lean4-htt/tests/playground/DiscrTree.lean
2020-10-27 18:29:19 -07:00

245 lines
9.1 KiB
Text
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import Lean.Format
open Lean
/-- Set-like insertion: returns `as` unchanged when it already contains `a` (according to `BEq`),
and otherwise returns `as` with `a` added at the front. -/
def List.insert {α} [BEq α] (as : List α) (a : α) : List α :=
  match as.contains a with
  | true  => as
  | false => a :: as
-- First-order terms. These are the values inserted into, and looked up in, the `Trie` defined in this file.
inductive Term
-- A variable carrying a `Nat` index. `Term.key` maps every variable to `Key.var`, ignoring the index.
| var : Nat → Term
-- A symbol (the `String`) applied to an array of argument terms; `mkConst` uses an empty array for constants.
| app : String → Array Term → Term
-- The default term is the variable with index `0`.
instance : Inhabited Term := ⟨Term.var 0⟩
-- Edge labels of the trie. `Term.key` computes the key of a term from its head only.
inductive Key
-- Label used for every variable term.
| var : Key
-- Label for an application: the symbol name together with a `Nat`, which `Term.key` sets to the number of arguments.
| sym : String → Nat → Key
-- The default key is `Key.var`.
instance : Inhabited Key := ⟨Key.var⟩
/-- Boolean equality on keys: two `Key.sym` keys are equal when both the symbol name and the
arity agree, `Key.var` is equal only to `Key.var`, and every mixed pair is unequal. -/
def Key.beq : Key → Key → Bool
  | Key.sym f₁ n₁, Key.sym f₂ n₂ => f₁ == f₂ && n₁ == n₂
  | Key.var, Key.var => true
  | _, _ => false
-- `==` on keys is `Key.beq`.
instance : BEq Key := ⟨fun a b => Key.beq a b⟩
/-- Strict ordering on keys: `Key.var` is smaller than every `Key.sym`; two `Key.sym` keys are
compared by symbol name first and by arity when the names are equal. No key is smaller than
itself, and no `Key.sym` is smaller than `Key.var`. -/
def Key.lt : Key → Key → Bool
  | Key.var, Key.sym _ _ => true
  | Key.sym f₁ n₁, Key.sym f₂ n₂ => f₁ < f₂ || (f₁ == f₂ && n₁ < n₂)
  | _, _ => false
-- `<` on keys is `Key.lt`.
instance : Less Key := ⟨fun a b => Key.lt a b⟩
/-- Render a key: `Key.var` is printed as `*`; a symbol of arity `0` is printed as its bare name,
and a symbol of positive arity as `name.arity`. -/
def Key.format : Key → Format
  | Key.sym f n => if n == 0 then f else f ++ "." ++ fmt n
  | Key.var => "*"
-- Keys are formatted with `Key.format`.
instance : HasFormat Key := ⟨fun k => Key.format k⟩
/-- The trie key of a term, computed from its head only: an application yields `Key.sym` with the
symbol name and the number of arguments; any variable yields `Key.var`. -/
def Term.key : Term → Key
  | Term.app sym operands => Key.sym sym operands.size
  | Term.var _ => Key.var
/-- The immediate arguments of a term: the argument array of an application, or the empty array
for a variable. -/
def Term.args : Term → Array Term
  | Term.app _ operands => operands
  | Term.var _ => #[]
-- TODO: root should be a persistent hash map
-- A trie node indexed by `Key`.
--  * `vals`: the values stored at this node (`insertAux` adds a value here once its work stack is empty).
--  * `children`: outgoing edges, each a key paired with the sub-trie it leads to. `insertAux` adds
--    edges through `binInsertM` and the fold functions look them up with `binSearch`, both comparing
--    keys with `<` (i.e. `Key.lt`).
inductive Trie (α : Type)
| node (vals : List α) (children : Array (Key × Trie)) : Trie
namespace Trie
/-- The trie that stores nothing: a single node without values and without children. -/
def empty {α} : Trie α :=
  Trie.node [] #[]
-- The default trie is the empty one.
instance {α} : Inhabited (Trie α) := ⟨Trie.empty⟩
/-- Worker for `appendTodo`. Given a counter `n`, pushes `as[n-1]`, `as[n-2]`, …, `as[0]` (in that
order) onto the accumulator and returns it. -/
partial def appendTodoAux (as : Array Term) : Nat → Array Term → Array Term
  | n+1, stack => appendTodoAux n (stack.push (as.get! n))
  | 0, stack => stack
/-- Push all elements of `as` onto the work stack `todo` in reverse order, so that the first
element of `as` ends up at the back of the result (the position read by `Array.back`). -/
def appendTodo (todo : Array Term) (as : Array Term) : Array Term :=
  let count := as.size;
  appendTodoAux as count todo
/-- Build a fresh single-branch trie for the terms remaining on the work stack.
When the stack is empty the result is a leaf holding just `v`. Otherwise the term at the back is
removed, its arguments are pushed in its place, and a node with exactly one child (labelled with
that term's key) is created for the rest of the stack. -/
partial def createNodes {α} (v : α) : Array Term → Trie α
  | pending =>
    if pending.isEmpty then node [v] #[]
    else
      let head := pending.back;
      let rest := appendTodo pending.pop head.args;
      node [] #[(head.key, createNodes rest)]
/--
Insert the value `v` below the given trie node, following the terms on the work stack `todo`
(the next term to process is the one at the back of the array).

* Empty stack: `v` is added to the node's values with `List.insert`, so a value that is already
  present is not stored twice.
* Otherwise the back term is popped, its arguments are pushed with `appendTodo`, and its key
  selects the child edge: an existing child with that key is updated recursively, and if there is
  none a new branch is built with `createNodes`.
-/
partial def insertAux {α} [BEq α] (v : α) : Array Term → Trie α → Trie α
| todo, node vs cs =>
if todo.isEmpty then node (vs.insert v) cs
else
let t := todo.back;
let todo := todo.pop;
-- the arguments of `t` are processed next, first argument first
let todo := appendTodo todo t.args;
let k := t.key;
node vs $ Id.run $
cs.binInsertM
-- children are compared by key only
(fun a b => a.1 < b.1)
(fun ⟨_, s⟩ => (k, insertAux todo s)) -- a child with key `k` exists: insert into its sub-trie
(fun _ => (k, createNodes v todo)) -- no child with key `k`: build a fresh branch
-- search element; the comparator above only reads its key, so the trie component is a placeholder
(k, arbitrary _)
/-- Store value `v` under the term `k` in the trie `d`. The work stack handed to `insertAux`
starts out containing only `k` (allocated with an initial capacity of 32). -/
def insert {α} [BEq α] (d : Trie α) (k : Term) (v : α) : Trie α :=
  let worklist : Array Term := Array.push (Array.mkEmpty 32) k;
  insertAux v worklist d
/-- Pretty-print a trie as a parenthesised group `(node <vals> <children>)`. The values are
omitted when the node has none; every child is put on its own (soft) line as `(<key> => <sub-trie>)`. -/
partial def format {α} [HasFormat α] : Trie α → Format
  | node vs cs =>
    let valsFmt : Format := if vs.isEmpty then Format.nil else " " ++ fmt vs;
    let childFmts : List Format := cs.toList.map (fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ format c));
    Format.group $ Format.paren $ "node" ++ valsFmt ++ Format.join childFmts
instance {α} [HasFormat α] : HasFormat (Trie α) := ⟨format⟩
/--
Fold `f` over the values of every stored entry whose path matches the query terms on the work
stack `todo` (the next query term is the one at the back of the array).

* Empty stack: `f` is folded over the values of the current node.
* Node without children: the accumulator `b` is returned unchanged.
* Otherwise the back term `t` is popped.
  - If `t` is a variable, only a `Key.var` child is followed.
  - If `t` is an application, a `Key.var` child (if any) is followed without pushing the arguments
    of `t`, and the child whose key equals `t.key` (if any) is followed after pushing them.

A `Key.var` child is only looked for at index `0`; this relies on the children being ordered by
`Key.lt`, which places `Key.var` before every `Key.sym`.
-/
@[specialize] partial def foldMatchAux {α β} {m : Type → Type} [Monad m] (f : β → α → m β) : Array Term → Trie α → β → m β
| todo, node vs cs, b =>
if todo.isEmpty then vs.foldlM f b
else if cs.isEmpty then pure b
else
let t := todo.back;
let todo := todo.pop;
-- candidate `Key.var` child (checked with `first.1 == Key.var` before use)
let first := cs.get! 0;
let k := t.key;
match k with
-- query variable: only a stored variable matches it
| Key.var => if first.1 == Key.var then foldMatchAux todo first.2 b else pure b
| Key.sym _ _ => do
-- search element; the comparator only reads its key
match cs.binSearch (k, arbitrary _) (fun a b => a.1 < b.1) with
-- no child with the same symbol/arity: only the stored-variable branch can match
| none => if first.1 == Key.var then foldMatchAux todo first.2 b else pure b
| some c => do
-- first the stored-variable branch (the arguments of `t` are not pushed there) ...
b ← if first.1 == Key.var then foldMatchAux todo first.2 b else pure b;
-- ... then the branch with the same key, continuing with the arguments of `t`
let todo := appendTodo todo t.args;
foldMatchAux todo c.2 b
/-- Fold `f`, starting from `b`, over the values of all entries of `d` found by `foldMatchAux`
for the query term `k`. The work stack starts out containing only `k`. -/
@[specialize] def foldMatch {α β} {m : Type → Type} [Monad m] (d : Trie α) (k : Term) (f : β → α → m β) (b : β) : m β :=
  let worklist : Array Term := Array.push (Array.mkEmpty 32) k;
  foldMatchAux f worklist d b
/-- Collect, in an array, the values of all stored entries that `foldMatch` reports for the term
`k`, i.e. the stored terms that generalize `k`. The result is approximate because every variable
is indexed under the same key `Key.var`. -/
def getMatch {α} (d : Trie α) (k : Term) : Array α :=
  let collect : Array α → α → Id (Array α) := fun acc v => pure (acc.push v);
  Id.run (foldMatch d k collect #[])
/--
Fold `f` over the values of every stored entry that is a candidate for unification with the query
terms on the work stack `todo` (the next query term is the one at the back of the array).

The `Nat` argument counts how many stored terms still have to be skipped because a query variable
was paired with a stored application:

* Counter `skip+1`: the query is not consulted. Every child is followed; a `Key.var` child uses up
  one pending skip, while a `Key.sym _ a` child uses up one and adds `a` more for its arguments.
* Counter `0`, empty stack: `f` is folded over the values of the current node.
* Counter `0`, node without children: the accumulator is returned unchanged.
* Counter `0`, otherwise the back term `t` is popped.
  - If `t` is a variable, every child is followed; for a `Key.sym _ a` child the counter is set to
    `a` so that the stored arguments are skipped.
  - If `t` is an application, a `Key.var` child (if any) is followed without pushing the arguments
    of `t`, and the child whose key equals `t.key` (if any) is followed after pushing them.

A `Key.var` child is only looked for at index `0`; this relies on the children being ordered by
`Key.lt`, which places `Key.var` before every `Key.sym`.
-/
@[specialize] partial def foldUnifyAux {α β} {m : Type → Type} [Monad m] (f : β → α → m β) : Nat → Array Term → Trie α → β → m β
-- skipping mode: walk over stored terms without looking at the query
| skip+1, todo, node vs cs, b =>
if cs.isEmpty then pure b
else
cs.foldlM
(fun b ⟨k, c⟩ =>
match k with
-- a stored variable is one complete term: one fewer to skip
| Key.var => foldUnifyAux skip todo c b
-- a stored application head: its `a` arguments must be skipped as well
| Key.sym _ a => foldUnifyAux (skip + a) todo c b)
b
| 0, todo, node vs cs, b =>
if todo.isEmpty then vs.foldlM f b
else if cs.isEmpty then pure b
else
let t := todo.back;
let todo := todo.pop;
-- candidate `Key.var` child (checked with `first.1 == Key.var` before use)
let first := cs.get! 0;
let k := t.key;
match k with
-- query variable: it is paired with whatever stored term comes next, on every branch
| Key.var =>
cs.foldlM
(fun b ⟨k, c⟩ =>
match k with
| Key.var => foldUnifyAux 0 todo c b
-- skip the `a` stored arguments of this application
| Key.sym _ a => foldUnifyAux a todo c b)
b
| Key.sym _ _ => do
-- search element; the comparator only reads its key
match cs.binSearch (k, arbitrary _) (fun a b => a.1 < b.1) with
-- no child with the same symbol/arity: only the stored-variable branch remains
| none => if first.1 == Key.var then foldUnifyAux 0 todo first.2 b else pure b
| some c => do
-- first the stored-variable branch (the arguments of `t` are not pushed there) ...
b ← if first.1 == Key.var then foldUnifyAux 0 todo first.2 b else pure b;
-- ... then the branch with the same key, continuing with the arguments of `t`
let todo := appendTodo todo t.args;
foldUnifyAux 0 todo c.2 b
/-- Fold `f`, starting from `b`, over the values of all entries of `d` found by `foldUnifyAux`
for the query term `k`. The work stack starts out containing only `k` and the skip counter is `0`. -/
@[specialize] def foldUnify {α β} {m : Type → Type} [Monad m] (d : Trie α) (k : Term) (f : β → α → m β) (b : β) : m β :=
  let worklist : Array Term := Array.push (Array.mkEmpty 32) k;
  foldUnifyAux f 0 worklist d b
/-- Collect, in an array, the values of all stored entries that `foldUnify` reports for the term
`k`, i.e. the candidates for unification with `k`. -/
def getUnify {α} (d : Trie α) (k : Term) : Array α :=
  let collect : Array α → α → Id (Array α) := fun acc v => pure (acc.push v);
  Id.run (foldUnify d k collect #[])
end Trie
/-- Shorthand for the application of symbol `s` to the arguments `cs`. -/
def mkApp (s : String) (cs : Array Term) : Term := Term.app s cs
/-- Shorthand for a constant: the symbol `s` applied to no arguments. -/
def mkConst (s : String) : Term := mkApp s #[]
/-- Shorthand for the variable with index `i`. -/
def mkVar (i : Nat) : Term := Term.var i
/-- Smoke test for insertion and printing: inserts three fixed terms, then two families of twenty
terms each (differing only in one constant name), and prints the resulting trie. -/
def tst1 : IO Unit :=
  let d₀ : Trie Nat := Trie.empty;
  -- three hand-written entries
  let d₁ := d₀.insert (mkApp "f" #[mkApp "g" #[mkConst "a"], mkApp "g" #[mkConst "b"]]) 10;
  let d₂ := d₁.insert (mkApp "f" #[mkApp "h" #[mkConst "a", mkVar 0], mkConst "b"]) 20;
  let d₃ := d₂.insert (mkApp "f" #[mkConst "b", mkConst "c", mkConst "d"]) 20;
  -- twenty entries `f (h a *) (f c<i>)`, each stored with value `i`
  let d₄ := (20:Nat).fold
    (fun i (acc : Trie Nat) =>
      acc.insert (mkApp "f" #[mkApp "h" #[mkConst "a", mkVar 0], mkApp "f" #[mkConst ("c" ++ toString i)]]) i)
    d₃;
  -- twenty entries `f (g a<i>) (g b)`, each stored with value `i`
  let d₅ := (20:Nat).fold
    (fun i (acc : Trie Nat) =>
      acc.insert (mkApp "f" #[mkApp "g" #[mkConst ("a" ++ toString i)], mkApp "g" #[mkConst "b"]]) i)
    d₄;
  -- let t := mkApp "g" [mkApp "h" [mkConst "a"]];
  -- let d := d.insert t 10;
  IO.println (format d₅)
#eval tst1
/-- Succeeds when `as` and `bs` contain the same numbers with the same multiplicities, in any
order: both arrays are sorted ascending and compared. Throws the user error `check failed`
otherwise. -/
def check (as bs : Array Nat) : IO Unit :=
  let lhs := as.qsort (fun x y => x < y);
  let rhs := bs.qsort (fun x y => x < y);
  if lhs == rhs then pure ()
  else throw $ IO.userError "check failed"
/-- Regression test for `Trie.getMatch` and `Trie.getUnify`: builds a trie with fourteen entries
(values `1` … `14`; the trailing comment on each line shows the inserted pattern with `*` for a
variable), prints it, and compares the results of several queries with the expected sets of
values using `check`, which ignores order. -/
def tst2 : IO Unit :=
do
let d := @Trie.empty Nat;
-- entries headed by `f`
let d := d.insert (mkApp "f" #[mkVar 0, mkConst "a"]) 1; -- f * a
let d := d.insert (mkApp "f" #[mkConst "b", mkVar 0]) 2; -- f b *
let d := d.insert (mkApp "f" #[mkVar 0, mkVar 0]) 3; -- f * *
let d := d.insert (mkApp "f" #[mkVar 0, mkConst "b"]) 4; -- f * b
let d := d.insert (mkApp "f" #[mkApp "h" #[mkVar 0], mkConst "b"]) 5; -- f (h *) b
let d := d.insert (mkApp "f" #[mkApp "h" #[mkConst "a"], mkConst "b"]) 6; -- f (h a) b
let d := d.insert (mkApp "f" #[mkApp "h" #[mkConst "a"], mkVar 1]) 7; -- f (h a) *
let d := d.insert (mkApp "f" #[mkApp "h" #[mkConst "a"], mkVar 0]) 8; -- f (h a) * (same keys as entry 7: variable indices are ignored)
let d := d.insert (mkApp "f" #[mkApp "h" #[mkVar 0], mkApp "h" #[mkConst "b"]]) 9; -- f (h *) (h b)
-- entries headed by `g`
let d := d.insert (mkApp "g" #[mkVar 0, mkConst "a"]) 10; -- g * a
let d := d.insert (mkApp "g" #[mkConst "b", mkVar 0]) 11; -- g b *
let d := d.insert (mkApp "g" #[mkVar 0, mkVar 0]) 12; -- g * *
let d := d.insert (mkApp "g" #[mkApp "h" #[mkConst "a"], mkConst "b"]) 13; -- g (h a) b
let d := d.insert (mkApp "g" #[mkApp "h" #[mkConst "a"], mkVar 1]) 14; -- g (h a) *
IO.println (format d);
-- `getMatch`: stored patterns that generalize the query
let vs := d.getMatch (mkApp "f" #[mkApp "h" #[mkConst "a"], mkApp "h" #[mkConst "b"]]); -- f (h a) (h b)
check vs #[3, 7, 8, 9];
let vs := d.getMatch (mkApp "f" #[mkConst "b", mkConst "a"]); -- f b a
check vs #[1, 2, 3];
let vs := d.getMatch (mkApp "g" #[mkConst "b", mkConst "b"]); -- g b b
check vs #[11, 12];
-- `getUnify`: stored patterns that are candidates for unification with the query
let vs := d.getUnify (mkApp "f" #[mkApp "h" #[mkVar 0], mkApp "h" #[mkVar 0]]); -- f (h *) (h *)
check vs #[3, 7, 8, 9];
let vs := d.getUnify (mkApp "f" #[mkApp "h" #[mkVar 0], mkVar 0]); -- f (h *) *
check vs #[1, 3, 4, 5, 6, 7, 8, 9];
let vs := d.getUnify (mkApp "f" #[mkApp "h" #[mkConst "b"], mkVar 0]); -- f (h b) *
check vs #[1, 3, 4, 5, 9];
let vs := d.getUnify (mkVar 0); -- * (a bare variable: every entry is a candidate)
check vs (List.iota 14).toArray;
let vs := d.getUnify (mkApp "g" #[mkVar 0, mkConst "b"]); -- g * b
check vs #[11, 12, 13, 14];
let vs := d.getUnify (mkApp "g" #[mkApp "h" #[mkVar 0], mkConst "b"]); -- g (h *) b
check vs #[12, 13, 14];
let vs := d.getUnify (mkApp "g" #[mkApp "h" #[mkConst "b"], mkVar 0]); -- g (h b) *
check vs #[10, 12];
pure ()
#eval tst2