feat: pre and post may return "done" in Sym.simp (#11900)
This PR adds a `done` flag to the result returned by `Simproc`s in
`Sym.simp`.
The `done` flag controls whether simplification should continue after
the result:
- `done = false` (default): Continue with subsequent simplification
steps
- `done = true`: Stop processing, return this result as final
## Use cases for `done = true`
### In `pre` simprocs
Skip simplification of certain subterms entirely:
```
def skipLambdas : Simproc := fun e =>
if e.isLambda then return .rfl (done := true)
else return .rfl
```
### In `post` simprocs
Perform single-pass normalization without recursive simplification:
```
def singlePassNormalize : Simproc := fun e =>
if let some (e', h) ← tryNormalize e then
return .step e' h (done := true)
else return .rfl
```
With `done = true`, the result `e'` won't be recursively simplified.
This commit is contained in:
parent
6bf2486e13
commit
82f60a7ff3
7 changed files with 71 additions and 30 deletions
|
|
@ -48,17 +48,16 @@ def mkCongr (e : Expr) (f a : Expr) (fr : Result) (ar : Result) (_ : e = .app f
|
|||
let v ← getLevel β
|
||||
return mkApp2 (mkConst declName [u, v]) α β
|
||||
match fr, ar with
|
||||
| .rfl, .rfl =>
|
||||
return .rfl
|
||||
| .step f' hf, .rfl =>
|
||||
| .rfl _, .rfl _ => return .rfl
|
||||
| .step f' hf _, .rfl _ =>
|
||||
let e' ← mkAppS f' a
|
||||
let h := mkApp4 (← mkCongrPrefix ``congrFun') f f' hf a
|
||||
return .step e' h
|
||||
| .rfl, .step a' ha =>
|
||||
| .rfl _, .step a' ha _ =>
|
||||
let e' ← mkAppS f a'
|
||||
let h := mkApp4 (← mkCongrPrefix ``congrArg) a a' f ha
|
||||
return .step e' h
|
||||
| .step f' hf, .step a' ha =>
|
||||
| .step f' hf _, .step a' ha _ =>
|
||||
let e' ← mkAppS f' a'
|
||||
let h := mkApp6 (← mkCongrPrefix ``congr) f f' a a' hf ha
|
||||
return .step e' h
|
||||
|
|
@ -131,8 +130,8 @@ where
|
|||
if rewritable[i - 1] then
|
||||
mkCongr e f a fr (← simp a) h
|
||||
else match fr with
|
||||
| .rfl => return .rfl
|
||||
| .step f' hf => mkCongrFun e f a f' hf h
|
||||
| .rfl _ => return .rfl
|
||||
| .step f' hf _ => mkCongrFun e f a f' hf h
|
||||
| _ => unreachable!
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ def simpLambda (e : Expr) : SimpM Result := do
|
|||
-- **TODO**: Add free variable reuse
|
||||
lambdaTelescope e fun xs b => do
|
||||
match (← simp b) with
|
||||
| .rfl => return .rfl
|
||||
| .step b' h =>
|
||||
| .rfl _ => return .rfl
|
||||
| .step b' h _ =>
|
||||
let h ← mkLambdaFVars xs h
|
||||
-- **TODO**: Add `mkLambdaFVarsS`?
|
||||
let e' ← shareCommonInc (← mkLambdaFVars xs b')
|
||||
|
|
@ -57,8 +57,8 @@ def simpStep : Simproc := fun e => do
|
|||
| .mdata m b =>
|
||||
let r ← simp b
|
||||
match r with
|
||||
| .rfl => return .rfl
|
||||
| .step b' h => return .step (← mkMDataS m b') h
|
||||
| .rfl _ => return .rfl
|
||||
| .step b' h _ => return .step (← mkMDataS m b') h
|
||||
| .lam .. => simpLambda e
|
||||
| .forallE .. => simpForall e
|
||||
| .letE .. => simpLet e
|
||||
|
|
@ -79,12 +79,14 @@ def simpImpl (e₁ : Expr) : SimpM Result := withIncRecDepth do
|
|||
if numSteps % 1000 == 0 then
|
||||
checkSystem "simp"
|
||||
modify fun s => { s with numSteps }
|
||||
match (← pre e₁) with
|
||||
| .step e₂ h₁ => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂))
|
||||
| .rfl =>
|
||||
let r₁ ← (simpStep >> post) e₁
|
||||
let r₁ ← pre e₁
|
||||
match r₁ with
|
||||
| .rfl => cacheResult e₁ r₁
|
||||
| .step e₂ h₁ => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂))
|
||||
| .rfl true | .step _ _ true => cacheResult e₁ r₁
|
||||
| .step e₂ h₁ false => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂))
|
||||
| .rfl false =>
|
||||
let r₂ ← (simpStep >> post) e₁
|
||||
match r₂ with
|
||||
| .rfl _ | .step _ _ true => cacheResult e₁ r₂
|
||||
| .step e₂ h₁ false => cacheResult e₁ (← mkEqTransResult e₁ e₂ h₁ (← simp e₂))
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ public def mkEqTrans (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (e₃ : Expr) (h
|
|||
|
||||
public abbrev mkEqTransResult (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (r₂ : Result) : SymM Result :=
|
||||
match r₂ with
|
||||
| .rfl => return .step e₂ h₁
|
||||
| .step e₃ h₂ => return .step e₃ (← mkEqTrans e₁ e₂ h₁ e₃ h₂)
|
||||
| .rfl done => return .step e₂ h₁ done
|
||||
| .step e₃ h₂ done => return .step e₃ (← mkEqTrans e₁ e₂ h₁ e₃ h₂) done
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
|
|||
|
|
@ -104,12 +104,51 @@ structure Config where
|
|||
maxSteps : Nat := 0
|
||||
-- **TODO**: many are still missing
|
||||
|
||||
/-- The result of simplifying some expression `e`. -/
|
||||
/--
|
||||
The result of simplifying an expression `e`.
|
||||
|
||||
The `done` flag controls whether simplification should continue after this result:
|
||||
- `done = false` (default): Continue with subsequent simplification steps
|
||||
- `done = true`: Stop processing, return this result as final
|
||||
|
||||
## Use cases for `done = true`
|
||||
|
||||
### In `pre` simprocs
|
||||
Skip simplification of certain subterms entirely:
|
||||
```
|
||||
def skipLambdas : Simproc := fun e =>
|
||||
if e.isLambda then return .rfl (done := true)
|
||||
else return .rfl
|
||||
```
|
||||
|
||||
### In `post` simprocs
|
||||
Perform single-pass normalization without recursive simplification:
|
||||
```
|
||||
def singlePassNormalize : Simproc := fun e =>
|
||||
if let some (e', h) ← tryNormalize e then
|
||||
return .step e' h (done := true)
|
||||
else return .rfl
|
||||
```
|
||||
With `done = true`, the result `e'` won't be recursively simplified.
|
||||
|
||||
## Behavior
|
||||
|
||||
The `done` flag affects:
|
||||
1. **`andThen` composition**: If the first simproc returns `done = true`,
|
||||
the second simproc is skipped.
|
||||
2. **Recursive simplification**: After `pre` or `post` returns `.step e' h`,
|
||||
`simp` normally recurses on `e'`. With `done = true`, recursion is skipped.
|
||||
|
||||
The flag is orthogonal to caching: both `.rfl` and `.step` results are cached
|
||||
regardless of the `done` flag, and cached results are always treated as final.
|
||||
-/
|
||||
inductive Result where
|
||||
| /-- No change -/
|
||||
rfl
|
||||
| /-- Simplified expression `e'` and a proof that `e = e'` -/
|
||||
step (e' : Expr) (proof : Expr)
|
||||
/-- No change. If `done = true`, skip remaining simplification steps for this term. -/
|
||||
| rfl (done : Bool := false)
|
||||
/--
|
||||
Simplified to `e'` with proof `proof : e = e'`.
|
||||
If `done = true`, skip recursive simplification of `e'`. -/
|
||||
| step (e' : Expr) (proof : Expr) (done : Bool := false)
|
||||
|
||||
private opaque MethodsRefPointed : NonemptyType.{0}
|
||||
def MethodsRef : Type := MethodsRefPointed.type
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ open Grind
|
|||
public abbrev Simproc.andThen (f g : Simproc) : Simproc := fun e₁ => do
|
||||
let r ← f e₁
|
||||
match r with
|
||||
| .rfl => g e₁
|
||||
| .step e₂ h₁ => mkEqTransResult e₁ e₂ h₁ (← g e₂)
|
||||
| .step _ _ true | .rfl true => return r
|
||||
| .rfl false => g e₁
|
||||
| .step e₂ h₁ false => mkEqTransResult e₁ e₂ h₁ (← g e₂)
|
||||
|
||||
public instance : AndThen Simproc where
|
||||
andThen f g := Simproc.andThen f (g ())
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ namespace SimpBench
|
|||
|
||||
def getProofSize (r : Sym.Simp.Result) : MetaM Nat :=
|
||||
match r with
|
||||
| .rfl => return 0
|
||||
| .step _ p => p.numObjs
|
||||
| .rfl _ => return 0
|
||||
| .step _ p _ => p.numObjs
|
||||
|
||||
def mkSimpMethods : MetaM Sym.Simp.Methods := do
|
||||
let thms : Sym.Simp.Theorems := {}
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ namespace SimpBench
|
|||
|
||||
def getProofSize (r : Sym.Simp.Result) : MetaM Nat :=
|
||||
match r with
|
||||
| .rfl => return 0
|
||||
| .step _ p => p.numObjs
|
||||
| .rfl _ => return 0
|
||||
| .step _ p _ => p.numObjs
|
||||
|
||||
def mkSimpMethods : MetaM Sym.Simp.Methods := do
|
||||
let thms : Sym.Simp.Theorems := {}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue