From a2cf78ac4aec99ea23d162fa143566ef6d47ad56 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 3 Jan 2026 19:51:56 -0800 Subject: [PATCH] perf: `Sym.Simp.DiscrTree` retrieval (#11889) This PR improves the discrimination tree retrieval performance used by `Sym.simp`. --- src/Lean/Meta/Sym/Simp/DiscrTree.lean | 95 ++++++++++++--------------- 1 file changed, 41 insertions(+), 54 deletions(-) diff --git a/src/Lean/Meta/Sym/Simp/DiscrTree.lean b/src/Lean/Meta/Sym/Simp/DiscrTree.lean index 789cc6dc62..5ca55fbc10 100644 --- a/src/Lean/Meta/Sym/Simp/DiscrTree.lean +++ b/src/Lean/Meta/Sym/Simp/DiscrTree.lean @@ -44,16 +44,6 @@ Retrieval should use the standard `DiscrTree.getMatch` or similar, which will fi whose key sequence is compatible with the query term. -/ -/-- -Returns the number of child keys for a given discrimination tree key. -**Note**: Unlike the standard `DiscrTree` module, `Key.arrow` has arity 2. --/ -def getKeyArity : Key → Nat - | .const _ a => a - | .fvar _ a => a - | .arrow => 2 - | _ => 0 - /-- Returns `true` if argument at position `i` should be ignored (is a proof or instance). -/ def ignoreArg (infos : Array ProofInstArgInfo) (i : Nat) : Bool := if h : i < infos.size then @@ -132,69 +122,66 @@ public def insertPattern [BEq α] (d : DiscrTree α) (p : Pattern) (v : α) : Di let keys := p.mkDiscrTreeKeys d.insertKeyValue keys v -def getKeyArgs (e : Expr) : Key × Array Expr := - match e.getAppFn with - | .lit v => (.lit v, #[]) - | .const declName _ => (.const declName e.getAppNumArgs, e.getAppRevArgs) - | .fvar fvarId => (.fvar fvarId e.getAppNumArgs, e.getAppRevArgs) - | .forallE _ d b _ => (.arrow, #[b, d]) - | _ => (.other, #[]) - abbrev findKey? (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) := cs.binSearch (k, default) (fun a b => a.1 < b.1) +def getKey (e : Expr) : Key := + match e.getAppFn with + | .lit v => .lit v + | .const declName _ => .const declName e.getAppNumArgs + | .fvar fvarId => .fvar fvarId e.getAppNumArgs + | .forallE _ _ _ _ => .arrow + | _ => .other + +/-- Push `e` arguments/children into the `todo` stack. -/ +def pushArgsTodo (todo : Array Expr) (e : Expr) : Array Expr := + match e with + | .app f a => pushArgsTodo (todo.push a) f + | .forallE _ d b _ => todo.push b |>.push d + | _ => todo + partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : Array α := match c with | .node vs cs => + let csize := cs.size if todo.isEmpty then result ++ vs - else if cs.isEmpty then + else if h : csize = 0 then result else let e := todo.back! let todo := todo.pop - let first := cs[0]! /- Recall that `Key.star` is the minimal key -/ - let (k, args) := getKeyArgs e - /- We must always visit `Key.star` edges since they are wildcards. - Thus, `todo` is not used linearly when there is `Key.star` edge - and there is an edge for `k` and `k != Key.star`. -/ - let visitStar (result : Array α) : Array α := + let first := cs[0] /- Recall that `Key.star` is the minimal key -/ + if csize = 1 then + /- Special case: only one child node -/ if first.1 == .star then getMatchLoop todo first.2 result + else if first.1 == getKey e then + getMatchLoop (pushArgsTodo todo e) first.2 result + else + result + else + /- We must always visit `Key.star` edges since they are wildcards. + Thus, `todo` is not used linearly when there is `Key.star` edge + and there is an edge for `k` and `k != Key.star`. -/ + let result := if first.1 == .star then + getMatchLoop todo first.2 result else result - let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : Array α := - match findKey? cs k with + match findKey? cs (getKey e) with | none => result - | some c => getMatchLoop (todo ++ args) c.2 result - let result := visitStar result - match k with - | .star => result - | _ => visitNonStar k args result - -def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : Array α := - match d.root.find? k with - | none => result - | some c => getMatchLoop args c result - -def getStarResult (d : DiscrTree α) : Array α := - let result : Array α := .mkEmpty initCapacity - match d.root.find? .star with - | none => result - | some (.node vs _) => result ++ vs - -def getMatchCore (d : DiscrTree α) (e : Expr) : Key × Array α := - let result := getStarResult d - let (k, args) := getKeyArgs e - match k with - | .star => (k, result) - | _ => (k, getMatchRoot d k args result) + | some c => getMatchLoop (pushArgsTodo todo e) c.2 result /-- Retrieves all values whose patterns match the expression `e`. -/ public def getMatch (d : DiscrTree α) (e : Expr) : Array α := - getMatchCore d e |>.2 + let result := match d.root.find? .star with + | none => .mkEmpty initCapacity + | some (.node vs _) => vs + match d.root.find? (getKey e) with + | none => result + | some c => getMatchLoop (pushArgsTodo #[] e) c result /-- Retrieves all values whose patterns match a prefix of `e`, along with the number of @@ -204,11 +191,11 @@ This is useful for rewriting: if a pattern matches `f x` but `e` is `f x y z`, w still apply the rewrite and return `(value, 2)` indicating 2 extra arguments. -/ public partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : Array (α × Nat) := - let (k, result) := getMatchCore d e + let result := getMatch d e let result := result.map (·, 0) if !e.isApp then result - else if !mayMatchPrefix k then + else if !mayMatchPrefix (getKey e) then result else go e.appFn! 1 result @@ -225,7 +212,7 @@ where | _ => false go (e : Expr) (numExtra : Nat) (result : Array (α × Nat)) : Array (α × Nat) := - let result := result ++ (getMatchCore d e).2.map (., numExtra) + let result := result ++ (getMatch d e).map (., numExtra) if e.isApp then go e.appFn! (numExtra + 1) result else