perf: create unfolding theorem for wf-rec in one go (#9646)
This PR uses a more simple approach to proving the unfolding theorem for a function defined by well-founded recursion. Instead of looping a bunch of tactics, it uses simp in single-pass mode to (try to) exactly undo the changes done in `WF.Fix`, using a dedicated theorem that pushes the extra argument in for each matcher (or `casesOn`). Improves performance for recursive functions with large `match` statements, as in #9598.
This commit is contained in:
parent
b60f97cc19
commit
df9ca20339
6 changed files with 208 additions and 64 deletions
|
|
@ -141,28 +141,29 @@ private def betaReduceLetRecApps (preDefs : Array PreDefinition) : MetaM (Array
|
|||
|
||||
private def addSorried (preDefs : Array PreDefinition) : TermElabM Unit := do
|
||||
for preDef in preDefs do
|
||||
let value ← mkSorry (synthetic := true) preDef.type
|
||||
let decl := if preDef.kind.isTheorem then
|
||||
Declaration.thmDecl {
|
||||
name := preDef.declName,
|
||||
levelParams := preDef.levelParams,
|
||||
type := preDef.type,
|
||||
value
|
||||
}
|
||||
else
|
||||
Declaration.defnDecl {
|
||||
name := preDef.declName,
|
||||
levelParams := preDef.levelParams,
|
||||
type := preDef.type,
|
||||
hints := .abbrev
|
||||
safety := .safe
|
||||
value
|
||||
}
|
||||
addDecl decl
|
||||
withSaveInfoContext do -- save new env
|
||||
addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true)
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
|
||||
unless (← hasConst preDef.declName) do
|
||||
let value ← mkSorry (synthetic := true) preDef.type
|
||||
let decl := if preDef.kind.isTheorem then
|
||||
Declaration.thmDecl {
|
||||
name := preDef.declName,
|
||||
levelParams := preDef.levelParams,
|
||||
type := preDef.type,
|
||||
value
|
||||
}
|
||||
else
|
||||
Declaration.defnDecl {
|
||||
name := preDef.declName,
|
||||
levelParams := preDef.levelParams,
|
||||
type := preDef.type,
|
||||
hints := .abbrev
|
||||
safety := .safe
|
||||
value
|
||||
}
|
||||
addDecl decl
|
||||
withSaveInfoContext do -- save new env
|
||||
addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true)
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking
|
||||
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
|
||||
|
||||
def ensureFunIndReservedNamesAvailable (preDefs : Array PreDefinition) : MetaM Unit := do
|
||||
preDefs.forM fun preDef =>
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ public import Lean.Meta.Tactic.Split
|
|||
public import Lean.Elab.PreDefinition.Basic
|
||||
public import Lean.Elab.PreDefinition.Eqns
|
||||
public import Lean.Meta.ArgsPacker.Basic
|
||||
public import Lean.Elab.PreDefinition.WF.Unfold
|
||||
public import Lean.Elab.PreDefinition.FixedParams
|
||||
public import Init.Data.Array.Basic
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +1,34 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
Authors: Leonardo de Moura, Joachim Breitner
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Elab.PreDefinition.Basic
|
||||
public import Lean.Elab.PreDefinition.Eqns
|
||||
public import Lean.Meta.Tactic.Apply
|
||||
import Lean.Elab.PreDefinition.Eqns
|
||||
import Lean.Meta.Tactic.Apply
|
||||
import Lean.Meta.Tactic.Split
|
||||
public import Lean.Meta.Tactic.Simp.Types
|
||||
import Lean.Meta.Tactic.Simp.Main
|
||||
import Lean.Meta.Tactic.Simp.BuiltinSimprocs
|
||||
|
||||
public section
|
||||
/-!
|
||||
This module is responsible for proving the unfolding equation for functions defined
|
||||
by well-founded recursion. It uses `WellFounded.fix_eq`, and then has to undo
|
||||
the changes to matchers that `WF.Fix` did using `MatcherApp.addArg`.
|
||||
|
||||
This is done using a single-pass `simp` traversal of the expression that looks
|
||||
for expressions that were modified that way, and rewrites them back using the
|
||||
rather specialized `_arg_pusher` theorem that is generated by `mkMatchArgPusher`.
|
||||
-/
|
||||
|
||||
namespace Lean.Elab.WF
|
||||
open Meta
|
||||
open Eqns
|
||||
|
||||
private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
|
||||
def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
|
||||
let target ← mvarId.getType'
|
||||
let some (_, lhs, rhs) := target.eq? | unreachable!
|
||||
|
||||
|
|
@ -43,44 +54,157 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
|
|||
mvarId.assign (← mkEqTrans h mvarNew)
|
||||
return mvarNew.mvarId!
|
||||
|
||||
private partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := do
|
||||
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
|
||||
if ← withAtLeastTransparency .all (tryURefl mvarId) then
|
||||
trace[Elab.definition.wf.eqns] "refl!"
|
||||
return ()
|
||||
else if (← tryContradiction mvarId) then
|
||||
trace[Elab.definition.wf.eqns] "contradiction!"
|
||||
return ()
|
||||
else if let some mvarId ← simpMatch? mvarId then
|
||||
trace[Elab.definition.wf.eqns] "simpMatch!"
|
||||
mkUnfoldProof declName mvarId
|
||||
else if let some mvarId ← simpIf? mvarId (useNewSemantics := true) then
|
||||
trace[Elab.definition.wf.eqns] "simpIf!"
|
||||
mkUnfoldProof declName mvarId
|
||||
else
|
||||
let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none })
|
||||
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
|
||||
| TacticResultCNM.closed => return ()
|
||||
| TacticResultCNM.modified mvarId =>
|
||||
trace[Elab.definition.wf.eqns] "simp only!"
|
||||
mkUnfoldProof declName mvarId
|
||||
| TacticResultCNM.noChange =>
|
||||
if let some mvarIds ← casesOnStuckLHS? mvarId then
|
||||
trace[Elab.definition.wf.eqns] "case split into {mvarIds.size} goals"
|
||||
mvarIds.forM (mkUnfoldProof declName)
|
||||
else if let some mvarIds ← splitTarget? mvarId (useNewSemantics := true) then
|
||||
trace[Elab.definition.wf.eqns] "splitTarget into {mvarIds.length} goals"
|
||||
mvarIds.forM (mkUnfoldProof declName)
|
||||
else
|
||||
-- At some point in the past, we looked for occurrences of Wf.fix to fold on the
|
||||
-- LHS (introduced in 096e4eb), but it seems that code path was never used,
|
||||
-- so #3133 removed it again (and can be recovered from there if this was premature).
|
||||
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
|
||||
def isForallMotive (matcherApp : MatcherApp) : MetaM (Option Expr) := do
|
||||
lambdaBoundedTelescope matcherApp.motive matcherApp.discrs.size fun xs t =>
|
||||
if xs.size == matcherApp.discrs.size && t.isForall && !t.bindingBody!.hasLooseBVar 0 then
|
||||
return some (← mkLambdaFVars xs t.bindingBody!)
|
||||
else
|
||||
return none
|
||||
|
||||
def mkUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) (wfPreprocessProof : Simp.Result) : MetaM Unit := do
|
||||
|
||||
/-- Generalization of `splitMatch` that can handle `casesOn` -/
|
||||
def splitMatchOrCasesOn (mvarId : MVarId) (e : Expr) (matcherInfo : MatcherInfo) : MetaM (List MVarId) := do
|
||||
if (← isMatcherApp e) then
|
||||
Split.splitMatch mvarId e
|
||||
else
|
||||
assert! matcherInfo.numDiscrs = 1
|
||||
let discr := e.getAppArgs[matcherInfo.numParams + 1]!
|
||||
assert! discr.isFVar
|
||||
let subgoals ← mvarId.cases discr.fvarId!
|
||||
return subgoals.map (·.mvarId) |>.toList
|
||||
|
||||
/--
|
||||
Generates a theorem of the form
|
||||
```
|
||||
matcherArgPusher params motive {α} {β} (f : ∀ (x : α), β x) rel alt1 .. x1 x2
|
||||
:
|
||||
matcher params (motive := fun x1 x2 => ((y : α) → rel x1 x2 y → β y) → motive x1 x2)
|
||||
(alt1 := fun z1 z2 z2 f => alt1 z1 z2 z2 f) …
|
||||
x1 x2
|
||||
(fun y _h => f y)
|
||||
=
|
||||
matcher params (motive := motive)
|
||||
(alt1 := fun z1 z2 z2 => alt1 z1 z2 z2 (fun y _ => f y)) …
|
||||
x1 x2
|
||||
```
|
||||
-/
|
||||
def mkMatchArgPusher (matcherName : Name) (matcherInfo : MatcherInfo) : MetaM Name := do
|
||||
let name := (mkPrivateName (← getEnv) matcherName) ++ `_arg_pusher
|
||||
realizeConst matcherName name do
|
||||
let matcherVal ← getConstVal matcherName
|
||||
forallBoundedTelescope matcherVal.type (some (matcherInfo.numParams + 1)) fun xs _ => do
|
||||
let params := xs[*...matcherInfo.numParams]
|
||||
let motive' := xs[matcherInfo.numParams]!
|
||||
let u ← mkFreshUserName `u
|
||||
let v ← mkFreshUserName `v
|
||||
withLocalDeclD `α (.sort (.param u)) fun alpha => do
|
||||
withLocalDeclD `β (← mkArrow alpha (.sort (.param v))) fun beta => do
|
||||
withLocalDeclD `f (.forallE `x alpha (mkApp beta (.bvar 0)) .default) fun f => do
|
||||
let relType ← forallTelescope (← inferType motive') fun xs _ =>
|
||||
mkForallFVars xs (.forallE `x alpha (.sort 0) .default)
|
||||
withLocalDeclD `rel relType fun rel => do
|
||||
let motive ← forallTelescope (← inferType motive') fun xs _ => do
|
||||
let motiveBody := mkAppN motive' xs
|
||||
let extraArgType := .forallE `y alpha (.forallE `h (mkAppN rel (xs.push (.bvar 0))) (mkApp beta (.bvar 1)) .default) .default
|
||||
let motiveBody ← mkArrow extraArgType motiveBody
|
||||
mkLambdaFVars xs motiveBody
|
||||
|
||||
let uElim ← lambdaBoundedTelescope motive matcherInfo.numDiscrs fun _ motiveBody => do
|
||||
getLevel motiveBody
|
||||
let us := matcherVal.levelParams ++ [u, v]
|
||||
let matcherLevels' := matcherVal.levelParams.map mkLevelParam
|
||||
let matcherLevels ← match matcherInfo.uElimPos? with
|
||||
| none =>
|
||||
unless uElim.isZero do
|
||||
throwError "unexpected matcher application for {.ofConstName matcherName}, motive is not a proposition"
|
||||
pure matcherLevels'
|
||||
| some pos =>
|
||||
pure <| (matcherLevels'.toArray.set! pos uElim).toList
|
||||
let lhs := .const matcherName matcherLevels
|
||||
let rhs := .const matcherName matcherLevels'
|
||||
let lhs := mkAppN lhs params
|
||||
let rhs := mkAppN rhs params
|
||||
let lhs := mkApp lhs motive
|
||||
let rhs := mkApp rhs motive'
|
||||
forallBoundedTelescope (← inferType lhs) matcherInfo.numDiscrs fun discrs _ => do
|
||||
let lhs := mkAppN lhs discrs
|
||||
let rhs := mkAppN rhs discrs
|
||||
forallBoundedTelescope (← inferType lhs) matcherInfo.numAlts fun alts _ => do
|
||||
let lhs := mkAppN lhs alts
|
||||
|
||||
let mut rhs := rhs
|
||||
for alt in alts, altNumParams in matcherInfo.altNumParams do
|
||||
let alt' ← forallBoundedTelescope (← inferType alt) altNumParams fun ys altBodyType => do
|
||||
assert! altBodyType.isForall
|
||||
let altArg ← forallBoundedTelescope altBodyType.bindingDomain! (some 2) fun ys _ => do
|
||||
mkLambdaFVars ys (.app f ys[0]!)
|
||||
mkLambdaFVars ys (mkAppN alt (ys.push altArg))
|
||||
rhs := mkApp rhs alt'
|
||||
|
||||
let extraArg := .lam `y alpha (.lam `h (mkAppN rel (discrs.push (.bvar 0))) (mkApp f (.bvar 1)) .default) .default
|
||||
let lhs := mkApp lhs extraArg
|
||||
let goal ← mkEq lhs rhs
|
||||
|
||||
let value ← mkFreshExprSyntheticOpaqueMVar goal
|
||||
let mvarId := value.mvarId!
|
||||
let mvarIds ← splitMatchOrCasesOn mvarId rhs matcherInfo
|
||||
for mvarId in mvarIds do
|
||||
mvarId.refl
|
||||
let value ← instantiateMVars value
|
||||
let type ← mkForallFVars (params ++ #[motive', alpha, beta, f, rel] ++ discrs ++ alts) goal
|
||||
let value ← mkLambdaFVars (params ++ #[motive', alpha, beta, f, rel] ++ discrs ++ alts) value
|
||||
addDecl <| Declaration.thmDecl { name, levelParams := us, type, value}
|
||||
return name
|
||||
|
||||
builtin_simproc_decl matcherPushArg (_) := fun e => do
|
||||
let e := e.headBeta
|
||||
let some matcherApp ← matchMatcherApp? e (alsoCasesOn := true) | return .continue
|
||||
-- Check that the first remaining argument is of the form `(fun (x : α) p => (f x : β x))`
|
||||
let some fArg := matcherApp.remaining[0]? | return .continue
|
||||
unless fArg.isLambda do return .continue
|
||||
unless fArg.bindingBody!.isLambda do return .continue
|
||||
unless fArg.bindingBody!.bindingBody!.isApp do return .continue
|
||||
if fArg.bindingBody!.bindingBody!.hasLooseBVar 0 then return .continue
|
||||
unless fArg.bindingBody!.bindingBody!.appArg! == .bvar 1 do return .continue
|
||||
if fArg.bindingBody!.bindingBody!.appFn!.hasLooseBVar 1 then return .continue
|
||||
|
||||
let fExpr := fArg.bindingBody!.bindingBody!.appFn!
|
||||
let fExprType ← inferType fExpr
|
||||
let fExprType ← withTransparency .all (whnfForall fExprType)
|
||||
assert! fExprType.isForall
|
||||
let alpha := fExprType.bindingDomain!
|
||||
let beta := .lam fExprType.bindingName! fExprType.bindingDomain! fExprType.bindingBody! .default
|
||||
|
||||
-- Check that the motive has an extra parameter (from MatcherApp.addArg)
|
||||
let some motive' ← isForallMotive matcherApp | return .continue
|
||||
let rel ← lambdaTelescope matcherApp.motive fun xs motiveBody =>
|
||||
let motiveBodyArg := motiveBody.bindingDomain!
|
||||
mkLambdaFVars xs (.lam motiveBodyArg.bindingName! motiveBodyArg.bindingDomain! motiveBodyArg.bindingBody!.bindingDomain! .default)
|
||||
|
||||
let argPusher ← mkMatchArgPusher matcherApp.matcherName matcherApp.toMatcherInfo
|
||||
-- Let's infer the level paramters:
|
||||
let proof ← withTransparency .all <| mkAppOptM
|
||||
argPusher ((matcherApp.params ++ #[motive', alpha, beta, fExpr, rel] ++ matcherApp.discrs ++ matcherApp.alts).map some)
|
||||
let some (_, _, rhs) := (← inferType proof).eq? | throwError "matcherPushArg: expected equality:{indentExpr (← inferType proof)}"
|
||||
let step : Simp.Result := { expr := rhs, proof? := some proof }
|
||||
let step ← step.addExtraArgs matcherApp.remaining[1...*]
|
||||
return .continue (some step)
|
||||
|
||||
def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := withTransparency .all do
|
||||
let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none, letToHave := false, singlePass := true })
|
||||
let simprocs := ({} : Simp.SimprocsArray)
|
||||
let simprocs ← simprocs.add ``matcherPushArg (post := false)
|
||||
match (← simpTarget mvarId ctx (simprocs := simprocs)).1 with
|
||||
| none => return ()
|
||||
| some mvarId' =>
|
||||
prependError m!"failed to finish proof for equational theorem for '{.ofConstName declName}'" do
|
||||
mvarId'.refl
|
||||
|
||||
public def mkUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) (wfPreprocessProof : Simp.Result) : MetaM Unit := do
|
||||
let name := mkEqLikeNameFor (← getEnv) preDef.declName unfoldThmSuffix
|
||||
prependError m!"Cannot derive {name}" do
|
||||
prependError m!"Cannot derive unfold equation {name}" do
|
||||
withOptions (tactic.hygienic.set · false) do
|
||||
withoutExporting do
|
||||
lambdaTelescope preDef.value fun xs body => do
|
||||
let us := preDef.levelParams.map mkLevelParam
|
||||
let lhs := mkAppN (Lean.mkConst preDef.declName us) xs
|
||||
|
|
@ -111,7 +235,7 @@ theorem of `foo._unary` or `foo._binary`.
|
|||
|
||||
It should just be a specialization of that one, due to defeq.
|
||||
-/
|
||||
def mkBinaryUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) : MetaM Unit := do
|
||||
public def mkBinaryUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) : MetaM Unit := do
|
||||
let name := mkEqLikeNameFor (← getEnv) preDef.declName unfoldThmSuffix
|
||||
let unaryEqName:= mkEqLikeNameFor (← getEnv) unaryPreDefName unfoldThmSuffix
|
||||
prependError m!"Cannot derive {name} from {unaryEqName}" do
|
||||
|
|
|
|||
|
|
@ -23,5 +23,6 @@ def mkCasesOn (declName : Name) : MetaM Unit := do
|
|||
addDecl decl
|
||||
setReducibleAttribute name
|
||||
modifyEnv fun env => markAuxRecursor env name
|
||||
enableRealizationsForConst name
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ of matcher applications.
|
|||
-/
|
||||
def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCasesOn := false) :
|
||||
m (Option MatcherApp) := do
|
||||
unless e.isApp do return none
|
||||
if let .const declName declLevels := e.getAppFn then
|
||||
if let some info ← getMatcherInfo? declName then
|
||||
let args := e.getAppArgs
|
||||
|
|
@ -74,6 +75,13 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases
|
|||
|
||||
return none
|
||||
|
||||
def MatcherApp.toMatcherInfo (matcherApp : MatcherApp) : MatcherInfo where
|
||||
uElimPos? := matcherApp.uElimPos?
|
||||
discrInfos := matcherApp.discrInfos
|
||||
numParams := matcherApp.params.size
|
||||
numDiscrs := matcherApp.discrs.size
|
||||
altNumParams := matcherApp.altNumParams
|
||||
|
||||
def MatcherApp.toExpr (matcherApp : MatcherApp) : Expr :=
|
||||
let result := mkAppN (mkConst matcherApp.matcherName matcherApp.matcherLevels.toList) matcherApp.params
|
||||
let result := mkApp result matcherApp.motive
|
||||
|
|
|
|||
11
tests/lean/run/issue9646.lean
Normal file
11
tests/lean/run/issue9646.lean
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
/-!
|
||||
Checks that that the wfrec unfold theorem can be generated even if the
|
||||
function type is not manifestly a forall.
|
||||
-/
|
||||
|
||||
def T := Nat → Nat
|
||||
|
||||
def f : T
|
||||
| 0 => 0
|
||||
| n + 1 => f n + 1
|
||||
termination_by n => n
|
||||
Loading…
Add table
Reference in a new issue