From 555a3f0dcf69ce1ff5d3c070e2d5e5d1f134bd20 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 3 Sep 2020 17:37:06 -0700 Subject: [PATCH] feat: new and improved `mkAuxDefinition` --- src/Lean/LocalContext.lean | 6 + src/Lean/Meta/Closure.lean | 352 +++++++++++++++------------ src/Lean/Meta/EqnCompiler/Match.lean | 6 - tests/lean/run/meta7.lean | 17 +- 4 files changed, 210 insertions(+), 171 deletions(-) diff --git a/src/Lean/LocalContext.lean b/src/Lean/LocalContext.lean index 6a6197bcca..66ee4d5b79 100644 --- a/src/Lean/LocalContext.lean +++ b/src/Lean/LocalContext.lean @@ -357,4 +357,10 @@ export MonadLCtx (getLCtx) instance monadLCtxTrans (m n) [MonadLCtx m] [MonadLift m n] : MonadLCtx n := { getLCtx := liftM (getLCtx : m _) } +def replaceFVarIdAtLocalDecl (fvarId : FVarId) (e : Expr) (d : LocalDecl) : LocalDecl := +if d.fvarId == fvarId then d +else match d with + | LocalDecl.cdecl idx id n type bi => LocalDecl.cdecl idx id n (type.replaceFVarId fvarId e) bi + | LocalDecl.ldecl idx id n type val nonDep => LocalDecl.ldecl idx id n (type.replaceFVarId fvarId e) (val.replaceFVarId fvarId e) nonDep + end Lean diff --git a/src/Lean/Meta/Closure.lean b/src/Lean/Meta/Closure.lean index 4a00d3a339..a17f1ba69e 100644 --- a/src/Lean/Meta/Closure.lean +++ b/src/Lean/Meta/Closure.lean @@ -8,31 +8,36 @@ import Lean.MetavarContext import Lean.Environment import Lean.Util.FoldConsts import Lean.Meta.Basic +import Lean.Meta.Check namespace Lean namespace Meta - namespace Closure +structure ToProcessElement := +(fvarId : FVarId) (newFVarId : FVarId) + +instance ToProcessElement.inhabited : Inhabited ToProcessElement := +⟨⟨arbitrary _, arbitrary _⟩⟩ + structure Context := -(lctxInput : LocalContext) -(zeta : Bool) -- if `true` let-variables are expanded +(zeta : Bool) structure State := -(mctx : MetavarContext) -(lctxOutput : LocalContext := {}) -(ngen : NameGenerator := { namePrefix := `_closure }) -(visitedLevel : LevelMap Level := {}) -(visitedExpr : ExprStructMap Expr := {}) -(levelParams : Array Name := #[]) -(nextLevelIdx : Nat := 1) -(levelClosure : Array Level := #[]) -(nextExprIdx : Nat := 1) -(exprClosure : Array Expr := #[]) +(visitedLevel : LevelMap Level := {}) +(visitedExpr : ExprStructMap Expr := {}) +(levelParams : Array Name := #[]) +(nextLevelIdx : Nat := 1) +(levelArgs : Array Level := #[]) +(newLocalDecls : Array LocalDecl := #[]) +(newLocalDeclsForMVars : Array LocalDecl := #[]) +(newLetDecls : Array LocalDecl := #[]) +(nextExprIdx : Nat := 1) +(exprMVarArgs : Array Expr := #[]) +(exprFVarArgs : Array Expr := #[]) +(toProcess : Array ToProcessElement := #[]) -def Exception := String - -abbrev ClosureM := ReaderT Context (EStateM Exception State) +abbrev ClosureM := ReaderT Context $ StateRefT State MetaM @[inline] def visitLevel (f : Level → ClosureM Level) (u : Level) : ClosureM Level := if !u.hasMVar && !u.hasParam then pure u @@ -45,20 +50,23 @@ else do modify $ fun s => { s with visitedLevel := s.visitedLevel.insert u v }; pure v +@[inline] def visitExpr (f : Expr → ClosureM Expr) (e : Expr) : ClosureM Expr := +if !e.hasLevelParam && !e.hasFVar && !e.hasMVar then pure e +else do + s ← get; + match s.visitedExpr.find? e with + | some r => pure r + | none => do + r ← f e; + modify $ fun s => { s with visitedExpr := s.visitedExpr.insert e r }; + pure r + def mkNewLevelParam (u : Level) : ClosureM Level := do s ← get; let p := (`u).appendIndexAfter s.nextLevelIdx; -modify $ fun s => { s with levelParams := s.levelParams.push p, nextLevelIdx := s.nextLevelIdx + 1, levelClosure := s.levelClosure.push u }; +modify $ fun s => { s with levelParams := s.levelParams.push p, nextLevelIdx := s.nextLevelIdx + 1, levelArgs := s.levelArgs.push u }; pure $ mkLevelParam p -def getMCtx : ClosureM MetavarContext := do -s ← get; pure s.mctx - -def instantiateMVars (e : Expr) : ClosureM Expr := do -modifyGet fun s => - let (e, mctx) := s.mctx.instantiateMVars e; - (e, { s with mctx := mctx }) - partial def collectLevelAux : Level → ClosureM Level | u@(Level.succ v _) => do v ← visitLevel collectLevelAux v; pure $ u.updateSucc! v | u@(Level.max v w _) => do v ← visitLevel collectLevelAux v; w ← visitLevel collectLevelAux w; pure $ u.updateMax! v w @@ -67,12 +75,17 @@ partial def collectLevelAux : Level → ClosureM Level | u@(Level.param _ _) => mkNewLevelParam u | u@(Level.zero _) => pure u -def collectLevel (u : Level) : ClosureM Level := +def collectLevel (u : Level) : ClosureM Level := do +-- u ← instantiateLevelMVars u; visitLevel collectLevelAux u -instance : MonadNameGenerator ClosureM := -{ getNGen := do s ← get; pure s.ngen, - setNGen := fun ngen => modify fun s => { s with ngen := ngen } } +def preprocess (e : Expr) : ClosureM Expr := do +e ← instantiateMVars e; +ctx ← read; +-- If we are not zeta-expanding let-decls, then we use `check` to find +-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect. +when (!ctx.zeta) $ liftM $ check e; +pure e /-- Remark: This method does not guarantee unique user names. @@ -84,32 +97,8 @@ let n := (`_x).appendIndexAfter s.nextExprIdx; modify $ fun s => { s with nextExprIdx := s.nextExprIdx + 1 }; pure n -def getUserName (userName? : Option Name) : ClosureM Name := -match userName? with -| some userName => pure userName -| none => mkNextUserName - -def mkLocalDecl (userName? : Option Name) (type : Expr) (bi : BinderInfo) : ClosureM Expr := do -userName ← getUserName userName?; -fvarId ← mkFreshFVarId; -modify $ fun s => { s with lctxOutput := s.lctxOutput.mkLocalDecl fvarId userName type bi }; -pure $ mkFVar fvarId - -def mkLetDecl (userName : Name) (type : Expr) (value : Expr) (nonDep : Bool) : ClosureM Expr := do -fvarId ← mkFreshFVarId; -modify $ fun s => { s with lctxOutput := s.lctxOutput.mkLetDecl fvarId userName type value nonDep }; -pure $ mkFVar fvarId - -@[inline] def visitExpr (f : Expr → ClosureM Expr) (e : Expr) : ClosureM Expr := -if !e.hasLevelParam && !e.hasFVar && !e.hasMVar then pure e -else do - s ← get; - match s.visitedExpr.find? e with - | some r => pure r - | none => do - r ← f e; - modify $ fun s => { s with visitedExpr := s.visitedExpr.insert e r }; - pure r +def pushToProcess (elem : ToProcessElement) : ClosureM Unit := +modify fun s => { s with toProcess := s.toProcess.push elem } partial def collectExprAux : Expr → ClosureM Expr | e => @@ -124,126 +113,171 @@ partial def collectExprAux : Expr → ClosureM Expr | Expr.sort u _ => do u ← collectLevel u; pure (e.updateSort! u) | Expr.const c us _ => do us ← us.mapM collectLevel; pure (e.updateConst! us) | Expr.mvar mvarId _ => do - mctx ← getMCtx; - match mctx.findDecl? mvarId with - | none => throw "unknown metavariable" - | some mvarDecl => do - type ← instantiateMVars mvarDecl.type; - type ← collect type; - x ← mkLocalDecl none type BinderInfo.default; - modify $ fun s => { s with exprClosure := s.exprClosure.push e }; - pure x + mvarDecl ← getMVarDecl mvarId; + type ← preprocess mvarDecl.type; + type ← collect type; + newFVarId ← mkFreshFVarId; + userName ← mkNextUserName; + modify fun s => { s with + newLocalDeclsForMVars := s.newLocalDeclsForMVars.push $ LocalDecl.cdecl (arbitrary _) newFVarId userName type BinderInfo.default, + exprMVarArgs := s.exprMVarArgs.push e + }; + pure $ mkFVar newFVarId | Expr.fvar fvarId _ => do ctx ← read; - match ctx.lctxInput.find? fvarId with - | none => throw "unknown free variable" - | some (LocalDecl.cdecl _ _ userName type bi) => do - type ← instantiateMVars type; - type ← collect type; - x ← mkLocalDecl userName type bi; - modify $ fun s => { s with exprClosure := s.exprClosure.push e }; - pure x - | some (LocalDecl.ldecl _ _ userName type value nonDep) => - if ctx.zeta then do - value ← instantiateMVars value; - collect value - else do - type ← instantiateMVars type; - type ← collect type; - value ← instantiateMVars value; - value ← collect value; - -- Note that let-declarations do not need to be provided to the closure being constructed. - mkLetDecl userName type value nonDep + decl ← getLocalDecl fvarId; + match ctx.zeta, decl.value? with + | true, some value => do value ← preprocess value; collect value + | _, _ => do + newFVarId ← mkFreshFVarId; + pushToProcess ⟨fvarId, newFVarId⟩; + pure $ mkFVar newFVarId | e => pure e def collectExpr (e : Expr) : ClosureM Expr := do -e ← instantiateMVars e; +e ← preprocess e; visitExpr collectExprAux e -structure ExprToClose := -(expr : Expr) -(isType : Bool) - -instance ExprToClose.inhabited : Inhabited ExprToClose := ⟨⟨arbitrary _, arbitrary _⟩⟩ - -structure MkClosureResult := -(levelParams : Array Name) -(exprs : Array Expr) -(levelClosure : Array Level) -(exprClosure : Array Expr) -(mctx : MetavarContext) - -def mkClosure (mctx : MetavarContext) (lctx : LocalContext) (exprsToClose : Array ExprToClose) (zeta : Bool := false) : Except String MkClosureResult := -let shareCommonExprs : Std.ShareCommonM (Array ExprToClose) := exprsToClose.mapM fun ⟨e, isType⟩ => do { - e ← Std.withShareCommon e; - pure ⟨e, isType⟩ -}; -let exprsToClose := shareCommonExprs.run; -let mkExprs : ClosureM (Array Expr × MetavarContext) := do { - exprs ← exprsToClose.mapM fun ⟨e, _⟩ => collectExpr e; - mctx ← getMCtx; - pure (exprs, mctx) -}; -match (mkExprs { lctxInput := lctx, zeta := zeta }).run { mctx := mctx } with -| EStateM.Result.ok (exprs, mctx) s => - let fvars := s.lctxOutput.getFVars; - let exprs := exprs.mapIdx fun i e => - let isType := (exprsToClose.get! i).isType; - if isType then - s.lctxOutput.mkForall fvars e +partial def pickNextToProcessAux (lctx : LocalContext) + : Nat → Array ToProcessElement → ToProcessElement → ToProcessElement × Array ToProcessElement +| i, toProcess, elem => + if h : i < toProcess.size then + let elem' := toProcess.get ⟨i, h⟩; + if (lctx.get! elem.fvarId).index < (lctx.get! elem'.fvarId).index then + pickNextToProcessAux (i+1) (toProcess.set ⟨i, h⟩ elem) elem' else - s.lctxOutput.mkLambda fvars e; - Except.ok { - levelParams := s.levelParams, - exprs := exprs, - levelClosure := s.levelClosure, - exprClosure := s.exprClosure, - mctx := mctx - } -| EStateM.Result.error ex s => Except.error ex + pickNextToProcessAux (i+1) toProcess elem + else + (elem, toProcess) + +def pickNextToProcess? : ClosureM (Option ToProcessElement) := do +lctx ← getLCtx; +s ← get; +if s.toProcess.isEmpty then pure none +else + modifyGet fun s => + let elem := s.toProcess.back; + let toProcess := s.toProcess.pop; + let (elem, toProcess) := pickNextToProcessAux lctx 0 toProcess elem; + (some elem, { s with toProcess := toProcess }) + +def pushFVarArg (e : Expr) : ClosureM Unit := +modify fun s => { s with exprFVarArgs := s.exprFVarArgs.push e } + +def pushLocalDecl (newFVarId : FVarId) (userName : Name) (type : Expr) (bi := BinderInfo.default) : ClosureM Unit := do +type ← collectExpr type; +modify fun s => { s with newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) newFVarId userName type bi } + +partial def process : Unit → ClosureM Unit +| _ => do + elem? ← pickNextToProcess?; + match elem? with + | none => pure () + | some ⟨fvarId, newFVarId⟩ => do + localDecl ← getLocalDecl fvarId; + match localDecl with + | LocalDecl.cdecl _ _ userName type bi => do + pushLocalDecl newFVarId userName type bi; + pushFVarArg (mkFVar fvarId); + process () + | LocalDecl.ldecl _ _ userName type val _ => do + zetaFVarIds ← getZetaFVarIds; + if !zetaFVarIds.contains fvarId then do + /- Non-dependent let-decl + + Recall that if `fvarId` is in `zetaFVarIds`, then we zeta-expanded it + during type checking (see `check` at `collectExpr`). + + Our type checker may zeta-expand declarations that are not needed, but this + check is conservative, and seems to work well in practice. -/ + pushLocalDecl newFVarId userName type; + pushFVarArg (mkFVar fvarId); + process () + else do + /- Dependent let-decl -/ + type ← collectExpr type; + val ← collectExpr val; + modify fun s => { s with newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) newFVarId userName type val false }; + /- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of newFVarId + at `newLocalDecls` -/ + modify fun s => { s with newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl newFVarId val) }; + process () + +@[inline] def mkBinding (isLambda : Bool) (decls : Array LocalDecl) (b : Expr) : Expr := +let xs := decls.map LocalDecl.toExpr; +let b := b.abstract xs; +decls.size.foldRev + (fun i b => + let decl := decls.get! i; + match decl with + | LocalDecl.cdecl _ _ n ty bi => + let ty := ty.abstractRange i xs; + if isLambda then + Lean.mkLambda n bi ty b + else + Lean.mkForall n bi ty b + | LocalDecl.ldecl _ _ n ty val nonDep => + if b.hasLooseBVar 0 then + let ty := ty.abstractRange i xs; + let val := val.abstractRange i xs; + mkLet n ty val b nonDep + else + b.lowerLooseBVars 1 1) + b + +def mkLambda (decls : Array LocalDecl) (b : Expr) : Expr := +mkBinding true decls b + +def mkForall (decls : Array LocalDecl) (b : Expr) : Expr := +mkBinding false decls b structure MkValueTypeClosureResult := -(levelParams : Array Name) -(type : Expr) -(value : Expr) -(levelClosure : Array Level) -(exprClosure : Array Expr) -(mctx : MetavarContext) +(levelParams : Array Name) +(type : Expr) +(value : Expr) +(levelArgs : Array Level) +(exprArgs : Array Expr) -def mkValueTypeClosure (mctx : MetavarContext) (lctx : LocalContext) (type : Expr) (value : Expr) (zeta : Bool := false) - : Except String MkValueTypeClosureResult := do -r ← mkClosure mctx lctx #[ { expr := type, isType := true }, { expr := value, isType := false } ] zeta; +def mkValueTypeClosureAux (type : Expr) (value : Expr) : ClosureM (Expr × Expr) := do +resetZetaFVarIds; +withTrackingZeta do + type ← collectExpr type; + value ← collectExpr value; + process (); + pure (type, value) + +def mkValueTypeClosure (type : Expr) (value : Expr) (zeta : Bool) : MetaM MkValueTypeClosureResult := do +((type, value), s) ← ((mkValueTypeClosureAux type value).run { zeta := zeta }).run {}; +let newLocalDecls := s.newLocalDecls.reverse ++ s.newLocalDeclsForMVars; +let newLetDecls := s.newLetDecls.reverse; +let type := mkForall newLocalDecls (mkForall newLetDecls type); +let value := mkLambda newLocalDecls (mkLambda newLetDecls value); pure { - levelParams := r.levelParams, - type := r.exprs.get! 0, - value := r.exprs.get! 1, - levelClosure := r.levelClosure, - exprClosure := r.exprClosure, - mctx := r.mctx + type := type, + value := value, + levelParams := s.levelParams, + levelArgs := s.levelArgs, + exprArgs := s.exprFVarArgs.reverse ++ s.exprMVarArgs } + end Closure variables {m : Type → Type} [MonadLiftT MetaM m] private def mkAuxDefinitionImp (name : Name) (type : Expr) (value : Expr) (zeta : Bool) : MetaM Expr := do -opts ← getOptions; -mctx ← getMCtx; -lctx ← getLCtx; -match Closure.mkValueTypeClosure mctx lctx type value zeta with -| Except.error ex => throwError ex -| Except.ok result => do - env ← getEnv; - let decl := Declaration.defnDecl { - name := name, - lparams := result.levelParams.toList, - type := result.type, - value := result.value, - hints := ReducibilityHints.regular (getMaxHeight env result.value + 1), - isUnsafe := env.hasUnsafe result.type || env.hasUnsafe result.value - }; - setMCtx result.mctx; - addAndCompile decl; - pure $ mkAppN (mkConst name result.levelClosure.toList) result.exprClosure +result ← Closure.mkValueTypeClosure type value zeta; +env ← getEnv; +let decl := Declaration.defnDecl { + name := name, + lparams := result.levelParams.toList, + type := result.type, + value := result.value, + hints := ReducibilityHints.regular (getMaxHeight env result.value + 1), + isUnsafe := env.hasUnsafe result.type || env.hasUnsafe result.value +}; +trace! `Meta.debug (name ++ " : " ++ result.type ++ " := " ++ result.value); +addAndCompile decl; +pure $ mkAppN (mkConst name result.levelArgs.toList) result.exprArgs /-- Create an auxiliary definition with the given name, type and value. @@ -251,7 +285,8 @@ match Closure.mkValueTypeClosure mctx lctx type value zeta with A "closure" is computed, and a term of the form `name.{u_1 ... u_n} t_1 ... t_m` is returned where `u_i`s are universe parameters and metavariables `type` and `value` depend on, and `t_j`s are free and meta variables `type` and `value` depend on. -/ -def mkAuxDefinition (name : Name) (type : Expr) (value : Expr) (zeta : Bool := false) : m Expr := liftMetaM do +def mkAuxDefinition (name : Name) (type : Expr) (value : Expr) (zeta := false) : m Expr := liftMetaM do +trace! `Meta.debug (name ++ " : " ++ type ++ " := " ++ value); mkAuxDefinitionImp name type value zeta /-- Similar to `mkAuxDefinition`, but infers the type of `value`. -/ @@ -260,6 +295,5 @@ type ← inferType value; let type := type.headBeta; mkAuxDefinition name type value - end Meta end Lean diff --git a/src/Lean/Meta/EqnCompiler/Match.lean b/src/Lean/Meta/EqnCompiler/Match.lean index 147facf4c4..83dc197337 100644 --- a/src/Lean/Meta/EqnCompiler/Match.lean +++ b/src/Lean/Meta/EqnCompiler/Match.lean @@ -16,12 +16,6 @@ namespace Lean namespace Meta namespace Match -def replaceFVarIdAtLocalDecl (fvarId : FVarId) (e : Expr) (d : LocalDecl) : LocalDecl := -if d.fvarId == fvarId then d -else match d with - | LocalDecl.cdecl idx id n type bi => LocalDecl.cdecl idx id n (type.replaceFVarId fvarId e) bi - | LocalDecl.ldecl idx id n type val nonDep => LocalDecl.ldecl idx id n (type.replaceFVarId fvarId e) (val.replaceFVarId fvarId e) nonDep - inductive Pattern : Type | inaccessible (e : Expr) : Pattern | var (fvarId : FVarId) : Pattern diff --git a/tests/lean/run/meta7.lean b/tests/lean/run/meta7.lean index 862a4eb0c9..e644f8d4f1 100644 --- a/tests/lean/run/meta7.lean +++ b/tests/lean/run/meta7.lean @@ -2,10 +2,13 @@ import Lean.Meta open Lean open Lean.Meta +def fact : Nat → Nat +| 0 => 1 +| n+1 => (n+1)*fact n + set_option trace.Meta true -set_option trace.Meta.isDefEq.step false -set_option trace.Meta.isDefEq.delta false -set_option trace.Meta.isDefEq.assign false +set_option trace.Meta.isDefEq false +set_option trace.Meta.check false def print (msg : MessageData) : MetaM Unit := trace! `Meta.debug msg @@ -13,8 +16,8 @@ trace! `Meta.debug msg def check (x : MetaM Bool) : MetaM Unit := unlessM x $ throwError "check failed" -def ex : Nat × Nat := -let x := 10; +def ex (x_1 x_2 x_3 : Nat) : Nat × Nat := +let x := fact (10 + x_1 + x_2 + x_3); let ty := Nat → Nat; let f : ty := fun x => x; let n := 20; @@ -31,7 +34,9 @@ lambdaTelescope c.value?.get! fun xs body => let ys := ys.toList.map mkFVar; print ys; check $ pure $ ys.length == 2; - mkAuxDefinitionFor `foo body; + c ← mkAuxDefinitionFor `foo body; + print c; + Meta.check c; pure () #eval tst1