feat: add DiscrTree.insert
This commit is contained in:
parent
c6048f0f94
commit
84d582bf9a
3 changed files with 192 additions and 0 deletions
|
|
@ -32,6 +32,14 @@ def Literal.beq : Literal → Literal → Bool
|
|||
|
||||
instance Literal.hasBeq : HasBeq Literal := ⟨Literal.beq⟩
|
||||
|
||||
def Literal.lt : Literal → Literal → Bool
|
||||
| Literal.natVal _, Literal.strVal _ => true
|
||||
| Literal.natVal v₁, Literal.natVal v₂ => v₁ < v₂
|
||||
| Literal.strVal v₁, Literal.strVal v₂ => v₁ < v₂
|
||||
| _, _ => false
|
||||
|
||||
instance Literal.hasLess : HasLess Literal := ⟨fun a b => a.lt b⟩
|
||||
|
||||
inductive BinderInfo
|
||||
| default | implicit | strictImplicit | instImplicit | auxDecl
|
||||
|
||||
|
|
|
|||
|
|
@ -10,3 +10,4 @@ import Init.Lean.Meta.WHNF
|
|||
import Init.Lean.Meta.InferType
|
||||
import Init.Lean.Meta.FunInfo
|
||||
import Init.Lean.Meta.ExprDefEq
|
||||
import Init.Lean.Meta.DiscrTree
|
||||
|
|
|
|||
183
src/Init/Lean/Meta/DiscrTree.lean
Normal file
183
src/Init/Lean/Meta/DiscrTree.lean
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
/-
|
||||
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Lean.Meta.Basic
|
||||
import Init.Lean.Meta.FunInfo
|
||||
|
||||
namespace Lean
|
||||
namespace Meta
|
||||
namespace DiscrTree
|
||||
|
||||
inductive Key
|
||||
| const : Name → Nat → Key
|
||||
| fvar : Name → Nat → Key
|
||||
| lit : Literal → Key
|
||||
| star : Key
|
||||
| other : Key
|
||||
|
||||
instance Key.inhabited : Inhabited Key := ⟨Key.star⟩
|
||||
|
||||
def Key.hash : Key → USize
|
||||
| Key.const n a => mixHash 5237 $ mixHash (hash n) (hash a)
|
||||
| Key.fvar n a => mixHash 3541 $ mixHash (hash n) (hash a)
|
||||
| Key.lit v => mixHash 1879 $ hash v
|
||||
| Key.star => 7883
|
||||
| Key.other => 2411
|
||||
|
||||
instance Key.hashable : Hashable Key := ⟨Key.hash⟩
|
||||
|
||||
def Key.beq : Key → Key → Bool
|
||||
| Key.const c₁ a₁, Key.const c₂ a₂ => c₁ == c₂ && a₁ == a₂
|
||||
| Key.fvar c₁ a₁, Key.fvar c₂ a₂ => c₁ == c₂ && a₁ == a₂
|
||||
| Key.lit v₁, Key.lit v₂ => v₁ == v₂
|
||||
| Key.star, Key.star => true
|
||||
| Key.other, Key.other => true
|
||||
| _, _ => false
|
||||
|
||||
instance Key.hasBeq : HasBeq Key := ⟨Key.beq⟩
|
||||
|
||||
def Key.lt : Key → Key → Bool
|
||||
| Key.star, Key.star => false
|
||||
| Key.star, _ => true
|
||||
| Key.other, Key.star => false
|
||||
| Key.other, Key.other => false
|
||||
| Key.other, _ => true
|
||||
| Key.lit v₁, Key.lit v₂ => v₁ < v₂
|
||||
| Key.lit _, Key.const _ _ => true
|
||||
| Key.lit _, Key.fvar _ _ => true
|
||||
| Key.lit _, _ => false
|
||||
| Key.fvar n₁ a₁, Key.fvar n₂ a₂ => Name.quickLt n₁ n₂ || (n₁ == n₂ && a₁ < a₂)
|
||||
| Key.fvar _ _, Key.const _ _ => true
|
||||
| Key.fvar _ _, _ => false
|
||||
| Key.const n₁ a₁, Key.const n₂ a₂ => Name.quickLt n₁ n₂ || (n₁ == n₂ && a₁ < a₂)
|
||||
| Key.const _ _, _ => false
|
||||
|
||||
instance Key.hasLess : HasLess Key := ⟨fun a b => Key.lt a b⟩
|
||||
|
||||
def Key.format : Key → Format
|
||||
| Key.star => "*"
|
||||
| Key.other => "◾"
|
||||
| Key.lit (Literal.natVal v) => fmt v
|
||||
| Key.lit (Literal.strVal v) => repr v
|
||||
| Key.const k _ => fmt k
|
||||
| Key.fvar k _ => fmt k
|
||||
|
||||
instance Key.hasFormat : HasFormat Key := ⟨Key.format⟩
|
||||
|
||||
inductive Trie (α : Type)
|
||||
| node (vs : Array α) (children : Array (Key × Trie)) : Trie
|
||||
|
||||
instance Trie.inhabited {α} : Inhabited (Trie α) := ⟨Trie.node #[] #[]⟩
|
||||
|
||||
end DiscrTree
|
||||
|
||||
open DiscrTree
|
||||
|
||||
structure DiscrTree (α : Type) :=
|
||||
(root : PersistentHashMap Key (Trie α) := {})
|
||||
|
||||
namespace DiscrTree
|
||||
|
||||
def empty {α} : DiscrTree α := { root := {} }
|
||||
|
||||
/- The discrimination tree ignores implicit arguments and proofs.
|
||||
We use the following auxiliary id as a "mark". -/
|
||||
private def tmpMVarId : Name := `_discr_tree_tmp
|
||||
private def tmpStar := mkMVar tmpMVarId
|
||||
|
||||
private partial def pushArgsAux (infos : Array ParamInfo) : Nat → Expr → Array Expr → MetaM (Array Expr)
|
||||
| i, Expr.app f a _, todo =>
|
||||
if h : i < infos.size then
|
||||
let info := infos.get ⟨i, h⟩;
|
||||
if info.implicit || info.instImplicit || info.prop then
|
||||
pushArgsAux (i-1) f (todo.push tmpStar)
|
||||
else
|
||||
pushArgsAux (i-1) f (todo.push a)
|
||||
else
|
||||
pushArgsAux (i-1) f (todo.push a)
|
||||
| _, _, todo => pure todo
|
||||
|
||||
private def pushArgs (todo : Array Expr) (e : Expr) : MetaM (Key × Array Expr) :=
|
||||
do e ← whnf e;
|
||||
let fn := e.getAppFn;
|
||||
let push (k : Key) (nargs : Nat) : MetaM (Key × Array Expr) := do {
|
||||
info ← getFunInfoNArgs fn nargs;
|
||||
todo ← pushArgsAux info.paramInfo (nargs-1) e todo;
|
||||
pure (k, todo)
|
||||
};
|
||||
match fn with
|
||||
| Expr.lit v _ => pure (Key.lit v, todo)
|
||||
| Expr.const c _ _ => let nargs := e.getAppNumArgs; push (Key.const c nargs) nargs
|
||||
| Expr.fvar fvarId _ => let nargs := e.getAppNumArgs; push (Key.fvar fvarId nargs) nargs
|
||||
| Expr.mvar mvarId _ =>
|
||||
if mvarId == `_tmp then
|
||||
-- We use `tmp to mark implicit arguments and proofs
|
||||
pure (Key.star, todo)
|
||||
else condM (isReadOnlyOrSyntheticExprMVar mvarId)
|
||||
(pure (Key.other, todo))
|
||||
(pure (Key.star, todo))
|
||||
| _ => pure (Key.other, todo)
|
||||
|
||||
private partial def createNodes {α} (v : α) : Array Expr → MetaM (Trie α)
|
||||
| todo =>
|
||||
if todo.isEmpty then pure $ Trie.node #[v] #[]
|
||||
else do
|
||||
let e := todo.back;
|
||||
let todo := todo.pop;
|
||||
(k, todo) ← pushArgs todo e;
|
||||
c ← createNodes todo;
|
||||
pure $ Trie.node #[] #[(k, c)]
|
||||
|
||||
private def insertVal {α} [HasBeq α] (vs : Array α) (v : α) : Array α :=
|
||||
if vs.contains v then vs else vs.push v
|
||||
|
||||
private partial def insertAux {α} [HasBeq α] (v : α) : Array Expr → Trie α → MetaM (Trie α)
|
||||
| todo, Trie.node vs cs =>
|
||||
if todo.isEmpty then
|
||||
pure $ Trie.node (insertVal vs v) cs
|
||||
else do
|
||||
let e := todo.back;
|
||||
let todo := todo.pop;
|
||||
(k, todo) ← pushArgs todo e;
|
||||
c ← cs.binInsertM
|
||||
(fun a b => a.1 < b.1)
|
||||
(fun ⟨_, s⟩ => do c ← insertAux todo s; pure (k, c)) -- merge with existing
|
||||
(fun _ => do c ← createNodes v todo; pure (k, c))
|
||||
(k, arbitrary _);
|
||||
pure $ Trie.node vs c
|
||||
|
||||
private def initCapacity := 16
|
||||
|
||||
private def insert {α} [HasBeq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α) :=
|
||||
usingTransparency TransparencyMode.reducible $ do
|
||||
(k, todo) ← pushArgs (Array.mkEmpty initCapacity) e;
|
||||
match d.root.find k with
|
||||
| none => do
|
||||
c ← createNodes v todo;
|
||||
pure $ { root := d.root.insert k c }
|
||||
| some c => do
|
||||
c ← insertAux v todo c;
|
||||
pure $ { root := d.root.insert k c }
|
||||
|
||||
partial def Trie.format {α} [HasFormat α] : Trie α → Format
|
||||
| Trie.node vs cs => Format.group $ Format.paren $
|
||||
"node" ++ (if vs.isEmpty then Format.nil else " " ++ fmt vs)
|
||||
++ Format.join (cs.toList.map $ fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ Trie.format c))
|
||||
|
||||
instance Trie.hasFormat {α} [HasFormat α] : HasFormat (Trie α) := ⟨Trie.format⟩
|
||||
|
||||
partial def format {α} [HasFormat α] (d : DiscrTree α) : Format :=
|
||||
let (_, r) := d.root.foldl
|
||||
(fun (p : Bool × Format) k c =>
|
||||
(false, p.2 ++ (if p.1 then Format.line else Format.nil) ++ Format.paren (fmt k ++ " => " ++ fmt c)))
|
||||
(true, Format.nil);
|
||||
Format.group r
|
||||
|
||||
instance DiscrTree.hasFormat {α} [HasFormat α] : HasFormat (DiscrTree α) := ⟨format⟩
|
||||
|
||||
end DiscrTree
|
||||
end Meta
|
||||
end Lean
|
||||
Loading…
Add table
Reference in a new issue