lean4-htt/src/Lean/Elab/Binders.lean
2020-06-25 11:21:17 -07:00

445 lines
18 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Elab.Term
import Lean.Elab.Quotation
namespace Lean
namespace Elab
namespace Term
/--
Extract the type from a binder-type fragment, which has one of the forms
a) (`:` term)?  -- possibly empty
b) `:` term
When the fragment is non-empty its `term` is returned; an empty fragment yields a hole `_`. -/
private def expandBinderType (stx : Syntax) : Syntax :=
if stx.getNumArgs != 0 then
  -- argument 0 is the `:` token, argument 1 is the type term
  stx.getArg 1
else
  mkHole stx
/--
Given syntax of the form `ident <|> hole`, return the `ident` unchanged.
For a hole `_`, a fresh anonymous identifier (from `mkFreshAnonymousIdent`) is returned instead,
so every binder ends up with an identifier. -/
private def expandBinderIdent (stx : Syntax) : TermElabM Syntax :=
match_syntax stx with
-- `_`: invent a fresh anonymous name positioned at the hole
| `(_) => mkFreshAnonymousIdent stx
-- already an identifier: keep it
| _ => pure stx
/--
Given syntax of the form `(ident >> " : ")?`, return the `ident` when it is present.
Otherwise return an identifier (positioned at `stx`) for a freshly generated instance name. -/
private def expandOptIdent (stx : Syntax) : TermElabM Syntax :=
if stx.getNumArgs != 0 then
  -- argument 0 is the identifier, argument 1 the `:` token
  pure (stx.getArg 0)
else do
  instName ← mkFreshInstanceName;
  pure (mkIdentFrom stx instName)
/--
Normalized view of a single bound variable, as produced by `matchBinder`:
- `id`: the identifier syntax being bound (holes have already been replaced by fresh identifiers),
- `type`: the syntax of its type (a hole `_` when no type was written; possibly wrapped in
  `optParam`/`autoParam` for explicit binders with a default value or tactic),
- `bi`: its binder annotation (default, implicit, or instance-implicit). -/
structure BinderView :=
(id : Syntax) (type : Syntax) (bi : BinderInfo)
/--
Convert the tactic syntax `stx` into a term syntax that, when elaborated and evaluated, constructs
a copy of `stx`: nodes become `Syntax.node kind (Array.push ... (Array.push Array.empty a_1) ... a_n)`
and atoms become `Syntax.atom {} val` (source information is dropped).
Identifiers and antiquotations (including antiquotation splices inside null nodes) are rejected with
an error, since this quotation scheme only reproduces node kinds and atoms. -/
partial def quoteAutoTactic : Syntax → TermElabM Syntax
| stx@(Syntax.ident _ _ _ _) => throwError stx "invalid auto tactic, identifier is not allowed"
| stx@(Syntax.node k args) =>
  if Quotation.isAntiquot stx then
    throwError stx "invalid auto tactic, antiquotation is not allowed"
  else do
    empty ← `(Array.empty);
    -- rebuild the argument array one `Array.push` at a time, quoting every child recursively
    args ← args.foldlM (fun args arg =>
      if k == nullKind && Quotation.isAntiquotSplice arg then
        throwError arg "invalid auto tactic, antiquotation is not allowed"
      else do
        arg ← quoteAutoTactic arg;
        `(Array.push $args $arg)) empty;
    `(Syntax.node $(quote k) $args)
| Syntax.atom _ val => `(Syntax.atom {} $(quote val))
| Syntax.missing => unreachable!
/--
Store the tactic syntax `tactic` in an auxiliary definition and return that definition's name.
The definition is named `_auto` plus a fresh macro scope and has type `Lean.Syntax`. Its value is the
elaborated quotation of `tactic` produced by `quoteAutoTactic`, so it evaluates to a copy of the tactic
syntax without source information. The definition is added to the environment with opaque
reducibility hints and compiled.
`expandBinderModifier` passes the returned name as the second argument of `autoParam`. -/
def declareTacticSyntax (tactic : Syntax) : TermElabM Name :=
withFreshMacroScope $ do
name ← MonadQuotation.addMacroScope `_auto;
let type := Lean.mkConst `Lean.Syntax;
tactic ← quoteAutoTactic tactic;
val ← elabTerm tactic type;
val ← instantiateMVars tactic val;
-- report the elaborated quotation under the `Elab.autoParam` trace class
trace `Elab.autoParam tactic $ fun _ => val;
let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false };
addDecl tactic decl;
compileDecl tactic decl;
pure name
/--
Wrap `type` according to an optional binder modifier, i.e. syntax for
`optional (binderTactic <|> binderDefault)` where
```
def binderTactic  := parser! " := " >> " by " >> tacticParser
def binderDefault := parser! " := " >> termParser
```
- no modifier: `type` is returned unchanged;
- `binderDefault`: the result is `optParam type val`;
- `binderTactic`: the tactic is stored via `declareTacticSyntax` and the result is
  `autoParam type <auxDeclName>`.
