feat: allow overloaded notation in patterns
This commit is contained in:
parent
fddc8b06ac
commit
3214a20d33
4 changed files with 47 additions and 27 deletions
17
RELEASES.md
17
RELEASES.md
|
|
@ -171,3 +171,20 @@ example (a : A) : a.x = 1 := by
|
|||
-- `h` has now type `x = 1` instead of `autoParam (x = 1) auto✝`
|
||||
assumption
|
||||
```
|
||||
|
||||
* We now accept overloaded notation in patterns, but we require the set of pattern variables in each alternative to be the same. Example:
|
||||
```lean
|
||||
inductive Vector (α : Type u) : Nat → Type u
|
||||
| nil : Vector α 0
|
||||
| cons : α → Vector α n → Vector α (n+1)
|
||||
|
||||
infix:67 " :: " => Vector.cons -- Overloading the `::` notation
|
||||
|
||||
def head1 (x : List α) (h : x ≠ []) : α :=
|
||||
match x with
|
||||
| a :: as => a -- `::` is `List.cons` here
|
||||
|
||||
def head2 (x : Vector α (n+1)) : α :=
|
||||
match x with
|
||||
| a :: as => a -- `::` is `Vector.cons` here
|
||||
```
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array P
|
|||
let rec loop (i : Nat) (decls : Array PatternVarDecl) (userNames : Array Name) := do
|
||||
if h : i < pVars.size then
|
||||
match pVars.get ⟨i, h⟩ with
|
||||
| PatternVar.localVar userName =>
|
||||
| { userName } =>
|
||||
let type ← mkFreshTypeMVar
|
||||
withLocalDecl userName BinderInfo.default type fun x =>
|
||||
loop (i+1) (decls.push { fvarId := x.fvarId! }) (userNames.push Name.anonymous)
|
||||
|
|
@ -729,7 +729,7 @@ private def generalize (discrs : Array Expr) (matchType : Expr) (altViews : Arra
|
|||
if ysUserNames.contains yUserName then
|
||||
yUserName ← mkFreshUserName yUserName
|
||||
-- Explicitly provided pattern variables shadow `y`
|
||||
else if patternVars.any fun | PatternVar.localVar x => x == yUserName then
|
||||
else if patternVars.any fun x => x.userName == yUserName then
|
||||
yUserName ← mkFreshUserName yUserName
|
||||
return ysUserNames.push yUserName
|
||||
let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ namespace Lean.Elab.Term
|
|||
|
||||
open Meta
|
||||
|
||||
inductive PatternVar where
|
||||
| localVar (userName : Name)
|
||||
structure PatternVar where
|
||||
userName : Name
|
||||
deriving BEq
|
||||
|
||||
instance : ToString PatternVar where
|
||||
toString := fun ⟨x⟩ => toString x
|
||||
toString x := toString x.userName
|
||||
|
||||
/-
|
||||
Patterns define new local variables.
|
||||
|
|
@ -47,22 +48,6 @@ structure State where
|
|||
|
||||
abbrev M := StateRefT State TermElabM
|
||||
|
||||
structure SavedState where
|
||||
term : Term.SavedState
|
||||
collect : State
|
||||
deriving Inhabited
|
||||
|
||||
protected def saveState : M SavedState :=
|
||||
return { term := (← Term.saveState), collect := (← get) }
|
||||
|
||||
def SavedState.restore (s : SavedState) (restoreInfo : Bool := false) : M Unit := do
|
||||
s.term.restore restoreInfo
|
||||
set s.collect
|
||||
|
||||
instance : MonadBacktrack SavedState M where
|
||||
saveState := CollectPatternVars.saveState
|
||||
restoreState b := b.restore
|
||||
|
||||
private def throwCtorExpected {α} : M α :=
|
||||
throwError "invalid pattern, constructor or constant marked with '[matchPattern]' expected"
|
||||
|
||||
|
|
@ -126,7 +111,7 @@ private def processVar (idStx : Syntax) : M Syntax := do
|
|||
throwError "invalid pattern variable, must be atomic"
|
||||
if (← get).found.contains id then
|
||||
throwError "invalid pattern, variable '{id}' occurred more than once"
|
||||
modify fun s => { s with vars := s.vars.push (PatternVar.localVar id), found := s.found.insert id }
|
||||
modify fun s => { s with vars := s.vars.push { userName := id }, found := s.found.insert id }
|
||||
return idStx
|
||||
|
||||
private def nameToPattern : Name → TermElabM Syntax
|
||||
|
|
@ -142,6 +127,12 @@ private def quotedNameToPattern (stx : Syntax) : TermElabM Syntax :=
|
|||
private def doubleQuotedNameToPattern (stx : Syntax) : TermElabM Syntax := do
|
||||
nameToPattern (← resolveGlobalConstNoOverloadWithInfo stx[2])
|
||||
|
||||
private def samePatternsVariables (startingAt : Nat) (s₁ s₂ : State) : Bool :=
|
||||
if h : s₁.vars.size = s₂.vars.size then
|
||||
Array.isEqvAux s₁.vars s₂.vars h (.==.) startingAt
|
||||
else
|
||||
false
|
||||
|
||||
partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroScope do
|
||||
let k := stx.getKind
|
||||
if k == identKind then
|
||||
|
|
@ -233,7 +224,18 @@ partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroSc
|
|||
/- Similar to previous case -/
|
||||
doubleQuotedNameToPattern stx
|
||||
else if k == choiceKind then
|
||||
throwError "invalid pattern, notation is ambiguous"
|
||||
let args := stx.getArgs
|
||||
let stateSaved ← get
|
||||
let arg0 ← collect args[0]
|
||||
let stateNew ← get
|
||||
let mut argsNew := #[arg0]
|
||||
for arg in args[1:] do
|
||||
set stateSaved
|
||||
argsNew := argsNew.push (← collect arg)
|
||||
unless samePatternsVariables stateSaved.vars.size stateNew (← get) do
|
||||
throwError "invalid pattern, overloaded notation is only allowed when all alternative have the same set of pattern variables"
|
||||
set stateNew
|
||||
return mkNode choiceKind argsNew
|
||||
else
|
||||
throwInvalidPattern
|
||||
|
||||
|
|
@ -354,7 +356,6 @@ def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) :=
|
|||
return s.vars
|
||||
|
||||
def getPatternVarNames (pvars : Array PatternVar) : Array Name :=
|
||||
pvars.filterMap fun
|
||||
| PatternVar.localVar x => some x
|
||||
pvars.map fun x => x.userName
|
||||
|
||||
end Lean.Elab.Term
|
||||
|
|
|
|||
|
|
@ -33,9 +33,11 @@ inductive Env : Vector Ty n → Type where
|
|||
| nil : Env Vector.nil
|
||||
| cons : Ty.interp a → Env ctx → Env (a :: ctx)
|
||||
|
||||
infix:67 " :: " => Env.cons
|
||||
|
||||
def Env.lookup : HasType i ctx ty → Env ctx → ty.interp
|
||||
| stop, Env.cons x xs => x
|
||||
| pop k, Env.cons x xs => lookup k xs
|
||||
| stop, x :: xs => x
|
||||
| pop k, x :: xs => lookup k xs
|
||||
|
||||
def Expr.interp (env : Env ctx) : Expr ctx ty → ty.interp
|
||||
| var i => env.lookup i
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue