feat: add DiscrTree.insert

This commit is contained in:
Leonardo de Moura 2019-11-23 09:07:21 -08:00
parent c6048f0f94
commit 84d582bf9a
3 changed files with 192 additions and 0 deletions

View file

@ -32,6 +32,14 @@ def Literal.beq : Literal → Literal → Bool
instance Literal.hasBeq : HasBeq Literal := ⟨Literal.beq⟩
def Literal.lt : Literal → Literal → Bool
| Literal.natVal _, Literal.strVal _ => true
| Literal.natVal v₁, Literal.natVal v₂ => v₁ < v₂
| Literal.strVal v₁, Literal.strVal v₂ => v₁ < v₂
| _, _ => false
instance Literal.hasLess : HasLess Literal := ⟨fun a b => a.lt b⟩
inductive BinderInfo
| default | implicit | strictImplicit | instImplicit | auxDecl

View file

@ -10,3 +10,4 @@ import Init.Lean.Meta.WHNF
import Init.Lean.Meta.InferType
import Init.Lean.Meta.FunInfo
import Init.Lean.Meta.ExprDefEq
import Init.Lean.Meta.DiscrTree

View file

@ -0,0 +1,183 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Lean.Meta.Basic
import Init.Lean.Meta.FunInfo
namespace Lean
namespace Meta
namespace DiscrTree
inductive Key
| const : Name → Nat → Key
| fvar : Name → Nat → Key
| lit : Literal → Key
| star : Key
| other : Key
instance Key.inhabited : Inhabited Key := ⟨Key.star⟩
def Key.hash : Key → USize
| Key.const n a => mixHash 5237 $ mixHash (hash n) (hash a)
| Key.fvar n a => mixHash 3541 $ mixHash (hash n) (hash a)
| Key.lit v => mixHash 1879 $ hash v
| Key.star => 7883
| Key.other => 2411
instance Key.hashable : Hashable Key := ⟨Key.hash⟩
def Key.beq : Key → Key → Bool
| Key.const c₁ a₁, Key.const c₂ a₂ => c₁ == c₂ && a₁ == a₂
| Key.fvar c₁ a₁, Key.fvar c₂ a₂ => c₁ == c₂ && a₁ == a₂
| Key.lit v₁, Key.lit v₂ => v₁ == v₂
| Key.star, Key.star => true
| Key.other, Key.other => true
| _, _ => false
instance Key.hasBeq : HasBeq Key := ⟨Key.beq⟩
def Key.lt : Key → Key → Bool
| Key.star, Key.star => false
| Key.star, _ => true
| Key.other, Key.star => false
| Key.other, Key.other => false
| Key.other, _ => true
| Key.lit v₁, Key.lit v₂ => v₁ < v₂
| Key.lit _, Key.const _ _ => true
| Key.lit _, Key.fvar _ _ => true
| Key.lit _, _ => false
| Key.fvar n₁ a₁, Key.fvar n₂ a₂ => Name.quickLt n₁ n₂ || (n₁ == n₂ && a₁ < a₂)
| Key.fvar _ _, Key.const _ _ => true
| Key.fvar _ _, _ => false
| Key.const n₁ a₁, Key.const n₂ a₂ => Name.quickLt n₁ n₂ || (n₁ == n₂ && a₁ < a₂)
| Key.const _ _, _ => false
instance Key.hasLess : HasLess Key := ⟨fun a b => Key.lt a b⟩
def Key.format : Key → Format
| Key.star => "*"
| Key.other => "◾"
| Key.lit (Literal.natVal v) => fmt v
| Key.lit (Literal.strVal v) => repr v
| Key.const k _ => fmt k
| Key.fvar k _ => fmt k
instance Key.hasFormat : HasFormat Key := ⟨Key.format⟩
inductive Trie (α : Type)
| node (vs : Array α) (children : Array (Key × Trie)) : Trie
instance Trie.inhabited {α} : Inhabited (Trie α) := ⟨Trie.node #[] #[]⟩
end DiscrTree
open DiscrTree
structure DiscrTree (α : Type) :=
(root : PersistentHashMap Key (Trie α) := {})
namespace DiscrTree
def empty {α} : DiscrTree α := { root := {} }
/- The discrimination tree ignores implicit arguments and proofs.
We use the following auxiliary id as a "mark". -/
private def tmpMVarId : Name := `_discr_tree_tmp
private def tmpStar := mkMVar tmpMVarId
private partial def pushArgsAux (infos : Array ParamInfo) : Nat → Expr → Array Expr → MetaM (Array Expr)
| i, Expr.app f a _, todo =>
if h : i < infos.size then
let info := infos.get ⟨i, h⟩;
if info.implicit || info.instImplicit || info.prop then
pushArgsAux (i-1) f (todo.push tmpStar)
else
pushArgsAux (i-1) f (todo.push a)
else
pushArgsAux (i-1) f (todo.push a)
| _, _, todo => pure todo
private def pushArgs (todo : Array Expr) (e : Expr) : MetaM (Key × Array Expr) :=
do e ← whnf e;
let fn := e.getAppFn;
let push (k : Key) (nargs : Nat) : MetaM (Key × Array Expr) := do {
info ← getFunInfoNArgs fn nargs;
todo ← pushArgsAux info.paramInfo (nargs-1) e todo;
pure (k, todo)
};
match fn with
| Expr.lit v _ => pure (Key.lit v, todo)
| Expr.const c _ _ => let nargs := e.getAppNumArgs; push (Key.const c nargs) nargs
| Expr.fvar fvarId _ => let nargs := e.getAppNumArgs; push (Key.fvar fvarId nargs) nargs
| Expr.mvar mvarId _ =>
if mvarId == `_tmp then
-- We use `tmp to mark implicit arguments and proofs
pure (Key.star, todo)
else condM (isReadOnlyOrSyntheticExprMVar mvarId)
(pure (Key.other, todo))
(pure (Key.star, todo))
| _ => pure (Key.other, todo)
private partial def createNodes {α} (v : α) : Array Expr → MetaM (Trie α)
| todo =>
if todo.isEmpty then pure $ Trie.node #[v] #[]
else do
let e := todo.back;
let todo := todo.pop;
(k, todo) ← pushArgs todo e;
c ← createNodes todo;
pure $ Trie.node #[] #[(k, c)]
private def insertVal {α} [HasBeq α] (vs : Array α) (v : α) : Array α :=
if vs.contains v then vs else vs.push v
private partial def insertAux {α} [HasBeq α] (v : α) : Array Expr → Trie α → MetaM (Trie α)
| todo, Trie.node vs cs =>
if todo.isEmpty then
pure $ Trie.node (insertVal vs v) cs
else do
let e := todo.back;
let todo := todo.pop;
(k, todo) ← pushArgs todo e;
c ← cs.binInsertM
(fun a b => a.1 < b.1)
(fun ⟨_, s⟩ => do c ← insertAux todo s; pure (k, c)) -- merge with existing
(fun _ => do c ← createNodes v todo; pure (k, c))
(k, arbitrary _);
pure $ Trie.node vs c
private def initCapacity := 16
private def insert {α} [HasBeq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α) :=
usingTransparency TransparencyMode.reducible $ do
(k, todo) ← pushArgs (Array.mkEmpty initCapacity) e;
match d.root.find k with
| none => do
c ← createNodes v todo;
pure $ { root := d.root.insert k c }
| some c => do
c ← insertAux v todo c;
pure $ { root := d.root.insert k c }
partial def Trie.format {α} [HasFormat α] : Trie α → Format
| Trie.node vs cs => Format.group $ Format.paren $
"node" ++ (if vs.isEmpty then Format.nil else " " ++ fmt vs)
++ Format.join (cs.toList.map $ fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ Trie.format c))
instance Trie.hasFormat {α} [HasFormat α] : HasFormat (Trie α) := ⟨Trie.format⟩
partial def format {α} [HasFormat α] (d : DiscrTree α) : Format :=
let (_, r) := d.root.foldl
(fun (p : Bool × Format) k c =>
(false, p.2 ++ (if p.1 then Format.line else Format.nil) ++ Format.paren (fmt k ++ " => " ++ fmt c)))
(true, Format.nil);
Format.group r
instance DiscrTree.hasFormat {α} [HasFormat α] : HasFormat (DiscrTree α) := ⟨format⟩
end DiscrTree
end Meta
end Lean