From 8347dd5826377859a47d59fe681ae698fceba641 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 14 Oct 2020 14:24:09 -0700 Subject: [PATCH] chore: move to new frontend --- src/Lean/Elab/Binders.lean | 444 ++++++++++++++++++------------------- 1 file changed, 216 insertions(+), 228 deletions(-) diff --git a/src/Lean/Elab/Binders.lean b/src/Lean/Elab/Binders.lean index bef9e45081..d99474e8ac 100644 --- a/src/Lean/Elab/Binders.lean +++ b/src/Lean/Elab/Binders.lean @@ -1,3 +1,4 @@ +#lang lean4 /- Copyright (c) 2019 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. @@ -6,10 +7,7 @@ Authors: Leonardo de Moura import Lean.Elab.Term import Lean.Elab.Quotation -namespace Lean -namespace Elab -namespace Term - +namespace Lean.Elab.Term open Meta /-- @@ -21,7 +19,7 @@ private def expandBinderType (ref : Syntax) (stx : Syntax) : Syntax := if stx.getNumArgs == 0 then mkHole ref else - stx.getArg 1 + stx[1] /-- Given syntax of the form `ident <|> hole`, return `ident`. If `hole`, then we create a new anonymous name. -/ private def expandBinderIdent (stx : Syntax) : TermElabM Syntax := @@ -30,43 +28,43 @@ match_syntax stx with | _ => pure stx /-- Given syntax of the form `(ident >> " : ")?`, return `ident`, or a new instance name. -/ -private def expandOptIdent (stx : Syntax) : TermElabM Syntax := -if stx.getNumArgs == 0 then do - id ← mkFreshInstanceName; pure $ mkIdentFrom stx id +private def expandOptIdent (stx : Syntax) : TermElabM Syntax := do +if stx.getNumArgs == 0 then + pure $ mkIdentFrom stx (← mkFreshInstanceName) else - pure $ stx.getArg 0 + pure stx[0] structure BinderView := (id : Syntax) (type : Syntax) (bi : BinderInfo) partial def quoteAutoTactic : Syntax → TermElabM Syntax | stx@(Syntax.ident _ _ _ _) => throwErrorAt stx "invalic auto tactic, identifier is not allowed" -| stx@(Syntax.node k args) => +| stx@(Syntax.node k args) => do if stx.isAntiquot then throwErrorAt stx "invalic auto tactic, antiquotation is not allowed" - else do - empty ← `(Array.empty); - args ← args.foldlM (fun args arg => + else + let quotedArgs ← `(Array.empty) + for arg in args do if k == nullKind && Quotation.isAntiquotSplice arg then throwErrorAt arg "invalic auto tactic, antiquotation is not allowed" - else do - arg ← quoteAutoTactic arg; - `(Array.push $args $arg)) empty; - `(Syntax.node $(quote k) $args) + else + let quotedArg ← quoteAutoTactic arg + quotedArgs ← `(Array.push $quotedArgs $quotedArg) + `(Syntax.node $(quote k) $quotedArgs) | Syntax.atom info val => `(Syntax.atom {} $(quote val)) | Syntax.missing => unreachable! def declareTacticSyntax (tactic : Syntax) : TermElabM Name := -withFreshMacroScope $ do - name ← MonadQuotation.addMacroScope `_auto; - let type := Lean.mkConst `Lean.Syntax; - tactic ← quoteAutoTactic tactic; - val ← elabTerm tactic type; - val ← instantiateMVars val; - trace `Elab.autoParam $ fun _ => val; - let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false }; - addDecl decl; - compileDecl decl; +withFreshMacroScope do + let name ← MonadQuotation.addMacroScope `_auto + let type := Lean.mkConst `Lean.Syntax + let tactic ← quoteAutoTactic tactic + let val ← elabTerm tactic type + let val ← instantiateMVars val + trace[Elab.autoParam]! val + let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false } + addDecl decl + compileDecl decl pure name /- @@ -77,21 +75,21 @@ def binderDefault := parser! " := " >> termParser private def expandBinderModifier (type : Syntax) (optBinderModifier : Syntax) : TermElabM Syntax := if optBinderModifier.isNone then pure type else - let modifier := optBinderModifier.getArg 0; - let kind := modifier.getKind; + let modifier := optBinderModifier[0] + let kind := modifier.getKind if kind == `Lean.Parser.Term.binderDefault then do - let defaultVal := modifier.getArg 1; + let defaultVal := modifier[1] `(optParam $type $defaultVal) else if kind == `Lean.Parser.Term.binderTactic then do - let tac := modifier.getArg 2; - name ← declareTacticSyntax tac; + let tac := modifier[2] + let name ← declareTacticSyntax tac `(autoParam $type $(mkIdentFrom tac name)) else throwUnsupportedSyntax private def getBinderIds (ids : Syntax) : TermElabM (Array Syntax) := ids.getArgs.mapM $ fun id => - let k := id.getKind; + let k := id.getKind if k == identKind || k == `Lean.Parser.Term.hole then pure id else @@ -99,28 +97,28 @@ ids.getArgs.mapM $ fun id => private def matchBinder (stx : Syntax) : TermElabM (Array BinderView) := match stx with -| Syntax.node k args => - if k == `Lean.Parser.Term.simpleBinder then do +| Syntax.node k args => do + if k == `Lean.Parser.Term.simpleBinder then -- binderIdent+ - ids ← getBinderIds (args.get! 0); - let type := mkHole stx; - ids.mapM $ fun id => do id ← expandBinderIdent id; pure { id := id, type := type, bi := BinderInfo.default } - else if k == `Lean.Parser.Term.explicitBinder then do + let ids ← getBinderIds args[0] + let type := mkHole stx + ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.default } + else if k == `Lean.Parser.Term.explicitBinder then -- `(` binderIdent+ binderType (binderDefault <|> binderTactic)? `)` - ids ← getBinderIds (args.get! 1); - let type := expandBinderType stx (args.get! 2); - let optModifier := args.get! 3; - type ← expandBinderModifier type optModifier; - ids.mapM $ fun id => do id ← expandBinderIdent id; pure { id := id, type := type, bi := BinderInfo.default } - else if k == `Lean.Parser.Term.implicitBinder then do + let ids ← getBinderIds args[1] + let type := expandBinderType stx args[2] + let optModifier := args[3] + let type ← expandBinderModifier type optModifier + ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.default } + else if k == `Lean.Parser.Term.implicitBinder then -- `{` binderIdent+ binderType `}` - ids ← getBinderIds (args.get! 1); - let type := expandBinderType stx (args.get! 2); - ids.mapM $ fun id => do id ← expandBinderIdent id; pure { id := id, type := type, bi := BinderInfo.implicit } - else if k == `Lean.Parser.Term.instBinder then do + let ids ← getBinderIds args[1] + let type := expandBinderType stx args[2] + ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.implicit } + else if k == `Lean.Parser.Term.instBinder then -- `[` optIdent type `]` - id ← expandOptIdent (args.get! 1); - let type := args.get! 2; + let id ← expandOptIdent args[1] + let type := args[2] pure #[ { id := id, type := type, bi := BinderInfo.instImplicit } ] else throwUnsupportedSyntax @@ -133,32 +131,30 @@ private partial def elabBinderViews (binderViews : Array BinderView) : Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances) | i, fvars, lctx, localInsts => if h : i < binderViews.size then - let binderView := binderViews.get ⟨i, h⟩; - withRef binderView.type $ withLCtx lctx localInsts $ do - type ← elabType binderView.type; - registerFailedToInferBinderTypeInfo type binderView.type; - fvarId ← mkFreshFVarId; - let fvar := mkFVar fvarId; - let fvars := fvars.push fvar; - -- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type); - let lctx := lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi; - className? ← isClass? type; - match className? with - | none => elabBinderViews (i+1) fvars lctx localInsts + let binderView := binderViews.get ⟨i, h⟩ + withRef binderView.type $ withLCtx lctx localInsts do + let type ← elabType binderView.type + registerFailedToInferBinderTypeInfo type binderView.type + let fvarId ← mkFreshFVarId + let fvar := mkFVar fvarId + let fvars := fvars.push fvar + let lctx := lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi + match (← isClass? type) with + | none => elabBinderViews binderViews (i+1) fvars lctx localInsts | some className => resettingSynthInstanceCache do - let localInsts := localInsts.push { className := className, fvar := mkFVar fvarId }; - elabBinderViews (i+1) fvars lctx localInsts + let localInsts := localInsts.push { className := className, fvar := mkFVar fvarId } + elabBinderViews binderViews (i+1) fvars lctx localInsts else pure (fvars, lctx, localInsts) private partial def elabBindersAux (binders : Array Syntax) : Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances) -| i, fvars, lctx, localInsts => - if h : i < binders.size then do - binderViews ← matchBinder (binders.get ⟨i, h⟩); - (fvars, lctx, localInsts) ← elabBinderViews binderViews 0 fvars lctx localInsts; - elabBindersAux (i+1) fvars lctx localInsts +| i, fvars, lctx, localInsts => do + if h : i < binders.size then + let binderViews ← matchBinder (binders.get ⟨i, h⟩) + let (fvars, lctx, localInsts) ← elabBinderViews binderViews 0 fvars lctx localInsts + elabBindersAux binders (i+1) fvars lctx localInsts else pure (fvars, lctx, localInsts) @@ -166,12 +162,13 @@ private partial def elabBindersAux (binders : Array Syntax) Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracketedBinder`), update the local context, set of local instances, reset instance chache (if needed), and then execute `x` with the updated context. -/ -def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α := -if binders.isEmpty then x #[] -else do - lctx ← getLCtx; - localInsts ← getLocalInstances; - (fvars, lctx, newLocalInsts) ← elabBindersAux binders 0 #[] lctx localInsts; +def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α := do +if binders.isEmpty then + x #[] +else + let lctx ← getLCtx + let localInsts ← getLocalInstances + let (fvars, lctx, newLocalInsts) ← elabBindersAux binders 0 #[] lctx localInsts resettingSynthInstanceCacheWhen (newLocalInsts.size > localInsts.size) $ withLCtx lctx newLocalInsts $ x fvars @@ -181,8 +178,8 @@ elabBinders #[binder] (fun fvars => x (fvars.get! 0)) @[builtinTermElab «forall»] def elabForall : TermElab := fun stx _ => match_syntax stx with | `(forall $binders*, $term) => - elabBinders binders $ fun xs => do - e ← elabType term; + elabBinders binders fun xs => do + let e ← elabType term mkForallFVars xs e | _ => throwUnsupportedSyntax @@ -194,22 +191,22 @@ adaptExpander $ fun stx => match_syntax stx with @[builtinTermElab depArrow] def elabDepArrow : TermElab := fun stx _ => -- bracketedBinder `->` term - let binder := stx.getArg 0; - let term := stx.getArg 2; - elabBinders #[binder] $ fun xs => do - e ← elabType term; - mkForallFVars xs e + let binder := stx[0] + let term := stx[2] + elabBinders #[binder] fun xs => do + mkForallFVars xs (← elabType term) /-- Main loop `getFunBinderIds?` -/ private partial def getFunBinderIdsAux? : Bool → Syntax → Array Syntax → TermElabM (Option (Array Syntax)) | idOnly, stx, acc => match_syntax stx with - | `($f $a) => - if idOnly then pure none - else do - (some acc) ← getFunBinderIdsAux? false f acc | pure none; + | `($f $a) => do + if idOnly then + pure none + else + let (some acc) ← getFunBinderIdsAux? false f acc | pure none getFunBinderIdsAux? true a acc - | `(_) => do { ident ← mkFreshIdent stx; pure (some (acc.push ident)) } + | `(_) => do let ident ← mkFreshIdent stx; pure (some (acc.push ident)) | `($id:ident) => pure (some (acc.push id)) | _ => pure none @@ -224,44 +221,43 @@ getFunBinderIdsAux? false stx #[] The resulting `Bool` is true if a pattern was found. We use it to "mark" a macro expansion. -/ private partial def expandFunBindersAux (binders : Array Syntax) : Syntax → Nat → Array Syntax → TermElabM (Array Syntax × Syntax × Bool) -| body, i, newBinders => +| body, i, newBinders => do if h : i < binders.size then - let binder := binders.get ⟨i, h⟩; - let processAsPattern : Unit → TermElabM (Array Syntax × Syntax × Bool) := fun _ => do { - let pattern := binder; - major ← mkFreshIdent binder; - (binders, newBody, _) ← expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder major (mkHole binder)); - newBody ← `(match $major:ident with | $pattern => $newBody); + let binder := binders.get ⟨i, h⟩ + let processAsPattern : Unit → TermElabM (Array Syntax × Syntax × Bool) := fun _ => do + let pattern := binder + let major ← mkFreshIdent binder + let (binders, newBody, _) ← expandFunBindersAux binders body (i+1) (newBinders.push $ mkExplicitBinder major (mkHole binder)) + let newBody ← `(match $major:ident with | $pattern => $newBody) pure (binders, newBody, true) - }; match binder with - | Syntax.node `Lean.Parser.Term.implicitBinder _ => expandFunBindersAux body (i+1) (newBinders.push binder) - | Syntax.node `Lean.Parser.Term.instBinder _ => expandFunBindersAux body (i+1) (newBinders.push binder) - | Syntax.node `Lean.Parser.Term.explicitBinder _ => expandFunBindersAux body (i+1) (newBinders.push binder) - | Syntax.node `Lean.Parser.Term.hole _ => do - ident ← mkFreshIdent binder; - let type := binder; - expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder ident type) + | Syntax.node `Lean.Parser.Term.implicitBinder _ => expandFunBindersAux binders body (i+1) (newBinders.push binder) + | Syntax.node `Lean.Parser.Term.instBinder _ => expandFunBindersAux binders body (i+1) (newBinders.push binder) + | Syntax.node `Lean.Parser.Term.explicitBinder _ => expandFunBindersAux binders body (i+1) (newBinders.push binder) + | Syntax.node `Lean.Parser.Term.hole _ => + let ident ← mkFreshIdent binder + let type := binder + expandFunBindersAux binders body (i+1) (newBinders.push $ mkExplicitBinder ident type) | Syntax.node `Lean.Parser.Term.paren args => -- `(` (termParser >> parenSpecial)? `)` -- parenSpecial := (tupleTail <|> typeAscription)? - let binderBody := binder.getArg 1; + let binderBody := binder[1] if binderBody.isNone then processAsPattern () else - let idents := binderBody.getArg 0; - let special := binderBody.getArg 1; + let idents := binderBody[0] + let special := binderBody[1] if special.isNone then processAsPattern () - else if (special.getArg 0).getKind != `Lean.Parser.Term.typeAscription then processAsPattern () - else do + else if special[0].getKind != `Lean.Parser.Term.typeAscription then + processAsPattern () + else -- typeAscription := `:` term - let type := ((special.getArg 0).getArg 1); - idents? ← getFunBinderIds? idents; - match idents? with - | some idents => expandFunBindersAux body (i+1) (newBinders ++ idents.map (fun ident => mkExplicitBinder ident type)) + let type := special[0][1] + match (← getFunBinderIds? idents) with + | some idents => expandFunBindersAux binders body (i+1) (newBinders ++ idents.map (fun ident => mkExplicitBinder ident type)) | none => processAsPattern () | Syntax.ident _ _ _ _ => - let type := mkHole binder; - expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder binder type) + let type := mkHole binder + expandFunBindersAux binders body (i+1) (newBinders.push $ mkExplicitBinder binder type) | _ => processAsPattern () else pure (newBinders, body, false) @@ -295,72 +291,72 @@ structure State := (expectedType? : Option Expr := none) private def checkNoOptAutoParam (type : Expr) : TermElabM Unit := do -type ← instantiateMVars type; -when type.isOptParam $ - throwError "optParam is not allowed at 'fun/λ' binders"; -when type.isAutoParam $ +type ← instantiateMVars type +if type.isOptParam then + throwError "optParam is not allowed at 'fun/λ' binders" +if type.isAutoParam then throwError "autoParam is not allowed at 'fun/λ' binders" private def propagateExpectedType (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do match s.expectedType? with | none => pure s -| some expectedType => do - expectedType ← whnfForall expectedType; +| some expectedType => + expectedType ← whnfForall expectedType match expectedType with - | Expr.forallE _ d b _ => do - _ ← isDefEq fvarType d; - checkNoOptAutoParam fvarType; - let b := b.instantiate1 fvar; + | Expr.forallE _ d b _ => + isDefEq fvarType d + checkNoOptAutoParam fvarType + let b := b.instantiate1 fvar pure { s with expectedType? := some b } | _ => pure { s with expectedType? := none } private partial def elabFunBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State | i, s => if h : i < binderViews.size then - let binderView := binderViews.get ⟨i, h⟩; + let binderView := binderViews.get ⟨i, h⟩ withRef binderView.type $ withLCtx s.lctx s.localInsts $ do - type ← elabType binderView.type; - registerFailedToInferBinderTypeInfo type binderView.type; - checkNoOptAutoParam type; - fvarId ← mkFreshFVarId; - let fvar := mkFVar fvarId; - let s := { s with fvars := s.fvars.push fvar }; - -- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type); + let type ← elabType binderView.type + registerFailedToInferBinderTypeInfo type binderView.type + checkNoOptAutoParam type + let fvarId ← mkFreshFVarId + let fvar := mkFVar fvarId + let s := { s with fvars := s.fvars.push fvar } + -- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type) /- We do **not** want to support default and auto arguments in lambda abstractions. Example: `fun (x : Nat := 10) => x+1`. We do not believe this is an useful feature, and it would complicate the logic here. -/ - let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi; - s ← withRef binderView.id $ propagateExpectedType fvar type s; - let s := { s with lctx := lctx }; - className? ← isClass? type; - match className? with - | none => elabFunBinderViews (i+1) s - | some className => do + let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi + let s ← withRef binderView.id $ propagateExpectedType fvar type s + let s := { s with lctx := lctx } + match (← isClass? type) with + | none => elabFunBinderViews binderViews (i+1) s + | some className => resettingSynthInstanceCache do - let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId }; - elabFunBinderViews (i+1) { s with localInsts := localInsts } + let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId } + elabFunBinderViews binderViews (i+1) { s with localInsts := localInsts } else pure s partial def elabFunBindersAux (binders : Array Syntax) : Nat → State → TermElabM State -| i, s => - if h : i < binders.size then do - binderViews ← matchBinder (binders.get ⟨i, h⟩); - s ← elabFunBinderViews binderViews 0 s; - elabFunBindersAux (i+1) s +| i, s => do + if h : i < binders.size then + let binderViews ← matchBinder (binders.get ⟨i, h⟩) + let s ← elabFunBinderViews binderViews 0 s + elabFunBindersAux binders (i+1) s else pure s end FunBinders def elabFunBinders {α} (binders : Array Syntax) (expectedType? : Option Expr) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α := -if binders.isEmpty then x #[] expectedType? +if binders.isEmpty then + x #[] expectedType? else do - lctx ← getLCtx; - localInsts ← getLocalInstances; - s ← FunBinders.elabFunBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? }; + let lctx ← getLCtx + let localInsts ← getLocalInstances + let s ← FunBinders.elabFunBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? } resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $ x s.fvars s.expectedType? @@ -374,12 +370,12 @@ def expandOptType (ref : Syntax) (optType : Syntax) : Syntax := if optType.isNone then mkHole ref else - (optType.getArg 0).getArg 1 + optType[0][1] /- Helper function for `expandEqnsIntoMatch` -/ private def getMatchAltNumPatterns (matchAlts : Syntax) : Nat := -let alt0 := (matchAlts.getArg 1).getArg 0; -let pats := (alt0.getArg 0).getSepArgs; +let alt0 := matchAlts[1][0] +let pats := alt0[0].getSepArgs pats.size /- Helper function for `expandMatchAltsIntoMatch` -/ @@ -388,10 +384,10 @@ private def expandMatchAltsIntoMatchAux (ref : Syntax) (matchAlts : Syntax) (mat pure $ Syntax.node (if matchTactic then `Lean.Parser.Tactic.match else `Lean.Parser.Term.match) #[mkAtomFrom ref "match ", mkNullNode discrs, mkNullNode, mkAtomFrom ref " with ", matchAlts] | n+1, discrs => withFreshMacroScope do - x ← `(x); - let discrs := if discrs.isEmpty then discrs else discrs.push $ mkAtomFrom ref ", "; - let discrs := discrs.push $ Syntax.node `Lean.Parser.Term.matchDiscr #[mkNullNode, x]; - body ← expandMatchAltsIntoMatchAux n discrs; + let x ← `(x) + let discrs := if discrs.isEmpty then discrs else discrs.push $ mkAtomFrom ref ", " + let discrs := discrs.push $ Syntax.node `Lean.Parser.Term.matchDiscr #[mkNullNode, x] + let body ← expandMatchAltsIntoMatchAux ref matchAlts matchTactic n discrs if matchTactic then `(tactic| intro $x:term; $body:tactic) else @@ -426,24 +422,24 @@ def expandMatchAltsIntoMatchTactic (ref : Syntax) (matchAlts : Syntax) : MacroM expandMatchAltsIntoMatchAux ref matchAlts true (getMatchAltNumPatterns matchAlts) #[] @[builtinTermElab «fun»] def elabFun : TermElab := -fun stx expectedType? => +fun stx expectedType? => do -- "fun " >> ((many1 funBinder >> darrow >> termParser) <|> matchAlts) -if (stx.getArg 1).isOfKind `Lean.Parser.Term.matchAlts then do - stxNew ← liftMacroM $ expandMatchAltsIntoMatch stx (stx.getArg 1); +if stx[1].isOfKind `Lean.Parser.Term.matchAlts then + let stxNew ← liftMacroM $ expandMatchAltsIntoMatch stx stx[1] withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? -else do - let binders := (stx.getArg 1).getArgs; - let body := stx.getArg 3; - (binders, body, expandedPattern) ← expandFunBinders binders body; - if expandedPattern then do - newStx ← `(fun $binders* => $body); +else + let binders := stx[1].getArgs + let body := stx[3] + let (binders, body, expandedPattern) ← expandFunBinders binders body + if expandedPattern then + let newStx ← `(fun $binders* => $body) withMacroExpansion stx newStx $ elabTerm newStx expectedType? else elabFunBinders binders expectedType? $ fun xs expectedType? => do /- We ensure the expectedType here since it will force coercions to be applied if needed. If we just use `elabTerm`, then we will need to a coercion `Coe (α → β) (α → δ)` whenever there is a coercion `Coe β δ`, and another instance for the dependent version. -/ - e ← elabTermEnsuringType body expectedType?; + let e ← elabTermEnsuringType body expectedType? mkLambdaFVars xs e /- If `useLetExpr` is true, then a kernel let-expression `let x : type := val; body` is created. @@ -453,41 +449,37 @@ else do If `elabBodyFirst == true`, then we use the order `binders`, `typeStx`, `body`, and `valStx`. -/ def elabLetDeclAux (n : Name) (binders : Array Syntax) (typeStx : Syntax) (valStx : Syntax) (body : Syntax) (expectedType? : Option Expr) (useLetExpr : Bool) (elabBodyFirst : Bool) : TermElabM Expr := do -(type, val, arity) ← elabBinders binders $ fun xs => do { - type ← elabType typeStx; - registerCustomErrorIfMVar type typeStx "failed to infer 'let' declaration type"; - if elabBodyFirst then do - type ← mkForallFVars xs type; - val ← mkFreshExprMVar type; +let (type, val, arity) ← elabBinders binders fun xs => do + let type ← elabType typeStx + registerCustomErrorIfMVar type typeStx "failed to infer 'let' declaration type" + if elabBodyFirst then + let type ← mkForallFVars xs type + let val ← mkFreshExprMVar type pure (type, val, xs.size) - else do - val ← elabTermEnsuringType valStx type; - type ← mkForallFVars xs type; - val ← mkLambdaFVars xs val; + else + let val ← elabTermEnsuringType valStx type + let type ← mkForallFVars xs type + let val ← mkLambdaFVars xs val pure (type, val, xs.size) -}; -trace `Elab.let.decl $ fun _ => n ++ " : " ++ type ++ " := " ++ val; -result ← +trace[Elab.let.decl]! "{n} : {type} := {val}" +let result ← if useLetExpr then - withLetDecl n type val $ fun x => do - body ← elabTerm body expectedType?; - body ← instantiateMVars body; + withLetDecl n type val fun x => do + let body ← elabTerm body expectedType? + let body ← instantiateMVars body mkLetFVars #[x] body - else do { - f ← withLocalDecl n BinderInfo.default type $ fun x => do { - body ← elabTerm body expectedType?; - body ← instantiateMVars body; + else + let f ← withLocalDecl n BinderInfo.default type fun x => do + let body ← elabTerm body expectedType? + let body ← instantiateMVars body mkLambdaFVars #[x] body - }; pure $ mkApp f val - }; -when elabBodyFirst do { +if elabBodyFirst then forallBoundedTelescope type arity fun xs type => do - valResult ← elabTermEnsuringType valStx type; - valResult ← mkLambdaFVars xs valResult; - unlessM (isDefEq val valResult) do + let valResult ← elabTermEnsuringType valStx type + let valResult ← mkLambdaFVars xs valResult + unless (← isDefEq val valResult) do throwError "unexpected error when elaborating 'let'" -}; pure result structure LetIdDeclView := @@ -498,11 +490,11 @@ structure LetIdDeclView := def mkLetIdDeclView (letIdDecl : Syntax) : LetIdDeclView := -- `letIdDecl` is of the form `ident >> many bracketedBinder >> optType >> " := " >> termParser -let id := (letIdDecl.getArg 0).getId; -let binders := (letIdDecl.getArg 1).getArgs; -let optType := letIdDecl.getArg 2; -let type := expandOptType letIdDecl optType; -let value := letIdDecl.getArg 4; +let id := letIdDecl[0].getId +let binders := letIdDecl[1].getArgs +let optType := letIdDecl[2] +let type := expandOptType letIdDecl optType +let value := letIdDecl[4] { id := id, binders := binders, type := type, value := value } private def expandLetEqnsDeclVal (ref : Syntax) (alts : Syntax) : Nat → Array Syntax → MacroM Syntax @@ -510,42 +502,42 @@ private def expandLetEqnsDeclVal (ref : Syntax) (alts : Syntax) : Nat → Array pure $ Syntax.node `Lean.Parser.Term.match #[mkAtomFrom ref "match ", mkNullNode discrs, mkNullNode, mkAtomFrom ref " with ", alts] | n+1, discrs => withFreshMacroScope do - x ← `(x); - let discrs := if discrs.isEmpty then discrs else discrs.push $ mkAtomFrom ref ", "; - let discrs := discrs.push $ Syntax.node `Lean.Parser.Term.matchDiscr #[mkNullNode, x]; - body ← expandLetEqnsDeclVal n discrs; + let x ← `(x) + let discrs := if discrs.isEmpty then discrs else discrs.push $ mkAtomFrom ref ", " + let discrs := discrs.push $ Syntax.node `Lean.Parser.Term.matchDiscr #[mkNullNode, x] + let body ← expandLetEqnsDeclVal ref alts n discrs `(fun $x => $body) def expandLetEqnsDecl (letDecl : Syntax) : MacroM Syntax := do -let ref := letDecl; -let matchAlts := letDecl.getArg 3; -val ← expandMatchAltsIntoMatch ref matchAlts; -pure $ Syntax.node `Lean.Parser.Term.letIdDecl #[letDecl.getArg 0, letDecl.getArg 1, letDecl.getArg 2, mkAtomFrom ref " := ", val] +let ref := letDecl +let matchAlts := letDecl[3] +let val ← expandMatchAltsIntoMatch ref matchAlts +pure $ Syntax.node `Lean.Parser.Term.letIdDecl #[letDecl[0], letDecl[1], letDecl[2], mkAtomFrom ref " := ", val] def elabLetDeclCore (stx : Syntax) (expectedType? : Option Expr) (useLetExpr : Bool) (elabBodyFirst : Bool) : TermElabM Expr := do -let ref := stx; -let letDecl := (stx.getArg 1).getArg 0; -let body := stx.getArg 3; +let ref := stx +let letDecl := stx[1][0] +let body := stx[3] if letDecl.getKind == `Lean.Parser.Term.letIdDecl then - let { id := id, binders := binders, type := type, value := val } := mkLetIdDeclView letDecl; + let { id := id, binders := binders, type := type, value := val } := mkLetIdDeclView letDecl elabLetDeclAux id binders type val body expectedType? useLetExpr elabBodyFirst -else if letDecl.getKind == `Lean.Parser.Term.letPatDecl then do +else if letDecl.getKind == `Lean.Parser.Term.letPatDecl then -- node `Lean.Parser.Term.letPatDecl $ try (termParser >> pushNone >> optType >> " := ") >> termParser - let pat := letDecl.getArg 0; - let optType := letDecl.getArg 2; - let type := expandOptType stx optType; - let val := letDecl.getArg 4; - stxNew ← `(let x : $type := $val; match x with | $pat => $body); + let pat := letDecl[0] + let optType := letDecl[2] + let type := expandOptType stx optType + let val := letDecl[4] + let stxNew ← `(let x : $type := $val; match x with | $pat => $body) let stxNew := match useLetExpr, elabBodyFirst with | true, false => stxNew | true, true => stxNew.updateKind `Lean.Parser.Term.«let*» | false, true => stxNew.updateKind `Lean.Parser.Term.«let!» - | false, false => unreachable!; + | false, false => unreachable! withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? -else if letDecl.getKind == `Lean.Parser.Term.letEqnsDecl then do - letDeclIdNew ← liftMacroM $ expandLetEqnsDecl letDecl; - let declNew := (stx.getArg 1).setArg 0 letDeclIdNew; - let stxNew := stx.setArg 1 declNew; +else if letDecl.getKind == `Lean.Parser.Term.letEqnsDecl then + let letDeclIdNew ← liftMacroM $ expandLetEqnsDecl letDecl + let declNew := stx[1].setArg 0 letDeclIdNew + let stxNew := stx.setArg 1 declNew withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? else throwUnsupportedSyntax @@ -559,10 +551,6 @@ fun stx expectedType? => elabLetDeclCore stx expectedType? false false @[builtinTermElab «let*»] def elabLetStarDecl : TermElab := fun stx expectedType? => elabLetDeclCore stx expectedType? true true -@[init] private def regTraceClasses : IO Unit := do -registerTraceClass `Elab.let; -pure () +initialize registerTraceClass `Elab.let -end Term -end Elab -end Lean +end Lean.Elab.Term