perf: create unfolding theorem for wf-rec in one go (#9646)

This PR uses a more simple approach to proving the unfolding theorem for
a function defined by well-founded recursion. Instead of looping a bunch
of tactics, it uses simp in single-pass mode to (try to) exactly undo
the changes done in `WF.Fix`, using a dedicated theorem that pushes the
extra argument in for each matcher (or `casesOn`).

Improves performance for recursive functions with large `match`
statements, as in #9598.
This commit is contained in:
Joachim Breitner 2025-08-02 17:26:02 +02:00 committed by GitHub
parent b60f97cc19
commit df9ca20339
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 208 additions and 64 deletions

View file

@ -141,28 +141,29 @@ private def betaReduceLetRecApps (preDefs : Array PreDefinition) : MetaM (Array
private def addSorried (preDefs : Array PreDefinition) : TermElabM Unit := do
for preDef in preDefs do
let value ← mkSorry (synthetic := true) preDef.type
let decl := if preDef.kind.isTheorem then
Declaration.thmDecl {
name := preDef.declName,
levelParams := preDef.levelParams,
type := preDef.type,
value
}
else
Declaration.defnDecl {
name := preDef.declName,
levelParams := preDef.levelParams,
type := preDef.type,
hints := .abbrev
safety := .safe
value
}
addDecl decl
withSaveInfoContext do -- save new env
addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true)
applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
unless (← hasConst preDef.declName) do
let value ← mkSorry (synthetic := true) preDef.type
let decl := if preDef.kind.isTheorem then
Declaration.thmDecl {
name := preDef.declName,
levelParams := preDef.levelParams,
type := preDef.type,
value
}
else
Declaration.defnDecl {
name := preDef.declName,
levelParams := preDef.levelParams,
type := preDef.type,
hints := .abbrev
safety := .safe
value
}
addDecl decl
withSaveInfoContext do -- save new env
addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true)
applyAttributesOf #[preDef] AttributeApplicationTime.afterTypeChecking
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
def ensureFunIndReservedNamesAvailable (preDefs : Array PreDefinition) : MetaM Unit := do
preDefs.forM fun preDef =>

View file

@ -11,7 +11,6 @@ public import Lean.Meta.Tactic.Split
public import Lean.Elab.PreDefinition.Basic
public import Lean.Elab.PreDefinition.Eqns
public import Lean.Meta.ArgsPacker.Basic
public import Lean.Elab.PreDefinition.WF.Unfold
public import Lean.Elab.PreDefinition.FixedParams
public import Init.Data.Array.Basic

View file