Any other modifier kind is reported with `throwUnsupportedSyntax`. -/
private def expandBinderModifier (type : Syntax) (optBinderModifier : Syntax) : TermElabM Syntax :=
if optBinderModifier.isNone then
  pure type
else
  let modifier := optBinderModifier.getArg 0;
  match modifier.getKind with
  | `Lean.Parser.Term.binderDefault =>
    -- argument 0 is the ` := ` token, argument 1 the default-value term
    `(optParam $type $(modifier.getArg 1))
  | `Lean.Parser.Term.binderTactic => do
    -- arguments 0 and 1 are the ` := ` and ` by ` tokens, argument 2 the tactic
    let tac := modifier.getArg 2;
    auxName ← declareTacticSyntax tac;
    `(autoParam $type $(mkTermIdFrom tac auxName))
  | _ => throwUnsupportedSyntax
/--
Validate the children of `ids` as binder identifiers and return them.
Each child must be an identifier or a hole `_`, both returned as-is. A `Term.id` node with exactly two
arguments whose second argument is empty is unwrapped to its identifier. Anything else raises
"identifier or `_` expected". -/
private def getBinderIds (ids : Syntax) : TermElabM (Array Syntax) :=
ids.getArgs.mapM $ fun binderId =>
  let kind := binderId.getKind;
  if kind != identKind && kind != `Lean.Parser.Term.hole then
    if kind == `Lean.Parser.Term.id && binderId.getArgs.size == 2 && (binderId.getArg 1).isNone then
      -- The parser never produces this shape, but it is convenient when writing macros.
      pure (binderId.getArg 0)
    else
      throwError binderId "identifier or `_` expected"
  else
    pure binderId
/--
Convert one binder syntax node into a `BinderView` for each identifier it binds. Supported node kinds:
- `simpleBinder` (`binderIdent+`): each identifier gets a hole `_` as its type and default binder info;
- `explicitBinder` (`(` binderIdent+ binderType (binderDefault <|> binderTactic)? `)`): default binder
  info. When a modifier is present, `expandBinderModifier` wraps the type in `optParam`/`autoParam`;
- `implicitBinder` (`{` binderIdent+ binderType `}`): implicit binder info;
- `instBinder` (`[` optIdent type `]`): a single instance-implicit view. A fresh instance name is used
  when no identifier is given.
Holes among the identifiers are replaced by fresh anonymous identifiers via `expandBinderIdent`.
Any other syntax is reported with `throwUnsupportedSyntax`. -/
private def matchBinder (stx : Syntax) : TermElabM (Array BinderView) :=
match stx with
| Syntax.node k args =>
if k == `Lean.Parser.Term.simpleBinder then do
-- binderIdent+
ids ← getBinderIds (args.get! 0);
let type := mkHole stx;
-- all identifiers share the same type syntax
ids.mapM $ fun id => do id ← expandBinderIdent id; pure { id := id, type := type, bi := BinderInfo.default }
else if k == `Lean.Parser.Term.explicitBinder then do
-- `(` binderIdent+ binderType (binderDefault <|> binderTactic)? `)`
ids ← getBinderIds (args.get! 1);
let type := expandBinderType (args.get! 2);
let optModifier := args.get! 3;
-- The identifiers are validated before this call. It may declare an auxiliary definition
-- when the modifier is a tactic (see `declareTacticSyntax`).
type ← expandBinderModifier type optModifier;
ids.mapM $ fun id => do id ← expandBinderIdent id; pure { id := id, type := type, bi := BinderInfo.default }
else if k == `Lean.Parser.Term.implicitBinder then do
-- `{` binderIdent+ binderType `}`
ids ← getBinderIds (args.get! 1);
let type := expandBinderType (args.get! 2);
ids.mapM $ fun id => do id ← expandBinderIdent id; pure { id := id, type := type, bi := BinderInfo.implicit }
else if k == `Lean.Parser.Term.instBinder then do
-- `[` optIdent type `]`
id ← expandOptIdent (args.get! 1);
-- the type is mandatory here, so it is used directly (no `expandBinderType`)
let type := args.get! 2;
pure #[ { id := id, type := type, bi := BinderInfo.instImplicit } ]
else
throwUnsupportedSyntax
| _ => throwUnsupportedSyntax
/--
Elaborate `binderViews[i:]` in order, starting from the accumulated free variables `fvars`, local
context `lctx` and local instances `localInsts`.

