From 3214a20d33b4f88ff4783cb2fa3148ade47b833a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 10 Mar 2022 12:46:21 -0800 Subject: [PATCH] feat: allow overloaded notation in patterns --- RELEASES.md | 17 +++++++++++++ src/Lean/Elab/Match.lean | 4 +-- src/Lean/Elab/PatternVar.lean | 47 ++++++++++++++++++----------------- tests/lean/run/interp.lean | 6 +++-- 4 files changed, 47 insertions(+), 27 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index ef63f29ed6..2b9bc95821 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -171,3 +171,20 @@ example (a : A) : a.x = 1 := by -- `h` has now type `x = 1` instead of `autoParam (x = 1) auto✝` assumption ``` + +* We now accept overloaded notation in patterns, but we require the set of pattern variables in each alternative to be the same. Example: +```lean +inductive Vector (α : Type u) : Nat → Type u + | nil : Vector α 0 + | cons : α → Vector α n → Vector α (n+1) + +infix:67 " :: " => Vector.cons -- Overloading the `::` notation + +def head1 (x : List α) (h : x ≠ []) : α := + match x with + | a :: as => a -- `::` is `List.cons` here + +def head2 (x : Vector α (n+1)) : α := + match x with + | a :: as => a -- `::` is `Vector.cons` here +``` diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index b89f5dfc92..5683589356 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -198,7 +198,7 @@ private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array P let rec loop (i : Nat) (decls : Array PatternVarDecl) (userNames : Array Name) := do if h : i < pVars.size then match pVars.get ⟨i, h⟩ with - | PatternVar.localVar userName => + | { userName } => let type ← mkFreshTypeMVar withLocalDecl userName BinderInfo.default type fun x => loop (i+1) (decls.push { fvarId := x.fvarId! }) (userNames.push Name.anonymous) @@ -729,7 +729,7 @@ private def generalize (discrs : Array Expr) (matchType : Expr) (altViews : Arra if ysUserNames.contains yUserName then yUserName ← mkFreshUserName yUserName -- Explicitly provided pattern variables shadow `y` - else if patternVars.any fun | PatternVar.localVar x => x == yUserName then + else if patternVars.any fun x => x.userName == yUserName then yUserName ← mkFreshUserName yUserName return ysUserNames.push yUserName let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n diff --git a/src/Lean/Elab/PatternVar.lean b/src/Lean/Elab/PatternVar.lean index cc64ae5b1b..2c6296a242 100644 --- a/src/Lean/Elab/PatternVar.lean +++ b/src/Lean/Elab/PatternVar.lean @@ -11,11 +11,12 @@ namespace Lean.Elab.Term open Meta -inductive PatternVar where - | localVar (userName : Name) +structure PatternVar where + userName : Name + deriving BEq instance : ToString PatternVar where - toString := fun ⟨x⟩ => toString x + toString x := toString x.userName /- Patterns define new local variables. @@ -47,22 +48,6 @@ structure State where abbrev M := StateRefT State TermElabM -structure SavedState where - term : Term.SavedState - collect : State - deriving Inhabited - -protected def saveState : M SavedState := - return { term := (← Term.saveState), collect := (← get) } - -def SavedState.restore (s : SavedState) (restoreInfo : Bool := false) : M Unit := do - s.term.restore restoreInfo - set s.collect - -instance : MonadBacktrack SavedState M where - saveState := CollectPatternVars.saveState - restoreState b := b.restore - private def throwCtorExpected {α} : M α := throwError "invalid pattern, constructor or constant marked with '[matchPattern]' expected" @@ -126,7 +111,7 @@ private def processVar (idStx : Syntax) : M Syntax := do throwError "invalid pattern variable, must be atomic" if (← get).found.contains id then throwError "invalid pattern, variable '{id}' occurred more than once" - modify fun s => { s with vars := s.vars.push (PatternVar.localVar id), found := s.found.insert id } + modify fun s => { s with vars := s.vars.push { userName := id }, found := s.found.insert id } return idStx private def nameToPattern : Name → TermElabM Syntax @@ -142,6 +127,12 @@ private def quotedNameToPattern (stx : Syntax) : TermElabM Syntax := private def doubleQuotedNameToPattern (stx : Syntax) : TermElabM Syntax := do nameToPattern (← resolveGlobalConstNoOverloadWithInfo stx[2]) +private def samePatternsVariables (startingAt : Nat) (s₁ s₂ : State) : Bool := + if h : s₁.vars.size = s₂.vars.size then + Array.isEqvAux s₁.vars s₂.vars h (.==.) startingAt + else + false + partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroScope do let k := stx.getKind if k == identKind then @@ -233,7 +224,18 @@ partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroSc /- Similar to previous case -/ doubleQuotedNameToPattern stx else if k == choiceKind then - throwError "invalid pattern, notation is ambiguous" + let args := stx.getArgs + let stateSaved ← get + let arg0 ← collect args[0] + let stateNew ← get + let mut argsNew := #[arg0] + for arg in args[1:] do + set stateSaved + argsNew := argsNew.push (← collect arg) + unless samePatternsVariables stateSaved.vars.size stateNew (← get) do + throwError "invalid pattern, overloaded notation is only allowed when all alternative have the same set of pattern variables" + set stateNew + return mkNode choiceKind argsNew else throwInvalidPattern @@ -354,7 +356,6 @@ def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := return s.vars def getPatternVarNames (pvars : Array PatternVar) : Array Name := - pvars.filterMap fun - | PatternVar.localVar x => some x + pvars.map fun x => x.userName end Lean.Elab.Term diff --git a/tests/lean/run/interp.lean b/tests/lean/run/interp.lean index 41e4a4e115..9ef1761547 100644 --- a/tests/lean/run/interp.lean +++ b/tests/lean/run/interp.lean @@ -33,9 +33,11 @@ inductive Env : Vector Ty n → Type where | nil : Env Vector.nil | cons : Ty.interp a → Env ctx → Env (a :: ctx) +infix:67 " :: " => Env.cons + def Env.lookup : HasType i ctx ty → Env ctx → ty.interp - | stop, Env.cons x xs => x - | pop k, Env.cons x xs => lookup k xs + | stop, x :: xs => x + | pop k, x :: xs => lookup k xs def Expr.interp (env : Env ctx) : Expr ctx ty → ty.interp | var i => env.lookup i