From 30c05416a4fbad5de739a52b4a2965cdf0830551 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 23 Nov 2019 05:44:26 -0800 Subject: [PATCH] test: add `foldMatch` and `getMatch` These are the primitives we need for retrieving candidate simp lemmas. --- tests/playground/DiscrTree.lean | 61 +++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/playground/DiscrTree.lean b/tests/playground/DiscrTree.lean index cfad72500d..79f264e8c6 100644 --- a/tests/playground/DiscrTree.lean +++ b/tests/playground/DiscrTree.lean @@ -94,6 +94,34 @@ insertAux v todo d partial def format {α} [HasFormat α] : Trie α → Format | node vs cs => Format.group $ Format.paren $ "node" ++ (if vs.isEmpty then Format.nil else " " ++ fmt vs) ++ Format.join (cs.toList.map $ fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ format c)) +@[specialize] partial def foldMatchAux {α β} {m : Type → Type} [Monad m] (f : β → α → m β) : Array Term → Trie α → β → m β +| todo, node vs cs, b => + if todo.isEmpty then vs.foldlM f b + else if cs.isEmpty then pure b + else + let t := todo.back; + let todo := todo.pop; + let first := cs.get! 0; + let k := t.key; + match k with + | Key.var => if first.1 == Key.var then foldMatchAux todo first.2 b else pure b + | Key.sym _ _ => do + match cs.binSearch (k, arbitrary _) (fun a b => a.1 < b.1) with + | none => if first.1 == Key.var then foldMatchAux todo first.2 b else pure b + | some c => do + b ← if first.1 == Key.var then foldMatchAux todo first.2 b else pure b; + let todo := appendTodo todo t.args; + foldMatchAux todo c.2 b + +@[specialize] def foldMatch {α β} {m : Type → Type} [Monad m] (d : Trie α) (k : Term) (f : β → α → m β) (b : β) : m β := +let todo : Array Term := Array.mkEmpty 32; +let todo := todo.push k; +foldMatchAux f todo d b + +/-- Return all (approximate) matches (aka generalizations) of the term `k` -/ +def getMatch {α} (d : Trie α) (k : Term) : Array α := +Id.run $ d.foldMatch k (fun (r : Array α) v => pure $ r.push v) #[] + instance {α} [HasFormat α] : HasFormat (Trie α) := ⟨format⟩ end Trie @@ -125,3 +153,36 @@ let d := (20:Nat).fold IO.println (format d) #eval tst1 + +def check (as bs : Array Nat) : IO Unit := +let as := as.qsort (fun a b => a < b); +let bs := bs.qsort (fun a b => a < b); +unless (as == bs) $ throw $ IO.userError "check failed" + +def tst2 : IO Unit := +do +let d := @Trie.empty Nat; +let d := d.insert (mkApp "f" #[mkVar 0, mkConst "a"]) 1; -- f * a +let d := d.insert (mkApp "f" #[mkConst "b", mkVar 0]) 2; -- f b * +let d := d.insert (mkApp "f" #[mkVar 0, mkVar 0]) 3; -- f * * +let d := d.insert (mkApp "f" #[mkVar 0, mkConst "b"]) 4; -- f * b +let d := d.insert (mkApp "f" #[mkApp "h" #[mkVar 0], mkConst "b"]) 5; -- f (h *) b +let d := d.insert (mkApp "f" #[mkApp "h" #[mkConst "a"], mkConst "b"]) 6; -- f (h a) b +let d := d.insert (mkApp "f" #[mkApp "h" #[mkConst "a"], mkVar 1]) 7; -- f (h a) * +let d := d.insert (mkApp "f" #[mkApp "h" #[mkConst "a"], mkVar 0]) 8; -- f (h a) * +let d := d.insert (mkApp "f" #[mkApp "h" #[mkVar 0], mkApp "h" #[mkConst "b"]]) 9; -- f (h *) (h b) +let d := d.insert (mkApp "g" #[mkVar 0, mkConst "a"]) 10; -- g * a +let d := d.insert (mkApp "g" #[mkConst "b", mkVar 0]) 11; -- g b * +let d := d.insert (mkApp "g" #[mkVar 0, mkVar 0]) 12; -- g * * +let d := d.insert (mkApp "g" #[mkApp "h" #[mkConst "a"], mkConst "b"]) 13; -- g (h a) b +let d := d.insert (mkApp "g" #[mkApp "h" #[mkConst "a"], mkVar 1]) 14; -- g (h a) * +IO.println (format d); +let vs := d.getMatch (mkApp "f" #[mkApp "h" #[mkConst "a"], mkApp "h" #[mkConst "b"]]); -- f (h a) (h b) +check vs #[3, 7, 8, 9]; +let vs := d.getMatch (mkApp "f" #[mkConst "b", mkConst "a"]); -- f a b +check vs #[1, 2, 3]; +let vs := d.getMatch (mkApp "g" #[mkConst "b", mkConst "b"]); -- g b b +check vs #[11, 12]; +pure () + +#eval tst2