feat: lift nested 'let-rec's
This commit is contained in:
parent
f2a6562eed
commit
06c6002d45
2 changed files with 134 additions and 1 deletions
|
|
@ -49,7 +49,7 @@ decls ← (letRec.getArg 1).getArgs.getSepElems.mapM fun attrDeclStx => do {
|
|||
type ← mkForallFVars xs type;
|
||||
pure (type, xs.size)
|
||||
};
|
||||
mvar ← mkFreshExprMVar type MetavarKind.synthetic;
|
||||
mvar ← mkFreshExprMVar type MetavarKind.syntheticOpaque;
|
||||
valStx ←
|
||||
if decl.isOfKind `Lean.Parser.Term.letIdDecl then
|
||||
pure $ decl.getArg 4
|
||||
|
|
|
|||
|
|
@ -199,6 +199,131 @@ letRecsToLift.forM fun toLift => do
|
|||
fnName ← getFunName fvarId letRecsToLift;
|
||||
throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously")
|
||||
|
||||
private def replaceFunFVarsWithConsts (headers : Array DefViewElabHeader) (funFVars : Array Expr) (vars : Array Expr) (e : Expr) : Expr :=
|
||||
e.replace fun x => match x with
|
||||
| Expr.fvar fvarId _ =>
|
||||
match funFVars.indexOf x with
|
||||
| some idx => match headers.get? idx.val with
|
||||
| some header => some $ mkAppN (Lean.mkConst header.declName) vars -- Remark: we add the universe levels later
|
||||
| none => none
|
||||
| none => none
|
||||
| _ => none
|
||||
|
||||
/- A mapping from -/
|
||||
abbrev Closures := NameMap NameSet
|
||||
|
||||
namespace LetRecClosure
|
||||
|
||||
structure State :=
|
||||
(closures : Closures := {})
|
||||
(modified : Bool := false)
|
||||
|
||||
abbrev M := ReaderT (List FVarId) $ StateM State
|
||||
|
||||
private def isModified : M Bool := do
|
||||
s ← get; pure s.modified
|
||||
|
||||
private def resetModified : M Unit :=
|
||||
modify fun s => { s with modified := false }
|
||||
|
||||
private def markModified : M Unit :=
|
||||
modify fun s => { s with modified := true }
|
||||
|
||||
private def getClosures : M Closures :=
|
||||
do s ← get; pure s.closures
|
||||
|
||||
private def modifyClosures (f : Closures → Closures) : M Unit :=
|
||||
modify fun s => { s with closures := f s.closures }
|
||||
|
||||
-- merge s₂ into s₁
|
||||
private def merge (s₁ s₂ : NameSet) : M NameSet :=
|
||||
s₂.foldM
|
||||
(fun (s₁ : NameSet) k =>
|
||||
if s₁.contains k then
|
||||
pure s₁
|
||||
else do
|
||||
markModified;
|
||||
pure $ s₁.insert k)
|
||||
s₁
|
||||
|
||||
private def updateClosureOf (fvarId : FVarId) : M Unit := do
|
||||
closures ← getClosures;
|
||||
match closures.find? fvarId with
|
||||
| none => pure ()
|
||||
| some closure => do
|
||||
closureNew ← closure.foldM
|
||||
(fun (closureNew : NameSet) (fvarId' : FVarId) =>
|
||||
if fvarId == fvarId' then
|
||||
pure closureNew
|
||||
else
|
||||
match closures.find? fvarId' with
|
||||
| none => pure closureNew
|
||||
-- We are being sloppy here `otherClosure` may contain free variables that are not in the context of the let-rec associated with fvarId.
|
||||
-- We filter these out-of-context free variables later.
|
||||
| some otherClosure => merge closureNew otherClosure)
|
||||
closure;
|
||||
modifyClosures fun closures => closures.insert fvarId closureNew
|
||||
|
||||
private partial def closureFixpoint : Unit → M Unit
|
||||
| _ => do
|
||||
resetModified;
|
||||
letRecFVarIds ← read;
|
||||
letRecFVarIds.forM updateClosureOf;
|
||||
whenM isModified $ closureFixpoint ()
|
||||
|
||||
private def mkInitialClosures (letRecsToLift : List LetRecToLift) : Closures :=
|
||||
letRecsToLift.foldl
|
||||
(fun (closures : Closures) toLift =>
|
||||
let state := Lean.collectFVars {} toLift.val;
|
||||
let state := Lean.collectFVars state toLift.type;
|
||||
closures.insert toLift.fvarId state.fvarSet)
|
||||
{}
|
||||
|
||||
private def mkClosures (letRecsToLift : List LetRecToLift) : Closures :=
|
||||
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
|
||||
let closures := mkInitialClosures letRecsToLift;
|
||||
let (_, s) := ((closureFixpoint ()).run letRecFVarIds).run { closures := closures };
|
||||
s.closures
|
||||
|
||||
private def nameSetToFVars (s : NameSet) (lctx : LocalContext) (letRecFVarIds : List FVarId) : Array FVarId :=
|
||||
let fvarIds := s.fold
|
||||
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
|
||||
if lctx.contains fvarId && !letRecFVarIds.contains fvarId then
|
||||
fvarIds.push fvarId
|
||||
else
|
||||
fvarIds)
|
||||
#[];
|
||||
fvarIds.qsort fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
|
||||
|
||||
def main (letRecsToLift : List LetRecToLift) : TermElabM (List LetRecToLift) := do
|
||||
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
|
||||
let closures := mkClosures letRecsToLift;
|
||||
-- Assign metavariables associated with each let-rec
|
||||
ps ← letRecsToLift.mapM fun toLift => do {
|
||||
let s := (closures.find? toLift.fvarId).get!;
|
||||
let lctx := toLift.lctx;
|
||||
let xs := (nameSetToFVars s lctx letRecFVarIds).map mkFVar;
|
||||
assignExprMVar toLift.mvarId (mkAppN (Lean.mkConst toLift.declName) xs);
|
||||
pure (xs, toLift)
|
||||
};
|
||||
ps.mapM fun (p : Array Expr × LetRecToLift) => do
|
||||
let (xs, toLift) := p;
|
||||
type ← instantiateMVars toLift.type;
|
||||
val ← instantiateMVars toLift.val;
|
||||
let lctx := toLift.lctx;
|
||||
-- Val may contain outer let-recs, we must replace them with constants
|
||||
let val := val.replace fun x => match x with
|
||||
| Expr.fvar fvarId _ => match ps.find? fun (p : Array Expr × LetRecToLift) => p.2.fvarId == fvarId with
|
||||
| some p => some (mkAppN (Lean.mkConst p.2.declName) p.1)
|
||||
| none => none
|
||||
| _ => none;
|
||||
-- Apply closure
|
||||
let type := lctx.mkForall xs type;
|
||||
let val := lctx.mkLambda xs val;
|
||||
pure { toLift with type := type, val := val }
|
||||
|
||||
end LetRecClosure
|
||||
|
||||
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
|
||||
scopeLevelNames ← getLevelNames;
|
||||
headers ← elabHeaders views;
|
||||
|
|
@ -213,7 +338,15 @@ withFunLocalDecls headers fun funFVars => do
|
|||
letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift;
|
||||
checkLetRecsToLiftTypes funFVars letRecsToLift;
|
||||
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
|
||||
let values := values.map $ replaceFunFVarsWithConsts headers funFVars vars;
|
||||
let letRecsToLift := letRecsToLift.map fun toLift => { toLift with val := replaceFunFVarsWithConsts headers funFVars vars toLift.val };
|
||||
values ← values.mapM $ mkLambdaFVars vars;
|
||||
headers ← headers.mapM fun header => do { type ← mkForallFVars vars header.type; pure { header with type := type } };
|
||||
letRecsToLift ← LetRecClosure.main letRecsToLift;
|
||||
values ← values.mapM instantiateMVars;
|
||||
-- TODO
|
||||
values.forM fun val => IO.println (toString val);
|
||||
letRecsToLift.forM fun toLift => IO.println (toString toLift.declName ++ " := " ++ toString toLift.val);
|
||||
throwError "WIP mutual def"
|
||||
|
||||
end Term
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue