feat: elaborate lambda abstractions

This commit is contained in:
Leonardo de Moura 2019-12-17 11:44:40 -08:00
parent d9d1c67d86
commit 3d6146756f
2 changed files with 72 additions and 0 deletions

View file

@ -181,6 +181,7 @@ match type? with
| none => liftMetaM ref $ do u ← Meta.mkFreshLevelMVar; Meta.mkFreshExprMVar (mkSort u) userName? kind
def getLevel (ref : Syntax) (type : Expr) : TermElabM Level := liftMetaM ref $ Meta.getLevel type
def mkForall (ref : Syntax) (xs : Array Expr) (e : Expr) : TermElabM Expr := liftMetaM ref $ Meta.mkForall xs e
def mkLambda (ref : Syntax) (xs : Array Expr) (e : Expr) : TermElabM Expr := liftMetaM ref $ Meta.mkLambda xs e
def trySynthInstance (ref : Syntax) (type : Expr) : TermElabM (LOption Expr) := liftMetaM ref $ Meta.trySynthInstance type
def mkAppM (ref : Syntax) (constName : Name) (args : Array Expr) : TermElabM Expr := liftMetaM ref $ Meta.mkAppM constName args
def decLevel? (ref : Syntax) (u : Level) : TermElabM (Option Level) := liftMetaM ref $ Meta.decLevel? u
@ -436,6 +437,76 @@ fun stx _ =>
e ← elabType term;
mkForall stx.val xs e
/-
Auxiliary functions for converting `Term.app ... (Term.app id_1 id_2) ... id_n` into #[id_1, ..., id_m]`
It is used at `expandFunBinders`. -/
partial def getFunBinderIdsAux? : Bool → Syntax → Array Syntax → TermElabM (Option (Array Syntax))
| false, Syntax.node `Lean.Parser.Term.app args, acc => do
(some acc) ← getFunBinderIdsAux? false (args.getA 0) acc | pure none;
getFunBinderIdsAux? true (args.getA 1) acc
| _, Syntax.node `Lean.Parser.Term.id args, acc =>
if (args.getA 1).isNone then
pure (some (acc.push (args.getA 0)))
else
pure none
| _, _, _ => pure none
def getFunBinderIds? (stx : Syntax) : TermElabM (Option (Array Syntax)) :=
getFunBinderIdsAux? false stx #[]
partial def expandFunBindersAux (binders : Array Syntax) : Syntax → Nat → Array Syntax → TermElabM (Array Syntax × Syntax)
| body, i, newBinders =>
if h : i < binders.size then
let binder := binders.get ⟨i, h⟩;
let processAsPattern : Unit → TermElabM (Array Syntax × Syntax) := fun _ => do {
throwError binder "not implemented yet"
};
match binder with
| Syntax.node `Lean.Parser.Term.id args => do
unless (args.getA 1).isNone $ throwError binder "invalid binder, simple identifier expected";
let id := args.getA 0;
let type := mkHole;
expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder id type)
| Syntax.node `Lean.Parser.Term.hole _ => do
id ← mkFreshAnonymousName;
let id := mkIdentFrom binder id;
let type := binder;
expandFunBindersAux body (i+1) (newBinders.push $ mkExplicitBinder id type)
| Syntax.node `Lean.Parser.Term.paren args =>
-- `(` (termParser >> parenSpecial)? `)`
-- parenSpecial := (tupleTail <|> typeAscription)?
let binderBody := binder.getArg 1;
if binderBody.isNone then processAsPattern ()
else
let ids := binderBody.getArg 0;
let special := binderBody.getArg 1;
if special.isNone then processAsPattern ()
else if (special.getArg 0).getKind != `Lean.Parser.Term.typeAscription then processAsPattern ()
else do
-- typeAscription := `:` term
let type := ((special.getArg 0).getArg 1);
ids? ← getFunBinderIds? ids;
match ids? with
| some ids => expandFunBindersAux body (i+1) (newBinders ++ ids.map (fun id => mkExplicitBinder id type))
| none => processAsPattern ()
| _ => processAsPattern ()
else
pure (newBinders, body)
def expandFunBinders (binders : Array Syntax) (body : Syntax) : TermElabM (Array Syntax × Syntax) :=
expandFunBindersAux binders body 0 #[]
@[builtinTermElab «fun»] def elabFun : TermElab :=
fun stx expectedType? => do
-- `fun` term+ `=>` term
let binders := (stx.getArg 1).getArgs;
let body := stx.getArg 3;
(binders, body) ← expandFunBinders binders body;
elabBinders binders $ fun xs => do
-- TODO: expected type
e ← elabTerm body none;
mkLambda stx.val xs e
@[builtinTermElab paren] def elabParen : TermElab :=
fun stx expectedType? =>
-- `(` (termParser >> parenSpecial)? `)`

View file

@ -125,3 +125,4 @@ def m : Monoid Nat :=
#eval run "#check \"hello\""
#eval run "#check 1"
#eval run "#check Nat.succ 1"
#eval run "#check fun _ a (x y : Int) => x + y + a"