/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Elab.Term
import Lean.Elab.Quotation

namespace Lean
namespace Elab
namespace Term

/--
  Given syntax of one of the forms
    a) (`:` term)?
    b) `:` term
  return `term` if it is present, or a hole (positioned at `stx`) if not. -/
private def expandBinderType (stx : Syntax) : Syntax :=
if stx.getNumArgs == 0 then
  mkHole stx
else
  stx.getArg 1

/--
  Given syntax of the form `ident <|> hole`, return `ident`.
  If it is a `hole`, a fresh anonymous identifier is created instead. -/
private def expandBinderIdent (stx : Syntax) : TermElabM Syntax :=
match_syntax stx with
| `(_) => mkFreshAnonymousIdent stx
| _    => pure stx

/--
  Given syntax of the form `(ident >> " : ")?`, return `ident`, or, when the optional part is
  absent, an identifier for a fresh instance name (from `mkFreshInstanceName`). -/
private def expandOptIdent (stx : Syntax) : TermElabM Syntax :=
if stx.getNumArgs == 0 then do
  id ← mkFreshInstanceName;
  pure $ mkIdentFrom stx id
else
  pure $ stx.getArg 0

/-- Normalized view of a single binder, produced from binder syntax before elaboration. -/
structure BinderView :=
(id : Syntax)       -- binder name (an identifier)
(type : Syntax)     -- binder type syntax (may be a hole, `optParam ...` or `autoParam ...`)
(bi : BinderInfo)   -- binder annotation (default / implicit / instImplicit)

/--
  Translate the tactic syntax `stx` of an auto-param binder (`:= by tac`) into a term that,
  when evaluated, rebuilds `stx` as a `Syntax` value:
  - a node `k args` becomes `Syntax.node k (Array.push (... (Array.push Array.empty a₁) ...) aₙ)`,
    quoting each child recursively;
  - an atom becomes `Syntax.atom {} val` (the original source info is not preserved);
  - identifiers and antiquotations (including antiquotation splices in null nodes) are rejected
    with an error, since they cannot be quoted here. -/
partial def quoteAutoTactic : Syntax → TermElabM Syntax
| stx@(Syntax.ident _ _ _ _) => throwError stx "invalid auto tactic, identifier is not allowed"
| stx@(Syntax.node k args) =>
  if Quotation.isAntiquot stx then
    throwError stx "invalid auto tactic, antiquotation is not allowed"
  else do
    empty ← `(Array.empty);
    args ← args.foldlM (fun args arg =>
      if k == nullKind && Quotation.isAntiquotSplice arg then
        throwError arg "invalid auto tactic, antiquotation is not allowed"
      else do
        arg ← quoteAutoTactic arg;
        `(Array.push $args $arg)) empty;
    `(Syntax.node $(quote k) $args)
| Syntax.atom _ val => `(Syntax.atom {} $(quote val))
| Syntax.missing    => unreachable!
/--
  Turn the auto-param tactic syntax `tactic` into a top-level definition and return its name.
  The tactic is quoted with `quoteAutoTactic`, elaborated at type `Lean.Syntax`, and added to
  the environment (then compiled) as a non-`unsafe` definition with no universe parameters and
  `ReducibilityHints.opaque`. The name is `_auto` extended with a fresh macro scope (inside
  `withFreshMacroScope`), so repeated calls should not clash. -/
def declareTacticSyntax (tactic : Syntax) : TermElabM Name :=
withFreshMacroScope $ do
  name ← MonadQuotation.addMacroScope `_auto;
  let type := Lean.mkConst `Lean.Syntax;
  tactic ← quoteAutoTactic tactic;
  val ← elabTerm tactic type;
  val ← instantiateMVars tactic val;
  -- NOTE(review): the `Elab.autoParam` trace class is not registered in this file; confirm it is registered elsewhere.
  trace `Elab.autoParam tactic $ fun _ => val;
  let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false };
  addDecl tactic decl;
  compileDecl tactic decl;
  pure name

