perf: Sym.Simp.DiscrTree retrieval (#11889)
This PR improves the discrimination tree retrieval performance used by `Sym.simp`.
This commit is contained in:
parent
bc72487aed
commit
a2cf78ac4a
1 changed files with 41 additions and 54 deletions
|
|
@ -44,16 +44,6 @@ Retrieval should use the standard `DiscrTree.getMatch` or similar, which will fi
|
|||
whose key sequence is compatible with the query term.
|
||||
-/
|
||||
|
||||
/--
|
||||
Returns the number of child keys for a given discrimination tree key.
|
||||
**Note**: Unlike the standard `DiscrTree` module, `Key.arrow` has arity 2.
|
||||
-/
|
||||
def getKeyArity : Key → Nat
|
||||
| .const _ a => a
|
||||
| .fvar _ a => a
|
||||
| .arrow => 2
|
||||
| _ => 0
|
||||
|
||||
/-- Returns `true` if argument at position `i` should be ignored (is a proof or instance). -/
|
||||
def ignoreArg (infos : Array ProofInstArgInfo) (i : Nat) : Bool :=
|
||||
if h : i < infos.size then
|
||||
|
|
@ -132,69 +122,66 @@ public def insertPattern [BEq α] (d : DiscrTree α) (p : Pattern) (v : α) : Di
|
|||
let keys := p.mkDiscrTreeKeys
|
||||
d.insertKeyValue keys v
|
||||
|
||||
def getKeyArgs (e : Expr) : Key × Array Expr :=
|
||||
match e.getAppFn with
|
||||
| .lit v => (.lit v, #[])
|
||||
| .const declName _ => (.const declName e.getAppNumArgs, e.getAppRevArgs)
|
||||
| .fvar fvarId => (.fvar fvarId e.getAppNumArgs, e.getAppRevArgs)
|
||||
| .forallE _ d b _ => (.arrow, #[b, d])
|
||||
| _ => (.other, #[])
|
||||
|
||||
abbrev findKey? (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) :=
|
||||
cs.binSearch (k, default) (fun a b => a.1 < b.1)
|
||||
|
||||
def getKey (e : Expr) : Key :=
|
||||
match e.getAppFn with
|
||||
| .lit v => .lit v
|
||||
| .const declName _ => .const declName e.getAppNumArgs
|
||||
| .fvar fvarId => .fvar fvarId e.getAppNumArgs
|
||||
| .forallE _ _ _ _ => .arrow
|
||||
| _ => .other
|
||||
|
||||
/-- Push `e` arguments/children into the `todo` stack. -/
|
||||
def pushArgsTodo (todo : Array Expr) (e : Expr) : Array Expr :=
|
||||
match e with
|
||||
| .app f a => pushArgsTodo (todo.push a) f
|
||||
| .forallE _ d b _ => todo.push b |>.push d
|
||||
| _ => todo
|
||||
|
||||
partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : Array α :=
|
||||
match c with
|
||||
| .node vs cs =>
|
||||
let csize := cs.size
|
||||
if todo.isEmpty then
|
||||
result ++ vs
|
||||
else if cs.isEmpty then
|
||||
else if h : csize = 0 then
|
||||
result
|
||||
else
|
||||
let e := todo.back!
|
||||
let todo := todo.pop
|
||||
let first := cs[0]! /- Recall that `Key.star` is the minimal key -/
|
||||
let (k, args) := getKeyArgs e
|
||||
/- We must always visit `Key.star` edges since they are wildcards.
|
||||
Thus, `todo` is not used linearly when there is `Key.star` edge
|
||||
and there is an edge for `k` and `k != Key.star`. -/
|
||||
let visitStar (result : Array α) : Array α :=
|
||||
let first := cs[0] /- Recall that `Key.star` is the minimal key -/
|
||||
if csize = 1 then
|
||||
/- Special case: only one child node -/
|
||||
if first.1 == .star then
|
||||
getMatchLoop todo first.2 result
|
||||
else if first.1 == getKey e then
|
||||
getMatchLoop (pushArgsTodo todo e) first.2 result
|
||||
else
|
||||
result
|
||||
else
|
||||
/- We must always visit `Key.star` edges since they are wildcards.
|
||||
Thus, `todo` is not used linearly when there is `Key.star` edge
|
||||
and there is an edge for `k` and `k != Key.star`. -/
|
||||
let result := if first.1 == .star then
|
||||
getMatchLoop todo first.2 result
|
||||
else
|
||||
result
|
||||
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : Array α :=
|
||||
match findKey? cs k with
|
||||
match findKey? cs (getKey e) with
|
||||
| none => result
|
||||
| some c => getMatchLoop (todo ++ args) c.2 result
|
||||
let result := visitStar result
|
||||
match k with
|
||||
| .star => result
|
||||
| _ => visitNonStar k args result
|
||||
|
||||
def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : Array α :=
|
||||
match d.root.find? k with
|
||||
| none => result
|
||||
| some c => getMatchLoop args c result
|
||||
|
||||
def getStarResult (d : DiscrTree α) : Array α :=
|
||||
let result : Array α := .mkEmpty initCapacity
|
||||
match d.root.find? .star with
|
||||
| none => result
|
||||
| some (.node vs _) => result ++ vs
|
||||
|
||||
def getMatchCore (d : DiscrTree α) (e : Expr) : Key × Array α :=
|
||||
let result := getStarResult d
|
||||
let (k, args) := getKeyArgs e
|
||||
match k with
|
||||
| .star => (k, result)
|
||||
| _ => (k, getMatchRoot d k args result)
|
||||
| some c => getMatchLoop (pushArgsTodo todo e) c.2 result
|
||||
|
||||
/--
|
||||
Retrieves all values whose patterns match the expression `e`.
|
||||
-/
|
||||
public def getMatch (d : DiscrTree α) (e : Expr) : Array α :=
|
||||
getMatchCore d e |>.2
|
||||
let result := match d.root.find? .star with
|
||||
| none => .mkEmpty initCapacity
|
||||
| some (.node vs _) => vs
|
||||
match d.root.find? (getKey e) with
|
||||
| none => result
|
||||
| some c => getMatchLoop (pushArgsTodo #[] e) c result
|
||||
|
||||
/--
|
||||
Retrieves all values whose patterns match a prefix of `e`, along with the number of
|
||||
|
|
@ -204,11 +191,11 @@ This is useful for rewriting: if a pattern matches `f x` but `e` is `f x y z`, w
|
|||
still apply the rewrite and return `(value, 2)` indicating 2 extra arguments.
|
||||
-/
|
||||
public partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : Array (α × Nat) :=
|
||||
let (k, result) := getMatchCore d e
|
||||
let result := getMatch d e
|
||||
let result := result.map (·, 0)
|
||||
if !e.isApp then
|
||||
result
|
||||
else if !mayMatchPrefix k then
|
||||
else if !mayMatchPrefix (getKey e) then
|
||||
result
|
||||
else
|
||||
go e.appFn! 1 result
|
||||
|
|
@ -225,7 +212,7 @@ where
|
|||
| _ => false
|
||||
|
||||
go (e : Expr) (numExtra : Nat) (result : Array (α × Nat)) : Array (α × Nat) :=
|
||||
let result := result ++ (getMatchCore d e).2.map (., numExtra)
|
||||
let result := result ++ (getMatch d e).map (., numExtra)
|
||||
if e.isApp then
|
||||
go e.appFn! (numExtra + 1) result
|
||||
else
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue