lean4-htt/src/Lean/Elab/DoNotation.lean
2020-06-25 11:21:17 -07:00

236 lines
8.7 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Elab.Term
import Lean.Elab.Binders
import Lean.Elab.Quotation
namespace Lean
namespace Elab
namespace Term
/-- Monad information for elaborating a `do` block, produced by `extractBind`
from the block's expected type. -/
structure ExtractMonadResult :=
(m : Expr)           -- the monad (type former) the block is elaborated in
(α : Expr)           -- the block's result type (the argument of `m` in the expected type, or the whole type for the `Id` fallback)
(hasBindInst : Expr) -- the `HasBind m` instance used when building `HasBind.bind` applications
/--
Fallback monad information: elaborate the `do` block in the `Id` monad with result
type `type`, using the `Id.hasBind` instance.

`Id` and `Id.hasBind` are instantiated at the universe `u` with `type : Type u`. That is
what `getDecLevel` returns, and `mkBind` uses it the same way for `HasBind.bind`.
`getLevel` would instead return the level of the sort `type` lives in (`u+1`), which
yields an `Id` one universe too high.
-/
private def mkIdBindFor (ref : Syntax) (type : Expr) : TermElabM ExtractMonadResult := do
u ← getDecLevel ref type;
let id := Lean.mkConst `Id [u];
let idBindVal := Lean.mkConst `Id.hasBind [u];
pure { m := id, hasBindInst := idBindVal, α := type }
/--
Determine the monad in which a `do` block is elaborated, from the block's expected type.

Throws if there is no expected type, or if its reducible whnf has a metavariable head.
If the reduced type is an application `m α` and a `HasBind m` instance can be
synthesized, returns `m`, `α` and that instance. Otherwise it falls back to
`mkIdBindFor` on the reduced type.
-/
private def extractBind (ref : Syntax) (expectedType? : Option Expr) : TermElabM ExtractMonadResult := do
match expectedType? with
| none => throwError ref "invalid do notation, expected type is not available"
| some expectedType => do
type ← withReducible $ whnf ref expectedType;
-- A metavariable head means the monad itself is not determined yet.
when type.getAppFn.isMVar $ throwError ref "invalid do notation, expected type is not available";
match type with
| Expr.app m α _ =>
-- NOTE(review): `catch` falls back to `Id` on *any* exception raised by `mkAppM` or
-- `synthesizeInst`, not only on "no `HasBind m` instance"; confirm this is intended.
catch
(do
bindInstType ← mkAppM ref `HasBind #[m];
bindInstVal ← synthesizeInst ref bindInstType;
pure { m := m, hasBindInst := bindInstVal, α := α })
(fun ex => mkIdBindFor ref type)
-- Not an application: the whole type becomes the result type of an `Id` computation.
| _ => mkIdBindFor ref type
/--
Return the element list of a `do` term. Both the bracketed form `do { ... }` and the
plain form are accepted. The result is the argument array of the underlying sequence
node, so it may also contain separator tokens.
-/
private def getDoElems (stx : Syntax) : Array Syntax :=
-- parser! "do " >> (bracketedDoSeq <|> doSeq)
let body := stx.getArg 1;
-- bracketedDoSeq is parser! "{" >> doSeq >> "}", so the sequence is its argument 1
let seq := if body.getKind == `Lean.Parser.Term.bracketedDoSeq then body.getArg 1 else body;
seq.getArgs
/--
`true` iff `stx` contains a `liftMethod` node that is not nested inside an inner `do`
term. Non-node syntax (atoms, identifiers, missing) never contains one.
-/
private partial def hasLiftMethod : Syntax → Bool
| Syntax.node k args =>
  -- an inner `do` handles its own lifts, so its contents are not inspected here
  k != `Lean.Parser.Term.do &&
    (k == `Lean.Parser.Term.liftMethod || args.any hasLiftMethod)
| _ => false
/--
Hoist `liftMethod` nodes out of `stx`, without descending into nested `do` terms.