/--
  Apply the optional binder modifier `optional (binderTactic <|> binderDefault)` to `type`, where
  ```
  def binderTactic  := parser! " := " >> " by " >> tacticParser
  def binderDefault := parser! " := " >> termParser
  ```
  - no modifier: `type` is returned unchanged;
  - `binderDefault`: returns `optParam type val`;
  - `binderTactic`: the tactic is turned into a definition by `declareTacticSyntax`, and
    `autoParam type <that definition>` is returned;
  - any other kind: `throwUnsupportedSyntax`. -/
private def expandBinderModifier (type : Syntax) (optBinderModifier : Syntax) : TermElabM Syntax :=
if optBinderModifier.isNone then pure type
else
  let modifier := optBinderModifier.getArg 0;
  let kind := modifier.getKind;
  if kind == `Lean.Parser.Term.binderDefault then do
    -- binderDefault: arg 0 is " := ", arg 1 is the default value term
    let defaultVal := modifier.getArg 1;
    `(optParam $type $defaultVal)
  else if kind == `Lean.Parser.Term.binderTactic then do
    -- binderTactic: args 0 and 1 are " := " and " by ", arg 2 is the tactic
    let tac := modifier.getArg 2;
    name ← declareTacticSyntax tac;
    `(autoParam $type $(mkTermIdFrom tac name))
  else
    throwUnsupportedSyntax

/--
  Return the binder identifiers contained in `ids`. Each argument must be an identifier or a `_`
  hole; a `Lean.Parser.Term.id` node with exactly two arguments, the second one empty, is
  unwrapped to its first argument. Anything else raises "identifier or `_` expected". -/
private def getBinderIds (ids : Syntax) : TermElabM (Array Syntax) :=
ids.getArgs.mapM $ fun id =>
  let k := id.getKind;
  if k == identKind || k == `Lean.Parser.Term.hole then
    pure id
  else if k == `Lean.Parser.Term.id && id.getArgs.size == 2 && (id.getArg 1).isNone then
    -- The parser never generates this case, but it is convenient when writing macros.
    pure (id.getArg 0)
  else
    throwError id "identifier or `_` expected"

/--
  Convert one binder syntax node into `BinderView`s, one per bound identifier (holes become fresh
  anonymous identifiers via `expandBinderIdent`):
  - `simpleBinder`: hole type, `BinderInfo.default`;
  - `explicitBinder`: type from `expandBinderType`, wrapped by `expandBinderModifier`
    (default value / auto tactic), `BinderInfo.default`;
  - `implicitBinder`: type from `expandBinderType`, `BinderInfo.implicit`;
  - `instBinder`: a single view whose name comes from `expandOptIdent`, `BinderInfo.instImplicit`.
  Any other syntax throws `throwUnsupportedSyntax`. -/
private def matchBinder (stx : Syntax) : TermElabM (Array BinderView) :=
match stx with
| Syntax.node k args =>
  if k == `Lean.Parser.Term.simpleBinder then do
    -- binderIdent+
    ids ← getBinderIds (args.get! 0);
    let type := mkHole stx;
    ids.mapM $ fun id => do id ← expandBinderIdent id; pure { id := id, type := type, bi := BinderInfo.default }
  else if k == `Lean.Parser.Term.explicitBinder then do
    -- `(` binderIdent+ binderType (binderDefault <|> binderTactic)? `)`
    ids ← getBinderIds (args.get! 1);
    let type := expandBinderType (args.get! 2);
    let optModifier := args.get! 3;
    type ← expandBinderModifier type optModifier;
    ids.mapM $ fun id => do id ← expandBinderIdent id; pure { id := id, type := type, bi := BinderInfo.default }
  else if k == `Lean.Parser.Term.implicitBinder then do
    -- `{` binderIdent+ binderType `}`
    ids ← getBinderIds (args.get! 1);
    let type := expandBinderType (args.get! 2);
    ids.mapM $ fun id => do id ← expandBinderIdent id; pure { id := id, type := type, bi := BinderInfo.implicit }
  else if k == `Lean.Parser.Term.instBinder then do
    -- `[` optIdent type `]`
    id ← expandOptIdent (args.get! 1);
    let type := args.get! 2;
    pure #[ { id := id, type := type, bi := BinderInfo.instImplicit } ]
  else
    throwUnsupportedSyntax
| _ => throwUnsupportedSyntax

/--
  Elaborate `binderViews` starting at index `i`, threading the accumulated free variables, local
  context and local instances. For each view, its type is elaborated in the context built so far,
  a fresh free variable is created, and a local declaration is added for it. When `isClass`
  reports a class for the type, the new variable is also pushed as a local instance and
  `resetSynthInstanceCache` is called. Returns the final `(fvars, lctx, localInsts)`. -/
private partial def elabBinderViews (binderViews : Array BinderView) : Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances)
| i, fvars, lctx, localInsts =>
  if h : i < binderViews.size then
    let binderView := binderViews.get ⟨i, h⟩;
    -- The type is elaborated in the context *before* this binder; the new declaration only
    -- becomes visible through the `lctx` passed to the recursive call.
    withLCtx lctx localInsts $ do
      type ← elabType binderView.type;
      fvarId ← mkFreshFVarId;
      let fvar := mkFVar fvarId;
      let fvars := fvars.push fvar;
      -- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
      let lctx := lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
      className? ← isClass binderView.type type;
      match className? with
      | none => elabBinderViews (i+1) fvars lctx localInsts
      | some className => do
        resetSynthInstanceCache;
        let localInsts := localInsts.push { className := className, fvar := mkFVar fvarId };
        elabBinderViews (i+1) fvars lctx localInsts
  else
    pure (fvars, lctx, localInsts)

/--
  For each binder syntax in `binders` from index `i` onward, split it into views with
  `matchBinder` and elaborate them with `elabBinderViews`, threading the accumulated state. -/
private partial def elabBindersAux (binders : Array Syntax) : Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances)
| i, fvars, lctx, localInsts =>
  if h : i < binders.size then do
    binderViews ← matchBinder (binders.get ⟨i, h⟩);
    (fvars, lctx, localInsts) ← elabBinderViews binderViews 0 fvars lctx localInsts;
    elabBindersAux (i+1) fvars lctx localInsts
  else
    pure (fvars, lctx, localInsts)

/--
  Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracketedBinder`),
  update the local context, set of local instances, reset instance cache (if needed), and then
  execute `x` with the updated context and the free variables created for the binders.
  With no binders, `x #[]` runs in the current context. -/
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
if binders.isEmpty then x #[]
else do
  lctx ← getLCtx;
  localInsts ← getLocalInsts;
  (fvars, lctx, newLocalInsts) ← elabBindersAux binders 0 #[] lctx localInsts;
  -- the cache is reset only if some binder added a local instance
  resettingSynthInstanceCacheWhen (newLocalInsts.size > localInsts.size) $ withLCtx lctx newLocalInsts $
    x fvars

/--
  Single-binder version of `elabBinders`. `x` receives only the first free variable created;
  NOTE(review): a binder introducing several names (e.g. `(a b : T)`) passes just the first. -/
