feat: main loop for Sym.simp (#11866)

This PR implements the core simplification loop for the `Sym` framework,
with efficient congruence-based argument rewriting.
This commit is contained in:
Leonardo de Moura 2026-01-01 15:21:22 -08:00 committed by GitHub
parent ef9777ec0d
commit 97c23abf8e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 302 additions and 2 deletions

View file

@ -22,6 +22,9 @@ public import Lean.Meta.Sym.Apply
public import Lean.Meta.Sym.InferType
public import Lean.Meta.Sym.SimpM
public import Lean.Meta.Sym.CongrInfo
public import Lean.Meta.Sym.EqTrans
public import Lean.Meta.Sym.Congr
public import Lean.Meta.Sym.SimpResult
public import Lean.Meta.Sym.Simp
/-!

View file

@ -0,0 +1,159 @@
/-
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.SimpM
import Lean.Meta.Tactic.Grind.AlphaShareBuilder
import Lean.Meta.Sym.InferType
import Lean.Meta.Sym.SimpResult
import Lean.Meta.Sym.CongrInfo
namespace Lean.Meta.Sym.Simp
open Grind
/-!
# Simplifying Application Arguments and Congruence Lemma Application
This module provides functions for building congruence proofs during simplification.
Given a function application `f a₁ ... aₙ` where some arguments are rewritable,
we recursively simplify those arguments (via `simp`) and construct a proof that the
original expression equals the simplified one.
The key challenge is efficiency: we want to avoid repeatedly inferring types, or destroying sharing,
The `CongrInfo` type (see `SymM.lean`) categorizes functions
by their argument structure, allowing us to choose the most efficient proof strategy:
- `fixedPrefix`: Use simple `congrArg`/`congrFun'`/`congr` for trailing arguments. We exploit
the fact that there are no dependent arguments in the suffix and use the cheaper `congrFun'`
instead of `congrFun`.
- `interlaced`: Mix rewritable and fixed arguments. It may have to use `congrFun` for fixed
dependent arguments.
- `congrTheorem`: Apply a pre-generated congruence theorem for dependent arguments
**Design principle**: Never infer the type of proofs. This avoids expensive type
inference on proof terms, which can be arbitrarily complex, and often destroys sharing.
-/
/--
Helper function for constructing a congruence proof using `congrFun'`, `congrArg`, `congr`.
For the dependent case, use `mkCongrFun`
-/
def mkCongr (e : Expr) (f a : Expr) (fr : Result) (ar : Result) (_ : e = .app f a) : SymM Result := do
let mkCongrPrefix (declName : Name) : SymM Expr := do
let α ← inferType a
let u ← getLevel α
let β ← inferType e
let v ← getLevel β
return mkApp2 (mkConst declName [u, v]) α β
match isSameExpr fr.expr f, isSameExpr ar.expr a with
| true, true =>
return { expr := e }
| false, true =>
let expr ← mkAppS fr.expr a
let proof? := mkApp4 (← mkCongrPrefix ``congrFun') f fr.expr (← fr.getProof) a
return { expr, proof? }
| true, false =>
let expr ← mkAppS f ar.expr
let proof? := mkApp4 (← mkCongrPrefix ``congrArg) a ar.expr f (← ar.getProof)
return { expr, proof? }
| false, false =>
let expr ← mkAppS fr.expr ar.expr
let proof? := mkApp6 (← mkCongrPrefix ``congr) f fr.expr a ar.expr (← fr.getProof) (← ar.getProof)
return { expr, proof? }
/--
Returns a proof using `congrFun`
```
congrFun.{u, v} {α : Sort u} {β : α → Sort v} {f g : (x : α) → β x} (h : f = g) (a : α) : f a = g a
```
-/
def mkCongrFun (e : Expr) (f a : Expr) (fr : Result) (_ : e = .app f a) : SymM Result := do
let .forallE x _ βx _ ← whnfD (← inferType f)
| throwError "failed to build congruence proof, function expected{indentExpr f}"
let α ← inferType a
let u ← getLevel α
let v ← getLevel (← inferType e)
let β := Lean.mkLambda x .default α βx
let expr ← mkAppS fr.expr a
let proof? := mkApp6 (mkConst ``congrFun [u, v]) α β f fr.expr (← fr.getProof) a
return { expr, proof? }
/--
Simplify arguments of a function application with a fixed prefix structure.
Recursively simplifies the trailing `suffixSize` arguments, leaving the first
`prefixSize` arguments unchanged.
-/
def congrFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do
let numArgs := e.getAppNumArgs
if numArgs ≤ prefixSize then
-- Nothing to be done
return { expr := e }
else if numArgs > prefixSize + suffixSize then
-- **TODO**: over-applied case
return { expr := e }
else
go numArgs e
where
go (i : Nat) (e : Expr) : SimpM Result := do
if i == prefixSize then
return { expr := e }
else
match h : e with
| .app f a => mkCongr e f a (← go (i - 1) f) (← simp a) h
| _ => unreachable!
/--
Simplify arguments of a function application with interlaced rewritable/fixed arguments.
Uses `rewritable[i]` to determine whether argument `i` should be simplified.
For rewritable arguments, calls `simp` and uses `congrFun'`, `congrArg`, and `congr`; for fixed arguments,
uses `congrFun` to propagate changes from earlier arguments.
-/
def congrInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do
let numArgs := e.getAppNumArgs
if h : numArgs = 0 then
-- Nothing to be done
return { expr := e }
else if h : numArgs > rewritable.size then
-- **TODO**: over-applied case
return { expr := e }
else
go numArgs e (by omega)
where
go (i : Nat) (e : Expr) (h : i ≤ rewritable.size) : SimpM Result := do
if h : i = 0 then
return { expr := e }
else
match h : e with
| .app f a =>
let fr ← go (i - 1) f (by omega)
if rewritable[i - 1] then
mkCongr e f a fr (← simp a) h
else if isSameExpr fr.expr f then
return { expr := e }
else
mkCongrFun e f a fr h
| _ => unreachable!
/--
Simplify arguments using a pre-generated congruence theorem.
Used for functions with proof or `Decidable` arguments.
-/
def congrThm (e : Expr) (_ : CongrTheorem) : SimpM Result := do
-- **TODO**
return { expr := e }
/--
Main entry point for simplifying function application arguments.
Dispatches to the appropriate strategy based on the function's `CongrInfo`.
-/
public def congrArgs (e : Expr) : SimpM Result := do
let f := e.getAppFn
match (← getCongrInfo f) with
| .none => return { expr := e }
| .fixedPrefix prefixSize suffixSize => congrFixedPrefix e prefixSize suffixSize
| .interlaced rewritable => congrInterlaced e rewritable
| .congrTheorem thm => congrThm e thm
end Lean.Meta.Sym.Simp

View file

@ -0,0 +1,22 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.SimpM
import Lean.Meta.Sym.InferType
namespace Lean.Meta.Sym
public def mkEqTrans (e₁ : Expr) (e₂ : Expr) (h₁? : Option Expr) (e₃ : Expr) (h₂? : Option Expr) : SymM (Option Expr) := do
match h₁?, h₂? with
| none, none => return none
| some _, none => return h₁?
| none, some _ => return h₂?
| some h₁, some h₂ =>
let α ← inferType e₁
let u ← getLevel α
return mkApp6 (mkConst ``Eq.trans [u]) α e₁ e₂ e₃ h₁ h₂
end Lean.Meta.Sym

View file

@ -17,4 +17,18 @@ public def inferType (e : Expr) : SymM Expr := do
modify fun s => { s with inferType := s.inferType.insert { expr := e } type }
return type
@[inherit_doc Meta.getLevel]
public def getLevel (type : Expr) : SymM Level := do
if let some u := (← get).getLevel.find? { expr := type } then
return u
else
let u ← Meta.getLevel type
modify fun s => { s with getLevel := s.getLevel.insert { expr := type } u }
return u
public def mkEqRefl (e : Expr) : SymM Expr := do
let α ← inferType e
let u ← getLevel α
return mkApp2 (mkConst ``Eq.refl [u]) α e
end Lean.Meta.Sym

View file

@ -6,10 +6,79 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Meta.Sym.SimpM
import Lean.Meta.Tactic.Grind.AlphaShareBuilder
import Lean.Meta.Sym.EqTrans
import Lean.Meta.Sym.Congr
namespace Lean.Meta.Sym.Simp
open Grind
def simpConst (e : Expr) : SimpM Result := do
-- **TODO**
return { expr := e }
def simpLambda (e : Expr) : SimpM Result := do
-- **TODO**
return { expr := e }
def simpForall (e : Expr) : SimpM Result := do
-- **TODO**
return { expr := e }
def simpLet (e : Expr) : SimpM Result := do
-- **TODO**
return { expr := e }
def simpFVar (e : Expr) : SimpM Result := do
-- **TODO**
return { expr := e }
def simpMVar (e : Expr) : SimpM Result := do
-- **TODO**
return { expr := e }
def simpApp (e : Expr) : SimpM Result := do
congrArgs e
def simpStep (e : Expr) : SimpM Result := do
match e with
| .lit _ | .sort _ | .bvar _ => return { expr := e }
| .proj .. =>
reportIssue! "unexpected kernel projection term during simplification{indentExpr e}\npre-process and fold them as projection applications"
return { expr := e }
| .mdata m b =>
let r ← simp b
if isSameExpr r.expr b then return { expr := e }
else return { r with expr := (← mkMDataS m r.expr) }
| .const .. => simpConst e
| .lam .. => simpLambda e
| .forallE .. => simpForall e
| .letE .. => simpLet e
| .fvar .. => simpFVar e
| .mvar .. => simpMVar e
| .app .. => simpApp e
def mkEqTrans (e : Expr) (r₁ : Result) (r₂ : Result) : SimpM Result := do
let proof? ← Sym.mkEqTrans e r₁.expr r₁.proof? r₂.expr r₂.proof?
return { r₂ with proof? }
def cacheResult (e : Expr) (r : Result) : SimpM Result := do
modify fun s => { s with cache := s.cache.insert { expr := e } r }
return r
@[export lean_sym_simp]
def simpImpl (e : Expr) : SimpM Result := do
throwError "NIY {e}"
if (← get).numSteps >= (← getConfig).maxSteps then
throwError "`simp` failed: maximum number of steps exceeded"
if let some result := (← getCache).find? { expr := e } then
return result
let r ← simpStep e
if isSameExpr r.expr e then
cacheResult e r
else
let r' ← simp r.expr
if isSameExpr r'.expr r.expr then
cacheResult e r
else
cacheResult e (← mkEqTrans e r r')
end Lean.Meta.Sym.Simp

View file

@ -103,6 +103,8 @@ invalidating the cache and causing O(2^n) behavior on conditional trees.
structure Config where
/-- If `true`, unfold let-bindings (zeta reduction) during simplification. -/
zetaDelta : Bool := true
/-- Maximum number of steps that can be performed by the simplifier. -/
maxSteps : Nat := 0
-- **TODO**: many are still missing
/-- The result of simplifying some expression `e`. -/
@ -145,11 +147,16 @@ abbrev Cache := PHashMap ExprPtr Result
/-- Mutable state for the simplifier. -/
structure State where
/-- Cache of previously simplified expressions to avoid redundant work. -/
/--
Cache of previously simplified expressions to avoid redundant work.
**Note**: Consider moving to `SymM.State`
-/
cache : Cache := {}
/-- Stack of free variables available for reuse when re-entering binders.
Each entry is (type pointer, fvarId). -/
binderStack : List (ExprPtr × FVarId) := []
/-- Number of steps performed so far. -/
numSteps := 0
/-- Monad for the structural simplifier, layered on top of `SymM`. -/
abbrev SimpM := ReaderT Context StateRefT State SymM
@ -165,4 +172,10 @@ abbrev SimpM.run (x : SimpM α) (thms : Theorems := {}) (config : Config := {})
@[extern "lean_sym_simp"] -- Forward declaration
opaque simp (e : Expr) : SimpM Result
def getConfig : SimpM Config :=
return (← read).config
abbrev getCache : SimpM Cache :=
return (← get).cache
end Lean.Meta.Sym.Simp

View file

@ -0,0 +1,16 @@
/-
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.SimpM
import Lean.Meta.Sym.InferType
namespace Lean.Meta.Sym.Simp
public def Result.getProof (result : Result) : SymM Expr := do
let some proof := result.proof? | mkEqRefl result.expr
return proof
end Lean.Meta.Sym.Simp

View file

@ -106,6 +106,10 @@ structure State where
Remark: type inference is a bottleneck on `Meta.Tactic.Simp` simplifier.
-/
inferType : PHashMap ExprPtr Expr := {}
/--
Cache for `getLevel` results, keyed by pointer equality.
-/
getLevel : PHashMap ExprPtr Level := {}
congrInfo : PHashMap ExprPtr CongrInfo := {}
abbrev SymM := ReaderT Grind.Params StateRefT State Grind.GrindM