feat: main loop for Sym.simp (#11866)
This PR implements the core simplification loop for the `Sym` framework, with efficient congruence-based argument rewriting.
This commit is contained in:
parent
ef9777ec0d
commit
97c23abf8e
8 changed files with 302 additions and 2 deletions
|
|
@ -22,6 +22,9 @@ public import Lean.Meta.Sym.Apply
|
|||
public import Lean.Meta.Sym.InferType
|
||||
public import Lean.Meta.Sym.SimpM
|
||||
public import Lean.Meta.Sym.CongrInfo
|
||||
public import Lean.Meta.Sym.EqTrans
|
||||
public import Lean.Meta.Sym.Congr
|
||||
public import Lean.Meta.Sym.SimpResult
|
||||
public import Lean.Meta.Sym.Simp
|
||||
|
||||
/-!
|
||||
|
|
|
|||
159
src/Lean/Meta/Sym/Congr.lean
Normal file
159
src/Lean/Meta/Sym/Congr.lean
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.SimpM
|
||||
import Lean.Meta.Tactic.Grind.AlphaShareBuilder
|
||||
import Lean.Meta.Sym.InferType
|
||||
import Lean.Meta.Sym.SimpResult
|
||||
import Lean.Meta.Sym.CongrInfo
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
open Grind
|
||||
|
||||
/-!
|
||||
# Simplifying Application Arguments and Congruence Lemma Application
|
||||
|
||||
This module provides functions for building congruence proofs during simplification.
|
||||
Given a function application `f a₁ ... aₙ` where some arguments are rewritable,
|
||||
we recursively simplify those arguments (via `simp`) and construct a proof that the
|
||||
original expression equals the simplified one.
|
||||
|
||||
The key challenge is efficiency: we want to avoid repeatedly inferring types, or destroying sharing,
|
||||
The `CongrInfo` type (see `SymM.lean`) categorizes functions
|
||||
by their argument structure, allowing us to choose the most efficient proof strategy:
|
||||
|
||||
- `fixedPrefix`: Use simple `congrArg`/`congrFun'`/`congr` for trailing arguments. We exploit
|
||||
the fact that there are no dependent arguments in the suffix and use the cheaper `congrFun'`
|
||||
instead of `congrFun`.
|
||||
- `interlaced`: Mix rewritable and fixed arguments. It may have to use `congrFun` for fixed
|
||||
dependent arguments.
|
||||
- `congrTheorem`: Apply a pre-generated congruence theorem for dependent arguments
|
||||
|
||||
**Design principle**: Never infer the type of proofs. This avoids expensive type
|
||||
inference on proof terms, which can be arbitrarily complex, and often destroys sharing.
|
||||
-/
|
||||
|
||||
/--
|
||||
Helper function for constructing a congruence proof using `congrFun'`, `congrArg`, `congr`.
|
||||
For the dependent case, use `mkCongrFun`
|
||||
-/
|
||||
def mkCongr (e : Expr) (f a : Expr) (fr : Result) (ar : Result) (_ : e = .app f a) : SymM Result := do
|
||||
let mkCongrPrefix (declName : Name) : SymM Expr := do
|
||||
let α ← inferType a
|
||||
let u ← getLevel α
|
||||
let β ← inferType e
|
||||
let v ← getLevel β
|
||||
return mkApp2 (mkConst declName [u, v]) α β
|
||||
match isSameExpr fr.expr f, isSameExpr ar.expr a with
|
||||
| true, true =>
|
||||
return { expr := e }
|
||||
| false, true =>
|
||||
let expr ← mkAppS fr.expr a
|
||||
let proof? := mkApp4 (← mkCongrPrefix ``congrFun') f fr.expr (← fr.getProof) a
|
||||
return { expr, proof? }
|
||||
| true, false =>
|
||||
let expr ← mkAppS f ar.expr
|
||||
let proof? := mkApp4 (← mkCongrPrefix ``congrArg) a ar.expr f (← ar.getProof)
|
||||
return { expr, proof? }
|
||||
| false, false =>
|
||||
let expr ← mkAppS fr.expr ar.expr
|
||||
let proof? := mkApp6 (← mkCongrPrefix ``congr) f fr.expr a ar.expr (← fr.getProof) (← ar.getProof)
|
||||
return { expr, proof? }
|
||||
|
||||
/--
|
||||
Returns a proof using `congrFun`
|
||||
```
|
||||
congrFun.{u, v} {α : Sort u} {β : α → Sort v} {f g : (x : α) → β x} (h : f = g) (a : α) : f a = g a
|
||||
```
|
||||
-/
|
||||
def mkCongrFun (e : Expr) (f a : Expr) (fr : Result) (_ : e = .app f a) : SymM Result := do
|
||||
let .forallE x _ βx _ ← whnfD (← inferType f)
|
||||
| throwError "failed to build congruence proof, function expected{indentExpr f}"
|
||||
let α ← inferType a
|
||||
let u ← getLevel α
|
||||
let v ← getLevel (← inferType e)
|
||||
let β := Lean.mkLambda x .default α βx
|
||||
let expr ← mkAppS fr.expr a
|
||||
let proof? := mkApp6 (mkConst ``congrFun [u, v]) α β f fr.expr (← fr.getProof) a
|
||||
return { expr, proof? }
|
||||
|
||||
/--
|
||||
Simplify arguments of a function application with a fixed prefix structure.
|
||||
Recursively simplifies the trailing `suffixSize` arguments, leaving the first
|
||||
`prefixSize` arguments unchanged.
|
||||
-/
|
||||
def congrFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do
|
||||
let numArgs := e.getAppNumArgs
|
||||
if numArgs ≤ prefixSize then
|
||||
-- Nothing to be done
|
||||
return { expr := e }
|
||||
else if numArgs > prefixSize + suffixSize then
|
||||
-- **TODO**: over-applied case
|
||||
return { expr := e }
|
||||
else
|
||||
go numArgs e
|
||||
where
|
||||
go (i : Nat) (e : Expr) : SimpM Result := do
|
||||
if i == prefixSize then
|
||||
return { expr := e }
|
||||
else
|
||||
match h : e with
|
||||
| .app f a => mkCongr e f a (← go (i - 1) f) (← simp a) h
|
||||
| _ => unreachable!
|
||||
|
||||
/--
|
||||
Simplify arguments of a function application with interlaced rewritable/fixed arguments.
|
||||
Uses `rewritable[i]` to determine whether argument `i` should be simplified.
|
||||
For rewritable arguments, calls `simp` and uses `congrFun'`, `congrArg`, and `congr`; for fixed arguments,
|
||||
uses `congrFun` to propagate changes from earlier arguments.
|
||||
-/
|
||||
def congrInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do
|
||||
let numArgs := e.getAppNumArgs
|
||||
if h : numArgs = 0 then
|
||||
-- Nothing to be done
|
||||
return { expr := e }
|
||||
else if h : numArgs > rewritable.size then
|
||||
-- **TODO**: over-applied case
|
||||
return { expr := e }
|
||||
else
|
||||
go numArgs e (by omega)
|
||||
where
|
||||
go (i : Nat) (e : Expr) (h : i ≤ rewritable.size) : SimpM Result := do
|
||||
if h : i = 0 then
|
||||
return { expr := e }
|
||||
else
|
||||
match h : e with
|
||||
| .app f a =>
|
||||
let fr ← go (i - 1) f (by omega)
|
||||
if rewritable[i - 1] then
|
||||
mkCongr e f a fr (← simp a) h
|
||||
else if isSameExpr fr.expr f then
|
||||
return { expr := e }
|
||||
else
|
||||
mkCongrFun e f a fr h
|
||||
| _ => unreachable!
|
||||
|
||||
/--
|
||||
Simplify arguments using a pre-generated congruence theorem.
|
||||
Used for functions with proof or `Decidable` arguments.
|
||||
-/
|
||||
def congrThm (e : Expr) (_ : CongrTheorem) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return { expr := e }
|
||||
|
||||
/--
|
||||
Main entry point for simplifying function application arguments.
|
||||
Dispatches to the appropriate strategy based on the function's `CongrInfo`.
|
||||
-/
|
||||
public def congrArgs (e : Expr) : SimpM Result := do
|
||||
let f := e.getAppFn
|
||||
match (← getCongrInfo f) with
|
||||
| .none => return { expr := e }
|
||||
| .fixedPrefix prefixSize suffixSize => congrFixedPrefix e prefixSize suffixSize
|
||||
| .interlaced rewritable => congrInterlaced e rewritable
|
||||
| .congrTheorem thm => congrThm e thm
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
22
src/Lean/Meta/Sym/EqTrans.lean
Normal file
22
src/Lean/Meta/Sym/EqTrans.lean
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.SimpM
|
||||
import Lean.Meta.Sym.InferType
|
||||
namespace Lean.Meta.Sym
|
||||
|
||||
public def mkEqTrans (e₁ : Expr) (e₂ : Expr) (h₁? : Option Expr) (e₃ : Expr) (h₂? : Option Expr) : SymM (Option Expr) := do
|
||||
match h₁?, h₂? with
|
||||
| none, none => return none
|
||||
| some _, none => return h₁?
|
||||
| none, some _ => return h₂?
|
||||
| some h₁, some h₂ =>
|
||||
let α ← inferType e₁
|
||||
let u ← getLevel α
|
||||
return mkApp6 (mkConst ``Eq.trans [u]) α e₁ e₂ e₃ h₁ h₂
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
@ -17,4 +17,18 @@ public def inferType (e : Expr) : SymM Expr := do
|
|||
modify fun s => { s with inferType := s.inferType.insert { expr := e } type }
|
||||
return type
|
||||
|
||||
@[inherit_doc Meta.getLevel]
|
||||
public def getLevel (type : Expr) : SymM Level := do
|
||||
if let some u := (← get).getLevel.find? { expr := type } then
|
||||
return u
|
||||
else
|
||||
let u ← Meta.getLevel type
|
||||
modify fun s => { s with getLevel := s.getLevel.insert { expr := type } u }
|
||||
return u
|
||||
|
||||
public def mkEqRefl (e : Expr) : SymM Expr := do
|
||||
let α ← inferType e
|
||||
let u ← getLevel α
|
||||
return mkApp2 (mkConst ``Eq.refl [u]) α e
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
|
|||
|
|
@ -6,10 +6,79 @@ Authors: Leonardo de Moura
|
|||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.SimpM
|
||||
import Lean.Meta.Tactic.Grind.AlphaShareBuilder
|
||||
import Lean.Meta.Sym.EqTrans
|
||||
import Lean.Meta.Sym.Congr
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
open Grind
|
||||
|
||||
def simpConst (e : Expr) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return { expr := e }
|
||||
|
||||
def simpLambda (e : Expr) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return { expr := e }
|
||||
|
||||
def simpForall (e : Expr) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return { expr := e }
|
||||
|
||||
def simpLet (e : Expr) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return { expr := e }
|
||||
|
||||
def simpFVar (e : Expr) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return { expr := e }
|
||||
|
||||
def simpMVar (e : Expr) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return { expr := e }
|
||||
|
||||
def simpApp (e : Expr) : SimpM Result := do
|
||||
congrArgs e
|
||||
|
||||
def simpStep (e : Expr) : SimpM Result := do
|
||||
match e with
|
||||
| .lit _ | .sort _ | .bvar _ => return { expr := e }
|
||||
| .proj .. =>
|
||||
reportIssue! "unexpected kernel projection term during simplification{indentExpr e}\npre-process and fold them as projection applications"
|
||||
return { expr := e }
|
||||
| .mdata m b =>
|
||||
let r ← simp b
|
||||
if isSameExpr r.expr b then return { expr := e }
|
||||
else return { r with expr := (← mkMDataS m r.expr) }
|
||||
| .const .. => simpConst e
|
||||
| .lam .. => simpLambda e
|
||||
| .forallE .. => simpForall e
|
||||
| .letE .. => simpLet e
|
||||
| .fvar .. => simpFVar e
|
||||
| .mvar .. => simpMVar e
|
||||
| .app .. => simpApp e
|
||||
|
||||
def mkEqTrans (e : Expr) (r₁ : Result) (r₂ : Result) : SimpM Result := do
|
||||
let proof? ← Sym.mkEqTrans e r₁.expr r₁.proof? r₂.expr r₂.proof?
|
||||
return { r₂ with proof? }
|
||||
|
||||
def cacheResult (e : Expr) (r : Result) : SimpM Result := do
|
||||
modify fun s => { s with cache := s.cache.insert { expr := e } r }
|
||||
return r
|
||||
|
||||
@[export lean_sym_simp]
|
||||
def simpImpl (e : Expr) : SimpM Result := do
|
||||
throwError "NIY {e}"
|
||||
if (← get).numSteps >= (← getConfig).maxSteps then
|
||||
throwError "`simp` failed: maximum number of steps exceeded"
|
||||
if let some result := (← getCache).find? { expr := e } then
|
||||
return result
|
||||
let r ← simpStep e
|
||||
if isSameExpr r.expr e then
|
||||
cacheResult e r
|
||||
else
|
||||
let r' ← simp r.expr
|
||||
if isSameExpr r'.expr r.expr then
|
||||
cacheResult e r
|
||||
else
|
||||
cacheResult e (← mkEqTrans e r r')
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
|
|||
|
|
@ -103,6 +103,8 @@ invalidating the cache and causing O(2^n) behavior on conditional trees.
|
|||
structure Config where
|
||||
/-- If `true`, unfold let-bindings (zeta reduction) during simplification. -/
|
||||
zetaDelta : Bool := true
|
||||
/-- Maximum number of steps that can be performed by the simplifier. -/
|
||||
maxSteps : Nat := 0
|
||||
-- **TODO**: many are still missing
|
||||
|
||||
/-- The result of simplifying some expression `e`. -/
|
||||
|
|
@ -145,11 +147,16 @@ abbrev Cache := PHashMap ExprPtr Result
|
|||
|
||||
/-- Mutable state for the simplifier. -/
|
||||
structure State where
|
||||
/-- Cache of previously simplified expressions to avoid redundant work. -/
|
||||
/--
|
||||
Cache of previously simplified expressions to avoid redundant work.
|
||||
**Note**: Consider moving to `SymM.State`
|
||||
-/
|
||||
cache : Cache := {}
|
||||
/-- Stack of free variables available for reuse when re-entering binders.
|
||||
Each entry is (type pointer, fvarId). -/
|
||||
binderStack : List (ExprPtr × FVarId) := []
|
||||
/-- Number of steps performed so far. -/
|
||||
numSteps := 0
|
||||
|
||||
/-- Monad for the structural simplifier, layered on top of `SymM`. -/
|
||||
abbrev SimpM := ReaderT Context StateRefT State SymM
|
||||
|
|
@ -165,4 +172,10 @@ abbrev SimpM.run (x : SimpM α) (thms : Theorems := {}) (config : Config := {})
|
|||
@[extern "lean_sym_simp"] -- Forward declaration
|
||||
opaque simp (e : Expr) : SimpM Result
|
||||
|
||||
def getConfig : SimpM Config :=
|
||||
return (← read).config
|
||||
|
||||
abbrev getCache : SimpM Cache :=
|
||||
return (← get).cache
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
|
|||
16
src/Lean/Meta/Sym/SimpResult.lean
Normal file
16
src/Lean/Meta/Sym/SimpResult.lean
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.SimpM
|
||||
import Lean.Meta.Sym.InferType
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
|
||||
public def Result.getProof (result : Result) : SymM Expr := do
|
||||
let some proof := result.proof? | mkEqRefl result.expr
|
||||
return proof
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
@ -106,6 +106,10 @@ structure State where
|
|||
Remark: type inference is a bottleneck on `Meta.Tactic.Simp` simplifier.
|
||||
-/
|
||||
inferType : PHashMap ExprPtr Expr := {}
|
||||
/--
|
||||
Cache for `getLevel` results, keyed by pointer equality.
|
||||
-/
|
||||
getLevel : PHashMap ExprPtr Level := {}
|
||||
congrInfo : PHashMap ExprPtr CongrInfo := {}
|
||||
|
||||
abbrev SymM := ReaderT Grind.Params StateRefT State Grind.GrindM
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue