feat: propagate expected type at elabFunCore

This commit is contained in:
Leonardo de Moura 2020-02-11 17:47:51 -08:00
parent abd0f54ce6
commit 41baf46083
2 changed files with 37 additions and 10 deletions

View file

@ -85,9 +85,22 @@ match stx with
| _ => throwUnsupportedSyntax
structure State :=
(fvars : Array Expr := #[])
(lctx : LocalContext)
(localInsts : LocalInstances)
(fvars : Array Expr := #[])
(lctx : LocalContext)
(localInsts : LocalInstances)
(expectedType? : Option Expr := none)
private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do
match s.expectedType? with
| none => pure s
| some expectedType => do
expectedType ← whnfForall ref expectedType;
match expectedType with
| Expr.forallE _ d b _ => do
isDefEq ref fvarType d;
let b := b.instantiate1 fvar;
pure { expectedType? := some b, .. s }
| _ => pure { expectedType? := none, .. s }
private partial def elabBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State
| i, s =>
@ -100,13 +113,14 @@ private partial def elabBinderViews (binderViews : Array BinderView) : Nat → S
let fvars := s.fvars.push fvar;
-- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
s ← propagateExpectedType binderView.id fvar type s;
className? ← isClass binderView.type type;
match className? with
| none => elabBinderViews (i+1) { fvars := fvars, lctx := lctx, .. s }
| some className => do
resetSynthInstanceCache;
let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId };
elabBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts }
elabBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts, .. s }
else
pure s
@ -125,14 +139,17 @@ end Binders
Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracktedBinder`),
update the local context, set of local instances, reset instance chache (if needed), and then
execute `x` with the updated context. -/
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
if binders.isEmpty then x #[]
def elabBindersWithExpectedType {α} (binders : Array Syntax) (expectedType? : Option Expr) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α :=
if binders.isEmpty then x #[] expectedType?
else do
lctx ← getLCtx;
localInsts ← getLocalInsts;
s ← Binders.elabBindersAux binders 0 { lctx := lctx, localInsts := localInsts };
s ← Binders.elabBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? };
resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $
x s.fvars
x s.fvars s.expectedType?
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
elabBindersWithExpectedType binders none $ fun args _ => x args
@[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α :=
elabBinders #[binder] (fun fvars => x (fvars.get! 1))
@ -249,8 +266,8 @@ def elabFunCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=
let binders := (stx.getArg 1).getArgs;
let body := stx.getArg 3;
(binders, body) ← expandFunBinders binders body;
elabBinders binders $ fun xs => do {
e ← elabTerm body none;
elabBindersWithExpectedType binders expectedType? $ fun xs expectedType? => do {
e ← elabTerm body expectedType?;
mkLambda stx xs e
}

View file

@ -0,0 +1,10 @@
structure S :=
(g {α} : αα)
def f (h : Nat → (forall {α : Type}, αα) × Bool) : Nat :=
(h 0).1 1
new_frontend
def tst : Nat :=
f (fun n => (fun x => x, true))