From 97c23abf8ebaf62fabba2c98569d2d65290694bd Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 1 Jan 2026 15:21:22 -0800 Subject: [PATCH] feat: main loop for `Sym.simp` (#11866) This PR implements the core simplification loop for the `Sym` framework, with efficient congruence-based argument rewriting. --- src/Lean/Meta/Sym.lean | 3 + src/Lean/Meta/Sym/Congr.lean | 159 ++++++++++++++++++++++++++++++ src/Lean/Meta/Sym/EqTrans.lean | 22 +++++ src/Lean/Meta/Sym/InferType.lean | 14 +++ src/Lean/Meta/Sym/Simp.lean | 71 ++++++++++++- src/Lean/Meta/Sym/SimpM.lean | 15 ++- src/Lean/Meta/Sym/SimpResult.lean | 16 +++ src/Lean/Meta/Sym/SymM.lean | 4 + 8 files changed, 302 insertions(+), 2 deletions(-) create mode 100644 src/Lean/Meta/Sym/Congr.lean create mode 100644 src/Lean/Meta/Sym/EqTrans.lean create mode 100644 src/Lean/Meta/Sym/SimpResult.lean diff --git a/src/Lean/Meta/Sym.lean b/src/Lean/Meta/Sym.lean index e956de3b08..8be9880424 100644 --- a/src/Lean/Meta/Sym.lean +++ b/src/Lean/Meta/Sym.lean @@ -22,6 +22,9 @@ public import Lean.Meta.Sym.Apply public import Lean.Meta.Sym.InferType public import Lean.Meta.Sym.SimpM public import Lean.Meta.Sym.CongrInfo +public import Lean.Meta.Sym.EqTrans +public import Lean.Meta.Sym.Congr +public import Lean.Meta.Sym.SimpResult public import Lean.Meta.Sym.Simp /-! diff --git a/src/Lean/Meta/Sym/Congr.lean b/src/Lean/Meta/Sym/Congr.lean new file mode 100644 index 0000000000..1d2e4471f0 --- /dev/null +++ b/src/Lean/Meta/Sym/Congr.lean @@ -0,0 +1,159 @@ +/- +Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.SimpM +import Lean.Meta.Tactic.Grind.AlphaShareBuilder +import Lean.Meta.Sym.InferType +import Lean.Meta.Sym.SimpResult +import Lean.Meta.Sym.CongrInfo +namespace Lean.Meta.Sym.Simp +open Grind + +/-! +# Simplifying Application Arguments and Congruence Lemma Application + +This module provides functions for building congruence proofs during simplification. +Given a function application `f a₁ ... aₙ` where some arguments are rewritable, +we recursively simplify those arguments (via `simp`) and construct a proof that the +original expression equals the simplified one. + +The key challenge is efficiency: we want to avoid repeatedly inferring types, or destroying sharing, +The `CongrInfo` type (see `SymM.lean`) categorizes functions +by their argument structure, allowing us to choose the most efficient proof strategy: + +- `fixedPrefix`: Use simple `congrArg`/`congrFun'`/`congr` for trailing arguments. We exploit + the fact that there are no dependent arguments in the suffix and use the cheaper `congrFun'` + instead of `congrFun`. +- `interlaced`: Mix rewritable and fixed arguments. It may have to use `congrFun` for fixed + dependent arguments. +- `congrTheorem`: Apply a pre-generated congruence theorem for dependent arguments + +**Design principle**: Never infer the type of proofs. This avoids expensive type +inference on proof terms, which can be arbitrarily complex, and often destroys sharing. +-/ + +/-- +Helper function for constructing a congruence proof using `congrFun'`, `congrArg`, `congr`. +For the dependent case, use `mkCongrFun` +-/ +def mkCongr (e : Expr) (f a : Expr) (fr : Result) (ar : Result) (_ : e = .app f a) : SymM Result := do + let mkCongrPrefix (declName : Name) : SymM Expr := do + let α ← inferType a + let u ← getLevel α + let β ← inferType e + let v ← getLevel β + return mkApp2 (mkConst declName [u, v]) α β + match isSameExpr fr.expr f, isSameExpr ar.expr a with + | true, true => + return { expr := e } + | false, true => + let expr ← mkAppS fr.expr a + let proof? := mkApp4 (← mkCongrPrefix ``congrFun') f fr.expr (← fr.getProof) a + return { expr, proof? } + | true, false => + let expr ← mkAppS f ar.expr + let proof? := mkApp4 (← mkCongrPrefix ``congrArg) a ar.expr f (← ar.getProof) + return { expr, proof? } + | false, false => + let expr ← mkAppS fr.expr ar.expr + let proof? := mkApp6 (← mkCongrPrefix ``congr) f fr.expr a ar.expr (← fr.getProof) (← ar.getProof) + return { expr, proof? } + +/-- +Returns a proof using `congrFun` +``` +congrFun.{u, v} {α : Sort u} {β : α → Sort v} {f g : (x : α) → β x} (h : f = g) (a : α) : f a = g a +``` +-/ +def mkCongrFun (e : Expr) (f a : Expr) (fr : Result) (_ : e = .app f a) : SymM Result := do + let .forallE x _ βx _ ← whnfD (← inferType f) + | throwError "failed to build congruence proof, function expected{indentExpr f}" + let α ← inferType a + let u ← getLevel α + let v ← getLevel (← inferType e) + let β := Lean.mkLambda x .default α βx + let expr ← mkAppS fr.expr a + let proof? := mkApp6 (mkConst ``congrFun [u, v]) α β f fr.expr (← fr.getProof) a + return { expr, proof? } + +/-- +Simplify arguments of a function application with a fixed prefix structure. +Recursively simplifies the trailing `suffixSize` arguments, leaving the first +`prefixSize` arguments unchanged. +-/ +def congrFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do + let numArgs := e.getAppNumArgs + if numArgs ≤ prefixSize then + -- Nothing to be done + return { expr := e } + else if numArgs > prefixSize + suffixSize then + -- **TODO**: over-applied case + return { expr := e } + else + go numArgs e +where + go (i : Nat) (e : Expr) : SimpM Result := do + if i == prefixSize then + return { expr := e } + else + match h : e with + | .app f a => mkCongr e f a (← go (i - 1) f) (← simp a) h + | _ => unreachable! + +/-- +Simplify arguments of a function application with interlaced rewritable/fixed arguments. +Uses `rewritable[i]` to determine whether argument `i` should be simplified. +For rewritable arguments, calls `simp` and uses `congrFun'`, `congrArg`, and `congr`; for fixed arguments, +uses `congrFun` to propagate changes from earlier arguments. +-/ +def congrInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do + let numArgs := e.getAppNumArgs + if h : numArgs = 0 then + -- Nothing to be done + return { expr := e } + else if h : numArgs > rewritable.size then + -- **TODO**: over-applied case + return { expr := e } + else + go numArgs e (by omega) +where + go (i : Nat) (e : Expr) (h : i ≤ rewritable.size) : SimpM Result := do + if h : i = 0 then + return { expr := e } + else + match h : e with + | .app f a => + let fr ← go (i - 1) f (by omega) + if rewritable[i - 1] then + mkCongr e f a fr (← simp a) h + else if isSameExpr fr.expr f then + return { expr := e } + else + mkCongrFun e f a fr h + | _ => unreachable! + +/-- +Simplify arguments using a pre-generated congruence theorem. +Used for functions with proof or `Decidable` arguments. +-/ +def congrThm (e : Expr) (_ : CongrTheorem) : SimpM Result := do + -- **TODO** + return { expr := e } + +/-- +Main entry point for simplifying function application arguments. +Dispatches to the appropriate strategy based on the function's `CongrInfo`. +-/ +public def congrArgs (e : Expr) : SimpM Result := do + let f := e.getAppFn + match (← getCongrInfo f) with + | .none => return { expr := e } + | .fixedPrefix prefixSize suffixSize => congrFixedPrefix e prefixSize suffixSize + | .interlaced rewritable => congrInterlaced e rewritable + | .congrTheorem thm => congrThm e thm + +end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/EqTrans.lean b/src/Lean/Meta/Sym/EqTrans.lean new file mode 100644 index 0000000000..6e0393e55a --- /dev/null +++ b/src/Lean/Meta/Sym/EqTrans.lean @@ -0,0 +1,22 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.SimpM +import Lean.Meta.Sym.InferType +namespace Lean.Meta.Sym + +public def mkEqTrans (e₁ : Expr) (e₂ : Expr) (h₁? : Option Expr) (e₃ : Expr) (h₂? : Option Expr) : SymM (Option Expr) := do + match h₁?, h₂? with + | none, none => return none + | some _, none => return h₁? + | none, some _ => return h₂? + | some h₁, some h₂ => + let α ← inferType e₁ + let u ← getLevel α + return mkApp6 (mkConst ``Eq.trans [u]) α e₁ e₂ e₃ h₁ h₂ + +end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/InferType.lean b/src/Lean/Meta/Sym/InferType.lean index 3bb6bd2d83..e5252ab7d7 100644 --- a/src/Lean/Meta/Sym/InferType.lean +++ b/src/Lean/Meta/Sym/InferType.lean @@ -17,4 +17,18 @@ public def inferType (e : Expr) : SymM Expr := do modify fun s => { s with inferType := s.inferType.insert { expr := e } type } return type +@[inherit_doc Meta.getLevel] +public def getLevel (type : Expr) : SymM Level := do + if let some u := (← get).getLevel.find? { expr := type } then + return u + else + let u ← Meta.getLevel type + modify fun s => { s with getLevel := s.getLevel.insert { expr := type } u } + return u + +public def mkEqRefl (e : Expr) : SymM Expr := do + let α ← inferType e + let u ← getLevel α + return mkApp2 (mkConst ``Eq.refl [u]) α e + end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/Simp.lean b/src/Lean/Meta/Sym/Simp.lean index 8a03d4bb72..d76446cf61 100644 --- a/src/Lean/Meta/Sym/Simp.lean +++ b/src/Lean/Meta/Sym/Simp.lean @@ -6,10 +6,79 @@ Authors: Leonardo de Moura module prelude public import Lean.Meta.Sym.SimpM +import Lean.Meta.Tactic.Grind.AlphaShareBuilder +import Lean.Meta.Sym.EqTrans +import Lean.Meta.Sym.Congr namespace Lean.Meta.Sym.Simp +open Grind + +def simpConst (e : Expr) : SimpM Result := do + -- **TODO** + return { expr := e } + +def simpLambda (e : Expr) : SimpM Result := do + -- **TODO** + return { expr := e } + +def simpForall (e : Expr) : SimpM Result := do + -- **TODO** + return { expr := e } + +def simpLet (e : Expr) : SimpM Result := do + -- **TODO** + return { expr := e } + +def simpFVar (e : Expr) : SimpM Result := do + -- **TODO** + return { expr := e } + +def simpMVar (e : Expr) : SimpM Result := do + -- **TODO** + return { expr := e } + +def simpApp (e : Expr) : SimpM Result := do + congrArgs e + +def simpStep (e : Expr) : SimpM Result := do + match e with + | .lit _ | .sort _ | .bvar _ => return { expr := e } + | .proj .. => + reportIssue! "unexpected kernel projection term during simplification{indentExpr e}\npre-process and fold them as projection applications" + return { expr := e } + | .mdata m b => + let r ← simp b + if isSameExpr r.expr b then return { expr := e } + else return { r with expr := (← mkMDataS m r.expr) } + | .const .. => simpConst e + | .lam .. => simpLambda e + | .forallE .. => simpForall e + | .letE .. => simpLet e + | .fvar .. => simpFVar e + | .mvar .. => simpMVar e + | .app .. => simpApp e + +def mkEqTrans (e : Expr) (r₁ : Result) (r₂ : Result) : SimpM Result := do + let proof? ← Sym.mkEqTrans e r₁.expr r₁.proof? r₂.expr r₂.proof? + return { r₂ with proof? } + +def cacheResult (e : Expr) (r : Result) : SimpM Result := do + modify fun s => { s with cache := s.cache.insert { expr := e } r } + return r @[export lean_sym_simp] def simpImpl (e : Expr) : SimpM Result := do - throwError "NIY {e}" + if (← get).numSteps >= (← getConfig).maxSteps then + throwError "`simp` failed: maximum number of steps exceeded" + if let some result := (← getCache).find? { expr := e } then + return result + let r ← simpStep e + if isSameExpr r.expr e then + cacheResult e r + else + let r' ← simp r.expr + if isSameExpr r'.expr r.expr then + cacheResult e r + else + cacheResult e (← mkEqTrans e r r') end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/SimpM.lean b/src/Lean/Meta/Sym/SimpM.lean index 9c2dfeb991..5dc53eb8d0 100644 --- a/src/Lean/Meta/Sym/SimpM.lean +++ b/src/Lean/Meta/Sym/SimpM.lean @@ -103,6 +103,8 @@ invalidating the cache and causing O(2^n) behavior on conditional trees. structure Config where /-- If `true`, unfold let-bindings (zeta reduction) during simplification. -/ zetaDelta : Bool := true + /-- Maximum number of steps that can be performed by the simplifier. -/ + maxSteps : Nat := 0 -- **TODO**: many are still missing /-- The result of simplifying some expression `e`. -/ @@ -145,11 +147,16 @@ abbrev Cache := PHashMap ExprPtr Result /-- Mutable state for the simplifier. -/ structure State where - /-- Cache of previously simplified expressions to avoid redundant work. -/ + /-- + Cache of previously simplified expressions to avoid redundant work. + **Note**: Consider moving to `SymM.State` + -/ cache : Cache := {} /-- Stack of free variables available for reuse when re-entering binders. Each entry is (type pointer, fvarId). -/ binderStack : List (ExprPtr × FVarId) := [] + /-- Number of steps performed so far. -/ + numSteps := 0 /-- Monad for the structural simplifier, layered on top of `SymM`. -/ abbrev SimpM := ReaderT Context StateRefT State SymM @@ -165,4 +172,10 @@ abbrev SimpM.run (x : SimpM α) (thms : Theorems := {}) (config : Config := {}) @[extern "lean_sym_simp"] -- Forward declaration opaque simp (e : Expr) : SimpM Result +def getConfig : SimpM Config := + return (← read).config + +abbrev getCache : SimpM Cache := + return (← get).cache + end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/SimpResult.lean b/src/Lean/Meta/Sym/SimpResult.lean new file mode 100644 index 0000000000..3950c69dcc --- /dev/null +++ b/src/Lean/Meta/Sym/SimpResult.lean @@ -0,0 +1,16 @@ +/- +Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.SimpM +import Lean.Meta.Sym.InferType +namespace Lean.Meta.Sym.Simp + +public def Result.getProof (result : Result) : SymM Expr := do + let some proof := result.proof? | mkEqRefl result.expr + return proof + +end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/SymM.lean b/src/Lean/Meta/Sym/SymM.lean index 9fab0dba1e..345d1e8502 100644 --- a/src/Lean/Meta/Sym/SymM.lean +++ b/src/Lean/Meta/Sym/SymM.lean @@ -106,6 +106,10 @@ structure State where Remark: type inference is a bottleneck on `Meta.Tactic.Simp` simplifier. -/ inferType : PHashMap ExprPtr Expr := {} + /-- + Cache for `getLevel` results, keyed by pointer equality. + -/ + getLevel : PHashMap ExprPtr Level := {} congrInfo : PHashMap ExprPtr CongrInfo := {} abbrev SymM := ReaderT Grind.Params StateRefT State Grind.GrindM