@[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α :=
elabBinders #[binder] (fun fvars => x (fvars.get! 0))

/--
  Elaborate `forall binders*, term`: the binders are elaborated with `elabBinders`, `term` is
  elaborated as a type, and the result is abstracted over the binders with `mkForall`. -/
@[builtinTermElab «forall»] def elabForall : TermElab :=
fun stx _ => match_syntax stx with
| `(forall $binders*, $term) =>
  elabBinders binders $ fun xs => do
    e ← elabType term;
    mkForall stx xs e
| _ => throwUnsupportedSyntax

/--
  Elaborate the non-dependent arrow `dom -> rng` by expanding it to `forall (a : dom), rng`.
  The `a` comes from a syntax quotation (hygienic), so it is not expected to capture a user `a`. -/
@[builtinTermElab arrow] def elabArrow : TermElab :=
adaptExpander $ fun stx => match_syntax stx with
| `($dom:term -> $rng) => `(forall (a : $dom), $rng)
| _ => throwUnsupportedSyntax

/--
  Elaborate the dependent arrow `bracketedBinder -> term`: the binder (argument 0) is elaborated
  with `elabBinders`, the codomain (argument 2) as a type, and the result is built with `mkForall`. -/
@[builtinTermElab depArrow] def elabDepArrow : TermElab :=
fun stx _ =>
  -- bracketedBinder `->` term
  let binder := stx.getArg 0;
  let term := stx.getArg 2;
  elabBinders #[binder] $ fun xs => do
    e ← elabType term;
    mkForall stx xs e

/--
  Main loop of `getFunBinderIds?`. For an application `f a`, the function part `f` is processed
  recursively with `idOnly := false` (so it may itself be an application), while the argument `a`
  is processed with `idOnly := true`, where applications are rejected. A `_` becomes a fresh
  anonymous identifier; anything accepted by `isSimpleTermId? true` is kept; otherwise `none`. -/
private partial def getFunBinderIdsAux? : Bool → Syntax → Array Syntax → TermElabM (Option (Array Syntax))
| idOnly, stx, acc =>
  match_syntax stx with
  | `($f $a) =>
     if idOnly then pure none
     else do
       (some acc) ← getFunBinderIdsAux? false f acc | pure none;
       getFunBinderIdsAux? true a acc
  | `(_) => do ident ← mkFreshAnonymousIdent stx; pure (some (acc.push ident))
  | stx =>
    match stx.isSimpleTermId? true with
    | some id => pure (some (acc.push id))
    | _ => pure none

/--
  Auxiliary function for converting `Term.app ... (Term.app id_1 id_2) ... id_n` into `#[id_1, ..., id_n]`.
  It is used at `expandFunBinders`; `none` means the syntax is not of that shape. -/
private def getFunBinderIds? (stx : Syntax) : TermElabM (Option (Array Syntax)) :=
getFunBinderIdsAux? false stx #[]

/--
  Main loop for `expandFunBinders`. Processes `binders` from index `i`, accumulating the
  converted binders in `newBinders`:
  - `implicitBinder` / `instBinder`: kept unchanged;
  - `_`: an explicit binder with a fresh anonymous name whose type is the hole syntax itself;
  - `(ids : type)` where `getFunBinderIds?` accepts `ids`: one explicit binder per identifier,
    all sharing `type`;
  - an identifier (`isTermId? true` with an empty extra part): an explicit binder with a hole
    type; a non-empty extra part is an error;
  - anything else (including parenthesized terms without a type ascription) is a pattern: see
    `processAsPattern`, which binds a fresh variable and wraps the body in a `match`.
  Returns the converted binders and the (possibly wrapped) body. -/
private partial def expandFunBindersAux (binders : Array Syntax) : Syntax → Nat → Array Syntax → TermElabM (Array Syntax × Syntax)
| body, i, newBinders =>
  if h : i < binders.size then
    let binder := binders.get ⟨i, h⟩;
    -- Bind a fresh variable `major` (hole type), expand the remaining binders first, and then
    -- wrap the resulting body as `match major with | binder => newBody`.
    let processAsPattern : Unit → TermElabM (Array Syntax × Syntax) := fun _ => do {
      let pattern := binder;
      major ← mkFreshAnonymousIdent binder;
      (binders, newBody) ← expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder major (mkHole binder));
      newBody ← `(match $major:ident with | $pattern => $newBody);
      pure (binders, newBody)
    };
    match binder with
    | Syntax.node `Lean.Parser.Term.implicitBinder _ => expandFunBindersAux body (i+1) (newBinders.push binder)
    | Syntax.node `Lean.Parser.Term.instBinder _ => expandFunBindersAux body (i+1) (newBinders.push binder)
    | Syntax.node `Lean.Parser.Term.hole _ => do
      ident ← mkFreshAnonymousIdent binder;
      let type := binder;
      expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder ident type)
    | Syntax.node `Lean.Parser.Term.paren args =>
      -- `(` (termParser >> parenSpecial)? `)`
      -- parenSpecial := (tupleTail <|> typeAscription)?
      let binderBody := binder.getArg 1;
      if binderBody.isNone then processAsPattern ()
      else
        let idents := binderBody.getArg 0;
        let special := binderBody.getArg 1;
        if special.isNone then processAsPattern ()
        else if (special.getArg 0).getKind != `Lean.Parser.Term.typeAscription then processAsPattern ()
        else do
          -- typeAscription := `:` term
          let type := ((special.getArg 0).getArg 1);
          idents? ← getFunBinderIds? idents;
          match idents? with
          | some idents => expandFunBindersAux body (i+1) (newBinders ++ idents.map (fun ident => mkExplicitBinder ident type))
          | none => processAsPattern ()
    | binder =>
      match binder.isTermId? true with
      | some (ident, extra) => do
        unless extra.isNone $ throwError binder "invalid binder, simple identifier expected";
        let type := mkHole binder;
        expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder ident type)
      | none => processAsPattern ()
  else
    pure (newBinders, body)

/--
  Auxiliary function for expanding `fun` notation binders. Recall that `fun` parser is defined as
  ```
  def funBinder : Parser := implicitBinder <|> instBinder <|> termParser maxPrec
  parser! unicodeSymbol "λ" "fun" >> many1 funBinder >> "=>" >> termParser
  ```
  to allow notation such as `fun (a, b) => a + b`, where `(a, b)` should be treated as a pattern.
  The result is a pair `(explicitBinders, newBody)`, where `explicitBinders` is syntax of the form
  ```
  `(` ident `:` term `)`
  ```
  (implicit and instance binders are passed through unchanged), which can be elaborated using
  `elabBinders`, and `newBody` is the updated `body` syntax.
  We update the `body` syntax when expanding the pattern notation.
  Example: `fun (a, b) => a + b` expands into `fun _a_1 => match _a_1 with | (a, b) => a + b`.
  See local function `processAsPattern` at `expandFunBindersAux`. -/
def expandFunBinders (binders : Array Syntax) (body : Syntax) : TermElabM (Array Syntax × Syntax) :=
expandFunBindersAux binders body 0 #[]

namespace FunBinders

/-- State threaded through the elaboration of `fun` binders. -/
structure State :=
(fvars : Array Expr := #[])            -- free variables created so far, in binder order
(lctx : LocalContext)                  -- local context extended with those variables
(localInsts : LocalInstances)          -- local instances, extended for class-typed binders
(expectedType? : Option Expr := none)  -- expected type still to be consumed by the remaining binders/body

/--
  After instantiating metavariables, fail if `type` is an `optParam` or `autoParam`:
  default values and auto tactics are not allowed at `fun` binders. -/
private def checkNoOptAutoParam (ref : Syntax) (type : Expr) : TermElabM Unit := do
type ← instantiateMVars ref type;
when type.isOptParam $ throwError ref "optParam is not allowed at 'fun/λ' binders";
when type.isAutoParam $ throwError ref "autoParam is not allowed at 'fun/λ' binders"

/--
  Consume one binder of the expected type, if any. The expected type is normalized with
  `whnfForall`; if it is a Pi type `(x : d) → b`, the binder type `fvarType` is unified with `d`
  (the `isDefEq` result is ignored, presumably leaving mismatches to later type checking), the
  opt/auto-param check is repeated on `fvarType`, and the expected type continues as `b`
  instantiated with `fvar`. If it is not a Pi type, the expected type is dropped. -/
private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do
match s.expectedType? with
| none => pure s
| some expectedType => do
  expectedType ← whnfForall ref expectedType;
  match expectedType with
  | Expr.forallE _ d b _ => do
    _ ← isDefEq ref fvarType d;
    -- `fvarType` may have been refined by the unification above, so check it again
    checkNoOptAutoParam ref fvarType;
    let b := b.instantiate1 fvar;
    pure { s with expectedType? := some b }
  | _ => pure { s with expectedType? := none }

/--
  `fun` counterpart of `elabBinderViews`: elaborate `binderViews` from index `i`, rejecting
  opt/auto-param types, adding a local declaration per view, propagating the expected type with
  `propagateExpectedType`, and registering class-typed binders as local instances. -/
private partial def elabFunBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State
| i, s =>
  if h : i < binderViews.size then
    let binderView := binderViews.get ⟨i, h⟩;
    withLCtx s.lctx s.localInsts $ do
      type ← elabType binderView.type;
      checkNoOptAutoParam binderView.type type;
      fvarId ← mkFreshFVarId;
      let fvar := mkFVar fvarId;
      let s := { s with fvars := s.fvars.push fvar };
      -- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
      /-
        We do **not** want to support default and auto arguments in lambda abstractions.
        Example: `fun (x : Nat := 10) => x+1`.
        We do not believe this is a useful feature, and it would complicate the logic here. -/
      let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
      s ← propagateExpectedType binderView.id fvar type s;
      let s := { s with lctx := lctx };
      className? ← isClass binderView.type type;
      match className? with
      | none => elabFunBinderViews (i+1) s
      | some className => do
        resetSynthInstanceCache;
        let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId };
        elabFunBinderViews (i+1) { s with localInsts := localInsts }
  else
    pure s

/-- Elaborate the `fun` binder syntax nodes `binders` from index `i`, threading the state. -/
partial def elabFunBindersAux (binders : Array Syntax) : Nat → State → TermElabM State
| i, s =>
  if h : i < binders.size then do
    binderViews ← matchBinder (binders.get ⟨i, h⟩);
    s ← elabFunBinderViews binderViews 0 s;
    elabFunBindersAux (i+1) s
  else
    pure s

end FunBinders

/--
  Like `elabBinders`, but for `fun` binders: the optional expected type is propagated through the
  binders, and `x` receives the created free variables together with the remaining expected type
  (for the body). With no binders, `x #[] expectedType?` runs in the current context. -/
def elabFunBinders {α} (binders : Array Syntax) (expectedType? : Option Expr) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α :=
if binders.isEmpty then x #[] expectedType?
else do
  lctx ← getLCtx;
  localInsts ← getLocalInsts;
  s ← FunBinders.elabFunBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? };
  resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $
    x s.fvars s.expectedType?

/--
  Elaborate `fun binders* => body`. Pattern binders are first expanded by `expandFunBinders`,
  then the binders are elaborated with `elabFunBinders`, the body is elaborated against the
  propagated expected type, and the result is abstracted with `mkLambda`. -/
@[builtinTermElab «fun»] def elabFun : TermElab :=
fun stx expectedType? => do
  -- `fun` term+ `=>` term
  let binders := (stx.getArg 1).getArgs;
  let body := stx.getArg 3;
  (binders, body) ← expandFunBinders binders body;
  elabFunBinders binders expectedType? $ fun xs expectedType? => do {
    e ← elabTerm body expectedType?;
    mkLambda stx xs e
  }

/--
  Return the type in an optional type specification, or a hole (positioned at `ref`) if absent.
  Recall that
  ```
  def typeSpec := parser! " : " >> termParser
  def optType : Parser := optional typeSpec
  ```