Each `liftMethod` node is replaced by an identifier `a`, where `t` is its argument 1.
The state receives the elements of `do a ← t; <missing>` minus the trailing
placeholder (presumably the `a ← t` element and its separator).
Lifts inside `t` are hoisted first, so their elements precede the one for `t`.
-/
private partial def expandLiftMethodAux : Syntax → StateT (Array Syntax) MacroM Syntax
| stx@(Syntax.node k args) =>
-- nested `do` blocks are left as-is; they expand their own lifts
if k == `Lean.Parser.Term.do then pure stx
-- fresh macro scope, so the `a` introduced here does not clash with another lift's `a`
else if k == `Lean.Parser.Term.liftMethod then withFreshMacroScope $ do
let term := args.get! 1;
term ← expandLiftMethodAux term;
auxDo ← `(do a ← $term; $(Syntax.missing));
-- drop the `Syntax.missing` placeholder element
let auxDoElems := (getDoElems auxDo).pop;
modify $ fun s => s ++ auxDoElems;
`(a)
else do
args ← args.mapM expandLiftMethodAux;
pure $ Syntax.node k args
| stx => pure stx
/--
If `stx` contains lifts, return the `do` elements that replace it. These are the binds
produced by `expandLiftMethodAux`, followed by `stx` with each lift replaced by its
bound variable. Returns `none` when there is nothing to lift.
-/
private def expandLiftMethod (stx : Syntax) : MacroM (Option (Array Syntax)) :=
if hasLiftMethod stx then do
  (rewritten, binds) ← (expandLiftMethodAux stx).run #[];
  pure $ some $ binds.push rewritten
else
  pure none
/--
`true` iff `doElem` is a `doExpr` element whose term contains no `liftMethod` node
(outside nested `do` terms). Such a term can be used on its own, outside any `do` block.
-/
private def isPlainDoExpr (doElem : Syntax) : Bool :=
doElem.getKind == `Lean.Parser.Term.doExpr && not (hasLiftMethod doElem)

/--
Expand `doLet`, `doPat`, nonterminal `doExpr`s, and `liftMethod` in `doElems`,
scanning from index `i`.

`doElems` appears to interleave elements with their separator tokens, hence the `i+2`
strides. `modified` records whether an earlier step already rewrote the list.
Returns `some stx` with the rewritten `do` term, or `none` if nothing needed expanding.
-/
private partial def expandDoElemsAux : Bool → Array Syntax → Nat → MacroM (Option Syntax)
| modified, doElems, i =>
  -- Everything after element `i` and the token following it, as a single term.
  let mkRest : Unit → MacroM Syntax := fun _ => do {
    let restElems := doElems.extract (i+2) doElems.size;
    -- Only a lone, lift-free `doExpr` may be unwrapped to its bare term.
    -- Any other lone element (e.g. `x ← e`, or an expression containing a
    -- `liftMethod`) must stay inside a `do` block. Otherwise `getArg 0` would silently
    -- turn it into an unrelated plain term instead of letting the `do` elaborator
    -- expand it or reject it.
    if restElems.size == 1 && isPlainDoExpr (restElems.get! 0) then
      pure $ (restElems.get! 0).getArg 0
    else
      `(do { $restElems* })
  };
  -- Keep elements `0 .. i-1` (with their separators) and append `rest` as the final `doExpr`.
  let addPrefix (rest : Syntax) : MacroM (Option Syntax) := do {
    if i == 0 then
      pure rest
    else
      let newElems := doElems.extract 0 i;
      let newElems := newElems.push $ Syntax.node `Lean.Parser.Term.doExpr #[rest];
      `(do { $newElems* })
  };
  if h : i < doElems.size then do
    let doElem := doElems.get ⟨i, h⟩;
    doElemsNew? ← expandLiftMethod doElem;
    match doElemsNew? with
    | some doElemsNew => do
      -- Splice the hoisted binds in place of `doElem` and rescan from `i`,
      -- which now holds the first hoisted bind.
      let post := doElems.extract (i+1) doElems.size;
      let pre := doElems.extract 0 i;
      let doElems := pre ++ doElemsNew ++ post;
      expandDoElemsAux true doElems i
    | none =>
      if doElem.getKind == `Lean.Parser.Term.doLet then do
        -- `let decl; rest...` becomes the term `let decl; rest`
        let letDecl := doElem.getArg 1;
        rest ← mkRest ();
        newBody ← `(let $letDecl:letDecl; $rest);
        addPrefix newBody
      else if doElem.getKind == `Lean.Parser.Term.doPat then withFreshMacroScope $ do
        -- (termParser >> leftArrow) >> termParser >> optional (" | " >> termParser)
        -- `pat ← discr | alt; rest...` becomes a bind followed by a `match` on its result
        let pat := doElem.getArg 0;
        let discr := doElem.getArg 2;
        let optElse := doElem.getArg 3;
        rest ← mkRest ();
        newBody ←
          if optElse.isNone then do
            `(do x ← $discr; match x with | $pat => $rest)
          else
            let elseBody := optElse.getArg 1;
            `(do x ← $discr; match x with | $pat => $rest | _ => $elseBody);
        addPrefix newBody
      else if i < doElems.size - 1 && doElem.getKind == `Lean.Parser.Term.doExpr then do
        -- def doExpr := parser! termParser
        -- A non-terminal `e` is rewritten to `x ← e`, so that only binds and one final
        -- expression remain.
        let term := doElem.getArg 0;
        auxDo ← `(do x ← $term; $(Syntax.missing));
        let doElemNew := (getDoElems auxDo).get! 0;
        let doElems := doElems.set! i doElemNew;
        expandDoElemsAux true doElems (i+2)
      else
        expandDoElemsAux modified doElems (i+2)
  else if modified then
    `(do { $doElems* })
  else
    pure none
/-- Macro-level expansion of a `do` block's elements (as returned by `getDoElems`),
scanning from the first element with nothing rewritten yet. Returns `none` if no
rewriting was needed. -/
private partial def expandDoElems (doElems : Array Syntax) : MacroM (Option Syntax) :=
expandDoElemsAux false doElems 0
/-- An elaborated `x ← action` element of a `do` block. `processDoElemsAux` accumulates
these and `mkBind` turns them into nested `HasBind.bind` applications. -/
structure ProcessedDoElem :=
(action : Expr) -- the elaborated action, checked against `m T` for the element's type `T`
(var : Expr)    -- the local variable (from `withLocalDecl`) bound to the action's result
/-- Placeholder element with both fields `arbitrary`. Presumably needed by partial array
accessors such as `elems.back` in `mkBind`. -/
instance ProcessedDoElem.inhabited : Inhabited ProcessedDoElem :=
⟨{ action := arbitrary _, var := arbitrary _ }⟩
/--
Given a type that reduces (reducible whnf) to an application `f a`, return `a`.
For a monadic type `m β` this yields `β`. Throws if the reduced type is not an application.
-/
private def extractTypeFormerAppArg (ref : Syntax) (type : Expr) : TermElabM Expr := do
reduced ← withReducible (whnf ref type);
match reduced with
| Expr.app _ arg _ => pure arg
| other            => throwError ref ("type former application expected" ++ indentExpr other)
/--
Wrap `body` in nested `HasBind.bind` applications, one per element of `elems`.
For `elems = #[e₁, ..., eₙ]` the result is
`bind e₁.action (fun e₁.var => ... bind eₙ.action (fun eₙ.var => body) ...)`.
Returns `body` unchanged when `elems` is empty.

`m` is the monad and `bindInstVal` its `HasBind m` instance. `body` is expected to
already have a type of the form `m β`, whose argument is recovered with
`extractTypeFormerAppArg`.

