feat: lift nested 'let-rec's

This commit is contained in:
Leonardo de Moura 2020-09-02 18:53:18 -07:00
parent f2a6562eed
commit 06c6002d45
2 changed files with 134 additions and 1 deletions

View file

@ -49,7 +49,7 @@ decls ← (letRec.getArg 1).getArgs.getSepElems.mapM fun attrDeclStx => do {
type ← mkForallFVars xs type;
pure (type, xs.size)
};
mvar ← mkFreshExprMVar type MetavarKind.synthetic;
mvar ← mkFreshExprMVar type MetavarKind.syntheticOpaque;
valStx ←
if decl.isOfKind `Lean.Parser.Term.letIdDecl then
pure $ decl.getArg 4

View file

@ -199,6 +199,131 @@ letRecsToLift.forM fun toLift => do
fnName ← getFunName fvarId letRecsToLift;
throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously")
private def replaceFunFVarsWithConsts (headers : Array DefViewElabHeader) (funFVars : Array Expr) (vars : Array Expr) (e : Expr) : Expr :=
e.replace fun x => match x with
| Expr.fvar fvarId _ =>
match funFVars.indexOf x with
| some idx => match headers.get? idx.val with
| some header => some $ mkAppN (Lean.mkConst header.declName) vars -- Remark: we add the universe levels later
| none => none
| none => none
| _ => none
/- A mapping from -/
abbrev Closures := NameMap NameSet
namespace LetRecClosure
structure State :=
(closures : Closures := {})
(modified : Bool := false)
abbrev M := ReaderT (List FVarId) $ StateM State
private def isModified : M Bool := do
s ← get; pure s.modified
private def resetModified : M Unit :=
modify fun s => { s with modified := false }
private def markModified : M Unit :=
modify fun s => { s with modified := true }
private def getClosures : M Closures :=
do s ← get; pure s.closures
private def modifyClosures (f : Closures → Closures) : M Unit :=
modify fun s => { s with closures := f s.closures }
-- merge s₂ into s₁
private def merge (s₁ s₂ : NameSet) : M NameSet :=
s₂.foldM
(fun (s₁ : NameSet) k =>
if s₁.contains k then
pure s₁
else do
markModified;
pure $ s₁.insert k)
s₁
private def updateClosureOf (fvarId : FVarId) : M Unit := do
closures ← getClosures;
match closures.find? fvarId with
| none => pure ()
| some closure => do
closureNew ← closure.foldM
(fun (closureNew : NameSet) (fvarId' : FVarId) =>
if fvarId == fvarId' then
pure closureNew
else
match closures.find? fvarId' with
| none => pure closureNew
-- We are being sloppy here `otherClosure` may contain free variables that are not in the context of the let-rec associated with fvarId.
-- We filter these out-of-context free variables later.
| some otherClosure => merge closureNew otherClosure)
closure;
modifyClosures fun closures => closures.insert fvarId closureNew
private partial def closureFixpoint : Unit → M Unit
| _ => do
resetModified;
letRecFVarIds ← read;
letRecFVarIds.forM updateClosureOf;
whenM isModified $ closureFixpoint ()
private def mkInitialClosures (letRecsToLift : List LetRecToLift) : Closures :=
letRecsToLift.foldl
(fun (closures : Closures) toLift =>
let state := Lean.collectFVars {} toLift.val;
let state := Lean.collectFVars state toLift.type;
closures.insert toLift.fvarId state.fvarSet)
{}
private def mkClosures (letRecsToLift : List LetRecToLift) : Closures :=
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
let closures := mkInitialClosures letRecsToLift;
let (_, s) := ((closureFixpoint ()).run letRecFVarIds).run { closures := closures };
s.closures
private def nameSetToFVars (s : NameSet) (lctx : LocalContext) (letRecFVarIds : List FVarId) : Array FVarId :=
let fvarIds := s.fold
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
if lctx.contains fvarId && !letRecFVarIds.contains fvarId then
fvarIds.push fvarId
else
fvarIds)
#[];
fvarIds.qsort fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
def main (letRecsToLift : List LetRecToLift) : TermElabM (List LetRecToLift) := do
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
let closures := mkClosures letRecsToLift;
-- Assign metavariables associated with each let-rec
ps ← letRecsToLift.mapM fun toLift => do {
let s := (closures.find? toLift.fvarId).get!;
let lctx := toLift.lctx;
let xs := (nameSetToFVars s lctx letRecFVarIds).map mkFVar;
assignExprMVar toLift.mvarId (mkAppN (Lean.mkConst toLift.declName) xs);
pure (xs, toLift)
};
ps.mapM fun (p : Array Expr × LetRecToLift) => do
let (xs, toLift) := p;
type ← instantiateMVars toLift.type;
val ← instantiateMVars toLift.val;
let lctx := toLift.lctx;
-- Val may contain outer let-recs, we must replace them with constants
let val := val.replace fun x => match x with
| Expr.fvar fvarId _ => match ps.find? fun (p : Array Expr × LetRecToLift) => p.2.fvarId == fvarId with
| some p => some (mkAppN (Lean.mkConst p.2.declName) p.1)
| none => none
| _ => none;
-- Apply closure
let type := lctx.mkForall xs type;
let val := lctx.mkLambda xs val;
pure { toLift with type := type, val := val }
end LetRecClosure
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
scopeLevelNames ← getLevelNames;
headers ← elabHeaders views;
@ -213,7 +338,15 @@ withFunLocalDecls headers fun funFVars => do
letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift;
checkLetRecsToLiftTypes funFVars letRecsToLift;
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
let values := values.map $ replaceFunFVarsWithConsts headers funFVars vars;
let letRecsToLift := letRecsToLift.map fun toLift => { toLift with val := replaceFunFVarsWithConsts headers funFVars vars toLift.val };
values ← values.mapM $ mkLambdaFVars vars;
headers ← headers.mapM fun header => do { type ← mkForallFVars vars header.type; pure { header with type := type } };
letRecsToLift ← LetRecClosure.main letRecsToLift;
values ← values.mapM instantiateMVars;
-- TODO
values.forM fun val => IO.println (toString val);
letRecsToLift.forM fun toLift => IO.println (toString toLift.declName ++ " := " ++ toString toLift.val);
throwError "WIP mutual def"
end Term