From b40dabdecdf666cb50d20ed8f2037a41aaa999e2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 3 Jan 2026 12:28:07 -0800 Subject: [PATCH] feat: add discrimination tree retrieval for `Sym` (#11886) This PR adds `getMatch` and `getMatchWithExtra` for retrieving patterns from discrimination trees in the symbolic simulation framework. The PR also uses `DiscrTree` to implement indexing in `Sym.simp`. --- src/Lean/Meta/Sym/DiscrTree.lean | 99 +++++++++++++++++++++++++++++++- src/Lean/Meta/Sym/Rewrite.lean | 5 +- src/Lean/Meta/Sym/SimpM.lean | 15 ++++- tests/bench/sym/simp_1.lean | 5 +- 4 files changed, 115 insertions(+), 9 deletions(-) diff --git a/src/Lean/Meta/Sym/DiscrTree.lean b/src/Lean/Meta/Sym/DiscrTree.lean index 962d7f9e95..789cc6dc62 100644 --- a/src/Lean/Meta/Sym/DiscrTree.lean +++ b/src/Lean/Meta/Sym/DiscrTree.lean @@ -132,8 +132,103 @@ public def insertPattern [BEq α] (d : DiscrTree α) (p : Pattern) (v : α) : Di let keys := p.mkDiscrTreeKeys d.insertKeyValue keys v -/-! -**TODO** Retrieval. +def getKeyArgs (e : Expr) : Key × Array Expr := + match e.getAppFn with + | .lit v => (.lit v, #[]) + | .const declName _ => (.const declName e.getAppNumArgs, e.getAppRevArgs) + | .fvar fvarId => (.fvar fvarId e.getAppNumArgs, e.getAppRevArgs) + | .forallE _ d b _ => (.arrow, #[b, d]) + | _ => (.other, #[]) + +abbrev findKey? (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) := + cs.binSearch (k, default) (fun a b => a.1 < b.1) + +partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : Array α := + match c with + | .node vs cs => + if todo.isEmpty then + result ++ vs + else if cs.isEmpty then + result + else + let e := todo.back! + let todo := todo.pop + let first := cs[0]! /- Recall that `Key.star` is the minimal key -/ + let (k, args) := getKeyArgs e + /- We must always visit `Key.star` edges since they are wildcards. 
+ Thus, `todo` is not used linearly when there is `Key.star` edge + and there is an edge for `k` and `k != Key.star`. -/ + let visitStar (result : Array α) : Array α := + if first.1 == .star then + getMatchLoop todo first.2 result + else + result + let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : Array α := + match findKey? cs k with + | none => result + | some c => getMatchLoop (todo ++ args) c.2 result + let result := visitStar result + match k with + | .star => result + | _ => visitNonStar k args result + +def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : Array α := + match d.root.find? k with + | none => result + | some c => getMatchLoop args c result + +def getStarResult (d : DiscrTree α) : Array α := + let result : Array α := .mkEmpty initCapacity + match d.root.find? .star with + | none => result + | some (.node vs _) => result ++ vs + +def getMatchCore (d : DiscrTree α) (e : Expr) : Key × Array α := + let result := getStarResult d + let (k, args) := getKeyArgs e + match k with + | .star => (k, result) + | _ => (k, getMatchRoot d k args result) + +/-- +Retrieves all values whose patterns match the expression `e`. -/ +public def getMatch (d : DiscrTree α) (e : Expr) : Array α := + getMatchCore d e |>.2 + +/-- +Retrieves all values whose patterns match a prefix of `e`, along with the number of +extra (ignored) arguments. + +This is useful for rewriting: if a pattern matches `f x` but `e` is `f x y z`, we can +still apply the rewrite and return `(value, 2)` indicating 2 extra arguments. +-/ +public partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : Array (α × Nat) := + let (k, result) := getMatchCore d e + let result := result.map (·, 0) + if !e.isApp then + result + else if !mayMatchPrefix k then + result + else + go e.appFn! 1 result +where + mayMatchPrefix (k : Key) : Bool := + let cont (k : Key) : Bool := + if d.root.find? 
k |>.isSome then + true + else + mayMatchPrefix k + match k with + | .const f (n+1) => cont (.const f n) + | .fvar f (n+1) => cont (.fvar f n) + | _ => false + + go (e : Expr) (numExtra : Nat) (result : Array (α × Nat)) : Array (α × Nat) := + let result := result ++ (getMatchCore d e).2.map (., numExtra) + if e.isApp then + go e.appFn! (numExtra + 1) result + else + result end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/Rewrite.lean b/src/Lean/Meta/Sym/Rewrite.lean index 2a230a7e95..435b1641ae 100644 --- a/src/Lean/Meta/Sym/Rewrite.lean +++ b/src/Lean/Meta/Sym/Rewrite.lean @@ -8,6 +8,7 @@ prelude public import Lean.Meta.Sym.SimpM public import Lean.Meta.Sym.SimpFun import Lean.Meta.Sym.InstantiateS +import Lean.Meta.Sym.DiscrTree namespace Lean.Meta.Sym.Simp open Grind @@ -41,8 +42,8 @@ public def Theorem.rewrite? (thm : Theorem) (e : Expr) : SimpM (Option Result) : return none public def rewrite : SimpFun := fun e => do - -- **TODO**: use indexing - for thm in (← read).thms.thms do + -- **TODO**: over-applied terms + for thm in (← read).thms.getMatch e do if let some result ← thm.rewrite? e then return result return { expr := e } diff --git a/src/Lean/Meta/Sym/SimpM.lean b/src/Lean/Meta/Sym/SimpM.lean index 8e8ba3bcc6..ee92172546 100644 --- a/src/Lean/Meta/Sym/SimpM.lean +++ b/src/Lean/Meta/Sym/SimpM.lean @@ -7,6 +7,7 @@ module prelude public import Lean.Meta.Sym.SymM public import Lean.Meta.Sym.Pattern +import Lean.Meta.Sym.DiscrTree public section namespace Lean.Meta.Sym.Simp @@ -129,10 +130,18 @@ structure Theorem where /-- Right-hand side of the equation. -/ rhs : Expr +instance : BEq Theorem where + beq thm₁ thm₂ := thm₁.expr == thm₂.expr + /-- Collection of simplification theorems available to the simplifier. -/ structure Theorems where - /-- **TODO**: No indexing for now. We will add a structural discrimination tree later. 
-/ - thms : Array Theorem := #[] + thms : DiscrTree Theorem := {} + +def Theorems.insert (thms : Theorems) (thm : Theorem) : Theorems := + { thms with thms := insertPattern thms.thms thm.pattern thm } + +def Theorems.getMatch (thms : Theorems) (e : Expr) : Array Theorem := + Sym.getMatch thms.thms e /-- Read-only context for the simplifier. -/ structure Context where @@ -182,7 +191,7 @@ abbrev getCache : SimpM Cache := end Simp -public def simp (e : Expr) (thms : Simp.Theorems := {}) (config : Simp.Config := {}) : SymM Simp.Result := do +def simp (e : Expr) (thms : Simp.Theorems := {}) (config : Simp.Config := {}) : SymM Simp.Result := do Simp.SimpM.run (Simp.simp e) thms config end Lean.Meta.Sym diff --git a/tests/bench/sym/simp_1.lean b/tests/bench/sym/simp_1.lean index d22438ee72..edc00984f8 100644 --- a/tests/bench/sym/simp_1.lean +++ b/tests/bench/sym/simp_1.lean @@ -8,8 +8,9 @@ namespace SimpBench -/ def mkSimpTheorems : MetaM Sym.Simp.Theorems := do - let thm ← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add - return { thms := #[thm] } + let result : Sym.Simp.Theorems := {} + let result := result.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add) + return result def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run' do let e ← Grind.shareCommon e