feat: new and improved mkAuxDefinition

This commit is contained in:
Leonardo de Moura 2020-09-03 17:37:06 -07:00
parent f34fd3e6b4
commit 555a3f0dcf
4 changed files with 210 additions and 171 deletions

View file

@ -357,4 +357,10 @@ export MonadLCtx (getLCtx)
instance monadLCtxTrans (m n) [MonadLCtx m] [MonadLift m n] : MonadLCtx n :=
{ getLCtx := liftM (getLCtx : m _) }
def replaceFVarIdAtLocalDecl (fvarId : FVarId) (e : Expr) (d : LocalDecl) : LocalDecl :=
if d.fvarId == fvarId then d
else match d with
| LocalDecl.cdecl idx id n type bi => LocalDecl.cdecl idx id n (type.replaceFVarId fvarId e) bi
| LocalDecl.ldecl idx id n type val nonDep => LocalDecl.ldecl idx id n (type.replaceFVarId fvarId e) (val.replaceFVarId fvarId e) nonDep
end Lean

View file

@ -8,31 +8,36 @@ import Lean.MetavarContext
import Lean.Environment
import Lean.Util.FoldConsts
import Lean.Meta.Basic
import Lean.Meta.Check
namespace Lean
namespace Meta
namespace Closure
structure ToProcessElement :=
(fvarId : FVarId) (newFVarId : FVarId)
instance ToProcessElement.inhabited : Inhabited ToProcessElement :=
⟨⟨arbitrary _, arbitrary _⟩⟩
structure Context :=
(lctxInput : LocalContext)
(zeta : Bool) -- if `true` let-variables are expanded
(zeta : Bool)
structure State :=
(mctx : MetavarContext)
(lctxOutput : LocalContext := {})
(ngen : NameGenerator := { namePrefix := `_closure })
(visitedLevel : LevelMap Level := {})
(visitedExpr : ExprStructMap Expr := {})
(levelParams : Array Name := #[])
(nextLevelIdx : Nat := 1)
(levelClosure : Array Level := #[])
(nextExprIdx : Nat := 1)
(exprClosure : Array Expr := #[])
(visitedLevel : LevelMap Level := {})
(visitedExpr : ExprStructMap Expr := {})
(levelParams : Array Name := #[])
(nextLevelIdx : Nat := 1)
(levelArgs : Array Level := #[])
(newLocalDecls : Array LocalDecl := #[])
(newLocalDeclsForMVars : Array LocalDecl := #[])
(newLetDecls : Array LocalDecl := #[])
(nextExprIdx : Nat := 1)
(exprMVarArgs : Array Expr := #[])
(exprFVarArgs : Array Expr := #[])
(toProcess : Array ToProcessElement := #[])
def Exception := String
abbrev ClosureM := ReaderT Context (EStateM Exception State)
abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
@[inline] def visitLevel (f : Level → ClosureM Level) (u : Level) : ClosureM Level :=
if !u.hasMVar && !u.hasParam then pure u
@ -45,20 +50,23 @@ else do
modify $ fun s => { s with visitedLevel := s.visitedLevel.insert u v };
pure v
@[inline] def visitExpr (f : Expr → ClosureM Expr) (e : Expr) : ClosureM Expr :=
if !e.hasLevelParam && !e.hasFVar && !e.hasMVar then pure e
else do
s ← get;
match s.visitedExpr.find? e with
| some r => pure r
| none => do
r ← f e;
modify $ fun s => { s with visitedExpr := s.visitedExpr.insert e r };
pure r
def mkNewLevelParam (u : Level) : ClosureM Level := do
s ← get;
let p := (`u).appendIndexAfter s.nextLevelIdx;
modify $ fun s => { s with levelParams := s.levelParams.push p, nextLevelIdx := s.nextLevelIdx + 1, levelClosure := s.levelClosure.push u };
modify $ fun s => { s with levelParams := s.levelParams.push p, nextLevelIdx := s.nextLevelIdx + 1, levelArgs := s.levelArgs.push u };
pure $ mkLevelParam p
def getMCtx : ClosureM MetavarContext := do
s ← get; pure s.mctx
def instantiateMVars (e : Expr) : ClosureM Expr := do
modifyGet fun s =>
let (e, mctx) := s.mctx.instantiateMVars e;
(e, { s with mctx := mctx })
partial def collectLevelAux : Level → ClosureM Level
| u@(Level.succ v _) => do v ← visitLevel collectLevelAux v; pure $ u.updateSucc! v
| u@(Level.max v w _) => do v ← visitLevel collectLevelAux v; w ← visitLevel collectLevelAux w; pure $ u.updateMax! v w
@ -67,12 +75,17 @@ partial def collectLevelAux : Level → ClosureM Level
| u@(Level.param _ _) => mkNewLevelParam u
| u@(Level.zero _) => pure u
def collectLevel (u : Level) : ClosureM Level :=
def collectLevel (u : Level) : ClosureM Level := do
-- u ← instantiateLevelMVars u;
visitLevel collectLevelAux u
instance : MonadNameGenerator ClosureM :=
{ getNGen := do s ← get; pure s.ngen,
setNGen := fun ngen => modify fun s => { s with ngen := ngen } }
def preprocess (e : Expr) : ClosureM Expr := do
e ← instantiateMVars e;
ctx ← read;
-- If we are not zeta-expanding let-decls, then we use `check` to find
-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect.
when (!ctx.zeta) $ liftM $ check e;
pure e
/--
Remark: This method does not guarantee unique user names.
@ -84,32 +97,8 @@ let n := (`_x).appendIndexAfter s.nextExprIdx;
modify $ fun s => { s with nextExprIdx := s.nextExprIdx + 1 };
pure n
def getUserName (userName? : Option Name) : ClosureM Name :=
match userName? with
| some userName => pure userName
| none => mkNextUserName
def mkLocalDecl (userName? : Option Name) (type : Expr) (bi : BinderInfo) : ClosureM Expr := do
userName ← getUserName userName?;
fvarId ← mkFreshFVarId;
modify $ fun s => { s with lctxOutput := s.lctxOutput.mkLocalDecl fvarId userName type bi };
pure $ mkFVar fvarId
def mkLetDecl (userName : Name) (type : Expr) (value : Expr) (nonDep : Bool) : ClosureM Expr := do
fvarId ← mkFreshFVarId;
modify $ fun s => { s with lctxOutput := s.lctxOutput.mkLetDecl fvarId userName type value nonDep };
pure $ mkFVar fvarId
@[inline] def visitExpr (f : Expr → ClosureM Expr) (e : Expr) : ClosureM Expr :=
if !e.hasLevelParam && !e.hasFVar && !e.hasMVar then pure e
else do
s ← get;
match s.visitedExpr.find? e with
| some r => pure r
| none => do
r ← f e;
modify $ fun s => { s with visitedExpr := s.visitedExpr.insert e r };
pure r
def pushToProcess (elem : ToProcessElement) : ClosureM Unit :=
modify fun s => { s with toProcess := s.toProcess.push elem }
partial def collectExprAux : Expr → ClosureM Expr
| e =>
@ -124,126 +113,171 @@ partial def collectExprAux : Expr → ClosureM Expr
| Expr.sort u _ => do u ← collectLevel u; pure (e.updateSort! u)
| Expr.const c us _ => do us ← us.mapM collectLevel; pure (e.updateConst! us)
| Expr.mvar mvarId _ => do
mctx ← getMCtx;
match mctx.findDecl? mvarId with
| none => throw "unknown metavariable"
| some mvarDecl => do
type ← instantiateMVars mvarDecl.type;
type ← collect type;
x ← mkLocalDecl none type BinderInfo.default;
modify $ fun s => { s with exprClosure := s.exprClosure.push e };
pure x
mvarDecl ← getMVarDecl mvarId;
type ← preprocess mvarDecl.type;
type ← collect type;
newFVarId ← mkFreshFVarId;
userName ← mkNextUserName;
modify fun s => { s with
newLocalDeclsForMVars := s.newLocalDeclsForMVars.push $ LocalDecl.cdecl (arbitrary _) newFVarId userName type BinderInfo.default,
exprMVarArgs := s.exprMVarArgs.push e
};
pure $ mkFVar newFVarId
| Expr.fvar fvarId _ => do
ctx ← read;
match ctx.lctxInput.find? fvarId with
| none => throw "unknown free variable"
| some (LocalDecl.cdecl _ _ userName type bi) => do
type ← instantiateMVars type;
type ← collect type;
x ← mkLocalDecl userName type bi;
modify $ fun s => { s with exprClosure := s.exprClosure.push e };
pure x
| some (LocalDecl.ldecl _ _ userName type value nonDep) =>
if ctx.zeta then do
value ← instantiateMVars value;
collect value
else do
type ← instantiateMVars type;
type ← collect type;
value ← instantiateMVars value;
value ← collect value;
-- Note that let-declarations do not need to be provided to the closure being constructed.
mkLetDecl userName type value nonDep
decl ← getLocalDecl fvarId;
match ctx.zeta, decl.value? with
| true, some value => do value ← preprocess value; collect value
| _, _ => do
newFVarId ← mkFreshFVarId;
pushToProcess ⟨fvarId, newFVarId⟩;
pure $ mkFVar newFVarId
| e => pure e
def collectExpr (e : Expr) : ClosureM Expr := do
e ← instantiateMVars e;
e ← preprocess e;
visitExpr collectExprAux e
structure ExprToClose :=
(expr : Expr)
(isType : Bool)
instance ExprToClose.inhabited : Inhabited ExprToClose := ⟨⟨arbitrary _, arbitrary _⟩⟩
structure MkClosureResult :=
(levelParams : Array Name)
(exprs : Array Expr)
(levelClosure : Array Level)
(exprClosure : Array Expr)
(mctx : MetavarContext)
def mkClosure (mctx : MetavarContext) (lctx : LocalContext) (exprsToClose : Array ExprToClose) (zeta : Bool := false) : Except String MkClosureResult :=
let shareCommonExprs : Std.ShareCommonM (Array ExprToClose) := exprsToClose.mapM fun ⟨e, isType⟩ => do {
e ← Std.withShareCommon e;
pure ⟨e, isType⟩
};
let exprsToClose := shareCommonExprs.run;
let mkExprs : ClosureM (Array Expr × MetavarContext) := do {
exprs ← exprsToClose.mapM fun ⟨e, _⟩ => collectExpr e;
mctx ← getMCtx;
pure (exprs, mctx)
};
match (mkExprs { lctxInput := lctx, zeta := zeta }).run { mctx := mctx } with
| EStateM.Result.ok (exprs, mctx) s =>
let fvars := s.lctxOutput.getFVars;
let exprs := exprs.mapIdx fun i e =>
let isType := (exprsToClose.get! i).isType;
if isType then
s.lctxOutput.mkForall fvars e
partial def pickNextToProcessAux (lctx : LocalContext)
: Nat → Array ToProcessElement → ToProcessElement → ToProcessElement × Array ToProcessElement
| i, toProcess, elem =>
if h : i < toProcess.size then
let elem' := toProcess.get ⟨i, h⟩;
if (lctx.get! elem.fvarId).index < (lctx.get! elem'.fvarId).index then
pickNextToProcessAux (i+1) (toProcess.set ⟨i, h⟩ elem) elem'
else
s.lctxOutput.mkLambda fvars e;
Except.ok {
levelParams := s.levelParams,
exprs := exprs,
levelClosure := s.levelClosure,
exprClosure := s.exprClosure,
mctx := mctx
}
| EStateM.Result.error ex s => Except.error ex
pickNextToProcessAux (i+1) toProcess elem
else
(elem, toProcess)
def pickNextToProcess? : ClosureM (Option ToProcessElement) := do
lctx ← getLCtx;
s ← get;
if s.toProcess.isEmpty then pure none
else
modifyGet fun s =>
let elem := s.toProcess.back;
let toProcess := s.toProcess.pop;
let (elem, toProcess) := pickNextToProcessAux lctx 0 toProcess elem;
(some elem, { s with toProcess := toProcess })
def pushFVarArg (e : Expr) : ClosureM Unit :=
modify fun s => { s with exprFVarArgs := s.exprFVarArgs.push e }
def pushLocalDecl (newFVarId : FVarId) (userName : Name) (type : Expr) (bi := BinderInfo.default) : ClosureM Unit := do
type ← collectExpr type;
modify fun s => { s with newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) newFVarId userName type bi }
partial def process : Unit → ClosureM Unit
| _ => do
elem? ← pickNextToProcess?;
match elem? with
| none => pure ()
| some ⟨fvarId, newFVarId⟩ => do
localDecl ← getLocalDecl fvarId;
match localDecl with
| LocalDecl.cdecl _ _ userName type bi => do
pushLocalDecl newFVarId userName type bi;
pushFVarArg (mkFVar fvarId);
process ()
| LocalDecl.ldecl _ _ userName type val _ => do
zetaFVarIds ← getZetaFVarIds;
if !zetaFVarIds.contains fvarId then do
/- Non-dependent let-decl
Recall that if `fvarId` is in `zetaFVarIds`, then we zeta-expanded it
during type checking (see `check` at `collectExpr`).
Our type checker may zeta-expand declarations that are not needed, but this
check is conservative, and seems to work well in practice. -/
pushLocalDecl newFVarId userName type;
pushFVarArg (mkFVar fvarId);
process ()
else do
/- Dependent let-decl -/
type ← collectExpr type;
val ← collectExpr val;
modify fun s => { s with newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) newFVarId userName type val false };
/- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of newFVarId
at `newLocalDecls` -/
modify fun s => { s with newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl newFVarId val) };
process ()
@[inline] def mkBinding (isLambda : Bool) (decls : Array LocalDecl) (b : Expr) : Expr :=
let xs := decls.map LocalDecl.toExpr;
let b := b.abstract xs;
decls.size.foldRev
(fun i b =>
let decl := decls.get! i;
match decl with
| LocalDecl.cdecl _ _ n ty bi =>
let ty := ty.abstractRange i xs;
if isLambda then
Lean.mkLambda n bi ty b
else
Lean.mkForall n bi ty b
| LocalDecl.ldecl _ _ n ty val nonDep =>
if b.hasLooseBVar 0 then
let ty := ty.abstractRange i xs;
let val := val.abstractRange i xs;
mkLet n ty val b nonDep
else
b.lowerLooseBVars 1 1)
b
def mkLambda (decls : Array LocalDecl) (b : Expr) : Expr :=
mkBinding true decls b
def mkForall (decls : Array LocalDecl) (b : Expr) : Expr :=
mkBinding false decls b
structure MkValueTypeClosureResult :=
(levelParams : Array Name)
(type : Expr)
(value : Expr)
(levelClosure : Array Level)
(exprClosure : Array Expr)
(mctx : MetavarContext)
(levelParams : Array Name)
(type : Expr)
(value : Expr)
(levelArgs : Array Level)
(exprArgs : Array Expr)
def mkValueTypeClosure (mctx : MetavarContext) (lctx : LocalContext) (type : Expr) (value : Expr) (zeta : Bool := false)
: Except String MkValueTypeClosureResult := do
r ← mkClosure mctx lctx #[ { expr := type, isType := true }, { expr := value, isType := false } ] zeta;
def mkValueTypeClosureAux (type : Expr) (value : Expr) : ClosureM (Expr × Expr) := do
resetZetaFVarIds;
withTrackingZeta do
type ← collectExpr type;
value ← collectExpr value;
process ();
pure (type, value)
def mkValueTypeClosure (type : Expr) (value : Expr) (zeta : Bool) : MetaM MkValueTypeClosureResult := do
((type, value), s) ← ((mkValueTypeClosureAux type value).run { zeta := zeta }).run {};
let newLocalDecls := s.newLocalDecls.reverse ++ s.newLocalDeclsForMVars;
let newLetDecls := s.newLetDecls.reverse;
let type := mkForall newLocalDecls (mkForall newLetDecls type);
let value := mkLambda newLocalDecls (mkLambda newLetDecls value);
pure {
levelParams := r.levelParams,
type := r.exprs.get! 0,
value := r.exprs.get! 1,
levelClosure := r.levelClosure,
exprClosure := r.exprClosure,
mctx := r.mctx
type := type,
value := value,
levelParams := s.levelParams,
levelArgs := s.levelArgs,
exprArgs := s.exprFVarArgs.reverse ++ s.exprMVarArgs
}
end Closure
variables {m : Type → Type} [MonadLiftT MetaM m]
private def mkAuxDefinitionImp (name : Name) (type : Expr) (value : Expr) (zeta : Bool) : MetaM Expr := do
opts ← getOptions;
mctx ← getMCtx;
lctx ← getLCtx;
match Closure.mkValueTypeClosure mctx lctx type value zeta with
| Except.error ex => throwError ex
| Except.ok result => do
env ← getEnv;
let decl := Declaration.defnDecl {
name := name,
lparams := result.levelParams.toList,
type := result.type,
value := result.value,
hints := ReducibilityHints.regular (getMaxHeight env result.value + 1),
isUnsafe := env.hasUnsafe result.type || env.hasUnsafe result.value
};
setMCtx result.mctx;
addAndCompile decl;
pure $ mkAppN (mkConst name result.levelClosure.toList) result.exprClosure
result ← Closure.mkValueTypeClosure type value zeta;
env ← getEnv;
let decl := Declaration.defnDecl {
name := name,
lparams := result.levelParams.toList,
type := result.type,
value := result.value,
hints := ReducibilityHints.regular (getMaxHeight env result.value + 1),
isUnsafe := env.hasUnsafe result.type || env.hasUnsafe result.value
};
trace! `Meta.debug (name ++ " : " ++ result.type ++ " := " ++ result.value);
addAndCompile decl;
pure $ mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
/--
Create an auxiliary definition with the given name, type and value.
@ -251,7 +285,8 @@ match Closure.mkValueTypeClosure mctx lctx type value zeta with
A "closure" is computed, and a term of the form `name.{u_1 ... u_n} t_1 ... t_m` is
returned where `u_i`s are universe parameters and metavariables `type` and `value` depend on,
and `t_j`s are free and meta variables `type` and `value` depend on. -/
def mkAuxDefinition (name : Name) (type : Expr) (value : Expr) (zeta : Bool := false) : m Expr := liftMetaM do
def mkAuxDefinition (name : Name) (type : Expr) (value : Expr) (zeta := false) : m Expr := liftMetaM do
trace! `Meta.debug (name ++ " : " ++ type ++ " := " ++ value);
mkAuxDefinitionImp name type value zeta
/-- Similar to `mkAuxDefinition`, but infers the type of `value`. -/
@ -260,6 +295,5 @@ type ← inferType value;
let type := type.headBeta;
mkAuxDefinition name type value
end Meta
end Lean

View file

@ -16,12 +16,6 @@ namespace Lean
namespace Meta
namespace Match
def replaceFVarIdAtLocalDecl (fvarId : FVarId) (e : Expr) (d : LocalDecl) : LocalDecl :=
if d.fvarId == fvarId then d
else match d with
| LocalDecl.cdecl idx id n type bi => LocalDecl.cdecl idx id n (type.replaceFVarId fvarId e) bi
| LocalDecl.ldecl idx id n type val nonDep => LocalDecl.ldecl idx id n (type.replaceFVarId fvarId e) (val.replaceFVarId fvarId e) nonDep
inductive Pattern : Type
| inaccessible (e : Expr) : Pattern
| var (fvarId : FVarId) : Pattern

View file

@ -2,10 +2,13 @@ import Lean.Meta
open Lean
open Lean.Meta
def fact : Nat → Nat
| 0 => 1
| n+1 => (n+1)*fact n
set_option trace.Meta true
set_option trace.Meta.isDefEq.step false
set_option trace.Meta.isDefEq.delta false
set_option trace.Meta.isDefEq.assign false
set_option trace.Meta.isDefEq false
set_option trace.Meta.check false
def print (msg : MessageData) : MetaM Unit :=
trace! `Meta.debug msg
@ -13,8 +16,8 @@ trace! `Meta.debug msg
def check (x : MetaM Bool) : MetaM Unit :=
unlessM x $ throwError "check failed"
def ex : Nat × Nat :=
let x := 10;
def ex (x_1 x_2 x_3 : Nat) : Nat × Nat :=
let x := fact (10 + x_1 + x_2 + x_3);
let ty := Nat → Nat;
let f : ty := fun x => x;
let n := 20;
@ -31,7 +34,9 @@ lambdaTelescope c.value?.get! fun xs body =>
let ys := ys.toList.map mkFVar;
print ys;
check $ pure $ ys.length == 2;
mkAuxDefinitionFor `foo body;
c ← mkAuxDefinitionFor `foo body;
print c;
Meta.check c;
pure ()
#eval tst1