@ -1,23 +1,34 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
Authors: Leonardo de Moura, Joachim Breitner
-/
module
prelude
public import Lean.Elab.PreDefinition.Basic
public import Lean.Elab.PreDefinition.Eqns
public import Lean.Meta.Tactic.Apply
import Lean.Elab.PreDefinition.Eqns
import Lean.Meta.Tactic.Apply
import Lean.Meta.Tactic.Split
public import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.Simp.BuiltinSimprocs
public section
/-!
This module is responsible for proving the unfolding equation for functions defined
by well-founded recursion. It uses `WellFounded.fix_eq`, and then has to undo
the changes to matchers that `WF.Fix` did using `MatcherApp.addArg`.
This is done using a single-pass `simp` traversal of the expression that looks
for expressions that were modified that way, and rewrites them back using the
rather specialized `_arg_pusher` theorem that is generated by `mkMatchArgPusher`.
-/
namespace Lean.Elab.WF
open Meta
open Eqns
private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
let target ← mvarId.getType'
let some (_, lhs, rhs) := target.eq? | unreachable!
@ -43,44 +54,157 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
mvarId.assign (← mkEqTrans h mvarNew)
return mvarNew.mvarId!
private partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := do
trace[Elab.definition.wf.eqns] "step\n{MessageData.ofGoal mvarId}"
if ← withAtLeastTransparency .all (tryURefl mvarId) then
trace[Elab.definition.wf.eqns] "refl!"
return ()
else if (← tryContradiction mvarId) then
trace[Elab.definition.wf.eqns] "contradiction!"
return ()
else if let some mvarId ← simpMatch? mvarId then
trace[Elab.definition.wf.eqns] "simpMatch!"
mkUnfoldProof declName mvarId
else if let some mvarId ← simpIf? mvarId (useNewSemantics := true) then
trace[Elab.definition.wf.eqns] "simpIf!"
mkUnfoldProof declName mvarId
else
let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none })
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId =>
trace[Elab.definition.wf.eqns] "simp only!"
mkUnfoldProof declName mvarId
| TacticResultCNM.noChange =>
if let some mvarIds ← casesOnStuckLHS? mvarId then
trace[Elab.definition.wf.eqns] "case split into {mvarIds.size} goals"
mvarIds.forM (mkUnfoldProof declName)
else if let some mvarIds ← splitTarget? mvarId (useNewSemantics := true) then
trace[Elab.definition.wf.eqns] "splitTarget into {mvarIds.length} goals"
mvarIds.forM (mkUnfoldProof declName)
else
-- At some point in the past, we looked for occurrences of Wf.fix to fold on the
-- LHS (introduced in 096e4eb), but it seems that code path was never used,
-- so #3133 removed it again (and can be recovered from there if this was premature).
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
def isForallMotive (matcherApp : MatcherApp) : MetaM (Option Expr) := do
lambdaBoundedTelescope matcherApp.motive matcherApp.discrs.size fun xs t =>
if xs.size == matcherApp.discrs.size && t.isForall && !t.bindingBody!.hasLooseBVar 0 then
return some (← mkLambdaFVars xs t.bindingBody!)
else
return none
def mkUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) (wfPreprocessProof : Simp.Result) : MetaM Unit := do
/-- Generalization of `splitMatch` that can handle `casesOn` -/
def splitMatchOrCasesOn (mvarId : MVarId) (e : Expr) (matcherInfo : MatcherInfo) : MetaM (List MVarId) := do
if (← isMatcherApp e) then
Split.splitMatch mvarId e
else
assert! matcherInfo.numDiscrs = 1
let discr := e.getAppArgs[matcherInfo.numParams + 1]!
assert! discr.isFVar
let subgoals ← mvarId.cases discr.fvarId!
return subgoals.map (·.mvarId) |>.toList
/--
Generates a theorem of the form
```
matcherArgPusher params motive {α} {β} (f : ∀ (x : α), β x) rel alt1 .. x1 x2
:
matcher params (motive := fun x1 x2 => ((y : α) → rel x1 x2 y → β y) → motive x1 x2)
(alt1 := fun z1 z2 z2 f => alt1 z1 z2 z2 f) …
x1 x2
(fun y _h => f y)
=
matcher params (motive := motive)
(alt1 := fun z1 z2 z2 => alt1 z1 z2 z2 (fun y _ => f y)) …
x1 x2
```
-/
def mkMatchArgPusher (matcherName : Name) (matcherInfo : MatcherInfo) : MetaM Name := do
let name := (mkPrivateName (← getEnv) matcherName) ++ `_arg_pusher
realizeConst matcherName name do
let matcherVal ← getConstVal matcherName
forallBoundedTelescope matcherVal.type (some (matcherInfo.numParams + 1)) fun xs _ => do
let params := xs[*...matcherInfo.numParams]
let motive' := xs[matcherInfo.numParams]!
let u ← mkFreshUserName `u
let v ← mkFreshUserName `v
withLocalDeclD `α (.sort (.param u)) fun alpha => do
withLocalDeclD `β (← mkArrow alpha (.sort (.param v))) fun beta => do
withLocalDeclD `f (.forallE `x alpha (mkApp beta (.bvar 0)) .default) fun f => do
let relType ← forallTelescope (← inferType motive') fun xs _ =>
mkForallFVars xs (.forallE `x alpha (.sort 0) .default)
withLocalDeclD `rel relType fun rel => do
let motive ← forallTelescope (← inferType motive') fun xs _ => do
let motiveBody := mkAppN motive' xs
let extraArgType := .forallE `y alpha (.forallE `h (mkAppN rel (xs.push (.bvar 0))) (mkApp beta (.bvar 1)) .default) .default
let motiveBody ← mkArrow extraArgType motiveBody
mkLambdaFVars xs motiveBody
let uElim ← lambdaBoundedTelescope motive matcherInfo.numDiscrs fun _ motiveBody => do
getLevel motiveBody
let us := matcherVal.levelParams ++ [u, v]
let matcherLevels' := matcherVal.levelParams.map mkLevelParam
let matcherLevels ← match matcherInfo.uElimPos? with
| none =>
unless uElim.isZero do
throwError "unexpected matcher application for {.ofConstName matcherName}, motive is not a proposition"
pure matcherLevels'
| some pos =>
pure <| (matcherLevels'.toArray.set! pos uElim).toList
let lhs := .const matcherName matcherLevels
let rhs := .const matcherName matcherLevels'
let lhs := mkAppN lhs params
let rhs := mkAppN rhs params
let lhs := mkApp lhs motive
let rhs := mkApp rhs motive'
forallBoundedTelescope (← inferType lhs) matcherInfo.numDiscrs fun discrs _ => do
let lhs := mkAppN lhs discrs
let rhs := mkAppN rhs discrs
forallBoundedTelescope (← inferType lhs) matcherInfo.numAlts fun alts _ => do
let lhs := mkAppN lhs alts
let mut rhs := rhs
for alt in alts, altNumParams in matcherInfo.altNumParams do
let alt' ← forallBoundedTelescope (← inferType alt) altNumParams fun ys altBodyType => do
assert! altBodyType.isForall
let altArg ← forallBoundedTelescope altBodyType.bindingDomain! (some 2) fun ys _ => do
mkLambdaFVars ys (.app f ys[0]!)
mkLambdaFVars ys (mkAppN alt (ys.push altArg))
rhs := mkApp rhs alt'
let extraArg := .lam `y alpha (.lam `h (mkAppN rel (discrs.push (.bvar 0))) (mkApp f (.bvar 1)) .default) .default
let lhs := mkApp lhs extraArg
let goal ← mkEq lhs rhs
let value ← mkFreshExprSyntheticOpaqueMVar goal
let mvarId := value.mvarId!
let mvarIds ← splitMatchOrCasesOn mvarId rhs matcherInfo
for mvarId in mvarIds do
mvarId.refl
let value ← instantiateMVars value
let type ← mkForallFVars (params ++ #[motive', alpha, beta, f, rel] ++ discrs ++ alts) goal
let value ← mkLambdaFVars (params ++ #[motive', alpha, beta, f, rel] ++ discrs ++ alts) value
addDecl <| Declaration.thmDecl { name, levelParams := us, type, value}
return name
builtin_simproc_decl matcherPushArg (_) := fun e => do
let e := e.headBeta
let some matcherApp ← matchMatcherApp? e (alsoCasesOn := true) | return .continue
-- Check that the first remaining argument is of the form `(fun (x : α) p => (f x : β x))`
let some fArg := matcherApp.remaining[0]? | return .continue
unless fArg.isLambda do return .continue
unless fArg.bindingBody!.isLambda do return .continue
unless fArg.bindingBody!.bindingBody!.isApp do return .continue
if fArg.bindingBody!.bindingBody!.hasLooseBVar 0 then return .continue
unless fArg.bindingBody!.bindingBody!.appArg! == .bvar 1 do return .continue
if fArg.bindingBody!.bindingBody!.appFn!.hasLooseBVar 1 then return .continue
let fExpr := fArg.bindingBody!.bindingBody!.appFn!
let fExprType ← inferType fExpr
let fExprType ← withTransparency .all (whnfForall fExprType)
assert! fExprType.isForall
let alpha := fExprType.bindingDomain!
let beta := .lam fExprType.bindingName! fExprType.bindingDomain! fExprType.bindingBody! .default
-- Check that the motive has an extra parameter (from MatcherApp.addArg)
let some motive' ← isForallMotive matcherApp | return .continue
let rel ← lambdaTelescope matcherApp.motive fun xs motiveBody =>
let motiveBodyArg := motiveBody.bindingDomain!
mkLambdaFVars xs (.lam motiveBodyArg.bindingName! motiveBodyArg.bindingDomain! motiveBodyArg.bindingBody!.bindingDomain! .default)
let argPusher ← mkMatchArgPusher matcherApp.matcherName matcherApp.toMatcherInfo
-- Let's infer the level paramters:
let proof ← withTransparency .all <| mkAppOptM
argPusher ((matcherApp.params ++ #[motive', alpha, beta, fExpr, rel] ++ matcherApp.discrs ++ matcherApp.alts).map some)
let some (_, _, rhs) := (← inferType proof).eq? | throwError "matcherPushArg: expected equality:{indentExpr (← inferType proof)}"
let step : Simp.Result := { expr := rhs, proof? := some proof }
let step ← step.addExtraArgs matcherApp.remaining[1...*]
return .continue (some step)
def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := withTransparency .all do
let ctx ← Simp.mkContext (config := { dsimp := false, etaStruct := .none, letToHave := false, singlePass := true })
let simprocs := ({} : Simp.SimprocsArray)
let simprocs ← simprocs.add ``matcherPushArg (post := false)
match (← simpTarget mvarId ctx (simprocs := simprocs)).1 with
| none => return ()
| some mvarId' =>
prependError m!"failed to finish proof for equational theorem for '{.ofConstName declName}'" do
mvarId'.refl
public def mkUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) (wfPreprocessProof : Simp.Result) : MetaM Unit := do
let name := mkEqLikeNameFor (← getEnv) preDef.declName unfoldThmSuffix
prependError m!"Cannot derive {name}" do
prependError m!"Cannot derive unfold equation {name}" do
withOptions (tactic.hygienic.set · false) do
withoutExporting do
lambdaTelescope preDef.value fun xs body => do
let us := preDef.levelParams.map mkLevelParam
let lhs := mkAppN (Lean.mkConst preDef.declName us) xs
@ -111,7 +235,7 @@ theorem of `foo._unary` or `foo._binary`.
It should just be a specialization of that one, due to defeq.
-/
def mkBinaryUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) : MetaM Unit := do
public def mkBinaryUnfoldEq (preDef : PreDefinition) (unaryPreDefName : Name) : MetaM Unit := do
let name := mkEqLikeNameFor (← getEnv) preDef.declName unfoldThmSuffix
let unaryEqName:= mkEqLikeNameFor (← getEnv) unaryPreDefName unfoldThmSuffix
prependError m!"Cannot derive {name} from {unaryEqName}" do

View file

@ -23,5 +23,6 @@ def mkCasesOn (declName : Name) : MetaM Unit := do
addDecl decl
setReducibleAttribute name
modifyEnv fun env => markAuxRecursor env name
enableRealizationsForConst name
end Lean

View file

@ -32,6 +32,7 @@ of matcher applications.
-/
def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCasesOn := false) :
m (Option MatcherApp) := do
unless e.isApp do return none
if let .const declName declLevels := e.getAppFn then
if let some info ← getMatcherInfo? declName then
let args := e.getAppArgs
@ -74,6 +75,13 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases
return none
def MatcherApp.toMatcherInfo (matcherApp : MatcherApp) : MatcherInfo where
uElimPos? := matcherApp.uElimPos?
discrInfos := matcherApp.discrInfos
numParams := matcherApp.params.size
numDiscrs := matcherApp.discrs.size
altNumParams := matcherApp.altNumParams
def MatcherApp.toExpr (matcherApp : MatcherApp) : Expr :=
let result := mkAppN (mkConst matcherApp.matcherName matcherApp.matcherLevels.toList) matcherApp.params
let result := mkApp result matcherApp.motive

View file

@ -0,0 +1,11 @@
/-!
Checks that that the wfrec unfold theorem can be generated even if the
function type is not manifestly a forall.
-/
def T := Nat → Nat
def f : T
| 0 => 0
| n + 1 => f n + 1
termination_by n => n