HasBind.bind : ∀ {m : Type u_1 → Type u_2} [self : HasBind m] {α β : Type u_1}, m α → (α → m β) → m β
-/
private def mkBind (ref : Syntax) (m bindInstVal : Expr) (elems : Array ProcessedDoElem) (body : Expr) : TermElabM Expr :=
if elems.isEmpty then
pure body
else do
let x := elems.back.var; -- any variable would work since they must be in the same universe
-- `u_1`: the universe with `xType : Type u_1` (the argument universe of `m`)
xType ← inferType ref x;
u_1 ← getDecLevel ref xType;
-- `u_2`: the universe with `bodyType : Type u_2`, where `bodyType` is `m β`
bodyType ← inferType ref body;
u_2 ← getDecLevel ref bodyType;
-- `@HasBind.bind m inst`, still awaiting `α β action f`
let bindAndInst := mkApp2 (Lean.mkConst `HasBind.bind [u_1, u_2]) m bindInstVal;
-- Right fold: the last element is wrapped first, so the first element's bind is outermost.
elems.foldrM
(fun elem body => do
-- dbgTrace (">>> " ++ toString body);
let var := elem.var;
let action := elem.action;
α ← inferType ref var;
mβ ← inferType ref body;
β ← extractTypeFormerAppArg ref mβ;
-- the continuation `fun var => body`
f ← mkLambda ref #[var] body;
-- dbgTrace (">>> f: " ++ toString f);
let body := mkAppN bindAndInst #[α, β, action, f];
pure body)
body
/--
Elaborate the separator-free element list `doElems`, starting at index `i`.
One `ProcessedDoElem` is accumulated into `elems` per `x ← e` element.

* A `doId` element (identifier, optional type, `←`, action) elaborates the type `T`,
  elaborates the action against `m T`, and continues with `x : T` as a local variable.
  It may not be the last element.
* A `doExpr` element must be the last one. It is elaborated against `expectedType`, and
  the accumulated binds are wrapped around it by `mkBind`.
* Any other element kind is rejected. `doLet`, `doPat`, lifts and non-terminal
  `doExpr`s are expected to have been removed by the earlier macro expansion.
-/
private partial def processDoElemsAux (doElems : Array Syntax) (m bindInstVal : Expr) (expectedType : Expr) : Nat → Array ProcessedDoElem → TermElabM Expr
| i, elems =>
  let doElem := doElems.get! i;
  let k := doElem.getKind;
  let ref := doElem;
  if k == `Lean.Parser.Term.doId then do
    when (i == doElems.size - 1) $
      throwError ref "the last statement in a 'do' block must be an expression";
    -- try (ident >> optType >> leftArrow) >> termParser
    let id := doElem.getIdAt 0;
    let typeStx := expandOptType ref (doElem.getArg 1);
    let actionStx := doElem.getArg 3;
    type ← elabType typeStx;
    let actionExpectedType := mkApp m type;
    action ← elabTerm actionStx actionExpectedType;
    action ← ensureHasType actionStx actionExpectedType action;
    withLocalDecl ref id BinderInfo.default type $ fun x =>
      processDoElemsAux (i+1) (elems.push { action := action, var := x })
  else if k == `Lean.Parser.Term.doExpr then do
    when (i != doElems.size - 1) $
      throwError ref ("unexpected 'do' expression element" ++ Format.line ++ doElem);
    let bodyStx := doElem.getArg 0;
    body ← elabTerm bodyStx expectedType;
    body ← ensureHasType ref expectedType body;
    mkBind ref m bindInstVal elems body
  else
    throwError ref ("unexpected 'do' expression element" ++ Format.line ++ doElem)
/-- Elaborate the separator-free elements `doElems` of a `do` block. The block is
elaborated in monad `m`, with `HasBind m` instance `bindInstVal`, against
`expectedType`, starting from the first element with no accumulated binds. -/
private def processDoElems (doElems : Array Syntax) (m bindInstVal : Expr) (expectedType : Expr) : TermElabM Expr :=
processDoElemsAux doElems m bindInstVal expectedType 0 #[]
/--
Elaborator for `do` terms. It first gives `tryPostponeIfNoneOrMVar` a chance to postpone
while the expected type is unknown. Then it either

1. elaborates the rewritten term, when `expandDoElems` expanded the block, or
2. determines the monad from the expected type (`extractBind`) and builds the nested
   `HasBind.bind` applications with `processDoElems`.
-/
@[builtinTermElab «do»] def elabDo : TermElab :=
fun stx expectedType? => do
  tryPostponeIfNoneOrMVar expectedType?;
  let doElems := getDoElems stx;
  expanded? ← liftMacroM $ expandDoElems doElems;
  match expanded? with
  | some expanded => withMacroExpansion stx expanded $ elabTerm expanded expectedType?
  | none => do
    trace `Elab.do stx $ fun _ => stx;
    monadInfo ← extractBind stx expectedType?;
    -- `extractBind` throws when `expectedType?` is `none`, so `get!` cannot fail here
    processDoElems doElems.getSepElems monadInfo.m monadInfo.hasBindInst expectedType?.get!
/-- Register the `Elab.do` trace class used by `elabDo`. -/
@[init] private def regTraceClasses : IO Unit :=
registerTraceClass `Elab.do
end Term
end Elab
end Lean