feat: add discrimination tree retrieval for Sym (#11886)
This PR adds `getMatch` and `getMatchWithExtra` for retrieving patterns from discrimination trees in the symbolic simulation framework. The PR also adds uses `DiscrTree` to implement indexing in `Sym.simp`.
This commit is contained in:
parent
19df2c41b3
commit
b40dabdecd
4 changed files with 115 additions and 9 deletions
|
|
@ -132,8 +132,103 @@ public def insertPattern [BEq α] (d : DiscrTree α) (p : Pattern) (v : α) : Di
|
|||
let keys := p.mkDiscrTreeKeys
|
||||
d.insertKeyValue keys v
|
||||
|
||||
/-!
|
||||
**TODO** Retrieval.
|
||||
def getKeyArgs (e : Expr) : Key × Array Expr :=
|
||||
match e.getAppFn with
|
||||
| .lit v => (.lit v, #[])
|
||||
| .const declName _ => (.const declName e.getAppNumArgs, e.getAppRevArgs)
|
||||
| .fvar fvarId => (.fvar fvarId e.getAppNumArgs, e.getAppRevArgs)
|
||||
| .forallE _ d b _ => (.arrow, #[b, d])
|
||||
| _ => (.other, #[])
|
||||
|
||||
abbrev findKey? (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) :=
|
||||
cs.binSearch (k, default) (fun a b => a.1 < b.1)
|
||||
|
||||
partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : Array α :=
|
||||
match c with
|
||||
| .node vs cs =>
|
||||
if todo.isEmpty then
|
||||
result ++ vs
|
||||
else if cs.isEmpty then
|
||||
result
|
||||
else
|
||||
let e := todo.back!
|
||||
let todo := todo.pop
|
||||
let first := cs[0]! /- Recall that `Key.star` is the minimal key -/
|
||||
let (k, args) := getKeyArgs e
|
||||
/- We must always visit `Key.star` edges since they are wildcards.
|
||||
Thus, `todo` is not used linearly when there is `Key.star` edge
|
||||
and there is an edge for `k` and `k != Key.star`. -/
|
||||
let visitStar (result : Array α) : Array α :=
|
||||
if first.1 == .star then
|
||||
getMatchLoop todo first.2 result
|
||||
else
|
||||
result
|
||||
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : Array α :=
|
||||
match findKey? cs k with
|
||||
| none => result
|
||||
| some c => getMatchLoop (todo ++ args) c.2 result
|
||||
let result := visitStar result
|
||||
match k with
|
||||
| .star => result
|
||||
| _ => visitNonStar k args result
|
||||
|
||||
def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : Array α :=
|
||||
match d.root.find? k with
|
||||
| none => result
|
||||
| some c => getMatchLoop args c result
|
||||
|
||||
def getStarResult (d : DiscrTree α) : Array α :=
|
||||
let result : Array α := .mkEmpty initCapacity
|
||||
match d.root.find? .star with
|
||||
| none => result
|
||||
| some (.node vs _) => result ++ vs
|
||||
|
||||
def getMatchCore (d : DiscrTree α) (e : Expr) : Key × Array α :=
|
||||
let result := getStarResult d
|
||||
let (k, args) := getKeyArgs e
|
||||
match k with
|
||||
| .star => (k, result)
|
||||
| _ => (k, getMatchRoot d k args result)
|
||||
|
||||
/--
|
||||
Retrieves all values whose patterns match the expression `e`.
|
||||
-/
|
||||
public def getMatch (d : DiscrTree α) (e : Expr) : Array α :=
|
||||
getMatchCore d e |>.2
|
||||
|
||||
/--
|
||||
Retrieves all values whose patterns match a prefix of `e`, along with the number of
|
||||
extra (ignored) arguments.
|
||||
|
||||
This is useful for rewriting: if a pattern matches `f x` but `e` is `f x y z`, we can
|
||||
still apply the rewrite and return `(value, 2)` indicating 2 extra arguments.
|
||||
-/
|
||||
public partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : Array (α × Nat) :=
|
||||
let (k, result) := getMatchCore d e
|
||||
let result := result.map (·, 0)
|
||||
if !e.isApp then
|
||||
result
|
||||
else if !mayMatchPrefix k then
|
||||
result
|
||||
else
|
||||
go e.appFn! 1 result
|
||||
where
|
||||
mayMatchPrefix (k : Key) : Bool :=
|
||||
let cont (k : Key) : Bool :=
|
||||
if d.root.find? k |>.isSome then
|
||||
true
|
||||
else
|
||||
mayMatchPrefix k
|
||||
match k with
|
||||
| .const f (n+1) => cont (.const f n)
|
||||
| .fvar f (n+1) => cont (.fvar f n)
|
||||
| _ => false
|
||||
|
||||
go (e : Expr) (numExtra : Nat) (result : Array (α × Nat)) : Array (α × Nat) :=
|
||||
let result := result ++ (getMatchCore d e).2.map (., numExtra)
|
||||
if e.isApp then
|
||||
go e.appFn! (numExtra + 1) result
|
||||
else
|
||||
result
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ prelude
|
|||
public import Lean.Meta.Sym.SimpM
|
||||
public import Lean.Meta.Sym.SimpFun
|
||||
import Lean.Meta.Sym.InstantiateS
|
||||
import Lean.Meta.Sym.DiscrTree
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
open Grind
|
||||
|
||||
|
|
@ -41,8 +42,8 @@ public def Theorem.rewrite? (thm : Theorem) (e : Expr) : SimpM (Option Result) :
|
|||
return none
|
||||
|
||||
public def rewrite : SimpFun := fun e => do
|
||||
-- **TODO**: use indexing
|
||||
for thm in (← read).thms.thms do
|
||||
-- **TODO**: over-applied terms
|
||||
for thm in (← read).thms.getMatch e do
|
||||
if let some result ← thm.rewrite? e then
|
||||
return result
|
||||
return { expr := e }
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ module
|
|||
prelude
|
||||
public import Lean.Meta.Sym.SymM
|
||||
public import Lean.Meta.Sym.Pattern
|
||||
import Lean.Meta.Sym.DiscrTree
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
|
||||
|
|
@ -129,10 +130,18 @@ structure Theorem where
|
|||
/-- Right-hand side of the equation. -/
|
||||
rhs : Expr
|
||||
|
||||
instance : BEq Theorem where
|
||||
beq thm₁ thm₂ := thm₁.expr == thm₂.expr
|
||||
|
||||
/-- Collection of simplification theorems available to the simplifier. -/
|
||||
structure Theorems where
|
||||
/-- **TODO**: No indexing for now. We will add a structural discrimination tree later. -/
|
||||
thms : Array Theorem := #[]
|
||||
thms : DiscrTree Theorem := {}
|
||||
|
||||
def Theorems.insert (thms : Theorems) (thm : Theorem) : Theorems :=
|
||||
{ thms with thms := insertPattern thms.thms thm.pattern thm }
|
||||
|
||||
def Theorems.getMatch (thms : Theorems) (e : Expr) : Array Theorem :=
|
||||
Sym.getMatch thms.thms e
|
||||
|
||||
/-- Read-only context for the simplifier. -/
|
||||
structure Context where
|
||||
|
|
@ -182,7 +191,7 @@ abbrev getCache : SimpM Cache :=
|
|||
|
||||
end Simp
|
||||
|
||||
public def simp (e : Expr) (thms : Simp.Theorems := {}) (config : Simp.Config := {}) : SymM Simp.Result := do
|
||||
def simp (e : Expr) (thms : Simp.Theorems := {}) (config : Simp.Config := {}) : SymM Simp.Result := do
|
||||
Simp.SimpM.run (Simp.simp e) thms config
|
||||
|
||||
end Lean.Meta.Sym
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ namespace SimpBench
|
|||
-/
|
||||
|
||||
def mkSimpTheorems : MetaM Sym.Simp.Theorems := do
|
||||
let thm ← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add
|
||||
return { thms := #[thm] }
|
||||
let result : Sym.Simp.Theorems := {}
|
||||
let result := result.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add)
|
||||
return result
|
||||
|
||||
def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run' do
|
||||
let e ← Grind.shareCommon e
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue