refactor: decouple elaboration for forall and lambda binders
Motivation: another refactoring to improve `elabFunCore`.
This commit is contained in:
parent
fae52a7ba6
commit
6429486f88
1 changed files with 88 additions and 45 deletions
|
|
@ -9,7 +9,7 @@ import Init.Lean.Elab.Term
|
|||
namespace Lean
|
||||
namespace Elab
|
||||
namespace Term
|
||||
namespace Binders
|
||||
|
||||
/--
|
||||
Given syntax of the forms
|
||||
a) (`:` term)?
|
||||
|
|
@ -37,6 +37,7 @@ else
|
|||
structure BinderView :=
|
||||
(id : Syntax) (type : Syntax) (bi : BinderInfo)
|
||||
|
||||
|
||||
/-
|
||||
Expand `optional (binderDefault <|> binderTactic)`
|
||||
def binderDefault := parser! " := " >> termParser
|
||||
|
|
@ -84,72 +85,50 @@ match stx with
|
|||
throwUnsupportedSyntax
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
structure State :=
|
||||
(fvars : Array Expr := #[])
|
||||
(lctx : LocalContext)
|
||||
(localInsts : LocalInstances)
|
||||
(expectedType? : Option Expr := none)
|
||||
|
||||
private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do
|
||||
match s.expectedType? with
|
||||
| none => pure s
|
||||
| some expectedType => do
|
||||
expectedType ← whnfForall ref expectedType;
|
||||
match expectedType with
|
||||
| Expr.forallE _ d b _ => do
|
||||
isDefEq ref fvarType d;
|
||||
let b := b.instantiate1 fvar;
|
||||
pure { expectedType? := some b, .. s }
|
||||
| _ => pure { expectedType? := none, .. s }
|
||||
|
||||
private partial def elabBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State
|
||||
| i, s =>
|
||||
private partial def elabBinderViews (binderViews : Array BinderView)
|
||||
: Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances)
|
||||
| i, fvars, lctx, localInsts =>
|
||||
if h : i < binderViews.size then
|
||||
let binderView := binderViews.get ⟨i, h⟩;
|
||||
withLCtx s.lctx s.localInsts $ do
|
||||
withLCtx lctx localInsts $ do
|
||||
type ← elabType binderView.type;
|
||||
fvarId ← mkFreshFVarId;
|
||||
let fvar := mkFVar fvarId;
|
||||
let fvars := s.fvars.push fvar;
|
||||
let fvars := fvars.push fvar;
|
||||
-- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
|
||||
let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
|
||||
s ← propagateExpectedType binderView.id fvar type s;
|
||||
let lctx := lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
|
||||
className? ← isClass binderView.type type;
|
||||
match className? with
|
||||
| none => elabBinderViews (i+1) { fvars := fvars, lctx := lctx, .. s }
|
||||
| none => elabBinderViews (i+1) fvars lctx localInsts
|
||||
| some className => do
|
||||
resetSynthInstanceCache;
|
||||
let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId };
|
||||
elabBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts, .. s }
|
||||
let localInsts := localInsts.push { className := className, fvar := mkFVar fvarId };
|
||||
elabBinderViews (i+1) fvars lctx localInsts
|
||||
else
|
||||
pure s
|
||||
pure (fvars, lctx, localInsts)
|
||||
|
||||
partial def elabBindersAux (binders : Array Syntax) : Nat → State → TermElabM State
|
||||
| i, s =>
|
||||
private partial def elabBindersAux (binders : Array Syntax)
|
||||
: Nat → Array Expr → LocalContext → LocalInstances → TermElabM (Array Expr × LocalContext × LocalInstances)
|
||||
| i, fvars, lctx, localInsts =>
|
||||
if h : i < binders.size then do
|
||||
binderViews ← matchBinder (binders.get ⟨i, h⟩);
|
||||
s ← elabBinderViews binderViews 0 s;
|
||||
elabBindersAux (i+1) s
|
||||
(fvars, lctx, localInsts) ← elabBinderViews binderViews 0 fvars lctx localInsts;
|
||||
elabBindersAux (i+1) fvars lctx localInsts
|
||||
else
|
||||
pure s
|
||||
|
||||
end Binders
|
||||
pure (fvars, lctx, localInsts)
|
||||
|
||||
/--
|
||||
Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracktedBinder`),
|
||||
update the local context, set of local instances, reset instance chache (if needed), and then
|
||||
execute `x` with the updated context. -/
|
||||
def elabBindersWithExpectedType {α} (binders : Array Syntax) (expectedType? : Option Expr) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α :=
|
||||
if binders.isEmpty then x #[] expectedType?
|
||||
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
|
||||
if binders.isEmpty then x #[]
|
||||
else do
|
||||
lctx ← getLCtx;
|
||||
localInsts ← getLocalInsts;
|
||||
s ← Binders.elabBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType? };
|
||||
resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $
|
||||
x s.fvars s.expectedType?
|
||||
|
||||
def elabBinders {α} (binders : Array Syntax) (x : Array Expr → TermElabM α) : TermElabM α :=
|
||||
elabBindersWithExpectedType binders none $ fun args _ => x args
|
||||
(fvars, lctx, newLocalInsts) ← elabBindersAux binders 0 #[] lctx localInsts;
|
||||
resettingSynthInstanceCacheWhen (newLocalInsts.size > localInsts.size) $ withLCtx lctx newLocalInsts $
|
||||
x fvars
|
||||
|
||||
@[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α :=
|
||||
elabBinders #[binder] (fun fvars => x (fvars.get! 1))
|
||||
|
|
@ -261,12 +240,76 @@ private partial def expandFunBindersAux (binders : Array Syntax) : Syntax → Na
|
|||
def expandFunBinders (binders : Array Syntax) (body : Syntax) : TermElabM (Array Syntax × Syntax) :=
|
||||
expandFunBindersAux binders body 0 #[]
|
||||
|
||||
namespace FunBinders
|
||||
|
||||
structure State :=
|
||||
(implicitArgs : Array Expr := #[])
|
||||
(fvars : Array Expr := #[])
|
||||
(lctx : LocalContext)
|
||||
(localInsts : LocalInstances)
|
||||
(expectedType? : Option Expr := none)
|
||||
(explicit : Bool := false)
|
||||
|
||||
private def propagateExpectedType (ref : Syntax) (fvar : Expr) (fvarType : Expr) (s : State) : TermElabM State := do
|
||||
match s.expectedType? with
|
||||
| none => pure s
|
||||
| some expectedType => do
|
||||
expectedType ← whnfForall ref expectedType;
|
||||
match expectedType with
|
||||
| Expr.forallE _ d b _ => do
|
||||
isDefEq ref fvarType d;
|
||||
let b := b.instantiate1 fvar;
|
||||
pure { expectedType? := some b, .. s }
|
||||
| _ => pure { expectedType? := none, .. s }
|
||||
|
||||
private partial def elabFunBinderViews (binderViews : Array BinderView) : Nat → State → TermElabM State
|
||||
| i, s =>
|
||||
if h : i < binderViews.size then
|
||||
let binderView := binderViews.get ⟨i, h⟩;
|
||||
withLCtx s.lctx s.localInsts $ do
|
||||
type ← elabType binderView.type;
|
||||
fvarId ← mkFreshFVarId;
|
||||
let fvar := mkFVar fvarId;
|
||||
let fvars := s.fvars.push fvar;
|
||||
-- dbgTrace (toString binderView.id.getId ++ " : " ++ toString type);
|
||||
let lctx := s.lctx.mkLocalDecl fvarId binderView.id.getId type binderView.bi;
|
||||
s ← propagateExpectedType binderView.id fvar type s;
|
||||
className? ← isClass binderView.type type;
|
||||
match className? with
|
||||
| none => elabFunBinderViews (i+1) { fvars := fvars, lctx := lctx, .. s }
|
||||
| some className => do
|
||||
resetSynthInstanceCache;
|
||||
let localInsts := s.localInsts.push { className := className, fvar := mkFVar fvarId };
|
||||
elabFunBinderViews (i+1) { fvars := fvars, lctx := lctx, localInsts := localInsts, .. s }
|
||||
else
|
||||
pure s
|
||||
|
||||
partial def elabFunBindersAux (binders : Array Syntax) : Nat → State → TermElabM State
|
||||
| i, s =>
|
||||
if h : i < binders.size then do
|
||||
binderViews ← matchBinder (binders.get ⟨i, h⟩);
|
||||
s ← elabFunBinderViews binderViews 0 s;
|
||||
elabFunBindersAux (i+1) s
|
||||
else
|
||||
pure s
|
||||
|
||||
end FunBinders
|
||||
|
||||
def elabFunBinders {α} (binders : Array Syntax) (expectedType? : Option Expr) (explicit : Bool) (x : Array Expr → Option Expr → TermElabM α) : TermElabM α :=
|
||||
if binders.isEmpty then x #[] expectedType?
|
||||
else do
|
||||
lctx ← getLCtx;
|
||||
localInsts ← getLocalInsts;
|
||||
s ← FunBinders.elabFunBindersAux binders 0 { lctx := lctx, localInsts := localInsts, expectedType? := expectedType?, explicit := explicit };
|
||||
resettingSynthInstanceCacheWhen (s.localInsts.size > localInsts.size) $ withLCtx s.lctx s.localInsts $
|
||||
x s.fvars s.expectedType?
|
||||
|
||||
def elabFunCore (stx : Syntax) (expectedType? : Option Expr) (explicit : Bool) : TermElabM Expr := do
|
||||
-- `fun` term+ `=>` term
|
||||
let binders := (stx.getArg 1).getArgs;
|
||||
let body := stx.getArg 3;
|
||||
(binders, body) ← expandFunBinders binders body;
|
||||
elabBindersWithExpectedType binders expectedType? $ fun xs expectedType? => do {
|
||||
elabFunBinders binders expectedType? explicit $ fun xs expectedType? => do {
|
||||
e ← elabTerm body expectedType?;
|
||||
mkLambda stx xs e
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue