feat: lift let rec expressions nested in a mutual block

This commit is contained in:
Leonardo de Moura 2020-09-04 15:40:58 -07:00
parent 12f69a78b7
commit b2f932c4dc
2 changed files with 337 additions and 101 deletions

View file

@ -3,6 +3,8 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Closure
import Lean.Meta.Check
import Lean.Elab.Command
import Lean.Elab.Definition
@ -21,6 +23,9 @@ structure DefViewElabHeader :=
(type : Expr) -- including the parameters
(declVal : Syntax)
instance DefViewElabHeader.inhabited : Inhabited DefViewElabHeader :=
⟨⟨arbitrary _, {}, DefKind.«def», arbitrary _, arbitrary _, [], 0, arbitrary _, arbitrary _⟩⟩
namespace Term
open Meta
@ -187,7 +192,7 @@ match decl? with
| some n => pure n
/-
Ensures that the of let-rec definitions do not contain functions being defined.
Ensures that the of let-rec definition types do not contain functions being defined.
In principle, this test can be improved. We could perform it after we separate the set of functions is strongly connected components.
However, this extra complication doesn't seem worth it.
-/
@ -199,41 +204,119 @@ letRecsToLift.forM fun toLift => do
fnName ← getFunName fvarId letRecsToLift;
throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously")
private def replaceFunFVarsWithConsts (headers : Array DefViewElabHeader) (funFVars : Array Expr) (vars : Array Expr) (e : Expr) : Expr :=
e.replace fun x => match x with
| Expr.fvar fvarId _ =>
match funFVars.indexOf x with
| some idx => match headers.get? idx.val with
| some header => some $ mkAppN (Lean.mkConst header.declName) vars -- Remark: we add the universe levels later
| none => none
| none => none
| _ => none
structure PreDeclaration :=
(kind : DefKind)
(modifiers : Modifiers)
(declName : Name)
(type : Expr)
(val : Expr)
/- A mapping from -/
abbrev Closures := NameMap NameSet
namespace MutualClosure
namespace LetRecClosure
/- A mapping from FVarId to Set of FVarIds. -/
abbrev UsedFVarsMap := NameMap NameSet
/-
Create the `UsedFVarsMap` mapping that takes the variable id for the mutually recursive functions being defined to the set of
free variables in its definition.
For `mainFVars`, this is just the set of section variables `sectionVars` used.
For nested let-rec functions, we collect their free variables.
Recall that a `let rec` expressions are encoded as follows in the elaborator.
```lean
let rec
f : A := t,
g : B := s;
body
```
is encoded as
```lean
let f : A := ?m₁;
let g : B := ?m₂;
body
```
where `?m₁` and `?m₂` are synthetic opaque metavariables. That are assigned by this module.
We may have nested `let rec`s.
```lean
let rec f : A :=
let rec g : B := t;
s;
body
```
is encoded as
```lean
let f : A := ?m₁;
body
```
and the body of `f` is stored the field `val` of a `LetRecToLift`. For the example above,
we would have a `LetRecToLift` containing:
```
{
mvarId := m₁,
val := `(let g : B := ?m₂; body)
...
}
```
Note that `g` is not a free variable at `(let g : B := ?m₂; body)`. We recover the fact that
`f` depends on `g` because it contains `m₂`
-/
private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift)
: UsedFVarsMap :=
let sectionVarSet := sectionVars.foldl (fun (s : NameSet) (var : Expr) => s.insert var.fvarId!) {};
let usedFVarMap := mainFVarIds.foldl
(fun (usedFVarMap : UsedFVarsMap) mainFVarId =>
usedFVarMap.insert mainFVarId sectionVarSet)
{};
letRecsToLift.foldl
(fun (usedFVarMap : UsedFVarsMap) toLift =>
let state := Lean.collectFVars {} toLift.val;
let state := Lean.collectFVars state toLift.type;
let set := state.fvarSet;
/- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId
for the associated let-rec because we need this information to compute the fixpoint later. -/
let mvarIds := (toLift.val.collectMVars {}).result;
let set := mvarIds.foldl
(fun (set : NameSet) (mvarId : MVarId) =>
match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with
| some fvarId => set.insert fvarId
| none => set)
set;
usedFVarMap.insert toLift.fvarId set)
usedFVarMap
/-
The let-recs may invoke each other. Example:
```
let rec
f (x : Nat) := g x + y
g : Nat → Nat
| 0 => 1
| x+1 => f x + z
```
`y` is free variable in `f`, and `z` is a free variable in `g`.
To close `f` and `g`, we `y` and `z` must be in the closure of both.
That is, we need to generate the top-level definitions.
```
def f (y z x : Nat) := g y z x + y
def g (y z : Nat) : Nat → Nat
| 0 => 1
| x+1 => f y z x + z
```
-/
namespace FixPoint
structure State :=
(closures : Closures := {})
(modified : Bool := false)
(usedFVarsMap : UsedFVarsMap := {})
(modified : Bool := false)
abbrev M := ReaderT (List FVarId) $ StateM State
private def isModified : M Bool := do
s ← get; pure s.modified
private def resetModified : M Unit :=
modify fun s => { s with modified := false }
private def markModified : M Unit :=
modify fun s => { s with modified := true }
private def getClosures : M Closures :=
do s ← get; pure s.closures
private def modifyClosures (f : Closures → Closures) : M Unit :=
modify fun s => { s with closures := f s.closures }
private def isModified : M Bool := do s ← get; pure s.modified
private def resetModified : M Unit := modify fun s => { s with modified := false }
private def markModified : M Unit := modify fun s => { s with modified := true }
private def getUsedFVarsMap : M UsedFVarsMap := do s ← get; pure s.usedFVarsMap
private def modifyUsedFVars (f : UsedFVarsMap → UsedFVarsMap) : M Unit := modify fun s => { s with usedFVarsMap := f s.usedFVarsMap }
-- merge s₂ into s₁
private def merge (s₁ s₂ : NameSet) : M NameSet :=
@ -246,84 +329,229 @@ s₂.foldM
pure $ s₁.insert k)
s₁
private def updateClosureOf (fvarId : FVarId) : M Unit := do
closures ← getClosures;
match closures.find? fvarId with
private def updateUsedVarsOf (fvarId : FVarId) : M Unit := do
usedFVarsMap ← getUsedFVarsMap;
match usedFVarsMap.find? fvarId with
| none => pure ()
| some closure => do
closureNew ← closure.foldM
(fun (closureNew : NameSet) (fvarId' : FVarId) =>
| some fvarIds => do
fvarIdsNew ← fvarIds.foldM
(fun (fvarIdsNew : NameSet) (fvarId' : FVarId) =>
if fvarId == fvarId' then
pure closureNew
pure fvarIdsNew
else
match closures.find? fvarId' with
| none => pure closureNew
-- We are being sloppy here `otherClosure` may contain free variables that are not in the context of the let-rec associated with fvarId.
-- We filter these out-of-context free variables later.
| some otherClosure => merge closureNew otherClosure)
closure;
modifyClosures fun closures => closures.insert fvarId closureNew
match usedFVarsMap.find? fvarId' with
| none => pure fvarIdsNew
/- We are being sloppy here `otherFVarIds` may contain free variables that are
not in the context of the let-rec associated with fvarId.
We filter these out-of-context free variables later. -/
| some otherFVarIds => merge fvarIdsNew otherFVarIds)
fvarIds;
modifyUsedFVars fun usedFVars => usedFVars.insert fvarId fvarIdsNew
private partial def closureFixpoint : Unit → M Unit
private partial def fixpoint : Unit → M Unit
| _ => do
resetModified;
letRecFVarIds ← read;
letRecFVarIds.forM updateClosureOf;
whenM isModified $ closureFixpoint ()
letRecFVarIds.forM updateUsedVarsOf;
whenM isModified $ fixpoint ()
private def mkInitialClosures (letRecsToLift : List LetRecToLift) : Closures :=
def run (letRecFVarIds : List FVarId) (usedFVarsMap : UsedFVarsMap) : UsedFVarsMap :=
let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap };
s.usedFVarsMap
end FixPoint
abbrev FreeVarMap := NameMap (Array FVarId)
private def mkFreeVarMap
(mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId)
(recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap :=
let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift;
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap;
letRecsToLift.foldl
(fun (closures : Closures) toLift =>
let state := Lean.collectFVars {} toLift.val;
let state := Lean.collectFVars state toLift.type;
closures.insert toLift.fvarId state.fvarSet)
(fun (freeVarMap : FreeVarMap) toLift =>
let lctx := toLift.lctx;
let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!;
let fvarIds := fvarIdsSet.fold
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
if lctx.contains fvarId && !recFVarIds.contains fvarId then
fvarIds.push fvarId
else
fvarIds)
#[];
freeVarMap.insert toLift.fvarId fvarIds)
{}
private def mkClosures (letRecsToLift : List LetRecToLift) : Closures :=
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
let closures := mkInitialClosures letRecsToLift;
let (_, s) := ((closureFixpoint ()).run letRecFVarIds).run { closures := closures };
s.closures
structure ClosureState :=
(newLocalDecls : Array LocalDecl := #[])
(localDecls : Array LocalDecl := #[])
(newLetDecls : Array LocalDecl := #[])
(exprArgs : Array Expr := #[])
private def nameSetToFVars (s : NameSet) (lctx : LocalContext) (letRecFVarIds : List FVarId) : Array FVarId :=
let fvarIds := s.fold
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
if lctx.contains fvarId && !letRecFVarIds.contains fvarId then
fvarIds.push fvarId
else
fvarIds)
#[];
fvarIds.qsort fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
private def pickMaxFVar? (lctx : LocalContext) (fvarIds : Array FVarId) : Option FVarId :=
fvarIds.getMax? fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
def main (letRecsToLift : List LetRecToLift) : TermElabM (List LetRecToLift) := do
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
let closures := mkClosures letRecsToLift;
-- Assign metavariables associated with each let-rec
ps ← letRecsToLift.mapM fun toLift => do {
let s := (closures.find? toLift.fvarId).get!;
let lctx := toLift.lctx;
let xs := (nameSetToFVars s lctx letRecFVarIds).map mkFVar;
assignExprMVar toLift.mvarId (mkAppN (Lean.mkConst toLift.declName) xs);
pure (xs, toLift)
private def preprocess (e : Expr) : TermElabM Expr := do
e ← instantiateMVars e;
-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect.
liftM $ check e;
pure e
/- Push free variables in `s` to `toProcess` if they are not already there. -/
private def pushNewVars (toProcess : Array FVarId) (s : CollectFVars.State) : Array FVarId :=
s.fvarSet.fold
(fun (toProcess : Array FVarId) fvarId => if toProcess.contains fvarId then toProcess else toProcess.push fvarId)
toProcess
private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName : Name) (type : Expr) (bi := BinderInfo.default)
: StateRefT ClosureState TermElabM (Array FVarId) := do
type ← liftM $ preprocess type;
modify fun s => { s with
newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) fvarId userName type bi,
exprArgs := s.exprArgs.push (mkFVar fvarId)
};
ps.mapM fun (p : Array Expr × LetRecToLift) => do
let (xs, toLift) := p;
type ← instantiateMVars toLift.type;
val ← instantiateMVars toLift.val;
let lctx := toLift.lctx;
-- Val may contain outer let-recs, we must replace them with constants
let val := val.replace fun x => match x with
| Expr.fvar fvarId _ => match ps.find? fun (p : Array Expr × LetRecToLift) => p.2.fvarId == fvarId with
| some p => some (mkAppN (Lean.mkConst p.2.declName) p.1)
| none => none
| _ => none;
withLCtx lctx toLift.localInstances do
-- Apply closure
type ← mkForallFVars xs type;
val ← mkLambdaFVars xs val;
pure { toLift with type := type, val := val }
pure $ pushNewVars toProcess (collectFVars {} type)
end LetRecClosure
private partial def mkClosureForAux : Array FVarId → StateRefT ClosureState TermElabM Unit
| toProcess => do
lctx ← getLCtx;
match pickMaxFVar? lctx toProcess with
| none => pure ()
| some fvarId => do
let toProcess := toProcess.erase fvarId;
localDecl ← getLocalDecl fvarId;
match localDecl with
| LocalDecl.cdecl _ _ userName type bi => do
toProcess ← pushLocalDecl toProcess fvarId userName type bi;
mkClosureForAux toProcess
| LocalDecl.ldecl _ _ userName type val _ => do
zetaFVarIds ← getZetaFVarIds;
if !zetaFVarIds.contains fvarId then do
/- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/
toProcess ← pushLocalDecl toProcess fvarId userName type;
mkClosureForAux toProcess
else do
/- Dependent let-decl. -/
type ← liftM $ preprocess type;
val ← liftM $ preprocess val;
modify fun s => { s with
newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) fvarId userName type val false,
/- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of fvarId
at `newLocalDecls` and `localDecls` -/
newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val),
localDecls := s.localDecls.map (replaceFVarIdAtLocalDecl fvarId val)
};
mkClosureForAux (pushNewVars toProcess (collectFVars (collectFVars {} type) val))
structure LetRecClosure :=
(localDecls : Array LocalDecl)
(closed : Expr) -- expression used to replace occurrences of the let-rec FVarId
(toLift : LetRecToLift)
private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) : TermElabM LetRecClosure := do
let lctx := toLift.lctx;
withLCtx lctx toLift.localInstances do
lambdaTelescope toLift.val fun xs val => do
type ← instantiateForall toLift.type xs;
lctx ← getLCtx;
(_, s) ← (mkClosureForAux freeVars).run { localDecls := xs.map fun x => lctx.get! x.fvarId! };
let type := Closure.mkForall s.localDecls $ Closure.mkForall s.newLetDecls type;
let val := Closure.mkLambda s.localDecls $ Closure.mkLambda s.newLetDecls val;
let c := mkAppN (Lean.mkConst toLift.declName) s.exprArgs;
assignExprMVar toLift.mvarId c;
pure ⟨s.newLocalDecls, c, { toLift with val := val, type := type }⟩
private def mkLetRecClosures (letRecsToLift : List LetRecToLift) (freeVarMap : FreeVarMap) : TermElabM (List LetRecClosure) :=
letRecsToLift.mapM fun toLift => mkLetRecClosureFor toLift (freeVarMap.find? toLift.fvarId).get!
/- Mapping from FVarId of mutually recursive functions being defined to "closure" expression. -/
abbrev Replacement := NameMap Expr
def insertReplacementForMainFns (r : Replacement) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) : Replacement :=
mainFVars.size.fold
(fun i (r : Replacement) =>
r.insert (mainFVars.get! i).fvarId! (mkAppN (Lean.mkConst (mainHeaders.get! i).declName) sectionVars))
r
def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecClosure) : Replacement :=
letRecClosures.foldl (fun (r : Replacement) c => r.insert c.toLift.fvarId c.closed) r
def Replacement.apply (r : Replacement) (e : Expr) : Expr :=
e.replace fun e => match e with
| Expr.fvar fvarId _ => match r.find? fvarId with
| some c => some c
| _ => none
| _ => none
def pushMain (preDecls : Array PreDeclaration) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
: TermElabM (Array PreDeclaration) :=
mainHeaders.size.foldM
(fun i (preDecls : Array PreDeclaration) => do
let header := mainHeaders.get! i;
val ← mkLambdaFVars sectionVars (mainVals.get! i);
type ← mkForallFVars sectionVars header.type;
pure $ preDecls.push {
kind := header.kind,
declName := header.declName,
modifiers := header.modifiers,
type := type,
val := val
})
preDecls
def pushLetRecs (preDecls : Array PreDeclaration) (letRecClosures : List LetRecClosure) (kind : DefKind) : Array PreDeclaration :=
letRecClosures.foldl
(fun (preDecls : Array PreDeclaration) (c : LetRecClosure) =>
let type := Closure.mkForall c.localDecls c.toLift.type;
let val := Closure.mkLambda c.localDecls c.toLift.val;
preDecls.push {
kind := kind,
declName := c.toLift.declName,
modifiers := { attrs := c.toLift.attrs },
type := type,
val := val
})
preDecls
def getKindForLetRecs (mainHeaders : Array DefViewElabHeader) : DefKind :=
if mainHeaders.any fun h => h.kind.isTheorem then DefKind.«theorem»
else DefKind.«def»
/-
- `sectionVars`: The section variables used in the `mutual` block.
- `mainHeaders`: The elaborated header of the top-level definitions being defined by the mutual block.
- `mainFVars`: The auxiliary variables used to represent the top-level definitions being defined by the mutual block.
- `mainVals`: The elaborated value for the top-level definitions
- `letRecsToLift`: The let-rec's definitions that need to be lifted
-/
def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) (mainVals : Array Expr) (letRecsToLift : List LetRecToLift)
: TermElabM (Array PreDeclaration) := do
-- Store in recFVarIds the fvarId of every function being defined by the mutual block.
let mainFVarIds := mainFVars.map Expr.fvarId!;
let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds;
-- Compute the set of free variables (excluding `recFVarIds`) for each let-rec.
mctx ← getMCtx;
let freeVarMap := mkFreeVarMap mctx sectionVars mainFVarIds recFVarIds letRecsToLift;
resetZetaFVarIds;
withTrackingZeta do
-- By checking `toLift.type` and `toLift.val` we populate `zetaFVarIds`. See comments at `src/Lean/Meta/Closure.lean`.
letRecsToLift.forM fun toLift => withLCtx toLift.lctx toLift.localInstances do { liftM $ check toLift.type; liftM $ check toLift.val };
letRecClosures ← mkLetRecClosures letRecsToLift freeVarMap;
-- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations.
mainVals ← mainVals.mapM instantiateMVars;
mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader;
letRecClosures ← letRecClosures.mapM fun closure => do { toLift ← instantiateMVarsAtLetRecToLift closure.toLift; pure { closure with toLift := toLift } };
-- Replace fvarIds for functions being defined with closed terms
let r := insertReplacementForMainFns {} sectionVars mainHeaders mainFVars;
let r := insertReplacementForLetRecs r letRecClosures;
let mainVals := mainVals.map r.apply;
let mainHeaders := mainHeaders.map fun h => { h with type := r.apply h.type };
let letRecClosures := letRecClosures.map fun c => { c with toLift := { c.toLift with type := r.apply c.toLift.type, val := r.apply c.toLift.val } };
let letRecKind := getKindForLetRecs mainHeaders;
pushMain (pushLetRecs #[] letRecClosures letRecKind) sectionVars mainHeaders mainVals
end MutualClosure
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
scopeLevelNames ← getLevelNames;
@ -339,19 +567,12 @@ withFunLocalDecls headers fun funFVars => do
letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift;
checkLetRecsToLiftTypes funFVars letRecsToLift;
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
let values := values.map $ replaceFunFVarsWithConsts headers funFVars vars;
let letRecsToLift := letRecsToLift.map fun toLift => { toLift with val := replaceFunFVarsWithConsts headers funFVars vars toLift.val };
values ← values.mapM $ mkLambdaFVars vars;
headers ← headers.mapM fun header => do { type ← mkForallFVars vars header.type; pure { header with type := type } };
letRecsToLift ← LetRecClosure.main letRecsToLift;
values ← values.mapM instantiateMVars;
preDecls ← MutualClosure.main vars headers funFVars values letRecsToLift;
-- TODO
values.forM fun val => IO.println (toString val);
letRecsToLift.forM fun toLift => IO.println (toString toLift.declName ++ " := " ++ toString toLift.val);
preDecls.forM fun preDecl => IO.println (toString preDecl.declName ++ " : " ++ toString preDecl.type ++ " :=\n" ++ toString preDecl.val ++ "\n");
throwError "WIP mutual def"
end Term
namespace Command
def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do

View file

@ -397,6 +397,21 @@ m.dAssignment.contains mvarId
def eraseDelayed (m : MetavarContext) (mvarId : MVarId) : MetavarContext :=
{ m with dAssignment := m.dAssignment.erase mvarId }
/- Given a sequence of delayed assignments
```
mvarId₁ := mvarId₂ ...;
...
mvarIdₙ := mvarId_root ... -- where `mvarId_root` is not delayed assigned
```
in `mctx`, `getDelayedRoot mctx mvarId₁` return `mvarId_root`.
If `mvarId₁` is not delayed assigned then return `mvarId₁` -/
partial def getDelayedRoot (m : MetavarContext) : MVarId → MVarId
| mvarId => match getDelayedAssignment? m mvarId with
| some d => match d.val.getAppFn with
| Expr.mvar mvarId _ => getDelayedRoot mvarId
| _ => mvarId
| none => mvarId
def isLevelAssignable (mctx : MetavarContext) (mvarId : MVarId) : Bool :=
match mctx.lDepth.find? mvarId with
| some d => d == mctx.depth