For each view:
- its type is elaborated in the context built so far, so later binder types may mention earlier binders;
- a fresh free variable of that type is added to the local context;
- if the type is a class, the variable is also registered as a local instance and the
  synthesized-instance cache is reset.

Returns the extended `(fvars, lctx, localInsts)`. -/
private partial def elabBinderViews (binderViews : Array BinderView)
: Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances)
| i, fvars, lctx, localInsts =>
if h : i < binderViews.size then
let binderView := binderViews.get ⟨i, h⟩;
-- elaborate under the context extended with the previous binders
withLCtx lctx localInsts $ do
type ← elabType binderView.type;
fvarId ← mkFreshFVarId;
let fvar := mkFVar fvarId;
let fvars := fvars.push fvar;
-- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
let lctx := lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
className? ← isClass binderView.type type;
match className? with
| none => elabBinderViews (i+1) fvars lctx localInsts
| some className => do
-- a new local instance may change instance resolution results, so drop cached ones
resetSynthInstanceCache;
let localInsts := localInsts.push { className := className, fvar := mkFVar fvarId };
elabBinderViews (i+1) fvars lctx localInsts
else
pure (fvars, lctx, localInsts)
/--
Elaborate `binders[i:]` one binder at a time: each binder is split into `BinderView`s by
`matchBinder`, and those are elaborated by `elabBinderViews`, threading the accumulated free
variables, local context and local instances through. -/
private partial def elabBindersAux (binders : Array Syntax)
    : Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances)
| i, fvars, lctx, localInsts =>
  if h : i < binders.size then do
    views ← matchBinder (binders.get ⟨i, h⟩);
    (fvars', lctx', localInsts') ← elabBinderViews views 0 fvars lctx localInsts;
    elabBindersAux (i+1) fvars' lctx' localInsts'
  else
    pure (fvars, lctx, localInsts)
/--
Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracketedBinder`), then run
`x` on the resulting free variables inside the extended local context and set of local instances.
The synthesized-instance cache is reset when new local instances were added.
An empty binder array runs `x #[]` in the current context. -/
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
if binders.isEmpty then
  x #[]
else do
  localInsts ← getLocalInsts;
  lctx ← getLCtx;
  (fvars, newLCtx, newLocalInsts) ← elabBindersAux binders 0 #[] lctx localInsts;
  resettingSynthInstanceCacheWhen (newLocalInsts.size > localInsts.size) $
    withLCtx newLCtx newLocalInsts (x fvars)
/--
Single-binder variant of `elabBinders`: elaborate `binder` and pass the first resulting free variable
to `x`. If the binder introduces several variables, only the first one is passed. -/
@[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α :=
elabBinders #[binder] $ fun xs =>
  x (xs.get! 0)
/--
Elaborator for `forall binders*, term`. The binders are elaborated with `elabBinders`, the body is
elaborated as a type in the extended context, and `mkForall` abstracts the new free variables.
The expected type is ignored. -/
@[builtinTermElab «forall»] def elabForall : TermElab :=
fun stx _ => match_syntax stx with
| `(forall $binders*, $term) =>
elabBinders binders $ fun xs => do
e ← elabType term;
mkForall stx xs e
| _ => throwUnsupportedSyntax
/--
Elaborator for the non-dependent arrow `dom -> rng`. It is expanded, via `adaptExpander`, into
`forall (a : dom), rng`. The binder name `a` comes from the quotation and is therefore hygienic, so it
cannot capture an `a` occurring in `rng`. -/
@[builtinTermElab arrow] def elabArrow : TermElab :=
adaptExpander $ fun stx => match_syntax stx with
| `($dom:term -> $rng) => `(forall (a : $dom), $rng)
| _ => throwUnsupportedSyntax
/--
Elaborator for the dependent arrow `bracketedBinder -> term`. The single binder is elaborated with
`elabBinders`, the right-hand side is elaborated as a type in the extended context, and the binder is
abstracted with `mkForall`. The expected type is ignored. -/
@[builtinTermElab depArrow] def elabDepArrow : TermElab :=
fun stx _ =>
  -- syntax shape: bracketedBinder `->` term  (argument 1 is the arrow token)
  elabBinders #[stx.getArg 0] $ fun xs => do
    codomain ← elabType (stx.getArg 2);
    mkForall stx xs codomain
/--
Main loop for `getFunBinderIds?`. It appends to `acc` the identifiers of an application spine
`f a_1 ... a_n`:
- for an application `f a`, the function part `f` is processed first (it may itself be an
  application), then `a` is processed with the `idOnly` flag set to `true`, so an argument that is
  itself an application is rejected;
- a hole `_` contributes a fresh anonymous identifier;
- anything accepted by `isSimpleTermId? true` contributes that identifier.

Returns `none` as soon as a component does not fit this shape. -/
private partial def getFunBinderIdsAux? : Bool → Syntax → Array Syntax → TermElabM (Option (Array Syntax))
| idOnly, stx, acc =>
match_syntax stx with
| `($f $a) =>
if idOnly then pure none
else do
(some acc) ← getFunBinderIdsAux? false f acc | pure none;
getFunBinderIdsAux? true a acc
| `(_) => do ident ← mkFreshAnonymousIdent stx; pure (some (acc.push ident))
| stx =>
match stx.isSimpleTermId? true with
| some id => pure (some (acc.push id))
| _ => pure none
/--
Auxiliary function for converting `Term.app ... (Term.app id_1 id_2) ... id_n` into `#[id_1, ..., id_n]`.
Holes `_` are replaced by fresh anonymous identifiers. Returns `none` if `stx` does not have this shape.
It is used by `expandFunBindersAux` (the main loop of `expandFunBinders`) to recognize binders such as
`(a b : T)`. -/
private def getFunBinderIds? (stx : Syntax) : TermElabM (Option (Array Syntax)) :=
getFunBinderIdsAux? false stx #[]
/--
Main loop for `expandFunBinders`. It processes `binders[i:]`, accumulates the resulting binders in
`newBinders`, and returns them together with the (possibly rewritten) `body`:
- `implicitBinder` and `instBinder` are kept unchanged;
- a hole `_` becomes an explicit binder `(a : _)` with a fresh anonymous identifier `a`;
- a parenthesized application spine of identifiers/holes with a type ascription, e.g. `(x y : T)`,
  becomes one explicit binder `(x : T)` per identifier;
- an identifier term (as recognized by `isTermId? true`) becomes `(x : _)`. It is an error if
  `isTermId?` reports additional syntax;
- anything else is treated as a pattern (see `processAsPattern`). -/
private partial def expandFunBindersAux (binders : Array Syntax) : Syntax → Nat → Array Syntax → TermElabM (Array Syntax × Syntax)
| body, i, newBinders =>
if h : i < binders.size then
let binder := binders.get ⟨i, h⟩;
-- Replace the binder by `(major : _)` for a fresh `major`, expand the remaining binders first,
-- then wrap the resulting body as `match major with | pattern => body`.
let processAsPattern : Unit → TermElabM (Array Syntax × Syntax) := fun _ => do {
let pattern := binder;
major ← mkFreshAnonymousIdent binder;
(binders, newBody) ← expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder major (mkHole binder));
newBody ← `(match $major:ident with | $pattern => $newBody);
pure (binders, newBody)
};
match binder with
| Syntax.node `Lean.Parser.Term.implicitBinder _ => expandFunBindersAux body (i+1) (newBinders.push binder)
| Syntax.node `Lean.Parser.Term.instBinder _ => expandFunBindersAux body (i+1) (newBinders.push binder)
| Syntax.node `Lean.Parser.Term.hole _ => do
ident ← mkFreshAnonymousIdent binder;
-- the hole itself serves as the (unknown) type
let type := binder;
expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder ident type)
| Syntax.node `Lean.Parser.Term.paren args =>
-- `(` (termParser >> parenSpecial)? `)`
-- parenSpecial := (tupleTail <|> typeAscription)?
let binderBody := binder.getArg 1;
-- `()`: a pattern
if binderBody.isNone then processAsPattern ()
else
let idents := binderBody.getArg 0;
let special := binderBody.getArg 1;
-- `(t)` without ascription: a pattern
if special.isNone then processAsPattern ()
-- tuple `(t, ...)`: a pattern
else if (special.getArg 0).getKind != `Lean.Parser.Term.typeAscription then processAsPattern ()
else do
-- typeAscription := `:` term
let type := ((special.getArg 0).getArg 1);
idents? ← getFunBinderIds? idents;
match idents? with
| some idents => expandFunBindersAux body (i+1) (newBinders ++ idents.map (fun ident => mkExplicitBinder ident type))
-- `(t : T)` where `t` is not an identifier spine: a pattern
| none => processAsPattern ()
| binder =>
match binder.isTermId? true with
| some (ident, extra) => do
unless extra.isNone $ throwError binder "invalid binder, simple identifier expected";
let type := mkHole binder;
expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder ident type)
| none => processAsPattern ()
else
pure (newBinders, body)
/--
Auxiliary function for expanding `fun` notation binders. Recall that the `fun` parser is defined as
```
def funBinder : Parser := implicitBinder <|> instBinder <|> termParser maxPrec
parser! unicodeSymbol "λ" "fun" >> many1 funBinder >> "=>" >> termParser
```
to allow notation such as `fun (a, b) => a + b`, where `(a, b)` should be treated as a pattern.
The result is a pair `(newBinders, newBody)`. Implicit and instance binders are passed through
unchanged; every other binder is turned into an explicit binder of the form
```
`(` ident `:` term `)`
```
All resulting binders can be elaborated using `elabBinders`/`elabFunBinders`. `newBody` is the updated
`body` syntax, which changes only when pattern notation is expanded.
Example: `fun (a, b) => a + b` expands into `fun _a_1 => match _a_1 with | (a, b) => a + b`.
See local function `processAsPattern` at `expandFunBindersAux`. -/
def expandFunBinders (binders : Array Syntax) (body : Syntax) : TermElabM (Array Syntax × Syntax) :=
expandFunBindersAux binders body 0 #[]
namespace FunBinders
/--
State threaded through the elaboration of `fun` binders:
- `fvars`: free variables created so far, in binder order;
- `lctx`/`localInsts`: the local context and local instances extended with those variables;
- `expectedType?`: the part of the expected type not yet consumed by the binders, or `none` once it
  is unknown or no longer a `forall`. -/
structure State :=
(fvars : Array Expr := #[])
(lctx : LocalContext)
(localInsts : LocalInstances)
(expectedType? : Option Expr := none)
/--
Reject binder types that are `optParam`/`autoParam` (after instantiating metavariables).
Default values and auto tactics are not supported on `fun`/`λ` binders; `ref` is used for error
positions. -/
private def checkNoOptAutoParam (ref : Syntax) (type : Expr) : TermElabM Unit := do
type ← instantiateMVars ref type;
if type.isOptParam then
  throwError ref "optParam is not allowed at 'fun/λ' binders"
else if type.isAutoParam then
  throwError ref "autoParam is not allowed at 'fun/λ' binders"
else
  pure ()
/--
Consume one binder of the expected type, if any.
When `s.expectedType?` reduces (with `whnfForall`) to `forall (_ : d), b`, the binder type `fvarType`
is unified with `d` (the result is ignored), re-checked against `optParam`/`autoParam`, and the expected
type becomes `b` instantiated with `fvar`. When it is not a `forall`, the expected type is dropped.
When there is no expected type, `s` is returned unchanged. -/
private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State :=
match s.expectedType? with
| none => pure s
| some expectedType => do
  expectedType ← whnfForall ref expectedType;
  match expectedType with
  | Expr.forallE _ d b _ => do
    _ ← isDefEq ref fvarType d;
    -- unification may have instantiated `fvarType` with an `optParam`/`autoParam`
    checkNoOptAutoParam ref fvarType;
    pure { s with expectedType? := some (b.instantiate1 fvar) }
  | _ => pure { s with expectedType? := none }
/--
`fun` counterpart of `elabBinderViews`: elaborate `binderViews[i:]` in order, threading the
`State` `s`.

For each view:
- its type is elaborated in the context built so far and must not be an `optParam`/`autoParam`;
- a fresh free variable of that type is created and added to `lctx`;
- the expected type is advanced with `propagateExpectedType`;
- if the type is a class, the variable is registered as a local instance and the
  synthesized-instance cache is reset. -/
private partial def elabFunBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State
| i, s =>
if h : i < binderViews.size then
let binderView := binderViews.get ⟨i, h⟩;
withLCtx s.lctx s.localInsts $ do
type ← elabType binderView.type;
checkNoOptAutoParam binderView.type type;
fvarId ← mkFreshFVarId;
let fvar := mkFVar fvarId;
let s := { s with fvars := s.fvars.push fvar };
-- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
/-
We do **not** want to support default and auto arguments in lambda abstractions.
Example: `fun (x : Nat := 10) => x+1`.
We do not believe this is a useful feature, and it would complicate the logic here.
-/
let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
-- `propagateExpectedType` does not touch `lctx`; the extended context is stored right after
s ← propagateExpectedType binderView.id fvar type s;
let s := { s with lctx := lctx };
className? ← isClass binderView.type type;
match className? with
| none => elabFunBinderViews (i+1) s
| some className => do
resetSynthInstanceCache;
let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId };
elabFunBinderViews (i+1) { s with localInsts := localInsts }
else
pure s
/--
Elaborate `binders[i:]` for `fun`. Each binder is split into views by `matchBinder`, and those views
are elaborated by `elabFunBinderViews`, threading the binder `State` through. -/
partial def elabFunBindersAux (binders : Array Syntax) : Nat → State → TermElabM State
| i, s =>
  if h : i < binders.size then do
    views ← matchBinder (binders.get ⟨i, h⟩);
    s' ← elabFunBinderViews views 0 s;
    elabFunBindersAux (i+1) s'
  else
    pure s
end FunBinders
/--
Elaborate `fun` binders, propagating `expectedType?` through them. Then run `x` with the new free
variables and the remaining expected type (if any), inside the extended local context and local
instances. The synthesized-instance cache is reset when new local instances were added.
With no binders, `x #[] expectedType?` runs in the current context. -/
def elabFunBinders {α} (binders : Array Syntax) (expectedType? : Option Expr) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α :=
if binders.isEmpty then
  x #[] expectedType?
else do
  localInsts ← getLocalInsts;
  lctx ← getLCtx;
  let initState : FunBinders.State := { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? };
  finalState ← FunBinders.elabFunBindersAux binders 0 initState;
  resettingSynthInstanceCacheWhen (finalState.localInsts.size > localInsts.size) $
    withLCtx finalState.lctx finalState.localInsts (x finalState.fvars finalState.expectedType?)
/--
Elaborator for `fun binders+ => body`. Pattern binders are first expanded into `match` expressions
by `expandFunBinders`. The resulting binders are elaborated against the expected type, and the body is
elaborated with the remaining expected type and abstracted with `mkLambda`. -/
@[builtinTermElab «fun»] def elabFun : TermElab :=
fun stx expectedType? => do
  -- argument 1 holds the binders, argument 2 is `=>`, argument 3 is the body
  (funBinders, funBody) ← expandFunBinders (stx.getArg 1).getArgs (stx.getArg 3);
  elabFunBinders funBinders expectedType? $ fun xs bodyExpectedType? => do
    bodyExpr ← elabTerm funBody bodyExpectedType?;
    mkLambda stx xs bodyExpr
/--
Return the type term of an optional type annotation, or a hole `_` (positioned at `ref`) when
the annotation is absent. Recall that
```
def typeSpec := parser! " : " >> termParser
def optType : Parser := optional typeSpec
``` -/
def expandOptType (ref : Syntax) (optType : Syntax) : Syntax :=
if optType.isNone then
  mkHole ref
else
  -- `optType` wraps a single `typeSpec` node: argument 0 is " : ", argument 1 the type
  let typeSpec := optType.getArg 0;
  typeSpec.getArg 1
/--
Elaborate `let n binders* : typeStx := valStx; body`.
The binders are abstracted over both the declared type and the value: `n` gets type
`∀ binders, typeStx` and value `fun binders => valStx`.
If `useLetExpr` is true, then a kernel let-expression `let n : type := val; body` is created.
Otherwise, we create a term of the form `(fun (n : type) => body) val`. In that case the body is
elaborated with `n` as a plain local declaration, so `val` is not visible while elaborating `body`. -/
def elabLetDeclAux (ref : Syntax) (n : Name) (binders : Array Syntax) (typeStx : Syntax) (valStx : Syntax) (body : Syntax)
(expectedType? : Option Expr) (useLetExpr : Bool) : TermElabM Expr := do
(type, val) ← elabBinders binders $ fun xs => do {
type ← elabType typeStx;
val ← elabTerm valStx type;
val ← ensureHasType valStx type val;
type ← mkForall ref xs type;
val ← mkLambda ref xs val;
pure (type, val)
};
-- NOTE(review): this uses trace class `Elab.let.decl`, while this file only registers `Elab.let`;
-- confirm that enabling the parent option also enables this sub-class.
trace `Elab.let.decl ref $ fun _ => n ++ " : " ++ type ++ " := " ++ val;
if useLetExpr then
withLetDecl ref n type val $ fun x => do
body ← elabTerm body expectedType?;
body ← instantiateMVars ref body;
mkLet ref x body
else do
f ← withLocalDecl ref n BinderInfo.default type $ fun x => do {
body ← elabTerm body expectedType?;
body ← instantiateMVars ref body;
mkLambda ref #[x] body
};
pure $ mkApp f val
/--
Elaborator for `let`.
- Identifier forms `let x args* (: type)? := val; body` are handled by `elabLetDeclAux` and produce a
  kernel let-expression; a missing type becomes a hole.
- Pattern forms `let pat (: type)? := val; body` are rewritten to
  `let x (: type)? := val; match x with pat => body`, and the rewritten syntax is elaborated
  (recorded as a macro expansion of `stx`).
-/
@[builtinTermElab «let»] def elabLetDecl : TermElab :=
fun stx expectedType? => match_syntax stx with
| `(let $id:ident $args* := $val; $body) =>
elabLetDeclAux stx id.getId args (mkHole stx) val body expectedType? true
| `(let $id:ident $args* : $type := $val; $body) =>
elabLetDeclAux stx id.getId args type val body expectedType? true
| `(let $pat:term := $val; $body) => do
stxNew ← `(let x := $val; match x with $pat => $body);
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| `(let $pat:term : $type := $val; $body) => do
stxNew ← `(let x : $type := $val; match x with $pat => $body);
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| _ => throwUnsupportedSyntax
/--
Elaborator for `let!`. It has the same forms as `let`, but identifier forms are elaborated by
`elabLetDeclAux` with `useLetExpr := false`, i.e. into `(fun (x : type) => body) val` instead of a
kernel let-expression. Pattern forms are rewritten to `let! x (: type)? := val; match x with pat => body`.
-/
@[builtinTermElab «let!»] def elabLetBangDecl : TermElab :=
fun stx expectedType? => match_syntax stx with
| `(let! $id:ident $args* := $val; $body) =>
elabLetDeclAux stx id.getId args (mkHole stx) val body expectedType? false
| `(let! $id:ident $args* : $type := $val; $body) =>
elabLetDeclAux stx id.getId args type val body expectedType? false
| `(let! $pat:term := $val; $body) => do
stxNew ← `(let! x := $val; match x with $pat => $body);
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| `(let! $pat:term : $type := $val; $body) => do
stxNew ← `(let! x : $type := $val; match x with $pat => $body);
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| _ => throwUnsupportedSyntax
/-- Register the `Elab.let` trace class at initialization time. -/
@[init] private def regTraceClasses : IO Unit :=
registerTraceClass `Elab.let
end Term
end Elab
end Lean