-/
def expandOptType (ref : Syntax) (optType : Syntax) : Syntax :=
if optType.isNone then
  mkHole ref
else
  (optType.getArg 0).getArg 1

/--
  Elaborate `let n binders* : typeStx := valStx; body`. The declaration is elaborated under its
  binders and then abstracted, so `type` becomes `∀ binders, typeStx` and `val` becomes
  `fun binders => valStx`.
  If `useLetExpr` is true, then a kernel let-expression `let x : type := val; body` is created.
  Otherwise, we create a term of the form `(fun (x : type) => body) val`. -/
def elabLetDeclAux (ref : Syntax) (n : Name) (binders : Array Syntax) (typeStx : Syntax) (valStx : Syntax) (body : Syntax) (expectedType? : Option Expr) (useLetExpr : Bool) : TermElabM Expr := do
(type, val) ← elabBinders binders $ fun xs => do {
  type ← elabType typeStx;
  val ← elabTerm valStx type;
  val ← ensureHasType valStx type val;
  type ← mkForall ref xs type;
  val ← mkLambda ref xs val;
  pure (type, val)
};
trace `Elab.let.decl ref $ fun _ => n ++ " : " ++ type ++ " := " ++ val;
if useLetExpr then
  withLetDecl ref n type val $ fun x => do
    body ← elabTerm body expectedType?;
    body ← instantiateMVars ref body;
    mkLet ref x body
else do
  f ← withLocalDecl ref n BinderInfo.default type $ fun x => do {
    body ← elabTerm body expectedType?;
    body ← instantiateMVars ref body;
    mkLambda ref #[x] body
  };
  pure $ mkApp f val

/--
  Elaborator for `let`. Identifier forms (listed first) go directly to `elabLetDeclAux` with
  `useLetExpr := true`; a missing type becomes a hole. Pattern forms are macro-expanded to
  `let x := val; match x with pat => body` (with the optional type kept on `x`). -/
@[builtinTermElab «let»] def elabLetDecl : TermElab :=
fun stx expectedType? => match_syntax stx with
| `(let $id:ident $args* := $val; $body) =>
  elabLetDeclAux stx id.getId args (mkHole stx) val body expectedType? true
| `(let $id:ident $args* : $type := $val; $body) =>
  elabLetDeclAux stx id.getId args type val body expectedType? true
| `(let $pat:term := $val; $body) => do
  stxNew ← `(let x := $val; match x with $pat => $body);
  withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| `(let $pat:term : $type := $val; $body) => do
  stxNew ← `(let x : $type := $val; match x with $pat => $body);
  withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| _ => throwUnsupportedSyntax

/--
  Elaborator for `let!`: same shapes as `let`, but identifier forms use `useLetExpr := false`,
  producing `(fun x => body) val` instead of a kernel let-expression; pattern forms expand to `let!`. -/
@[builtinTermElab «let!»] def elabLetBangDecl : TermElab :=
fun stx expectedType? => match_syntax stx with
| `(let! $id:ident $args* := $val; $body) =>
  elabLetDeclAux stx id.getId args (mkHole stx) val body expectedType? false
| `(let! $id:ident $args* : $type := $val; $body) =>
  elabLetDeclAux stx id.getId args type val body expectedType? false
| `(let! $pat:term := $val; $body) => do
  stxNew ← `(let! x := $val; match x with $pat => $body);
  withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| `(let! $pat:term : $type := $val; $body) => do
  stxNew ← `(let! x : $type := $val; match x with $pat => $body);
  withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| _ => throwUnsupportedSyntax

/-- Register the `Elab.let` trace class (the parent of `Elab.let.decl`, traced in `elabLetDeclAux`). -/
@[init] private def regTraceClasses : IO Unit := do
registerTraceClass `Elab.let;
pure ()

end Term
end Elab
end Lean