feat: pre and post may return "done" in Sym.simp (#11900)

This PR adds a `done` flag to the result returned by `Simproc`s in
`Sym.simp`.

The `done` flag controls whether simplification should continue after
the result:
- `done = false` (default): Continue with subsequent simplification
steps
- `done = true`: Stop processing, return this result as final

## Use cases for `done = true`

### In `pre` simprocs
Skip simplification of certain subterms entirely:
```
def skipLambdas : Simproc := fun e =>
  if e.isLambda then return .rfl (done := true)
  else return .rfl
```

### In `post` simprocs
Perform single-pass normalization without recursive simplification:
```
def singlePassNormalize : Simproc := fun e =>
  if let some (e', h) ← tryNormalize e then
    return .step e' h (done := true)
  else return .rfl
```
With `done = true`, the result `e'` won't be recursively simplified.
This commit is contained in:
Leonardo de Moura 2026-01-04 18:10:06 -08:00 committed by GitHub
parent 6bf2486e13
commit 82f60a7ff3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 71 additions and 30 deletions

View file

@ -48,17 +48,16 @@ def mkCongr (e : Expr) (f a : Expr) (fr : Result) (ar : Result) (_ : e = .app f
let v ← getLevel β
return mkApp2 (mkConst declName [u, v]) α β
match fr, ar with
| .rfl, .rfl =>
return .rfl
| .step f' hf, .rfl =>
| .rfl _, .rfl _ => return .rfl
| .step f' hf _, .rfl _ =>
let e' ← mkAppS f' a
let h := mkApp4 (← mkCongrPrefix ``congrFun') f f' hf a
return .step e' h
| .rfl, .step a' ha =>
| .rfl _, .step a' ha _ =>
let e' ← mkAppS f a'
let h := mkApp4 (← mkCongrPrefix ``congrArg) a a' f ha
return .step e' h
| .step f' hf, .step a' ha =>
| .step f' hf _, .step a' ha _ =>
let e' ← mkAppS f' a'
let h := mkApp6 (← mkCongrPrefix ``congr) f f' a a' hf ha
return .step e' h
@ -131,8 +130,8 @@ where
if rewritable[i - 1] then
mkCongr e f a fr (← simp a) h
else match fr with
| .rfl => return .rfl
| .step f' hf => mkCongrFun e f a f' hf h
| .rfl _ => return .rfl
| .step f' hf _ => mkCongrFun e f a f' hf h
| _ => unreachable!
/--

View file

@ -19,8 +19,8 @@ def simpLambda (e : Expr) : SimpM Result := do
-- **TODO**: Add free variable reuse
lambdaTelescope e fun xs b => do
match (← simp b) with
| .rfl => return .rfl
| .step b' h =>
| .rfl _ => return .rfl
| .step b' h _ =>
let h ← mkLambdaFVars xs h
-- **TODO**: Add `mkLambdaFVarsS`?
let e' ← shareCommonInc (← mkLambdaFVars xs b')
@ -57,8 +57,8 @@ def simpStep : Simproc := fun e => do
| .mdata m b =>
let r ← simp b
match r with
| .rfl => return .rfl
| .step b' h => return .step (← mkMDataS m b') h
| .rfl _ => return .rfl
| .step b' h _ => return .step (← mkMDataS m b') h
| .lam .. => simpLambda e
| .forallE .. => simpForall e
| .letE .. => simpLet e
@ -79,12 +79,14 @@ def simpImpl (e₁ : Expr) : SimpM Result := withIncRecDepth do
if numSteps % 1000 == 0 then
checkSystem "simp"
modify fun s => { s with numSteps }
match (← pre e₁) with
| .step e₂ h₁ => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂))
| .rfl =>
let r₁ ← (simpStep >> post) e₁
let r₁ ← pre e₁
match r₁ with
| .rfl => cacheResult e₁ r₁
| .step e₂ h₁ => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂))
| .rfl true | .step _ _ true => cacheResult e₁ r₁
| .step e₂ h₁ false => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂))
| .rfl false =>
let r₂ ← (simpStep >> post) e₁
match r₂ with
| .rfl _ | .step _ _ true => cacheResult e₁ r₂
| .step e₂ h₁ false => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂))
end Lean.Meta.Sym.Simp

