refactor: decouple elaboration for forall and lambda binders

Motivation: another refactoring to improve `elabFunCore`.
This commit is contained in:
Leonardo de Moura 2020-02-12 09:11:20 -08:00
parent fae52a7ba6
commit 6429486f88

View file

@ -9,7 +9,7 @@ import Init.Lean.Elab.Term
namespace Lean
namespace Elab
namespace Term
namespace Binders
/--
Given syntax of the forms
a) (`:` term)?
@ -37,6 +37,7 @@ else
structure BinderView :=
(id : Syntax) (type : Syntax) (bi : BinderInfo)
/-
Expand `optional (binderDefault <|> binderTactic)`
def binderDefault := parser! " := " >> termParser
@ -84,72 +85,50 @@ match stx with
throwUnsupportedSyntax
| _ => throwUnsupportedSyntax
structure State :=
(fvars : Array Expr := #[])
(lctx : LocalContext)
(localInsts : LocalInstances)
(expectedType? : Option Expr := none)
private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do
match s.expectedType? with
| none => pure s
| some expectedType => do
expectedType ← whnfForall ref expectedType;
match expectedType with
| Expr.forallE _ d b _ => do
isDefEq ref fvarType d;
let b := b.instantiate1 fvar;
pure { expectedType? := some b, .. s }
| _ => pure { expectedType? := none, .. s }
private partial def elabBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State
| i, s =>
private partial def elabBinderViews (binderViews : Array BinderView)
: Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances)
| i, fvars, lctx, localInsts =>
if h : i < binderViews.size then
let binderView := binderViews.get ⟨i, h⟩;
withLCtx s.lctx s.localInsts $ do
withLCtx lctx localInsts $ do
type ← elabType binderView.type;
fvarId ← mkFreshFVarId;
let fvar := mkFVar fvarId;
let fvars := s.fvars.push fvar;
let fvars := fvars.push fvar;
-- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
s ← propagateExpectedType binderView.id fvar type s;
let lctx := lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
className? ← isClass binderView.type type;
match className? with
| none => elabBinderViews (i+1) { fvars := fvars, lctx := lctx, .. s }
| none => elabBinderViews (i+1) fvars lctx localInsts
| some className => do
resetSynthInstanceCache;
let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId };
elabBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts, .. s }
let localInsts := localInsts.push { className := className, fvar := mkFVar fvarId };
elabBinderViews (i+1) fvars lctx localInsts
else
pure s
pure (fvars, lctx, localInsts)
partial def elabBindersAux (binders : Array Syntax) : Nat → State → TermElabM State
| i, s =>
private partial def elabBindersAux (binders : Array Syntax)
: Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances)
| i, fvars, lctx, localInsts =>
if h : i < binders.size then do
binderViews ← matchBinder (binders.get ⟨i, h⟩);
s ← elabBinderViews binderViews 0 s;
elabBindersAux (i+1) s
(fvars, lctx, localInsts) ← elabBinderViews binderViews 0 fvars lctx localInsts;
elabBindersAux (i+1) fvars lctx localInsts
else
pure s
end Binders
pure (fvars, lctx, localInsts)
/--
Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracktedBinder`),
update the local context, set of local instances, reset instance chache (if needed), and then
execute `x` with the updated context. -/
def elabBindersWithExpectedType {α} (binders : Array Syntax) (expectedType? : Option Expr) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α :=
if binders.isEmpty then x #[] expectedType?
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
if binders.isEmpty then x #[]
else do
lctx ← getLCtx;
localInsts ← getLocalInsts;
s ← Binders.elabBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? };
resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $
x s.fvars s.expectedType?
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
elabBindersWithExpectedType binders none $ fun args _ => x args
(fvars, lctx, newLocalInsts) ← elabBindersAux binders 0 #[] lctx localInsts;
resettingSynthInstanceCacheWhen (newLocalInsts.size > localInsts.size) $ withLCtx lctx newLocalInsts $
x fvars
@[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α :=
elabBinders #[binder] (fun fvars => x (fvars.get! 1))
@ -261,12 +240,76 @@ private partial def expandFunBindersAux (binders : Array Syntax) : Syntax → Na
def expandFunBinders (binders : Array Syntax) (body : Syntax) : TermElabM (Array Syntax × Syntax) :=
expandFunBindersAux binders body 0 #[]
namespace FunBinders
structure State :=
(implicitArgs : Array Expr := #[])
(fvars : Array Expr := #[])
(lctx : LocalContext)
(localInsts : LocalInstances)
(expectedType? : Option Expr := none)
(explicit : Bool := false)
private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do
match s.expectedType? with
| none => pure s
| some expectedType => do
expectedType ← whnfForall ref expectedType;
match expectedType with
| Expr.forallE _ d b _ => do
isDefEq ref fvarType d;
let b := b.instantiate1 fvar;
pure { expectedType? := some b, .. s }
| _ => pure { expectedType? := none, .. s }
private partial def elabFunBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State
| i, s =>
if h : i < binderViews.size then
let binderView := binderViews.get ⟨i, h⟩;
withLCtx s.lctx s.localInsts $ do
type ← elabType binderView.type;
fvarId ← mkFreshFVarId;
let fvar := mkFVar fvarId;
let fvars := s.fvars.push fvar;
-- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
s ← propagateExpectedType binderView.id fvar type s;
className? ← isClass binderView.type type;
match className? with
| none => elabFunBinderViews (i+1) { fvars := fvars, lctx := lctx, .. s }
| some className => do
resetSynthInstanceCache;
let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId };
elabFunBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts, .. s }
else
pure s
partial def elabFunBindersAux (binders : Array Syntax) : Nat → State → TermElabM State
| i, s =>
if h : i < binders.size then do
binderViews ← matchBinder (binders.get ⟨i, h⟩);
s ← elabFunBinderViews binderViews 0 s;
elabFunBindersAux (i+1) s
else
pure s
end FunBinders
def elabFunBinders {α} (binders : Array Syntax) (expectedType? : Option Expr) (explicit : Bool) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α :=
if binders.isEmpty then x #[] expectedType?
else do
lctx ← getLCtx;
localInsts ← getLocalInsts;
s ← FunBinders.elabFunBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType?, explicit := explicit };
resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $
x s.fvars s.expectedType?
def elabFunCore (stx : Syntax) (expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do
-- `fun` term+ `=>` term
let binders := (stx.getArg 1).getArgs;
let body := stx.getArg 3;
(binders, body) ← expandFunBinders binders body;
elabBindersWithExpectedType binders expectedType? $ fun xs expectedType? => do {
elabFunBinders binders expectedType? explicit $ fun xs expectedType? => do {
e ← elabTerm body expectedType?;
mkLambda stx xs e
}