feat: propagate expected type at elabFunCore
This commit is contained in:
parent
abd0f54ce6
commit
41baf46083
2 changed files with 37 additions and 10 deletions
|
|
@ -85,9 +85,22 @@ match stx with
|
|||
| _ => throwUnsupportedSyntax
|
||||
|
||||
structure State :=
|
||||
(fvars : Array Expr := #[])
|
||||
(lctx : LocalContext)
|
||||
(localInsts : LocalInstances)
|
||||
(fvars : Array Expr := #[])
|
||||
(lctx : LocalContext)
|
||||
(localInsts : LocalInstances)
|
||||
(expectedType? : Option Expr := none)
|
||||
|
||||
private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do
|
||||
match s.expectedType? with
|
||||
| none => pure s
|
||||
| some expectedType => do
|
||||
expectedType ← whnfForall ref expectedType;
|
||||
match expectedType with
|
||||
| Expr.forallE _ d b _ => do
|
||||
isDefEq ref fvarType d;
|
||||
let b := b.instantiate1 fvar;
|
||||
pure { expectedType? := some b, .. s }
|
||||
| _ => pure { expectedType? := none, .. s }
|
||||
|
||||
private partial def elabBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State
|
||||
| i, s =>
|
||||
|
|
@ -100,13 +113,14 @@ private partial def elabBinderViews (binderViews : Array BinderView) : Nat → S
|
|||
let fvars := s.fvars.push fvar;
|
||||
-- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
|
||||
let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
|
||||
s ← propagateExpectedType binderView.id fvar type s;
|
||||
className? ← isClass binderView.type type;
|
||||
match className? with
|
||||
| none => elabBinderViews (i+1) { fvars := fvars, lctx := lctx, .. s }
|
||||
| some className => do
|
||||
resetSynthInstanceCache;
|
||||
let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId };
|
||||
elabBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts }
|
||||
elabBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts, .. s }
|
||||
else
|
||||
pure s
|
||||
|
||||
|
|
@ -125,14 +139,17 @@ end Binders
|
|||
Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracktedBinder`),
|
||||
update the local context, set of local instances, reset instance chache (if needed), and then
|
||||
execute `x` with the updated context. -/
|
||||
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
|
||||
if binders.isEmpty then x #[]
|
||||
def elabBindersWithExpectedType {α} (binders : Array Syntax) (expectedType? : Option Expr) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α :=
|
||||
if binders.isEmpty then x #[] expectedType?
|
||||
else do
|
||||
lctx ← getLCtx;
|
||||
localInsts ← getLocalInsts;
|
||||
s ← Binders.elabBindersAux binders 0 { lctx := lctx, localInsts := localInsts };
|
||||
s ← Binders.elabBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? };
|
||||
resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $
|
||||
x s.fvars
|
||||
x s.fvars s.expectedType?
|
||||
|
||||
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
|
||||
elabBindersWithExpectedType binders none $ fun args _ => x args
|
||||
|
||||
@[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α :=
|
||||
elabBinders #[binder] (fun fvars => x (fvars.get! 1))
|
||||
|
|
@ -249,8 +266,8 @@ def elabFunCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=
|
|||
let binders := (stx.getArg 1).getArgs;
|
||||
let body := stx.getArg 3;
|
||||
(binders, body) ← expandFunBinders binders body;
|
||||
elabBinders binders $ fun xs => do {
|
||||
e ← elabTerm body none;
|
||||
elabBindersWithExpectedType binders expectedType? $ fun xs expectedType? => do {
|
||||
e ← elabTerm body expectedType?;
|
||||
mkLambda stx xs e
|
||||
}
|
||||
|
||||
|
|
|
|||
10
tests/lean/run/newfrontend3.lean
Normal file
10
tests/lean/run/newfrontend3.lean
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
structure S :=
|
||||
(g {α} : α → α)
|
||||
|
||||
def f (h : Nat → (forall {α : Type}, α → α) × Bool) : Nat :=
|
||||
(h 0).1 1
|
||||
|
||||
new_frontend
|
||||
|
||||
def tst : Nat :=
|
||||
f (fun n => (fun x => x, true))
|
||||
Loading…
Add table
Reference in a new issue