feat: add support for "arrow" at DiscrTree
This commit is contained in:
parent
59ac5be60d
commit
b7acc38810
2 changed files with 47 additions and 31 deletions
|
|
@ -54,6 +54,7 @@ def Key.ctorIdx : Key → Nat
|
|||
| Key.lit _ => 2
|
||||
| Key.fvar _ _ => 3
|
||||
| Key.const _ _ => 4
|
||||
| Key.arrow => 5
|
||||
|
||||
def Key.lt : Key → Key → Bool
|
||||
| Key.lit v₁, Key.lit v₂ => v₁ < v₂
|
||||
|
|
@ -71,40 +72,42 @@ def Key.format : Key → Format
|
|||
| Key.lit (Literal.strVal v) => repr v
|
||||
| Key.const k _ => fmt k
|
||||
| Key.fvar k _ => fmt k
|
||||
| Key.arrow => "→"
|
||||
|
||||
instance : ToFormat Key := ⟨Key.format⟩
|
||||
|
||||
def Key.arity : Key → Nat
|
||||
| Key.const _ a => a
|
||||
| Key.fvar _ a => a
|
||||
| Key.arrow => 2
|
||||
| _ => 0
|
||||
|
||||
instance {α} : Inhabited (Trie α) := ⟨Trie.node #[] #[]⟩
|
||||
instance : Inhabited (Trie α) := ⟨Trie.node #[] #[]⟩
|
||||
|
||||
def empty {α} : DiscrTree α := { root := {} }
|
||||
def empty : DiscrTree α := { root := {} }
|
||||
|
||||
partial def Trie.format {α} [ToFormat α] : Trie α → Format
|
||||
partial def Trie.format [ToFormat α] : Trie α → Format
|
||||
| Trie.node vs cs => Format.group $ Format.paren $
|
||||
"node" ++ (if vs.isEmpty then Format.nil else " " ++ fmt vs)
|
||||
++ Format.join (cs.toList.map $ fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ format c))
|
||||
|
||||
instance {α} [ToFormat α] : ToFormat (Trie α) := ⟨Trie.format⟩
|
||||
instance [ToFormat α] : ToFormat (Trie α) := ⟨Trie.format⟩
|
||||
|
||||
partial def format {α} [ToFormat α] (d : DiscrTree α) : Format :=
|
||||
partial def format [ToFormat α] (d : DiscrTree α) : Format :=
|
||||
let (_, r) := d.root.foldl
|
||||
(fun (p : Bool × Format) k c =>
|
||||
(false, p.2 ++ (if p.1 then Format.nil else Format.line) ++ Format.paren (fmt k ++ " => " ++ fmt c)))
|
||||
(true, Format.nil)
|
||||
Format.group r
|
||||
|
||||
instance {α} [ToFormat α] : ToFormat (DiscrTree α) := ⟨format⟩
|
||||
instance [ToFormat α] : ToFormat (DiscrTree α) := ⟨format⟩
|
||||
|
||||
/- The discrimination tree ignores implicit arguments and proofs.
|
||||
We use the following auxiliary id as a "mark". -/
|
||||
private def tmpMVarId : MVarId := `_discr_tree_tmp
|
||||
private def tmpStar := mkMVar tmpMVarId
|
||||
|
||||
instance {α} : Inhabited (DiscrTree α) where
|
||||
instance : Inhabited (DiscrTree α) where
|
||||
default := {}
|
||||
|
||||
/--
|
||||
|
|
@ -137,13 +140,13 @@ instance {α} : Inhabited (DiscrTree α) where
|
|||
Remark: if users have problems with the solution above, we may provide a `noIndexing` annotation,
|
||||
and `ignoreArg` would return true for any term of the form `noIndexing t`.
|
||||
-/
|
||||
private def ignoreArg (a : Expr) (i : Nat) (infos : Array ParamInfo) : MetaM Bool :=
|
||||
private def ignoreArg (a : Expr) (i : Nat) (infos : Array ParamInfo) : MetaM Bool := do
|
||||
if h : i < infos.size then
|
||||
let info := infos.get ⟨i, h⟩
|
||||
if info.instImplicit then
|
||||
pure true
|
||||
return true
|
||||
else if info.implicit then
|
||||
not <$> isType a
|
||||
return not (← isType a)
|
||||
else
|
||||
isProof a
|
||||
else
|
||||
|
|
@ -155,13 +158,13 @@ private partial def pushArgsAux (infos : Array ParamInfo) : Nat → Expr → Arr
|
|||
pushArgsAux infos (i-1) f (todo.push tmpStar)
|
||||
else
|
||||
pushArgsAux infos (i-1) f (todo.push a)
|
||||
| _, _, todo => pure todo
|
||||
| _, _, todo => return todo
|
||||
|
||||
private partial def whnfEta (e : Expr) : MetaM Expr := do
|
||||
let e ← whnf e
|
||||
match e.etaExpandedStrict? with
|
||||
| some e => whnfEta e
|
||||
| none => pure e
|
||||
| none => return e
|
||||
|
||||
/--
|
||||
Return true if `e` is one of the following
|
||||
|
|
@ -254,12 +257,17 @@ private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) : MetaM (Key
|
|||
return (Key.other, todo)
|
||||
else
|
||||
return (Key.star, todo)
|
||||
| Expr.forallE _ d b _ =>
|
||||
if b.hasLooseBVars then
|
||||
return (Key.other, todo)
|
||||
else
|
||||
return (Key.arrow, todo.push d |>.push b)
|
||||
| _ =>
|
||||
return (Key.other, todo)
|
||||
|
||||
partial def mkPathAux (root : Bool) (todo : Array Expr) (keys : Array Key) : MetaM (Array Key) := do
|
||||
if todo.isEmpty then
|
||||
pure keys
|
||||
return keys
|
||||
else
|
||||
let e := todo.back
|
||||
let todo := todo.pop
|
||||
|
|
@ -274,7 +282,7 @@ def mkPath (e : Expr) : MetaM (Array Key) := do
|
|||
let keys : Array Key := Array.mkEmpty initCapacity
|
||||
mkPathAux (root := true) (todo.push e) keys
|
||||
|
||||
private partial def createNodes {α} (keys : Array Key) (v : α) (i : Nat) : Trie α :=
|
||||
private partial def createNodes (keys : Array Key) (v : α) (i : Nat) : Trie α :=
|
||||
if h : i < keys.size then
|
||||
let k := keys.get ⟨i, h⟩
|
||||
let c := createNodes keys v (i+1)
|
||||
|
|
@ -282,10 +290,10 @@ private partial def createNodes {α} (keys : Array Key) (v : α) (i : Nat) : Tri
|
|||
else
|
||||
Trie.node #[v] #[]
|
||||
|
||||
private def insertVal {α} [BEq α] (vs : Array α) (v : α) : Array α :=
|
||||
private def insertVal [BEq α] (vs : Array α) (v : α) : Array α :=
|
||||
if vs.contains v then vs else vs.push v
|
||||
|
||||
private partial def insertAux {α} [BEq α] (keys : Array Key) (v : α) : Nat → Trie α → Trie α
|
||||
private partial def insertAux [BEq α] (keys : Array Key) (v : α) : Nat → Trie α → Trie α
|
||||
| i, Trie.node vs cs =>
|
||||
if h : i < keys.size then
|
||||
let k := keys.get ⟨i, h⟩
|
||||
|
|
@ -298,7 +306,7 @@ private partial def insertAux {α} [BEq α] (keys : Array Key) (v : α) : Nat
|
|||
else
|
||||
Trie.node (insertVal vs v) cs
|
||||
|
||||
def insertCore {α} [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTree α :=
|
||||
def insertCore [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTree α :=
|
||||
if keys.isEmpty then panic! "invalid key sequence"
|
||||
else
|
||||
let k := keys[0]
|
||||
|
|
@ -310,23 +318,23 @@ def insertCore {α} [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : Di
|
|||
let c := insertAux keys v 1 c
|
||||
{ root := d.root.insert k c }
|
||||
|
||||
def insert {α} [BEq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α) := do
|
||||
def insert [BEq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α) := do
|
||||
let keys ← mkPath e
|
||||
return d.insertCore keys v
|
||||
|
||||
private def getKeyArgs (e : Expr) (isMatch : Bool) : MetaM (Key × Array Expr) := do
|
||||
let e ← whnfEta e
|
||||
match e.getAppFn with
|
||||
| Expr.lit v _ => pure (Key.lit v, #[])
|
||||
| Expr.lit v _ => return (Key.lit v, #[])
|
||||
| Expr.const c _ _ =>
|
||||
let nargs := e.getAppNumArgs
|
||||
pure (Key.const c nargs, e.getAppRevArgs)
|
||||
return (Key.const c nargs, e.getAppRevArgs)
|
||||
| Expr.fvar fvarId _ =>
|
||||
let nargs := e.getAppNumArgs
|
||||
pure (Key.fvar fvarId nargs, e.getAppRevArgs)
|
||||
return (Key.fvar fvarId nargs, e.getAppRevArgs)
|
||||
| Expr.mvar mvarId _ =>
|
||||
if isMatch then
|
||||
pure (Key.other, #[])
|
||||
return (Key.other, #[])
|
||||
else do
|
||||
let ctx ← read
|
||||
if ctx.config.isDefEqStuckEx then
|
||||
|
|
@ -344,12 +352,18 @@ private def getKeyArgs (e : Expr) (isMatch : Bool) : MetaM (Key × Array Expr) :
|
|||
a regular metavariable here, otherwise we return the empty set of candidates.
|
||||
This is incorrect because it is equivalent to saying that there is no solution even if
|
||||
the caller assigns `?m` and try again. -/
|
||||
pure (Key.star, #[])
|
||||
return (Key.star, #[])
|
||||
else if (← isReadOnlyOrSyntheticOpaqueExprMVar mvarId) then
|
||||
pure (Key.other, #[])
|
||||
return (Key.other, #[])
|
||||
else
|
||||
pure (Key.star, #[])
|
||||
| _ => pure (Key.other, #[])
|
||||
return (Key.star, #[])
|
||||
| Expr.forallE _ d b _ =>
|
||||
if b.hasLooseBVars then
|
||||
return (Key.other, #[])
|
||||
else
|
||||
return (Key.arrow, #[d, b])
|
||||
| _ =>
|
||||
return (Key.other, #[])
|
||||
|
||||
private abbrev getMatchKeyArgs (e : Expr) : MetaM (Key × Array Expr) :=
|
||||
getKeyArgs e (isMatch := true)
|
||||
|
|
@ -357,21 +371,21 @@ private abbrev getMatchKeyArgs (e : Expr) : MetaM (Key × Array Expr) :=
|
|||
private abbrev getUnifyKeyArgs (e : Expr) : MetaM (Key × Array Expr) :=
|
||||
getKeyArgs e (isMatch := false)
|
||||
|
||||
private def getStarResult {α} (d : DiscrTree α) : Array α :=
|
||||
private def getStarResult (d : DiscrTree α) : Array α :=
|
||||
let result : Array α := Array.mkEmpty initCapacity
|
||||
match d.root.find? Key.star with
|
||||
| none => result
|
||||
| some (Trie.node vs _) => result ++ vs
|
||||
|
||||
partial def getMatch {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
|
||||
partial def getMatch (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
|
||||
withReducible do
|
||||
let result := getStarResult d
|
||||
let (k, args) ← getMatchKeyArgs e
|
||||
match k with
|
||||
| Key.star => pure result
|
||||
| Key.star => return result
|
||||
| _ =>
|
||||
match d.root.find? k with
|
||||
| none => pure result
|
||||
| none => return result
|
||||
| some c => process args c result
|
||||
where
|
||||
process (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do
|
||||
|
|
@ -403,7 +417,7 @@ where
|
|||
let result ← visitStarChild result
|
||||
process (todo ++ args) c.2 result
|
||||
|
||||
partial def getUnify {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
|
||||
partial def getUnify (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
|
||||
withReducible do
|
||||
let (k, args) ← getUnifyKeyArgs e
|
||||
match k with
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ inductive Key where
|
|||
| lit : Literal → Key
|
||||
| star : Key
|
||||
| other : Key
|
||||
| arrow : Key
|
||||
deriving Inhabited, BEq
|
||||
|
||||
protected def Key.hash : Key → USize
|
||||
|
|
@ -25,6 +26,7 @@ protected def Key.hash : Key → USize
|
|||
| Key.lit v => mixHash 1879 $ hash v
|
||||
| Key.star => 7883
|
||||
| Key.other => 2411
|
||||
| Key.arrow => 17
|
||||
|
||||
instance : Hashable Key := ⟨Key.hash⟩
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue