From 886eb3f51f435e26a3fb669ea2d40052c89d023a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 23 Nov 2019 06:23:04 -0800 Subject: [PATCH] test: add `foldUnify` and `getUnify` These are the primitives we need for retrieving candidate instances. --- tests/playground/DiscrTree.lean | 59 ++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/tests/playground/DiscrTree.lean b/tests/playground/DiscrTree.lean index 79f264e8c6..08a870d610 100644 --- a/tests/playground/DiscrTree.lean +++ b/tests/playground/DiscrTree.lean @@ -94,6 +94,8 @@ insertAux v todo d partial def format {α} [HasFormat α] : Trie α → Format | node vs cs => Format.group $ Format.paren $ "node" ++ (if vs.isEmpty then Format.nil else " " ++ fmt vs) ++ Format.join (cs.toList.map $ fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ format c)) +instance {α} [HasFormat α] : HasFormat (Trie α) := ⟨format⟩ + @[specialize] partial def foldMatchAux {α β} {m : Type → Type} [Monad m] (f : β → α → m β) : Array Term → Trie α → β → m β | todo, node vs cs, b => if todo.isEmpty then vs.foldlM f b @@ -122,7 +124,48 @@ foldMatchAux f todo d b def getMatch {α} (d : Trie α) (k : Term) : Array α := Id.run $ d.foldMatch k (fun (r : Array α) v => pure $ r.push v) #[] -instance {α} [HasFormat α] : HasFormat (Trie α) := ⟨format⟩ +@[specialize] partial def foldUnifyAux {α β} {m : Type → Type} [Monad m] (f : β → α → m β) : Nat → Array Term → Trie α → β → m β +| skip+1, todo, node vs cs, b => + if cs.isEmpty then pure b + else + cs.foldlM + (fun b ⟨k, c⟩ => + match k with + | Key.var => foldUnifyAux skip todo c b + | Key.sym _ a => foldUnifyAux (skip + a) todo c b) + b +| 0, todo, node vs cs, b => + if todo.isEmpty then vs.foldlM f b + else if cs.isEmpty then pure b + else + let t := todo.back; + let todo := todo.pop; + let first := cs.get! 0; + let k := t.key; + match k with + | Key.var => + cs.foldlM + (fun b ⟨k, c⟩ => + match k with + | Key.var => foldUnifyAux 0 todo c b + | Key.sym _ a => foldUnifyAux a todo c b) + b + | Key.sym _ _ => do + match cs.binSearch (k, arbitrary _) (fun a b => a.1 < b.1) with + | none => if first.1 == Key.var then foldUnifyAux 0 todo first.2 b else pure b + | some c => do + b ← if first.1 == Key.var then foldUnifyAux 0 todo first.2 b else pure b; + let todo := appendTodo todo t.args; + foldUnifyAux 0 todo c.2 b + +@[specialize] def foldUnify {α β} {m : Type → Type} [Monad m] (d : Trie α) (k : Term) (f : β → α → m β) (b : β) : m β := +let todo : Array Term := Array.mkEmpty 32; +let todo := todo.push k; +foldUnifyAux f 0 todo d b + +/-- Return all candidate unifiers of the term `k` -/ +def getUnify {α} (d : Trie α) (k : Term) : Array α := +Id.run $ d.foldUnify k (fun (r : Array α) v => pure $ r.push v) #[] end Trie @@ -183,6 +226,20 @@ let vs := d.getMatch (mkApp "f" #[mkConst "b", mkConst "a"]); -- f a b check vs #[1, 2, 3]; let vs := d.getMatch (mkApp "g" #[mkConst "b", mkConst "b"]); -- g b b check vs #[11, 12]; +let vs := d.getUnify (mkApp "f" #[mkApp "h" #[mkVar 0], mkApp "h" #[mkVar 0]]); -- f (h *) (h *) +check vs #[3, 7, 8, 9]; +let vs := d.getUnify (mkApp "f" #[mkApp "h" #[mkVar 0], mkVar 0]); -- f (h *) * +check vs #[1, 3, 4, 5, 6, 7, 8, 9]; +let vs := d.getUnify (mkApp "f" #[mkApp "h" #[mkConst "b"], mkVar 0]); -- f (h b) * +check vs #[1, 3, 4, 5, 9]; +let vs := d.getUnify (mkVar 0); -- * +check vs (List.iota 14).toArray; +let vs := d.getUnify (mkApp "g" #[mkVar 0, mkConst "b"]); -- g * b +check vs #[11, 12, 13, 14]; +let vs := d.getUnify (mkApp "g" #[mkApp "h" #[mkVar 0], mkConst "b"]); -- g (h *) b +check vs #[12, 13, 14]; +let vs := d.getUnify (mkApp "g" #[mkApp "h" #[mkConst "b"], mkVar 0]); -- g (h b) * +check vs #[10, 12]; pure () #eval tst2