From 41baf46083e4e6b0307387924431aee092b604e8 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 11 Feb 2020 17:47:51 -0800 Subject: [PATCH] feat: propagate expected type at `elabFunCore` --- src/Init/Lean/Elab/Binders.lean | 37 +++++++++++++++++++++++--------- tests/lean/run/newfrontend3.lean | 10 +++++++++ 2 files changed, 37 insertions(+), 10 deletions(-) create mode 100644 tests/lean/run/newfrontend3.lean diff --git a/src/Init/Lean/Elab/Binders.lean b/src/Init/Lean/Elab/Binders.lean index 34245dec3d..da57080cd5 100644 --- a/src/Init/Lean/Elab/Binders.lean +++ b/src/Init/Lean/Elab/Binders.lean @@ -85,9 +85,22 @@ match stx with | _ => throwUnsupportedSyntax structure State := -(fvars : Array Expr := #[]) -(lctx : LocalContext) -(localInsts : LocalInstances) +(fvars : Array Expr := #[]) +(lctx : LocalContext) +(localInsts : LocalInstances) +(expectedType? : Option Expr := none) + +private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do +match s.expectedType? with +| none => pure s +| some expectedType => do + expectedType ← whnfForall ref expectedType; + match expectedType with + | Expr.forallE _ d b _ => do + isDefEq ref fvarType d; + let b := b.instantiate1 fvar; + pure { expectedType? := some b, .. s } + | _ => pure { expectedType? := none, .. s } private partial def elabBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State | i, s => @@ -100,13 +113,14 @@ private partial def elabBinderViews (binderViews : Array BinderView) : Nat → S let fvars := s.fvars.push fvar; -- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type); let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi; + s ← propagateExpectedType binderView.id fvar type s; className? ← isClass binderView.type type; match className? with | none => elabBinderViews (i+1) { fvars := fvars, lctx := lctx, .. s } | some className => do resetSynthInstanceCache; let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId }; - elabBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts } + elabBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts, .. s } else pure s @@ -125,14 +139,17 @@ end Binders Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracktedBinder`), update the local context, set of local instances, reset instance chache (if needed), and then execute `x` with the updated context. -/ -def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α := -if binders.isEmpty then x #[] +def elabBindersWithExpectedType {α} (binders : Array Syntax) (expectedType? : Option Expr) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α := +if binders.isEmpty then x #[] expectedType? else do lctx ← getLCtx; localInsts ← getLocalInsts; - s ← Binders.elabBindersAux binders 0 { lctx := lctx, localInsts := localInsts }; + s ← Binders.elabBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? }; resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $ - x s.fvars + x s.fvars s.expectedType? + +def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α := +elabBindersWithExpectedType binders none $ fun args _ => x args @[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α := elabBinders #[binder] (fun fvars => x (fvars.get! 1)) @@ -249,8 +266,8 @@ def elabFunCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := let binders := (stx.getArg 1).getArgs; let body := stx.getArg 3; (binders, body) ← expandFunBinders binders body; -elabBinders binders $ fun xs => do { - e ← elabTerm body none; +elabBindersWithExpectedType binders expectedType? $ fun xs expectedType? => do { + e ← elabTerm body expectedType?; mkLambda stx xs e } diff --git a/tests/lean/run/newfrontend3.lean b/tests/lean/run/newfrontend3.lean new file mode 100644 index 0000000000..71a6402ac0 --- /dev/null +++ b/tests/lean/run/newfrontend3.lean @@ -0,0 +1,10 @@ +structure S := +(g {α} : α → α) + +def f (h : Nat → (forall {α : Type}, α → α) × Bool) : Nat := +(h 0).1 1 + +new_frontend + +def tst : Nat := +f (fun n => (fun x => x, true))