From 82f60a7ff3118e649bf51f796eb22cc2b5a24a04 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 4 Jan 2026 18:10:06 -0800 Subject: [PATCH] feat: `pre` and `post` may return "done" in `Sym.simp` (#11900) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a `done` flag to the result returned by `Simproc`s in `Sym.simp`. The `done` flag controls whether simplification should continue after the result: - `done = false` (default): Continue with subsequent simplification steps - `done = true`: Stop processing, return this result as final ## Use cases for `done = true` ### In `pre` simprocs Skip simplification of certain subterms entirely: ``` def skipLambdas : Simproc := fun e => if e.isLambda then return .rfl (done := true) else return .rfl ``` ### In `post` simprocs Perform single-pass normalization without recursive simplification: ``` def singlePassNormalize : Simproc := fun e => if let some (e', h) ← tryNormalize e then return .step e' h (done := true) else return .rfl ``` With `done = true`, the result `e'` won't be recursively simplified. --- src/Lean/Meta/Sym/Simp/Congr.lean | 13 ++++---- src/Lean/Meta/Sym/Simp/Main.lean | 22 +++++++------ src/Lean/Meta/Sym/Simp/Result.lean | 4 +-- src/Lean/Meta/Sym/Simp/SimpM.lean | 49 ++++++++++++++++++++++++++--- src/Lean/Meta/Sym/Simp/Simproc.lean | 5 +-- tests/bench/sym/simp_1.lean | 4 +-- tests/bench/sym/simp_2.lean | 4 +-- 7 files changed, 71 insertions(+), 30 deletions(-) diff --git a/src/Lean/Meta/Sym/Simp/Congr.lean b/src/Lean/Meta/Sym/Simp/Congr.lean index f268e04069..3f773518ea 100644 --- a/src/Lean/Meta/Sym/Simp/Congr.lean +++ b/src/Lean/Meta/Sym/Simp/Congr.lean @@ -48,17 +48,16 @@ def mkCongr (e : Expr) (f a : Expr) (fr : Result) (ar : Result) (_ : e = .app f let v ← getLevel β return mkApp2 (mkConst declName [u, v]) α β match fr, ar with - | .rfl, .rfl => - return .rfl - | .step f' hf, .rfl => + | .rfl _, .rfl _ => return .rfl + | .step f' hf _, .rfl _ => let e' ← mkAppS f' a let h := mkApp4 (← mkCongrPrefix ``congrFun') f f' hf a return .step e' h - | .rfl, .step a' ha => + | .rfl _, .step a' ha _ => let e' ← mkAppS f a' let h := mkApp4 (← mkCongrPrefix ``congrArg) a a' f ha return .step e' h - | .step f' hf, .step a' ha => + | .step f' hf _, .step a' ha _ => let e' ← mkAppS f' a' let h := mkApp6 (← mkCongrPrefix ``congr) f f' a a' hf ha return .step e' h @@ -131,8 +130,8 @@ where if rewritable[i - 1] then mkCongr e f a fr (← simp a) h else match fr with - | .rfl => return .rfl - | .step f' hf => mkCongrFun e f a f' hf h + | .rfl _ => return .rfl + | .step f' hf _ => mkCongrFun e f a f' hf h | _ => unreachable! /-- diff --git a/src/Lean/Meta/Sym/Simp/Main.lean b/src/Lean/Meta/Sym/Simp/Main.lean index 70c1b75f10..232c953c81 100644 --- a/src/Lean/Meta/Sym/Simp/Main.lean +++ b/src/Lean/Meta/Sym/Simp/Main.lean @@ -19,8 +19,8 @@ def simpLambda (e : Expr) : SimpM Result := do -- **TODO**: Add free variable reuse lambdaTelescope e fun xs b => do match (← simp b) with - | .rfl => return .rfl - | .step b' h => + | .rfl _ => return .rfl + | .step b' h _ => let h ← mkLambdaFVars xs h -- **TODO**: Add `mkLambdaFVarsS`? let e' ← shareCommonInc (← mkLambdaFVars xs b') @@ -57,8 +57,8 @@ def simpStep : Simproc := fun e => do | .mdata m b => let r ← simp b match r with - | .rfl => return .rfl - | .step b' h => return .step (← mkMDataS m b') h + | .rfl _ => return .rfl + | .step b' h _ => return .step (← mkMDataS m b') h | .lam .. => simpLambda e | .forallE .. => simpForall e | .letE .. => simpLet e @@ -79,12 +79,14 @@ def simpImpl (e₁ : Expr) : SimpM Result := withIncRecDepth do if numSteps % 1000 == 0 then checkSystem "simp" modify fun s => { s with numSteps } - match (← pre e₁) with - | .step e₂ h₁ => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂)) - | .rfl => - let r₁ ← (simpStep >> post) e₁ + let r₁ ← pre e₁ match r₁ with - | .rfl => cacheResult e₁ r₁ - | .step e₂ h₁ => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂)) + | .rfl true | .step _ _ true => cacheResult e₁ r₁ + | .step e₂ h₁ false => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂)) + | .rfl false => + let r₂ ← (simpStep >> post) e₁ + match r₂ with + | .rfl _ | .step _ _ true => cacheResult e₁ r₂ + | .step e₂ h₁ false => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂)) end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/Simp/Result.lean b/src/Lean/Meta/Sym/Simp/Result.lean index 1b5a5a3131..59bc0e3797 100644 --- a/src/Lean/Meta/Sym/Simp/Result.lean +++ b/src/Lean/Meta/Sym/Simp/Result.lean @@ -19,7 +19,7 @@ public def mkEqTrans (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (e₃ : Expr) (h public abbrev mkEqTransResult (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (r₂ : Result) : SymM Result := match r₂ with - | .rfl => return .step e₂ h₁ - | .step e₃ h₂ => return .step e₃ (← mkEqTrans e₁ e₂ h₁ e₃ h₂) + | .rfl done => return .step e₂ h₁ done + | .step e₃ h₂ done => return .step e₃ (← mkEqTrans e₁ e₂ h₁ e₃ h₂) done end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/Simp/SimpM.lean b/src/Lean/Meta/Sym/Simp/SimpM.lean index b17fdcf963..a643ab6fc8 100644 --- a/src/Lean/Meta/Sym/Simp/SimpM.lean +++ b/src/Lean/Meta/Sym/Simp/SimpM.lean @@ -104,12 +104,51 @@ structure Config where maxSteps : Nat := 0 -- **TODO**: many are still missing -/-- The result of simplifying some expression `e`. -/ +/-- +The result of simplifying an expression `e`. + +The `done` flag controls whether simplification should continue after this result: +- `done = false` (default): Continue with subsequent simplification steps +- `done = true`: Stop processing, return this result as final + +## Use cases for `done = true` + +### In `pre` simprocs +Skip simplification of certain subterms entirely: +``` +def skipLambdas : Simproc := fun e => + if e.isLambda then return .rfl (done := true) + else return .rfl +``` + +### In `post` simprocs +Perform single-pass normalization without recursive simplification: +``` +def singlePassNormalize : Simproc := fun e => + if let some (e', h) ← tryNormalize e then + return .step e' h (done := true) + else return .rfl +``` +With `done = true`, the result `e'` won't be recursively simplified. + +## Behavior + +The `done` flag affects: +1. **`andThen` composition**: If the first simproc returns `done = true`, + the second simproc is skipped. +2. **Recursive simplification**: After `pre` or `post` returns `.step e' h`, + `simp` normally recurses on `e'`. With `done = true`, recursion is skipped. + +The flag is orthogonal to caching: both `.rfl` and `.step` results are cached +regardless of the `done` flag, and cached results are always treated as final. +-/ inductive Result where - | /-- No change -/ - rfl - | /-- Simplified expression `e'` and a proof that `e = e'` -/ - step (e' : Expr) (proof : Expr) + /-- No change. If `done = true`, skip remaining simplification steps for this term. -/ + | rfl (done : Bool := false) + /-- + Simplified to `e'` with proof `proof : e = e'`. + If `done = true`, skip recursive simplification of `e'`. -/ + | step (e' : Expr) (proof : Expr) (done : Bool := false) private opaque MethodsRefPointed : NonemptyType.{0} def MethodsRef : Type := MethodsRefPointed.type diff --git a/src/Lean/Meta/Sym/Simp/Simproc.lean b/src/Lean/Meta/Sym/Simp/Simproc.lean index 8a9246edcf..9080e249be 100644 --- a/src/Lean/Meta/Sym/Simp/Simproc.lean +++ b/src/Lean/Meta/Sym/Simp/Simproc.lean @@ -13,8 +13,9 @@ open Grind public abbrev Simproc.andThen (f g : Simproc) : Simproc := fun e₁ => do let r ← f e₁ match r with - | .rfl => g e₁ - | .step e₂ h₁ => mkEqTransResult e₁ e₂ h₁ (← g e₂) + | .step _ _ true | .rfl true => return r + | .rfl false => g e₁ + | .step e₂ h₁ false => mkEqTransResult e₁ e₂ h₁ (← g e₂) public instance : AndThen Simproc where andThen f g := Simproc.andThen f (g ()) diff --git a/tests/bench/sym/simp_1.lean b/tests/bench/sym/simp_1.lean index 1159e615a4..bfd3f11ad3 100644 --- a/tests/bench/sym/simp_1.lean +++ b/tests/bench/sym/simp_1.lean @@ -9,8 +9,8 @@ namespace SimpBench def getProofSize (r : Sym.Simp.Result) : MetaM Nat := match r with - | .rfl => return 0 - | .step _ p => p.numObjs + | .rfl _ => return 0 + | .step _ p _ => p.numObjs def mkSimpMethods : MetaM Sym.Simp.Methods := do let thms : Sym.Simp.Theorems := {} diff --git a/tests/bench/sym/simp_2.lean b/tests/bench/sym/simp_2.lean index f8e75a767c..3909931917 100644 --- a/tests/bench/sym/simp_2.lean +++ b/tests/bench/sym/simp_2.lean @@ -9,8 +9,8 @@ namespace SimpBench def getProofSize (r : Sym.Simp.Result) : MetaM Nat := match r with - | .rfl => return 0 - | .step _ p => p.numObjs + | .rfl _ => return 0 + | .step _ p _ => p.numObjs def mkSimpMethods : MetaM Sym.Simp.Methods := do let thms : Sym.Simp.Theorems := {}