refactor: simplify PatternVar.lean and Match.lean

This commit is contained in:
Leonardo de Moura 2022-03-07 16:41:05 -08:00
parent 74411aa472
commit 1609f96128
4 changed files with 90 additions and 109 deletions

View file

@ -65,10 +65,12 @@ private def elabOptLevel (stx : Syntax) : TermElabM Level :=
let arg := stx[1]
let userName := if arg.isIdent then arg.getId else Name.anonymous
let mkNewHole : Unit → TermElabM Expr := fun _ => do
let mvar ← mkFreshExprMVar expectedType? MetavarKind.syntheticOpaque userName
let kind := if (← read).inPattern then MetavarKind.natural else MetavarKind.syntheticOpaque
let mvar ← mkFreshExprMVar expectedType? kind userName
if (← read).inPattern then trace[Elab.match] "userName: {userName}, kind: {repr kind}, mvar: {mvar}"
registerMVarErrorHoleInfo mvar.mvarId! stx
pure mvar
if userName.isAnonymous then
if userName.isAnonymous || (← read).inPattern then
mkNewHole ()
else
let mctx ← getMCtx

View file

@ -174,14 +174,6 @@ private def getMatchAlts : Syntax → Array MatchAltView
| _ => none
| _ => #[]
builtin_initialize Parser.registerBuiltinNodeKind `MVarWithIdKind
/--
The elaboration function for `Syntax` created using `mkMVarSyntax`.
It just converts the metavariable id wrapped by the Syntax into an `Expr`. -/
@[builtinTermElab MVarWithIdKind] def elabMVarWithIdKind : TermElab := fun stx expectedType? =>
return mkInaccessible <| mkMVar (getMVarSyntaxMVarId stx)
@[builtinTermElab inaccessible] def elabInaccessible : TermElab := fun stx expectedType? => do
let e ← elabTerm stx[1] expectedType?
return mkInaccessible e
@ -201,33 +193,17 @@ open Lean.Elab.Term.Quotation in
/- We convert the collected `PatternVar`s intro `PatternVarDecl` -/
inductive PatternVarDecl where
/- For `anonymousVar`, we create both a metavariable and a free variable. The free variable is used as an assignment for the metavariable
when it is not assigned during pattern elaboration. -/
| anonymousVar (mvarId : MVarId) (fvarId : FVarId)
| localVar (fvarId : FVarId)
private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α :=
let rec loop (i : Nat) (decls : Array PatternVarDecl) (userNames : Array Name) := do
if h : i < pVars.size then
match pVars.get ⟨i, h⟩ with
| PatternVar.anonymousVar mvarId userName =>
let type ← mkFreshTypeMVar
let userNameFVar ← if userName.isAnonymous then mkFreshBinderName else pure userName
withLocalDecl userNameFVar BinderInfo.default type fun x =>
loop (i+1) (decls.push (PatternVarDecl.anonymousVar mvarId x.fvarId!)) (userNames.push userName)
| PatternVar.localVar userName =>
let type ← mkFreshTypeMVar
withLocalDecl userName BinderInfo.default type fun x =>
loop (i+1) (decls.push (PatternVarDecl.localVar x.fvarId!)) (userNames.push Name.anonymous)
else
/- We must create the metavariables for `PatternVar.anonymousVar` AFTER we create the new local decls using `withLocalDecl`.
Reason: their scope must include the new local decls since some of them are assigned by typing constraints. -/
for decl in decls, userName in userNames do
match decl with
| PatternVarDecl.anonymousVar mvarId fvarId => do
let type ← inferType (mkFVar fvarId)
discard <| mkFreshExprMVarWithId mvarId type (userName := userName)
| _ => pure ()
k decls
loop 0 #[] #[]
@ -372,6 +348,9 @@ private partial def eraseIndices (type : Expr) : MetaM Expr := do
let (newIndices, _, _) ← forallMetaTelescopeReducing resultType (some (args.size - info.numParams))
return mkAppN result newIndices
private def withPatternElabConfig (x : TermElabM α) : TermElabM α :=
withoutErrToSorry <| withReader (fun ctx => { ctx with inPattern := true }) <| x
private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : ExceptT PatternElabException TermElabM (Array Expr × Expr) :=
withReader (fun ctx => { ctx with implicitLambda := false }) do
let mut patterns := #[]
@ -384,10 +363,10 @@ private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : Excep
let pattern ← do
let s ← saveState
try
liftM <| withSynthesize <| withoutErrToSorry <| elabTermEnsuringType patternStx d
liftM <| withSynthesize <| withPatternElabConfig <| elabTermEnsuringType patternStx d
catch ex : Exception =>
restoreState s
match (← liftM <| commitIfNoErrors? <| withoutErrToSorry do elabTermAndSynthesize patternStx (← eraseIndices d)) with
match (← liftM <| commitIfNoErrors? <| withPatternElabConfig do elabTermAndSynthesize patternStx (← eraseIndices d)) with
| some pattern =>
match (← findDiscrRefinementPath pattern d |>.run) with
| some path =>
@ -402,7 +381,7 @@ private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : Excep
| _ => throwError "unexpected match type"
return (patterns, matchType)
def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := do
private def patternDeclsToLocalDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := do
let mut decls := #[]
for pdecl in patternVarDecls do
match pdecl with
@ -410,19 +389,6 @@ def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (A
let decl ← getLocalDecl fvarId
let decl ← instantiateLocalDeclMVars decl
decls := decls.push decl
| PatternVarDecl.anonymousVar mvarId fvarId =>
let e ← instantiateMVars (mkMVar mvarId)
trace[Elab.match] "finalizePatternDecls: mvarId: {mvarId.name} := {e}, fvar: {mkFVar fvarId}"
match e with
| Expr.mvar newMVarId _ =>
/- Metavariable was not assigned, or assigned to another metavariable. So,
we assign to the auxiliary free variable we created at `withPatternVars` to `newMVarId`. -/
assignExprMVar newMVarId (mkFVar fvarId)
trace[Elab.match] "finalizePatternDecls: {mkMVar newMVarId} := {mkFVar fvarId}"
let decl ← getLocalDecl fvarId
let decl ← instantiateLocalDeclMVars decl
decls := decls.push decl
| _ => pure ()
/- We perform a topological sort (dependecies) on `decls` because the pattern elaboration process may produce a sequence where a declaration d₁ may occur after d₂ when d₂ depends on d₁. -/
sortLocalDecls decls
@ -435,7 +401,11 @@ structure State where
localDecls : Array LocalDecl
newLocals : FVarIdSet := {}
abbrev M := StateRefT State TermElabM
structure Context where
/-- TODO: document -/
userName : Name := Name.anonymous
abbrev M := ReaderT Context $ StateRefT State TermElabM
private def alreadyVisited (fvarId : FVarId) : M Bool := do
let s ← get
@ -448,30 +418,45 @@ private def throwInvalidPattern {α} (e : Expr) : M α :=
throwError "invalid pattern {indentExpr e}"
/- Create a new LocalDecl `x` for the metavariable `mvar`, and return `Pattern.var x` -/
private def mkLocalDeclFor (mvar : Expr) : M Pattern := do
private partial def mkLocalDeclFor (mvar : Expr) : M Pattern := do
let mvarId := mvar.mvarId!
assert! !(← isExprMVarAssigned mvarId)
let s ← get
match (← getExprMVarAssignment? mvarId) with
| some val => return Pattern.inaccessible val
| none =>
let fvarId ← mkFreshFVarId
let mvarDecl ← getMVarDecl mvarId
let type := mvarDecl.type
/- HACK: `fvarId` is not in the scope of `mvarId`
If this generates problems in the future, we should update the metavariable declarations. -/
assignExprMVar mvarId (mkFVar fvarId)
let userName ← if mvarDecl.userName.isAnonymous then mkFreshBinderName else pure mvarDecl.userName
let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default
modify fun s =>
{ s with
newLocals := s.newLocals.insert fvarId,
localDecls :=
match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with
| none => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar`
| some i => s.localDecls.insertAt i newDecl }
return Pattern.var fvarId
let fvarId ← mkFreshFVarId
let mvarDecl ← getMVarDecl mvarId
let type := mvarDecl.type
/- HACK: `fvarId` is not in the scope of `mvarId`
If this generates problems in the future, we should update the metavariable declarations. -/
assignExprMVar mvarId (mkFVar fvarId)
let mut userName := mvarDecl.userName
if userName.isAnonymous then
userName := (← read).userName
if userName.isAnonymous then
userName ← mkFreshBinderName
let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default
modify fun s =>
{ s with
newLocals := s.newLocals.insert fvarId,
localDecls :=
match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with
| none => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar`
| some i => s.localDecls.insertAt i newDecl }
trace[Elab.match] "mkLocalDeclFor {mvar} => {mkFVar fvarId}"
return Pattern.var fvarId
private def withMVar (mvarId : MVarId) (x : M α) : M α := do
let localDecl ← getMVarDecl mvarId
if !localDecl.userName.isAnonymous && (← read).userName.isAnonymous then
withReader (fun ctx => { ctx with userName := localDecl.userName }) x
else
x
private def isMatchValue' (e : Expr) : M Bool := do
-- TODO: optimize if it is a bottleneck. Simple trick: check head symbol before invoking `instantiateMVars`
return isMatchValue (← instantiateMVars e)
partial def main (e : Expr) : M Pattern := do
trace[Elab.match] "ToDepElimPattern.main: {e}"
let isLocalDecl (fvarId : FVarId) : M Bool := do
return (← get).localDecls.any fun d => d.fvarId == fvarId
let mkPatternVar (fvarId : FVarId) (e : Expr) : M Pattern := do
@ -480,7 +465,7 @@ partial def main (e : Expr) : M Pattern := do
else
markAsVisited fvarId
return Pattern.var e.fvarId!
let mkInaccessible (e : Expr) : M Pattern := do
let rec mkInaccessible (e : Expr) : M Pattern := do
match e with
| Expr.fvar fvarId _ =>
if (← isLocalDecl fvarId) then
@ -488,7 +473,14 @@ partial def main (e : Expr) : M Pattern := do
else
return Pattern.inaccessible e
| _ =>
return Pattern.inaccessible e
if e.getAppFn.isMVar then
let eNew ← instantiateMVars e
if eNew != e then
withMVar e.getAppFn.mvarId! <| mkInaccessible eNew
else
mkLocalDeclFor e.getAppFn
else
return Pattern.inaccessible (← instantiateMVars e)
match inaccessible? e with
| some t => mkInaccessible t
| none =>
@ -501,22 +493,26 @@ partial def main (e : Expr) : M Pattern := do
match e.getArg! 1, e.getArg! 3 with
| Expr.fvar x _, Expr.fvar h _ => return Pattern.as x p h
| _, _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
else if isMatchValue e then
return Pattern.val e
else if (← isMatchValue' e) then
return Pattern.val (← instantiateMVars e)
else if e.isFVar then
let fvarId := e.fvarId!
unless (← isLocalDecl fvarId) do
throwInvalidPattern e
mkPatternVar fvarId e
else if e.isMVar then
mkLocalDeclFor e
else
let newE ← whnf e
if newE != e then
main newE
else if e.getAppFn.isMVar then
let eNew ← instantiateMVars e
if eNew != e then
withMVar e.getAppFn.mvarId! <| main eNew
else
matchConstCtor e.getAppFn
(fun _ => do
mkLocalDeclFor e.getAppFn
else
let eNew ← whnf e
if eNew != e then
main eNew
else
matchConstCtor e.getAppFn
(fun _ => do
if (← isProof e) then
/- We mark nested proofs as inaccessible. This is fine due to proof irrelevance.
We need this feature to be able to elaborate definitions such as:
@ -529,11 +525,12 @@ partial def main (e : Expr) : M Pattern := do
return Pattern.inaccessible e
else
throwInvalidPattern e)
(fun v us => do
(fun v us => do
let args := e.getAppArgs
unless args.size == v.numParams + v.numFields do
throwInvalidPattern e
let params := args.extract 0 v.numParams
let params ← params.mapM fun p => instantiateMVars p
let fields := args.extract v.numParams args.size
let fields ← fields.mapM main
return Pattern.ctor v.name us params.toList fields.toList)
@ -541,13 +538,13 @@ partial def main (e : Expr) : M Pattern := do
end ToDepElimPattern
def withDepElimPatterns {α} (patternVarDecls : Array PatternVarDecl) (localDecls : Array LocalDecl) (ps : Array Expr) (k : Array LocalDecl → Array Pattern → TermElabM α) : TermElabM α := do
let (patterns, s) ← (ps.mapM ToDepElimPattern.main).run { localDecls := localDecls }
let (patterns, s) ← (ps.mapM ToDepElimPattern.main).run {} |>.run { localDecls := localDecls }
let localDecls ← s.localDecls.mapM fun d => instantiateLocalDeclMVars d
trace[Elab.match] "localDecls: {localDecls.map (·.userName)}"
/- toDepElimPatterns may have added new localDecls. Thus, we must update the local context before we execute `k` -/
let lctx ← getLCtx
let lctx := patternVarDecls.foldl (init := lctx) fun (lctx : LocalContext) d =>
match d with
| PatternVarDecl.anonymousVar _ fvarId => lctx.erase fvarId
| PatternVarDecl.localVar fvarId => lctx.erase fvarId
let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx
withTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) do
@ -557,8 +554,9 @@ private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array Patte
(k : AltLHS → Expr → TermElabM α) : ExceptT PatternElabException TermElabM α := do
let (patterns, matchType) ← withSynthesize <| elabPatterns patternStxs matchType
id (α := TermElabM α) do
let localDecls ← finalizePatternDecls patternVarDecls
let patterns ← patterns.mapM (instantiateMVars ·)
-- let patterns ← patterns.mapM (instantiateMVars ·)
trace[Elab.match] "patterns: {patterns}"
let localDecls ← patternDeclsToLocalDecls patternVarDecls
withDepElimPatterns patternVarDecls localDecls patterns fun localDecls patterns => do
k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
@ -638,7 +636,7 @@ private def generalize (discrs : Array Expr) (matchType : Expr) (altViews : Arra
if ysUserNames.contains yUserName then
yUserName ← mkFreshUserName yUserName
-- Explicitly provided pattern variables shadow `y`
else if patternVars.any fun | PatternVar.localVar x => x == yUserName | _ => false then
else if patternVars.any fun | PatternVar.localVar x => x == yUserName then
yUserName ← mkFreshUserName yUserName
return ysUserNames.push yUserName
let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n

View file

@ -12,25 +12,10 @@ namespace Lean.Elab.Term
open Meta
inductive PatternVar where
| localVar (userName : Name)
-- anonymous variables (`_`) are encoded using metavariables
| anonymousVar (mvarId : MVarId) (userName : Name)
| localVar (userName : Name)
instance : ToString PatternVar := ⟨fun
| PatternVar.localVar x => toString x
| PatternVar.anonymousVar mvarId _ => s!"?m{mvarId.name}"⟩
/--
Create an auxiliary Syntax node wrapping a fresh metavariable id.
We use this kind of Syntax for representing `_` occurring in patterns.
The metavariables are created before we elaborate the patterns into `Expr`s. -/
private def mkMVarSyntax : TermElabM Syntax := do
let mvarId ← mkFreshId
return mkNode `MVarWithIdKind #[mkNode mvarId #[]]
/-- Given a syntax node constructed using `mkMVarSyntax`, return its MVarId -/
def getMVarSyntaxMVarId (stx : Syntax) : MVarId :=
{ name := stx[0].getKind }
instance : ToString PatternVar where
toString := fun ⟨x⟩ => toString x
/-
Patterns define new local variables.
@ -171,14 +156,9 @@ partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroSc
pure <| field.setArg 0 field
return stx.setArg 2 <| mkNullNode fields
else if k == ``Lean.Parser.Term.hole then
let r ← mkMVarSyntax
modify fun s => { s with vars := s.vars.push <| PatternVar.anonymousVar (getMVarSyntaxMVarId r) Name.anonymous }
return r
`(.( $stx ))
else if k == ``Lean.Parser.Term.syntheticHole then
let r ← mkMVarSyntax
let userName := if stx[1].isIdent then stx[1].getId else Name.anonymous
modify fun s => { s with vars := s.vars.push <| PatternVar.anonymousVar (getMVarSyntaxMVarId r) userName }
return r
`(.( $stx ))
else if k == ``Lean.Parser.Term.paren then
let arg := stx[1]
if arg.isNone then
@ -359,6 +339,5 @@ def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) :=
def getPatternVarNames (pvars : Array PatternVar) : Array Name :=
pvars.filterMap fun
| PatternVar.localVar x => some x
| _ => none
end Lean.Elab.Term

View file

@ -55,10 +55,12 @@ structure Context where
sectionFVars : NameMap Expr := {}
/-- Enable/disable implicit lambdas feature. -/
implicitLambda : Bool := true
/-- noncomputable sections automatically add the `noncomputable` modifier to any declaration we cannot generate code for -/
/-- Noncomputable sections automatically add the `noncomputable` modifier to any declaration we cannot generate code for -/
isNoncomputableSection : Bool := false
/-- when `true` we skip TC failures. We use this option when processing patterns -/
/-- When `true` we skip TC failures. We use this option when processing patterns -/
ignoreTCFailures : Bool := false
/-- True when elaborating patterns. It affects how we elaborate named holes. -/
inPattern : Bool := false
/-- Saved context for postponed terms and tactics to be executed. -/
structure SavedContext where