feat: add support for "arrow" at DiscrTree

This commit is contained in:
Leonardo de Moura 2021-03-25 12:20:05 -07:00
parent 59ac5be60d
commit b7acc38810
2 changed files with 47 additions and 31 deletions

View file

@ -54,6 +54,7 @@ def Key.ctorIdx : Key → Nat
| Key.lit _ => 2
| Key.fvar _ _ => 3
| Key.const _ _ => 4
| Key.arrow => 5
def Key.lt : Key → Key → Bool
| Key.lit v₁, Key.lit v₂ => v₁ < v₂
@ -71,40 +72,42 @@ def Key.format : Key → Format
| Key.lit (Literal.strVal v) => repr v
| Key.const k _ => fmt k
| Key.fvar k _ => fmt k
| Key.arrow => "→"
instance : ToFormat Key := ⟨Key.format⟩
def Key.arity : Key → Nat
| Key.const _ a => a
| Key.fvar _ a => a
| Key.arrow => 2
| _ => 0
instance {α} : Inhabited (Trie α) := ⟨Trie.node #[] #[]⟩
instance : Inhabited (Trie α) := ⟨Trie.node #[] #[]⟩
def empty {α} : DiscrTree α := { root := {} }
def empty : DiscrTree α := { root := {} }
partial def Trie.format {α} [ToFormat α] : Trie α → Format
partial def Trie.format [ToFormat α] : Trie α → Format
| Trie.node vs cs => Format.group $ Format.paren $
"node" ++ (if vs.isEmpty then Format.nil else " " ++ fmt vs)
++ Format.join (cs.toList.map $ fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ format c))
instance {α} [ToFormat α] : ToFormat (Trie α) := ⟨Trie.format⟩
instance [ToFormat α] : ToFormat (Trie α) := ⟨Trie.format⟩
partial def format {α} [ToFormat α] (d : DiscrTree α) : Format :=
partial def format [ToFormat α] (d : DiscrTree α) : Format :=
let (_, r) := d.root.foldl
(fun (p : Bool × Format) k c =>
(false, p.2 ++ (if p.1 then Format.nil else Format.line) ++ Format.paren (fmt k ++ " => " ++ fmt c)))
(true, Format.nil)
Format.group r
instance {α} [ToFormat α] : ToFormat (DiscrTree α) := ⟨format⟩
instance [ToFormat α] : ToFormat (DiscrTree α) := ⟨format⟩
/- The discrimination tree ignores implicit arguments and proofs.
We use the following auxiliary id as a "mark". -/
private def tmpMVarId : MVarId := `_discr_tree_tmp
private def tmpStar := mkMVar tmpMVarId
instance {α} : Inhabited (DiscrTree α) where
instance : Inhabited (DiscrTree α) where
default := {}
/--
@ -137,13 +140,13 @@ instance {α} : Inhabited (DiscrTree α) where
Remark: if users have problems with the solution above, we may provide a `noIndexing` annotation,
and `ignoreArg` would return true for any term of the form `noIndexing t`.
-/
private def ignoreArg (a : Expr) (i : Nat) (infos : Array ParamInfo) : MetaM Bool :=
private def ignoreArg (a : Expr) (i : Nat) (infos : Array ParamInfo) : MetaM Bool := do
if h : i < infos.size then
let info := infos.get ⟨i, h⟩
if info.instImplicit then
pure true
return true
else if info.implicit then
not <$> isType a
return not (← isType a)
else
isProof a
else
@ -155,13 +158,13 @@ private partial def pushArgsAux (infos : Array ParamInfo) : Nat → Expr → Arr
pushArgsAux infos (i-1) f (todo.push tmpStar)
else
pushArgsAux infos (i-1) f (todo.push a)
| _, _, todo => pure todo
| _, _, todo => return todo
private partial def whnfEta (e : Expr) : MetaM Expr := do
let e ← whnf e
match e.etaExpandedStrict? with
| some e => whnfEta e
| none => pure e
| none => return e
/--
Return true if `e` is one of the following
@ -254,12 +257,17 @@ private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) : MetaM (Key
return (Key.other, todo)
else
return (Key.star, todo)
| Expr.forallE _ d b _ =>
if b.hasLooseBVars then
return (Key.other, todo)
else
return (Key.arrow, todo.push d |>.push b)
| _ =>
return (Key.other, todo)
partial def mkPathAux (root : Bool) (todo : Array Expr) (keys : Array Key) : MetaM (Array Key) := do
if todo.isEmpty then
pure keys
return keys
else
let e := todo.back
let todo := todo.pop
@ -274,7 +282,7 @@ def mkPath (e : Expr) : MetaM (Array Key) := do
let keys : Array Key := Array.mkEmpty initCapacity
mkPathAux (root := true) (todo.push e) keys
private partial def createNodes {α} (keys : Array Key) (v : α) (i : Nat) : Trie α :=
private partial def createNodes (keys : Array Key) (v : α) (i : Nat) : Trie α :=
if h : i < keys.size then
let k := keys.get ⟨i, h⟩
let c := createNodes keys v (i+1)
@ -282,10 +290,10 @@ private partial def createNodes {α} (keys : Array Key) (v : α) (i : Nat) : Tri
else
Trie.node #[v] #[]
private def insertVal {α} [BEq α] (vs : Array α) (v : α) : Array α :=
private def insertVal [BEq α] (vs : Array α) (v : α) : Array α :=
if vs.contains v then vs else vs.push v
private partial def insertAux {α} [BEq α] (keys : Array Key) (v : α) : Nat → Trie α → Trie α
private partial def insertAux [BEq α] (keys : Array Key) (v : α) : Nat → Trie α → Trie α
| i, Trie.node vs cs =>
if h : i < keys.size then
let k := keys.get ⟨i, h⟩
@ -298,7 +306,7 @@ private partial def insertAux {α} [BEq α] (keys : Array Key) (v : α) : Nat
else
Trie.node (insertVal vs v) cs
def insertCore {α} [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTree α :=
def insertCore [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTree α :=
if keys.isEmpty then panic! "invalid key sequence"
else
let k := keys[0]
@ -310,23 +318,23 @@ def insertCore {α} [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : Di
let c := insertAux keys v 1 c
{ root := d.root.insert k c }
def insert {α} [BEq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α) := do
def insert [BEq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α) := do
let keys ← mkPath e
return d.insertCore keys v
private def getKeyArgs (e : Expr) (isMatch : Bool) : MetaM (Key × Array Expr) := do
let e ← whnfEta e
match e.getAppFn with
| Expr.lit v _ => pure (Key.lit v, #[])
| Expr.lit v _ => return (Key.lit v, #[])
| Expr.const c _ _ =>
let nargs := e.getAppNumArgs
pure (Key.const c nargs, e.getAppRevArgs)
return (Key.const c nargs, e.getAppRevArgs)
| Expr.fvar fvarId _ =>
let nargs := e.getAppNumArgs
pure (Key.fvar fvarId nargs, e.getAppRevArgs)
return (Key.fvar fvarId nargs, e.getAppRevArgs)
| Expr.mvar mvarId _ =>
if isMatch then
pure (Key.other, #[])
return (Key.other, #[])
else do
let ctx ← read
if ctx.config.isDefEqStuckEx then
@ -344,12 +352,18 @@ private def getKeyArgs (e : Expr) (isMatch : Bool) : MetaM (Key × Array Expr) :
a regular metavariable here, otherwise we return the empty set of candidates.
This is incorrect because it is equivalent to saying that there is no solution even if
the caller assigns `?m` and try again. -/
pure (Key.star, #[])
return (Key.star, #[])
else if (← isReadOnlyOrSyntheticOpaqueExprMVar mvarId) then
pure (Key.other, #[])
return (Key.other, #[])
else
pure (Key.star, #[])
| _ => pure (Key.other, #[])
return (Key.star, #[])
| Expr.forallE _ d b _ =>
if b.hasLooseBVars then
return (Key.other, #[])
else
return (Key.arrow, #[d, b])
| _ =>
return (Key.other, #[])
private abbrev getMatchKeyArgs (e : Expr) : MetaM (Key × Array Expr) :=
getKeyArgs e (isMatch := true)
@ -357,21 +371,21 @@ private abbrev getMatchKeyArgs (e : Expr) : MetaM (Key × Array Expr) :=
private abbrev getUnifyKeyArgs (e : Expr) : MetaM (Key × Array Expr) :=
getKeyArgs e (isMatch := false)
private def getStarResult {α} (d : DiscrTree α) : Array α :=
private def getStarResult (d : DiscrTree α) : Array α :=
let result : Array α := Array.mkEmpty initCapacity
match d.root.find? Key.star with
| none => result
| some (Trie.node vs _) => result ++ vs
partial def getMatch {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
partial def getMatch (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
withReducible do
let result := getStarResult d
let (k, args) ← getMatchKeyArgs e
match k with
| Key.star => pure result
| Key.star => return result
| _ =>
match d.root.find? k with
| none => pure result
| none => return result
| some c => process args c result
where
process (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do
@ -403,7 +417,7 @@ where
let result ← visitStarChild result
process (todo ++ args) c.2 result
partial def getUnify {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
partial def getUnify (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
withReducible do
let (k, args) ← getUnifyKeyArgs e
match k with

View file

@ -17,6 +17,7 @@ inductive Key where
| lit : Literal → Key
| star : Key
| other : Key
| arrow : Key
deriving Inhabited, BEq
protected def Key.hash : Key → USize
@ -25,6 +26,7 @@ protected def Key.hash : Key → USize
| Key.lit v => mixHash 1879 $ hash v
| Key.star => 7883
| Key.other => 2411
| Key.arrow => 17
instance : Hashable Key := ⟨Key.hash⟩