View file

@ -19,7 +19,7 @@ public def mkEqTrans (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (e₃ : Expr) (h
public abbrev mkEqTransResult (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (r₂ : Result) : SymM Result :=
match r₂ with
| .rfl => return .step e₂ h₁
| .step e₃ h₂ => return .step e₃ (← mkEqTrans e₁ e₂ h₁ e₃ h₂)
| .rfl done => return .step e₂ h₁ done
| .step e₃ h₂ done => return .step e₃ (← mkEqTrans e₁ e₂ h₁ e₃ h₂) done
end Lean.Meta.Sym.Simp

View file

@ -104,12 +104,51 @@ structure Config where
maxSteps : Nat := 0
-- **TODO**: many are still missing
/-- The result of simplifying some expression `e`. -/
/--
The result of simplifying an expression `e`.
The `done` flag controls whether simplification should continue after this result:
- `done = false` (default): Continue with subsequent simplification steps
- `done = true`: Stop processing, return this result as final
## Use cases for `done = true`
### In `pre` simprocs
Skip simplification of certain subterms entirely:
```
def skipLambdas : Simproc := fun e =>
if e.isLambda then return .rfl (done := true)
else return .rfl
```
### In `post` simprocs
Perform single-pass normalization without recursive simplification:
```
def singlePassNormalize : Simproc := fun e =>
if let some (e', h) ← tryNormalize e then
return .step e' h (done := true)
else return .rfl
```
With `done = true`, the result `e'` won't be recursively simplified.
## Behavior
The `done` flag affects:
1. **`andThen` composition**: If the first simproc returns `done = true`,
the second simproc is skipped.
2. **Recursive simplification**: After `pre` or `post` returns `.step e' h`,
`simp` normally recurses on `e'`. With `done = true`, recursion is skipped.
The flag is orthogonal to caching: both `.rfl` and `.step` results are cached
regardless of the `done` flag, and cached results are always treated as final.
-/
inductive Result where
| /-- No change -/
rfl
| /-- Simplified expression `e'` and a proof that `e = e'` -/
step (e' : Expr) (proof : Expr)
/-- No change. If `done = true`, skip remaining simplification steps for this term. -/
| rfl (done : Bool := false)
/--
Simplified to `e'` with proof `proof : e = e'`.
If `done = true`, skip recursive simplification of `e'`. -/
| step (e' : Expr) (proof : Expr) (done : Bool := false)
private opaque MethodsRefPointed : NonemptyType.{0}
def MethodsRef : Type := MethodsRefPointed.type

View file

@ -13,8 +13,9 @@ open Grind
public abbrev Simproc.andThen (f g : Simproc) : Simproc := fun e₁ => do
let r ← f e₁
match r with
| .rfl => g e₁
| .step e₂ h₁ => mkEqTransResult e₁ e₂ h₁ (← g e₂)
| .step _ _ true | .rfl true => return r
| .rfl false => g e₁
| .step e₂ h₁ false => mkEqTransResult e₁ e₂ h₁ (← g e₂)
public instance : AndThen Simproc where
andThen f g := Simproc.andThen f (g ())

View file

@ -9,8 +9,8 @@ namespace SimpBench
def getProofSize (r : Sym.Simp.Result) : MetaM Nat :=
match r with
| .rfl => return 0
| .step _ p => p.numObjs
| .rfl _ => return 0
| .step _ p _ => p.numObjs
def mkSimpMethods : MetaM Sym.Simp.Methods := do
let thms : Sym.Simp.Theorems := {}

View file

@ -9,8 +9,8 @@ namespace SimpBench
def getProofSize (r : Sym.Simp.Result) : MetaM Nat :=
match r with
| .rfl => return 0
| .step _ p => p.numObjs
| .rfl _ => return 0
| .step _ p _ => p.numObjs
def mkSimpMethods : MetaM Sym.Simp.Methods := do
let thms : Sym.Simp.Theorems := {}