feat: allow overloaded notation in patterns

This commit is contained in:
Leonardo de Moura 2022-03-10 12:46:21 -08:00
parent fddc8b06ac
commit 3214a20d33
4 changed files with 47 additions and 27 deletions

View file

@ -171,3 +171,20 @@ example (a : A) : a.x = 1 := by
-- `h` has now type `x = 1` instead of `autoParam (x = 1) auto✝`
assumption
```
* We now accept overloaded notation in patterns, but we require the set of pattern variables in each alternative to be the same. Example:
```lean
inductive Vector (α : Type u) : Nat → Type u
| nil : Vector α 0
| cons : α → Vector α n → Vector α (n+1)
infix:67 " :: " => Vector.cons -- Overloading the `::` notation
def head1 (x : List α) (h : x ≠ []) : α :=
match x with
| a :: as => a -- `::` is `List.cons` here
def head2 (x : Vector α (n+1)) : α :=
match x with
| a :: as => a -- `::` is `Vector.cons` here
```

View file

@ -198,7 +198,7 @@ private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array P
let rec loop (i : Nat) (decls : Array PatternVarDecl) (userNames : Array Name) := do
if h : i < pVars.size then
match pVars.get ⟨i, h⟩ with
| PatternVar.localVar userName =>
| { userName } =>
let type ← mkFreshTypeMVar
withLocalDecl userName BinderInfo.default type fun x =>
loop (i+1) (decls.push { fvarId := x.fvarId! }) (userNames.push Name.anonymous)
@ -729,7 +729,7 @@ private def generalize (discrs : Array Expr) (matchType : Expr) (altViews : Arra
if ysUserNames.contains yUserName then
yUserName ← mkFreshUserName yUserName
-- Explicitly provided pattern variables shadow `y`
else if patternVars.any fun | PatternVar.localVar x => x == yUserName then
else if patternVars.any fun x => x.userName == yUserName then
yUserName ← mkFreshUserName yUserName
return ysUserNames.push yUserName
let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n

View file

@ -11,11 +11,12 @@ namespace Lean.Elab.Term
open Meta
inductive PatternVar where
| localVar (userName : Name)
structure PatternVar where
userName : Name
deriving BEq
instance : ToString PatternVar where
toString := fun ⟨x⟩ => toString x
toString x := toString x.userName
/-
Patterns define new local variables.
@ -47,22 +48,6 @@ structure State where
abbrev M := StateRefT State TermElabM
structure SavedState where
term : Term.SavedState
collect : State
deriving Inhabited
protected def saveState : M SavedState :=
return { term := (← Term.saveState), collect := (← get) }
def SavedState.restore (s : SavedState) (restoreInfo : Bool := false) : M Unit := do
s.term.restore restoreInfo
set s.collect
instance : MonadBacktrack SavedState M where
saveState := CollectPatternVars.saveState
restoreState b := b.restore
private def throwCtorExpected {α} : M α :=
throwError "invalid pattern, constructor or constant marked with '[matchPattern]' expected"
@ -126,7 +111,7 @@ private def processVar (idStx : Syntax) : M Syntax := do
throwError "invalid pattern variable, must be atomic"
if (← get).found.contains id then
throwError "invalid pattern, variable '{id}' occurred more than once"
modify fun s => { s with vars := s.vars.push (PatternVar.localVar id), found := s.found.insert id }
modify fun s => { s with vars := s.vars.push { userName := id }, found := s.found.insert id }
return idStx
private def nameToPattern : Name → TermElabM Syntax
@ -142,6 +127,12 @@ private def quotedNameToPattern (stx : Syntax) : TermElabM Syntax :=
private def doubleQuotedNameToPattern (stx : Syntax) : TermElabM Syntax := do
nameToPattern (← resolveGlobalConstNoOverloadWithInfo stx[2])
private def samePatternsVariables (startingAt : Nat) (s₁ s₂ : State) : Bool :=
if h : s₁.vars.size = s₂.vars.size then
Array.isEqvAux s₁.vars s₂.vars h (.==.) startingAt
else
false
partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroScope do
let k := stx.getKind
if k == identKind then
@ -233,7 +224,18 @@ partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroSc
/- Similar to previous case -/
doubleQuotedNameToPattern stx
else if k == choiceKind then
throwError "invalid pattern, notation is ambiguous"
let args := stx.getArgs
let stateSaved ← get
let arg0 ← collect args[0]
let stateNew ← get
let mut argsNew := #[arg0]
for arg in args[1:] do
set stateSaved
argsNew := argsNew.push (← collect arg)
unless samePatternsVariables stateSaved.vars.size stateNew (← get) do
throwError "invalid pattern, overloaded notation is only allowed when all alternative have the same set of pattern variables"
set stateNew
return mkNode choiceKind argsNew
else
throwInvalidPattern
@ -354,7 +356,6 @@ def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) :=
return s.vars
def getPatternVarNames (pvars : Array PatternVar) : Array Name :=
pvars.filterMap fun
| PatternVar.localVar x => some x
pvars.map fun x => x.userName
end Lean.Elab.Term

View file

@ -33,9 +33,11 @@ inductive Env : Vector Ty n → Type where
| nil : Env Vector.nil
| cons : Ty.interp a → Env ctx → Env (a :: ctx)
infix:67 " :: " => Env.cons
def Env.lookup : HasType i ctx ty → Env ctx → ty.interp
| stop, Env.cons x xs => x
| pop k, Env.cons x xs => lookup k xs
| stop, x :: xs => x
| pop k, x :: xs => lookup k xs
def Expr.interp (env : Env ctx) : Expr ctx ty → ty.interp
| var i => env.lookup i