diff --git a/src/Lean/Elab/BuiltinTerm.lean b/src/Lean/Elab/BuiltinTerm.lean index d951aaf146..78ac8ce6b6 100644 --- a/src/Lean/Elab/BuiltinTerm.lean +++ b/src/Lean/Elab/BuiltinTerm.lean @@ -65,10 +65,12 @@ private def elabOptLevel (stx : Syntax) : TermElabM Level := let arg := stx[1] let userName := if arg.isIdent then arg.getId else Name.anonymous let mkNewHole : Unit → TermElabM Expr := fun _ => do - let mvar ← mkFreshExprMVar expectedType? MetavarKind.syntheticOpaque userName + let kind := if (← read).inPattern then MetavarKind.natural else MetavarKind.syntheticOpaque + let mvar ← mkFreshExprMVar expectedType? kind userName + if (← read).inPattern then trace[Elab.match] "userName: {userName}, kind: {repr kind}, mvar: {mvar}" registerMVarErrorHoleInfo mvar.mvarId! stx pure mvar - if userName.isAnonymous then + if userName.isAnonymous || (← read).inPattern then mkNewHole () else let mctx ← getMCtx diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 376861d9c2..a3dbeacb2a 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -174,14 +174,6 @@ private def getMatchAlts : Syntax → Array MatchAltView | _ => none | _ => #[] -builtin_initialize Parser.registerBuiltinNodeKind `MVarWithIdKind - -/-- - The elaboration function for `Syntax` created using `mkMVarSyntax`. - It just converts the metavariable id wrapped by the Syntax into an `Expr`. -/ -@[builtinTermElab MVarWithIdKind] def elabMVarWithIdKind : TermElab := fun stx expectedType? => - return mkInaccessible <| mkMVar (getMVarSyntaxMVarId stx) - @[builtinTermElab inaccessible] def elabInaccessible : TermElab := fun stx expectedType? => do let e ← elabTerm stx[1] expectedType? return mkInaccessible e @@ -201,33 +193,17 @@ open Lean.Elab.Term.Quotation in /- We convert the collected `PatternVar`s intro `PatternVarDecl` -/ inductive PatternVarDecl where - /- For `anonymousVar`, we create both a metavariable and a free variable. The free variable is used as an assignment for the metavariable - when it is not assigned during pattern elaboration. -/ - | anonymousVar (mvarId : MVarId) (fvarId : FVarId) | localVar (fvarId : FVarId) private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α := let rec loop (i : Nat) (decls : Array PatternVarDecl) (userNames : Array Name) := do if h : i < pVars.size then match pVars.get ⟨i, h⟩ with - | PatternVar.anonymousVar mvarId userName => - let type ← mkFreshTypeMVar - let userNameFVar ← if userName.isAnonymous then mkFreshBinderName else pure userName - withLocalDecl userNameFVar BinderInfo.default type fun x => - loop (i+1) (decls.push (PatternVarDecl.anonymousVar mvarId x.fvarId!)) (userNames.push userName) | PatternVar.localVar userName => let type ← mkFreshTypeMVar withLocalDecl userName BinderInfo.default type fun x => loop (i+1) (decls.push (PatternVarDecl.localVar x.fvarId!)) (userNames.push Name.anonymous) else - /- We must create the metavariables for `PatternVar.anonymousVar` AFTER we create the new local decls using `withLocalDecl`. - Reason: their scope must include the new local decls since some of them are assigned by typing constraints. -/ - for decl in decls, userName in userNames do - match decl with - | PatternVarDecl.anonymousVar mvarId fvarId => do - let type ← inferType (mkFVar fvarId) - discard <| mkFreshExprMVarWithId mvarId type (userName := userName) - | _ => pure () k decls loop 0 #[] #[] @@ -372,6 +348,9 @@ private partial def eraseIndices (type : Expr) : MetaM Expr := do let (newIndices, _, _) ← forallMetaTelescopeReducing resultType (some (args.size - info.numParams)) return mkAppN result newIndices +private def withPatternElabConfig (x : TermElabM α) : TermElabM α := + withoutErrToSorry <| withReader (fun ctx => { ctx with inPattern := true }) <| x + private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : ExceptT PatternElabException TermElabM (Array Expr × Expr) := withReader (fun ctx => { ctx with implicitLambda := false }) do let mut patterns := #[] @@ -384,10 +363,10 @@ private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : Excep let pattern ← do let s ← saveState try - liftM <| withSynthesize <| withoutErrToSorry <| elabTermEnsuringType patternStx d + liftM <| withSynthesize <| withPatternElabConfig <| elabTermEnsuringType patternStx d catch ex : Exception => restoreState s - match (← liftM <| commitIfNoErrors? <| withoutErrToSorry do elabTermAndSynthesize patternStx (← eraseIndices d)) with + match (← liftM <| commitIfNoErrors? <| withPatternElabConfig do elabTermAndSynthesize patternStx (← eraseIndices d)) with | some pattern => match (← findDiscrRefinementPath pattern d |>.run) with | some path => @@ -402,7 +381,7 @@ private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : Excep | _ => throwError "unexpected match type" return (patterns, matchType) -def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := do +private def patternDeclsToLocalDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := do let mut decls := #[] for pdecl in patternVarDecls do match pdecl with @@ -410,19 +389,6 @@ def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (A let decl ← getLocalDecl fvarId let decl ← instantiateLocalDeclMVars decl decls := decls.push decl - | PatternVarDecl.anonymousVar mvarId fvarId => - let e ← instantiateMVars (mkMVar mvarId) - trace[Elab.match] "finalizePatternDecls: mvarId: {mvarId.name} := {e}, fvar: {mkFVar fvarId}" - match e with - | Expr.mvar newMVarId _ => - /- Metavariable was not assigned, or assigned to another metavariable. So, - we assign to the auxiliary free variable we created at `withPatternVars` to `newMVarId`. -/ - assignExprMVar newMVarId (mkFVar fvarId) - trace[Elab.match] "finalizePatternDecls: {mkMVar newMVarId} := {mkFVar fvarId}" - let decl ← getLocalDecl fvarId - let decl ← instantiateLocalDeclMVars decl - decls := decls.push decl - | _ => pure () /- We perform a topological sort (dependecies) on `decls` because the pattern elaboration process may produce a sequence where a declaration d₁ may occur after d₂ when d₂ depends on d₁. -/ sortLocalDecls decls @@ -435,7 +401,11 @@ structure State where localDecls : Array LocalDecl newLocals : FVarIdSet := {} -abbrev M := StateRefT State TermElabM +structure Context where + /-- TODO: document -/ + userName : Name := Name.anonymous + +abbrev M := ReaderT Context $ StateRefT State TermElabM private def alreadyVisited (fvarId : FVarId) : M Bool := do let s ← get @@ -448,30 +418,45 @@ private def throwInvalidPattern {α} (e : Expr) : M α := throwError "invalid pattern {indentExpr e}" /- Create a new LocalDecl `x` for the metavariable `mvar`, and return `Pattern.var x` -/ -private def mkLocalDeclFor (mvar : Expr) : M Pattern := do +private partial def mkLocalDeclFor (mvar : Expr) : M Pattern := do let mvarId := mvar.mvarId! + assert! !(← isExprMVarAssigned mvarId) let s ← get - match (← getExprMVarAssignment? mvarId) with - | some val => return Pattern.inaccessible val - | none => - let fvarId ← mkFreshFVarId - let mvarDecl ← getMVarDecl mvarId - let type := mvarDecl.type - /- HACK: `fvarId` is not in the scope of `mvarId` - If this generates problems in the future, we should update the metavariable declarations. -/ - assignExprMVar mvarId (mkFVar fvarId) - let userName ← if mvarDecl.userName.isAnonymous then mkFreshBinderName else pure mvarDecl.userName - let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default - modify fun s => - { s with - newLocals := s.newLocals.insert fvarId, - localDecls := - match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with - | none => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar` - | some i => s.localDecls.insertAt i newDecl } - return Pattern.var fvarId + let fvarId ← mkFreshFVarId + let mvarDecl ← getMVarDecl mvarId + let type := mvarDecl.type + /- HACK: `fvarId` is not in the scope of `mvarId` + If this generates problems in the future, we should update the metavariable declarations. -/ + assignExprMVar mvarId (mkFVar fvarId) + let mut userName := mvarDecl.userName + if userName.isAnonymous then + userName := (← read).userName + if userName.isAnonymous then + userName ← mkFreshBinderName + let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default + modify fun s => + { s with + newLocals := s.newLocals.insert fvarId, + localDecls := + match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with + | none => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar` + | some i => s.localDecls.insertAt i newDecl } + trace[Elab.match] "mkLocalDeclFor {mvar} => {mkFVar fvarId}" + return Pattern.var fvarId + +private def withMVar (mvarId : MVarId) (x : M α) : M α := do + let localDecl ← getMVarDecl mvarId + if !localDecl.userName.isAnonymous && (← read).userName.isAnonymous then + withReader (fun ctx => { ctx with userName := localDecl.userName }) x + else + x + +private def isMatchValue' (e : Expr) : M Bool := do + -- TODO: optimize if it is a bottleneck. Simple trick: check head symbol before invoking `instantiateMVars` + return isMatchValue (← instantiateMVars e) partial def main (e : Expr) : M Pattern := do + trace[Elab.match] "ToDepElimPattern.main: {e}" let isLocalDecl (fvarId : FVarId) : M Bool := do return (← get).localDecls.any fun d => d.fvarId == fvarId let mkPatternVar (fvarId : FVarId) (e : Expr) : M Pattern := do @@ -480,7 +465,7 @@ partial def main (e : Expr) : M Pattern := do else markAsVisited fvarId return Pattern.var e.fvarId! - let mkInaccessible (e : Expr) : M Pattern := do + let rec mkInaccessible (e : Expr) : M Pattern := do match e with | Expr.fvar fvarId _ => if (← isLocalDecl fvarId) then @@ -488,7 +473,14 @@ partial def main (e : Expr) : M Pattern := do else return Pattern.inaccessible e | _ => - return Pattern.inaccessible e + if e.getAppFn.isMVar then + let eNew ← instantiateMVars e + if eNew != e then + withMVar e.getAppFn.mvarId! <| mkInaccessible eNew + else + mkLocalDeclFor e.getAppFn + else + return Pattern.inaccessible (← instantiateMVars e) match inaccessible? e with | some t => mkInaccessible t | none => @@ -501,22 +493,26 @@ partial def main (e : Expr) : M Pattern := do match e.getArg! 1, e.getArg! 3 with | Expr.fvar x _, Expr.fvar h _ => return Pattern.as x p h | _, _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'" - else if isMatchValue e then - return Pattern.val e + else if (← isMatchValue' e) then + return Pattern.val (← instantiateMVars e) else if e.isFVar then let fvarId := e.fvarId! unless (← isLocalDecl fvarId) do throwInvalidPattern e mkPatternVar fvarId e - else if e.isMVar then - mkLocalDeclFor e - else - let newE ← whnf e - if newE != e then - main newE + else if e.getAppFn.isMVar then + let eNew ← instantiateMVars e + if eNew != e then + withMVar e.getAppFn.mvarId! <| main eNew else - matchConstCtor e.getAppFn - (fun _ => do + mkLocalDeclFor e.getAppFn + else + let eNew ← whnf e + if eNew != e then + main eNew + else + matchConstCtor e.getAppFn + (fun _ => do if (← isProof e) then /- We mark nested proofs as inaccessible. This is fine due to proof irrelevance. We need this feature to be able to elaborate definitions such as: @@ -529,11 +525,12 @@ partial def main (e : Expr) : M Pattern := do return Pattern.inaccessible e else throwInvalidPattern e) - (fun v us => do + (fun v us => do let args := e.getAppArgs unless args.size == v.numParams + v.numFields do throwInvalidPattern e let params := args.extract 0 v.numParams + let params ← params.mapM fun p => instantiateMVars p let fields := args.extract v.numParams args.size let fields ← fields.mapM main return Pattern.ctor v.name us params.toList fields.toList) @@ -541,13 +538,13 @@ partial def main (e : Expr) : M Pattern := do end ToDepElimPattern def withDepElimPatterns {α} (patternVarDecls : Array PatternVarDecl) (localDecls : Array LocalDecl) (ps : Array Expr) (k : Array LocalDecl → Array Pattern → TermElabM α) : TermElabM α := do - let (patterns, s) ← (ps.mapM ToDepElimPattern.main).run { localDecls := localDecls } + let (patterns, s) ← (ps.mapM ToDepElimPattern.main).run {} |>.run { localDecls := localDecls } let localDecls ← s.localDecls.mapM fun d => instantiateLocalDeclMVars d + trace[Elab.match] "localDecls: {localDecls.map (·.userName)}" /- toDepElimPatterns may have added new localDecls. Thus, we must update the local context before we execute `k` -/ let lctx ← getLCtx let lctx := patternVarDecls.foldl (init := lctx) fun (lctx : LocalContext) d => match d with - | PatternVarDecl.anonymousVar _ fvarId => lctx.erase fvarId | PatternVarDecl.localVar fvarId => lctx.erase fvarId let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx withTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) do @@ -557,8 +554,9 @@ private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array Patte (k : AltLHS → Expr → TermElabM α) : ExceptT PatternElabException TermElabM α := do let (patterns, matchType) ← withSynthesize <| elabPatterns patternStxs matchType id (α := TermElabM α) do - let localDecls ← finalizePatternDecls patternVarDecls - let patterns ← patterns.mapM (instantiateMVars ·) + -- let patterns ← patterns.mapM (instantiateMVars ·) + trace[Elab.match] "patterns: {patterns}" + let localDecls ← patternDeclsToLocalDecls patternVarDecls withDepElimPatterns patternVarDecls localDecls patterns fun localDecls patterns => do k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType @@ -638,7 +636,7 @@ private def generalize (discrs : Array Expr) (matchType : Expr) (altViews : Arra if ysUserNames.contains yUserName then yUserName ← mkFreshUserName yUserName -- Explicitly provided pattern variables shadow `y` - else if patternVars.any fun | PatternVar.localVar x => x == yUserName | _ => false then + else if patternVars.any fun | PatternVar.localVar x => x == yUserName then yUserName ← mkFreshUserName yUserName return ysUserNames.push yUserName let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n diff --git a/src/Lean/Elab/PatternVar.lean b/src/Lean/Elab/PatternVar.lean index d3c09160e0..280ff6dd57 100644 --- a/src/Lean/Elab/PatternVar.lean +++ b/src/Lean/Elab/PatternVar.lean @@ -12,25 +12,10 @@ namespace Lean.Elab.Term open Meta inductive PatternVar where - | localVar (userName : Name) - -- anonymous variables (`_`) are encoded using metavariables - | anonymousVar (mvarId : MVarId) (userName : Name) + | localVar (userName : Name) -instance : ToString PatternVar := ⟨fun - | PatternVar.localVar x => toString x - | PatternVar.anonymousVar mvarId _ => s!"?m{mvarId.name}"⟩ - -/-- - Create an auxiliary Syntax node wrapping a fresh metavariable id. - We use this kind of Syntax for representing `_` occurring in patterns. - The metavariables are created before we elaborate the patterns into `Expr`s. -/ -private def mkMVarSyntax : TermElabM Syntax := do - let mvarId ← mkFreshId - return mkNode `MVarWithIdKind #[mkNode mvarId #[]] - -/-- Given a syntax node constructed using `mkMVarSyntax`, return its MVarId -/ -def getMVarSyntaxMVarId (stx : Syntax) : MVarId := - { name := stx[0].getKind } +instance : ToString PatternVar where + toString := fun ⟨x⟩ => toString x /- Patterns define new local variables. @@ -171,14 +156,9 @@ partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroSc pure <| field.setArg 0 field return stx.setArg 2 <| mkNullNode fields else if k == ``Lean.Parser.Term.hole then - let r ← mkMVarSyntax - modify fun s => { s with vars := s.vars.push <| PatternVar.anonymousVar (getMVarSyntaxMVarId r) Name.anonymous } - return r + `(.( $stx )) else if k == ``Lean.Parser.Term.syntheticHole then - let r ← mkMVarSyntax - let userName := if stx[1].isIdent then stx[1].getId else Name.anonymous - modify fun s => { s with vars := s.vars.push <| PatternVar.anonymousVar (getMVarSyntaxMVarId r) userName } - return r + `(.( $stx )) else if k == ``Lean.Parser.Term.paren then let arg := stx[1] if arg.isNone then @@ -359,6 +339,5 @@ def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := def getPatternVarNames (pvars : Array PatternVar) : Array Name := pvars.filterMap fun | PatternVar.localVar x => some x - | _ => none end Lean.Elab.Term diff --git a/src/Lean/Elab/Term.lean b/src/Lean/Elab/Term.lean index 2106c6de57..4960e14b6d 100644 --- a/src/Lean/Elab/Term.lean +++ b/src/Lean/Elab/Term.lean @@ -55,10 +55,12 @@ structure Context where sectionFVars : NameMap Expr := {} /-- Enable/disable implicit lambdas feature. -/ implicitLambda : Bool := true - /-- noncomputable sections automatically add the `noncomputable` modifier to any declaration we cannot generate code for -/ + /-- Noncomputable sections automatically add the `noncomputable` modifier to any declaration we cannot generate code for -/ isNoncomputableSection : Bool := false - /-- when `true` we skip TC failures. We use this option when processing patterns -/ + /-- When `true` we skip TC failures. We use this option when processing patterns -/ ignoreTCFailures : Bool := false + /-- True when elaborating patterns. It affects how we elaborate named holes. -/ + inPattern : Bool := false /-- Saved context for postponed terms and tactics to be executed. -/ structure SavedContext where