feat: lift let rec expressions nested in a mutual block
This commit is contained in:
parent
12f69a78b7
commit
b2f932c4dc
2 changed files with 337 additions and 101 deletions
|
|
@ -3,6 +3,8 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
|||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Meta.Closure
|
||||
import Lean.Meta.Check
|
||||
import Lean.Elab.Command
|
||||
import Lean.Elab.Definition
|
||||
|
||||
|
|
@ -21,6 +23,9 @@ structure DefViewElabHeader :=
|
|||
(type : Expr) -- including the parameters
|
||||
(declVal : Syntax)
|
||||
|
||||
instance DefViewElabHeader.inhabited : Inhabited DefViewElabHeader :=
|
||||
⟨⟨arbitrary _, {}, DefKind.«def», arbitrary _, arbitrary _, [], 0, arbitrary _, arbitrary _⟩⟩
|
||||
|
||||
namespace Term
|
||||
|
||||
open Meta
|
||||
|
|
@ -187,7 +192,7 @@ match decl? with
|
|||
| some n => pure n
|
||||
|
||||
/-
|
||||
Ensures that the of let-rec definitions do not contain functions being defined.
|
||||
Ensures that the of let-rec definition types do not contain functions being defined.
|
||||
In principle, this test can be improved. We could perform it after we separate the set of functions is strongly connected components.
|
||||
However, this extra complication doesn't seem worth it.
|
||||
-/
|
||||
|
|
@ -199,41 +204,119 @@ letRecsToLift.forM fun toLift => do
|
|||
fnName ← getFunName fvarId letRecsToLift;
|
||||
throwErrorAt toLift.ref ("invalid type in 'let rec', it uses '" ++ fnName ++ "' which is being defined simultaneously")
|
||||
|
||||
private def replaceFunFVarsWithConsts (headers : Array DefViewElabHeader) (funFVars : Array Expr) (vars : Array Expr) (e : Expr) : Expr :=
|
||||
e.replace fun x => match x with
|
||||
| Expr.fvar fvarId _ =>
|
||||
match funFVars.indexOf x with
|
||||
| some idx => match headers.get? idx.val with
|
||||
| some header => some $ mkAppN (Lean.mkConst header.declName) vars -- Remark: we add the universe levels later
|
||||
| none => none
|
||||
| none => none
|
||||
| _ => none
|
||||
structure PreDeclaration :=
|
||||
(kind : DefKind)
|
||||
(modifiers : Modifiers)
|
||||
(declName : Name)
|
||||
(type : Expr)
|
||||
(val : Expr)
|
||||
|
||||
/- A mapping from -/
|
||||
abbrev Closures := NameMap NameSet
|
||||
namespace MutualClosure
|
||||
|
||||
namespace LetRecClosure
|
||||
/- A mapping from FVarId to Set of FVarIds. -/
|
||||
abbrev UsedFVarsMap := NameMap NameSet
|
||||
|
||||
/-
|
||||
Create the `UsedFVarsMap` mapping that takes the variable id for the mutually recursive functions being defined to the set of
|
||||
free variables in its definition.
|
||||
|
||||
For `mainFVars`, this is just the set of section variables `sectionVars` used.
|
||||
For nested let-rec functions, we collect their free variables.
|
||||
|
||||
Recall that a `let rec` expressions are encoded as follows in the elaborator.
|
||||
```lean
|
||||
let rec
|
||||
f : A := t,
|
||||
g : B := s;
|
||||
body
|
||||
```
|
||||
is encoded as
|
||||
```lean
|
||||
let f : A := ?m₁;
|
||||
let g : B := ?m₂;
|
||||
body
|
||||
```
|
||||
where `?m₁` and `?m₂` are synthetic opaque metavariables. That are assigned by this module.
|
||||
We may have nested `let rec`s.
|
||||
```lean
|
||||
let rec f : A :=
|
||||
let rec g : B := t;
|
||||
s;
|
||||
body
|
||||
```
|
||||
is encoded as
|
||||
```lean
|
||||
let f : A := ?m₁;
|
||||
body
|
||||
```
|
||||
and the body of `f` is stored the field `val` of a `LetRecToLift`. For the example above,
|
||||
we would have a `LetRecToLift` containing:
|
||||
```
|
||||
{
|
||||
mvarId := m₁,
|
||||
val := `(let g : B := ?m₂; body)
|
||||
...
|
||||
}
|
||||
```
|
||||
Note that `g` is not a free variable at `(let g : B := ?m₂; body)`. We recover the fact that
|
||||
`f` depends on `g` because it contains `m₂`
|
||||
-/
|
||||
private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift)
|
||||
: UsedFVarsMap :=
|
||||
let sectionVarSet := sectionVars.foldl (fun (s : NameSet) (var : Expr) => s.insert var.fvarId!) {};
|
||||
let usedFVarMap := mainFVarIds.foldl
|
||||
(fun (usedFVarMap : UsedFVarsMap) mainFVarId =>
|
||||
usedFVarMap.insert mainFVarId sectionVarSet)
|
||||
{};
|
||||
letRecsToLift.foldl
|
||||
(fun (usedFVarMap : UsedFVarsMap) toLift =>
|
||||
let state := Lean.collectFVars {} toLift.val;
|
||||
let state := Lean.collectFVars state toLift.type;
|
||||
let set := state.fvarSet;
|
||||
/- toLift.val may contain metavariables that are placeholders for nested let-recs. We should collect the fvarId
|
||||
for the associated let-rec because we need this information to compute the fixpoint later. -/
|
||||
let mvarIds := (toLift.val.collectMVars {}).result;
|
||||
let set := mvarIds.foldl
|
||||
(fun (set : NameSet) (mvarId : MVarId) =>
|
||||
match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with
|
||||
| some fvarId => set.insert fvarId
|
||||
| none => set)
|
||||
set;
|
||||
usedFVarMap.insert toLift.fvarId set)
|
||||
usedFVarMap
|
||||
|
||||
/-
|
||||
The let-recs may invoke each other. Example:
|
||||
```
|
||||
let rec
|
||||
f (x : Nat) := g x + y
|
||||
g : Nat → Nat
|
||||
| 0 => 1
|
||||
| x+1 => f x + z
|
||||
```
|
||||
`y` is free variable in `f`, and `z` is a free variable in `g`.
|
||||
To close `f` and `g`, we `y` and `z` must be in the closure of both.
|
||||
That is, we need to generate the top-level definitions.
|
||||
```
|
||||
def f (y z x : Nat) := g y z x + y
|
||||
def g (y z : Nat) : Nat → Nat
|
||||
| 0 => 1
|
||||
| x+1 => f y z x + z
|
||||
```
|
||||
-/
|
||||
namespace FixPoint
|
||||
|
||||
structure State :=
|
||||
(closures : Closures := {})
|
||||
(modified : Bool := false)
|
||||
(usedFVarsMap : UsedFVarsMap := {})
|
||||
(modified : Bool := false)
|
||||
|
||||
abbrev M := ReaderT (List FVarId) $ StateM State
|
||||
|
||||
private def isModified : M Bool := do
|
||||
s ← get; pure s.modified
|
||||
|
||||
private def resetModified : M Unit :=
|
||||
modify fun s => { s with modified := false }
|
||||
|
||||
private def markModified : M Unit :=
|
||||
modify fun s => { s with modified := true }
|
||||
|
||||
private def getClosures : M Closures :=
|
||||
do s ← get; pure s.closures
|
||||
|
||||
private def modifyClosures (f : Closures → Closures) : M Unit :=
|
||||
modify fun s => { s with closures := f s.closures }
|
||||
private def isModified : M Bool := do s ← get; pure s.modified
|
||||
private def resetModified : M Unit := modify fun s => { s with modified := false }
|
||||
private def markModified : M Unit := modify fun s => { s with modified := true }
|
||||
private def getUsedFVarsMap : M UsedFVarsMap := do s ← get; pure s.usedFVarsMap
|
||||
private def modifyUsedFVars (f : UsedFVarsMap → UsedFVarsMap) : M Unit := modify fun s => { s with usedFVarsMap := f s.usedFVarsMap }
|
||||
|
||||
-- merge s₂ into s₁
|
||||
private def merge (s₁ s₂ : NameSet) : M NameSet :=
|
||||
|
|
@ -246,84 +329,229 @@ s₂.foldM
|
|||
pure $ s₁.insert k)
|
||||
s₁
|
||||
|
||||
private def updateClosureOf (fvarId : FVarId) : M Unit := do
|
||||
closures ← getClosures;
|
||||
match closures.find? fvarId with
|
||||
private def updateUsedVarsOf (fvarId : FVarId) : M Unit := do
|
||||
usedFVarsMap ← getUsedFVarsMap;
|
||||
match usedFVarsMap.find? fvarId with
|
||||
| none => pure ()
|
||||
| some closure => do
|
||||
closureNew ← closure.foldM
|
||||
(fun (closureNew : NameSet) (fvarId' : FVarId) =>
|
||||
| some fvarIds => do
|
||||
fvarIdsNew ← fvarIds.foldM
|
||||
(fun (fvarIdsNew : NameSet) (fvarId' : FVarId) =>
|
||||
if fvarId == fvarId' then
|
||||
pure closureNew
|
||||
pure fvarIdsNew
|
||||
else
|
||||
match closures.find? fvarId' with
|
||||
| none => pure closureNew
|
||||
-- We are being sloppy here `otherClosure` may contain free variables that are not in the context of the let-rec associated with fvarId.
|
||||
-- We filter these out-of-context free variables later.
|
||||
| some otherClosure => merge closureNew otherClosure)
|
||||
closure;
|
||||
modifyClosures fun closures => closures.insert fvarId closureNew
|
||||
match usedFVarsMap.find? fvarId' with
|
||||
| none => pure fvarIdsNew
|
||||
/- We are being sloppy here `otherFVarIds` may contain free variables that are
|
||||
not in the context of the let-rec associated with fvarId.
|
||||
We filter these out-of-context free variables later. -/
|
||||
| some otherFVarIds => merge fvarIdsNew otherFVarIds)
|
||||
fvarIds;
|
||||
modifyUsedFVars fun usedFVars => usedFVars.insert fvarId fvarIdsNew
|
||||
|
||||
private partial def closureFixpoint : Unit → M Unit
|
||||
private partial def fixpoint : Unit → M Unit
|
||||
| _ => do
|
||||
resetModified;
|
||||
letRecFVarIds ← read;
|
||||
letRecFVarIds.forM updateClosureOf;
|
||||
whenM isModified $ closureFixpoint ()
|
||||
letRecFVarIds.forM updateUsedVarsOf;
|
||||
whenM isModified $ fixpoint ()
|
||||
|
||||
private def mkInitialClosures (letRecsToLift : List LetRecToLift) : Closures :=
|
||||
def run (letRecFVarIds : List FVarId) (usedFVarsMap : UsedFVarsMap) : UsedFVarsMap :=
|
||||
let (_, s) := ((fixpoint ()).run letRecFVarIds).run { usedFVarsMap := usedFVarsMap };
|
||||
s.usedFVarsMap
|
||||
|
||||
end FixPoint
|
||||
|
||||
abbrev FreeVarMap := NameMap (Array FVarId)
|
||||
|
||||
private def mkFreeVarMap
|
||||
(mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId)
|
||||
(recFVarIds : Array FVarId) (letRecsToLift : List LetRecToLift) : FreeVarMap :=
|
||||
let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift;
|
||||
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
|
||||
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap;
|
||||
letRecsToLift.foldl
|
||||
(fun (closures : Closures) toLift =>
|
||||
let state := Lean.collectFVars {} toLift.val;
|
||||
let state := Lean.collectFVars state toLift.type;
|
||||
closures.insert toLift.fvarId state.fvarSet)
|
||||
(fun (freeVarMap : FreeVarMap) toLift =>
|
||||
let lctx := toLift.lctx;
|
||||
let fvarIdsSet := (usedFVarsMap.find? toLift.fvarId).get!;
|
||||
let fvarIds := fvarIdsSet.fold
|
||||
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
|
||||
if lctx.contains fvarId && !recFVarIds.contains fvarId then
|
||||
fvarIds.push fvarId
|
||||
else
|
||||
fvarIds)
|
||||
#[];
|
||||
freeVarMap.insert toLift.fvarId fvarIds)
|
||||
{}
|
||||
|
||||
private def mkClosures (letRecsToLift : List LetRecToLift) : Closures :=
|
||||
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
|
||||
let closures := mkInitialClosures letRecsToLift;
|
||||
let (_, s) := ((closureFixpoint ()).run letRecFVarIds).run { closures := closures };
|
||||
s.closures
|
||||
structure ClosureState :=
|
||||
(newLocalDecls : Array LocalDecl := #[])
|
||||
(localDecls : Array LocalDecl := #[])
|
||||
(newLetDecls : Array LocalDecl := #[])
|
||||
(exprArgs : Array Expr := #[])
|
||||
|
||||
private def nameSetToFVars (s : NameSet) (lctx : LocalContext) (letRecFVarIds : List FVarId) : Array FVarId :=
|
||||
let fvarIds := s.fold
|
||||
(fun (fvarIds : Array FVarId) (fvarId : FVarId) =>
|
||||
if lctx.contains fvarId && !letRecFVarIds.contains fvarId then
|
||||
fvarIds.push fvarId
|
||||
else
|
||||
fvarIds)
|
||||
#[];
|
||||
fvarIds.qsort fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
|
||||
private def pickMaxFVar? (lctx : LocalContext) (fvarIds : Array FVarId) : Option FVarId :=
|
||||
fvarIds.getMax? fun fvarId₁ fvarId₂ => (lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
|
||||
|
||||
def main (letRecsToLift : List LetRecToLift) : TermElabM (List LetRecToLift) := do
|
||||
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId;
|
||||
let closures := mkClosures letRecsToLift;
|
||||
-- Assign metavariables associated with each let-rec
|
||||
ps ← letRecsToLift.mapM fun toLift => do {
|
||||
let s := (closures.find? toLift.fvarId).get!;
|
||||
let lctx := toLift.lctx;
|
||||
let xs := (nameSetToFVars s lctx letRecFVarIds).map mkFVar;
|
||||
assignExprMVar toLift.mvarId (mkAppN (Lean.mkConst toLift.declName) xs);
|
||||
pure (xs, toLift)
|
||||
private def preprocess (e : Expr) : TermElabM Expr := do
|
||||
e ← instantiateMVars e;
|
||||
-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect.
|
||||
liftM $ check e;
|
||||
pure e
|
||||
|
||||
/- Push free variables in `s` to `toProcess` if they are not already there. -/
|
||||
private def pushNewVars (toProcess : Array FVarId) (s : CollectFVars.State) : Array FVarId :=
|
||||
s.fvarSet.fold
|
||||
(fun (toProcess : Array FVarId) fvarId => if toProcess.contains fvarId then toProcess else toProcess.push fvarId)
|
||||
toProcess
|
||||
|
||||
private def pushLocalDecl (toProcess : Array FVarId) (fvarId : FVarId) (userName : Name) (type : Expr) (bi := BinderInfo.default)
|
||||
: StateRefT ClosureState TermElabM (Array FVarId) := do
|
||||
type ← liftM $ preprocess type;
|
||||
modify fun s => { s with
|
||||
newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) fvarId userName type bi,
|
||||
exprArgs := s.exprArgs.push (mkFVar fvarId)
|
||||
};
|
||||
ps.mapM fun (p : Array Expr × LetRecToLift) => do
|
||||
let (xs, toLift) := p;
|
||||
type ← instantiateMVars toLift.type;
|
||||
val ← instantiateMVars toLift.val;
|
||||
let lctx := toLift.lctx;
|
||||
-- Val may contain outer let-recs, we must replace them with constants
|
||||
let val := val.replace fun x => match x with
|
||||
| Expr.fvar fvarId _ => match ps.find? fun (p : Array Expr × LetRecToLift) => p.2.fvarId == fvarId with
|
||||
| some p => some (mkAppN (Lean.mkConst p.2.declName) p.1)
|
||||
| none => none
|
||||
| _ => none;
|
||||
withLCtx lctx toLift.localInstances do
|
||||
-- Apply closure
|
||||
type ← mkForallFVars xs type;
|
||||
val ← mkLambdaFVars xs val;
|
||||
pure { toLift with type := type, val := val }
|
||||
pure $ pushNewVars toProcess (collectFVars {} type)
|
||||
|
||||
end LetRecClosure
|
||||
private partial def mkClosureForAux : Array FVarId → StateRefT ClosureState TermElabM Unit
|
||||
| toProcess => do
|
||||
lctx ← getLCtx;
|
||||
match pickMaxFVar? lctx toProcess with
|
||||
| none => pure ()
|
||||
| some fvarId => do
|
||||
let toProcess := toProcess.erase fvarId;
|
||||
localDecl ← getLocalDecl fvarId;
|
||||
match localDecl with
|
||||
| LocalDecl.cdecl _ _ userName type bi => do
|
||||
toProcess ← pushLocalDecl toProcess fvarId userName type bi;
|
||||
mkClosureForAux toProcess
|
||||
| LocalDecl.ldecl _ _ userName type val _ => do
|
||||
zetaFVarIds ← getZetaFVarIds;
|
||||
if !zetaFVarIds.contains fvarId then do
|
||||
/- Non-dependent let-decl. See comment at src/Lean/Meta/Closure.lean -/
|
||||
toProcess ← pushLocalDecl toProcess fvarId userName type;
|
||||
mkClosureForAux toProcess
|
||||
else do
|
||||
/- Dependent let-decl. -/
|
||||
type ← liftM $ preprocess type;
|
||||
val ← liftM $ preprocess val;
|
||||
modify fun s => { s with
|
||||
newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) fvarId userName type val false,
|
||||
/- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of fvarId
|
||||
at `newLocalDecls` and `localDecls` -/
|
||||
newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl fvarId val),
|
||||
localDecls := s.localDecls.map (replaceFVarIdAtLocalDecl fvarId val)
|
||||
};
|
||||
mkClosureForAux (pushNewVars toProcess (collectFVars (collectFVars {} type) val))
|
||||
|
||||
structure LetRecClosure :=
|
||||
(localDecls : Array LocalDecl)
|
||||
(closed : Expr) -- expression used to replace occurrences of the let-rec FVarId
|
||||
(toLift : LetRecToLift)
|
||||
|
||||
private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId) : TermElabM LetRecClosure := do
|
||||
let lctx := toLift.lctx;
|
||||
withLCtx lctx toLift.localInstances do
|
||||
lambdaTelescope toLift.val fun xs val => do
|
||||
type ← instantiateForall toLift.type xs;
|
||||
lctx ← getLCtx;
|
||||
(_, s) ← (mkClosureForAux freeVars).run { localDecls := xs.map fun x => lctx.get! x.fvarId! };
|
||||
let type := Closure.mkForall s.localDecls $ Closure.mkForall s.newLetDecls type;
|
||||
let val := Closure.mkLambda s.localDecls $ Closure.mkLambda s.newLetDecls val;
|
||||
let c := mkAppN (Lean.mkConst toLift.declName) s.exprArgs;
|
||||
assignExprMVar toLift.mvarId c;
|
||||
pure ⟨s.newLocalDecls, c, { toLift with val := val, type := type }⟩
|
||||
|
||||
private def mkLetRecClosures (letRecsToLift : List LetRecToLift) (freeVarMap : FreeVarMap) : TermElabM (List LetRecClosure) :=
|
||||
letRecsToLift.mapM fun toLift => mkLetRecClosureFor toLift (freeVarMap.find? toLift.fvarId).get!
|
||||
|
||||
/- Mapping from FVarId of mutually recursive functions being defined to "closure" expression. -/
|
||||
abbrev Replacement := NameMap Expr
|
||||
|
||||
def insertReplacementForMainFns (r : Replacement) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) : Replacement :=
|
||||
mainFVars.size.fold
|
||||
(fun i (r : Replacement) =>
|
||||
r.insert (mainFVars.get! i).fvarId! (mkAppN (Lean.mkConst (mainHeaders.get! i).declName) sectionVars))
|
||||
r
|
||||
|
||||
def insertReplacementForLetRecs (r : Replacement) (letRecClosures : List LetRecClosure) : Replacement :=
|
||||
letRecClosures.foldl (fun (r : Replacement) c => r.insert c.toLift.fvarId c.closed) r
|
||||
|
||||
def Replacement.apply (r : Replacement) (e : Expr) : Expr :=
|
||||
e.replace fun e => match e with
|
||||
| Expr.fvar fvarId _ => match r.find? fvarId with
|
||||
| some c => some c
|
||||
| _ => none
|
||||
| _ => none
|
||||
|
||||
def pushMain (preDecls : Array PreDeclaration) (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainVals : Array Expr)
|
||||
: TermElabM (Array PreDeclaration) :=
|
||||
mainHeaders.size.foldM
|
||||
(fun i (preDecls : Array PreDeclaration) => do
|
||||
let header := mainHeaders.get! i;
|
||||
val ← mkLambdaFVars sectionVars (mainVals.get! i);
|
||||
type ← mkForallFVars sectionVars header.type;
|
||||
pure $ preDecls.push {
|
||||
kind := header.kind,
|
||||
declName := header.declName,
|
||||
modifiers := header.modifiers,
|
||||
type := type,
|
||||
val := val
|
||||
})
|
||||
preDecls
|
||||
|
||||
def pushLetRecs (preDecls : Array PreDeclaration) (letRecClosures : List LetRecClosure) (kind : DefKind) : Array PreDeclaration :=
|
||||
letRecClosures.foldl
|
||||
(fun (preDecls : Array PreDeclaration) (c : LetRecClosure) =>
|
||||
let type := Closure.mkForall c.localDecls c.toLift.type;
|
||||
let val := Closure.mkLambda c.localDecls c.toLift.val;
|
||||
preDecls.push {
|
||||
kind := kind,
|
||||
declName := c.toLift.declName,
|
||||
modifiers := { attrs := c.toLift.attrs },
|
||||
type := type,
|
||||
val := val
|
||||
})
|
||||
preDecls
|
||||
|
||||
def getKindForLetRecs (mainHeaders : Array DefViewElabHeader) : DefKind :=
|
||||
if mainHeaders.any fun h => h.kind.isTheorem then DefKind.«theorem»
|
||||
else DefKind.«def»
|
||||
|
||||
/-
|
||||
- `sectionVars`: The section variables used in the `mutual` block.
|
||||
- `mainHeaders`: The elaborated header of the top-level definitions being defined by the mutual block.
|
||||
- `mainFVars`: The auxiliary variables used to represent the top-level definitions being defined by the mutual block.
|
||||
- `mainVals`: The elaborated value for the top-level definitions
|
||||
- `letRecsToLift`: The let-rec's definitions that need to be lifted
|
||||
-/
|
||||
def main (sectionVars : Array Expr) (mainHeaders : Array DefViewElabHeader) (mainFVars : Array Expr) (mainVals : Array Expr) (letRecsToLift : List LetRecToLift)
|
||||
: TermElabM (Array PreDeclaration) := do
|
||||
-- Store in recFVarIds the fvarId of every function being defined by the mutual block.
|
||||
let mainFVarIds := mainFVars.map Expr.fvarId!;
|
||||
let recFVarIds := (letRecsToLift.toArray.map fun toLift => toLift.fvarId) ++ mainFVarIds;
|
||||
-- Compute the set of free variables (excluding `recFVarIds`) for each let-rec.
|
||||
mctx ← getMCtx;
|
||||
let freeVarMap := mkFreeVarMap mctx sectionVars mainFVarIds recFVarIds letRecsToLift;
|
||||
resetZetaFVarIds;
|
||||
withTrackingZeta do
|
||||
-- By checking `toLift.type` and `toLift.val` we populate `zetaFVarIds`. See comments at `src/Lean/Meta/Closure.lean`.
|
||||
letRecsToLift.forM fun toLift => withLCtx toLift.lctx toLift.localInstances do { liftM $ check toLift.type; liftM $ check toLift.val };
|
||||
letRecClosures ← mkLetRecClosures letRecsToLift freeVarMap;
|
||||
-- mkLetRecClosures assign metavariables that were placeholders for the lifted declarations.
|
||||
mainVals ← mainVals.mapM instantiateMVars;
|
||||
mainHeaders ← mainHeaders.mapM instantiateMVarsAtHeader;
|
||||
letRecClosures ← letRecClosures.mapM fun closure => do { toLift ← instantiateMVarsAtLetRecToLift closure.toLift; pure { closure with toLift := toLift } };
|
||||
-- Replace fvarIds for functions being defined with closed terms
|
||||
let r := insertReplacementForMainFns {} sectionVars mainHeaders mainFVars;
|
||||
let r := insertReplacementForLetRecs r letRecClosures;
|
||||
let mainVals := mainVals.map r.apply;
|
||||
let mainHeaders := mainHeaders.map fun h => { h with type := r.apply h.type };
|
||||
let letRecClosures := letRecClosures.map fun c => { c with toLift := { c.toLift with type := r.apply c.toLift.type, val := r.apply c.toLift.val } };
|
||||
let letRecKind := getKindForLetRecs mainHeaders;
|
||||
pushMain (pushLetRecs #[] letRecClosures letRecKind) sectionVars mainHeaders mainVals
|
||||
|
||||
end MutualClosure
|
||||
|
||||
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
|
||||
scopeLevelNames ← getLevelNames;
|
||||
|
|
@ -339,19 +567,12 @@ withFunLocalDecls headers fun funFVars => do
|
|||
letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift;
|
||||
checkLetRecsToLiftTypes funFVars letRecsToLift;
|
||||
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
|
||||
let values := values.map $ replaceFunFVarsWithConsts headers funFVars vars;
|
||||
let letRecsToLift := letRecsToLift.map fun toLift => { toLift with val := replaceFunFVarsWithConsts headers funFVars vars toLift.val };
|
||||
values ← values.mapM $ mkLambdaFVars vars;
|
||||
headers ← headers.mapM fun header => do { type ← mkForallFVars vars header.type; pure { header with type := type } };
|
||||
letRecsToLift ← LetRecClosure.main letRecsToLift;
|
||||
values ← values.mapM instantiateMVars;
|
||||
preDecls ← MutualClosure.main vars headers funFVars values letRecsToLift;
|
||||
-- TODO
|
||||
values.forM fun val => IO.println (toString val);
|
||||
letRecsToLift.forM fun toLift => IO.println (toString toLift.declName ++ " := " ++ toString toLift.val);
|
||||
preDecls.forM fun preDecl => IO.println (toString preDecl.declName ++ " : " ++ toString preDecl.type ++ " :=\n" ++ toString preDecl.val ++ "\n");
|
||||
throwError "WIP mutual def"
|
||||
|
||||
end Term
|
||||
|
||||
namespace Command
|
||||
|
||||
def elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
|
||||
|
|
|
|||
|
|
@ -397,6 +397,21 @@ m.dAssignment.contains mvarId
|
|||
def eraseDelayed (m : MetavarContext) (mvarId : MVarId) : MetavarContext :=
|
||||
{ m with dAssignment := m.dAssignment.erase mvarId }
|
||||
|
||||
/- Given a sequence of delayed assignments
|
||||
```
|
||||
mvarId₁ := mvarId₂ ...;
|
||||
...
|
||||
mvarIdₙ := mvarId_root ... -- where `mvarId_root` is not delayed assigned
|
||||
```
|
||||
in `mctx`, `getDelayedRoot mctx mvarId₁` return `mvarId_root`.
|
||||
If `mvarId₁` is not delayed assigned then return `mvarId₁` -/
|
||||
partial def getDelayedRoot (m : MetavarContext) : MVarId → MVarId
|
||||
| mvarId => match getDelayedAssignment? m mvarId with
|
||||
| some d => match d.val.getAppFn with
|
||||
| Expr.mvar mvarId _ => getDelayedRoot mvarId
|
||||
| _ => mvarId
|
||||
| none => mvarId
|
||||
|
||||
def isLevelAssignable (mctx : MetavarContext) (mvarId : MVarId) : Bool :=
|
||||
match mctx.lDepth.find? mvarId with
|
||||
| some d => d == mctx.depth
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue