feat: add getMatchWithExtra and improve tryLemma at simp

This commit is contained in:
Leonardo de Moura 2021-09-09 19:28:09 -07:00
parent 87f49be5dd
commit 19a710ffc9
2 changed files with 119 additions and 97 deletions

View file

@ -424,66 +424,76 @@ private def getStarResult (d : DiscrTree α) : Array α :=
private abbrev findKey (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) :=
cs.binSearch (k, arbitrary) (fun a b => a.1 < b.1)
private partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do
match c with
| Trie.node vs cs =>
if todo.isEmpty then
return result ++ vs
else if cs.isEmpty then
return result
else
let e := todo.back
let todo := todo.pop
let first := cs[0] /- Recall that `Key.star` is the minimal key -/
let (k, args) ← getMatchKeyArgs e (root := false)
/- We must always visit `Key.star` edges since they are wildcards.
Thus, `todo` is not used linearly when there is `Key.star` edge
and there is an edge for `k` and `k != Key.star`. -/
let visitStar (result : Array α) : MetaM (Array α) :=
if first.1 == Key.star then
getMatchLoop todo first.2 result
else
return result
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
match findKey cs k with
| none => result
| some c => getMatchLoop (todo ++ args) c.2 result
let result ← visitStar result
match k with
| Key.star => result
/-
Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are `(Key.arrow, #[a, b])`.
A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`). Thus, we also visit the `Key.other` child.
-/
| Key.arrow => visitNonStar Key.other #[] (← visitNonStar k args result)
| _ => visitNonStar k args result
private def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
match d.root.find? k with
| none => return result
| some c => getMatchLoop args c result
/--
Find values that match `e` in `d`.
If `allowExtraArgs == true`, we also return solutions that match prefixes of `e`.
-/
partial def getMatch (d : DiscrTree α) (e : Expr) (allowExtraArgs := false) : MetaM (Array α) :=
partial def getMatch (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
withReducible do
let result := getStarResult d
let (k, args) ← getMatchKeyArgs e (root := true)
match k with
| Key.star => return result
| _ => if allowExtraArgs then processRootWithExtra k args result else processRoot k args result
where
processRoot (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := do
match d.root.find? k with
| none => return result
| some c => process args c result
| _ => getMatchRoot d k args result
processRootWithExtra (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := do
let result ← processRoot k args result
/--
Similar to `getMatch`, but returns solutions that are prefixes of `e`.
We store the number of ignored arguments in the result.-/
partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : MetaM (Array (α × Nat)) :=
withReducible do
let result := getStarResult d |>.map (., 0)
let (k, args) ← getMatchKeyArgs e (root := true)
match k with
| Key.star => return result
| _ => process k args 0 result
where
process (k : Key) (args : Array Expr) (numExtraArgs : Nat) (result : Array (α × Nat)) : MetaM (Array (α × Nat)) := do
let result := result ++ ((← getMatchRoot d k args #[]).map (., numExtraArgs))
match k with
| Key.const f 0 => return result
| Key.const f (n+1) => processRootWithExtra (Key.const f n) args.pop result
| Key.const f (n+1) => process (Key.const f n) args.pop (numExtraArgs + 1) result
| Key.fvar f 0 => return result
| Key.fvar f (n+1) => processRootWithExtra (Key.fvar f n) args.pop result
| Key.fvar f (n+1) => process (Key.fvar f n) args.pop (numExtraArgs + 1) result
| _ => return result
process (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do
match c with
| Trie.node vs cs =>
if todo.isEmpty then
return result ++ vs
else if cs.isEmpty then
return result
else
let e := todo.back
let todo := todo.pop
let first := cs[0] /- Recall that `Key.star` is the minimal key -/
let (k, args) ← getMatchKeyArgs e (root := false)
/- We must always visit `Key.star` edges since they are wildcards.
Thus, `todo` is not used linearly when there is `Key.star` edge
and there is an edge for `k` and `k != Key.star`. -/
let visitStar (result : Array α) : MetaM (Array α) :=
if first.1 == Key.star then
process todo first.2 result
else
return result
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
match findKey cs k with
| none => result
| some c => process (todo ++ args) c.2 result
let result ← visitStar result
match k with
| Key.star => result
/-
Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are `(Key.arrow, #[a, b])`.
A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`). Thus, we also visit the `Key.other` child.
-/
| Key.arrow => visitNonStar Key.other #[] (← visitNonStar k args result)
| _ => visitNonStar k args result
partial def getUnify (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
withReducible do
let (k, args) ← getUnifyKeyArgs e (root := true)

View file

@ -42,72 +42,84 @@ where
trace[Meta.Tactic.simp.discharge] "{lemmaName}, failed to synthesize instance{indentExpr type}"
return false
def tryLemma? (e : Expr) (lemma : SimpLemma) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) :=
private def tryLemmaCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (lemma : SimpLemma) (numExtraArgs : Nat) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := do
let rec go (e : Expr) : SimpM (Option Result) := do
if (← isDefEq lhs e) then
unless (← synthesizeArgs lemma.getName xs bis discharge?) do
return none
let proof ← instantiateMVars (mkAppN val xs)
if ← hasAssignableMVar proof then
trace[Meta.Tactic.simp.rewrite] "{lemma}, has unassigned metavariables after unification"
return none
let rhs ← instantiateMVars type.appArg!
if e == rhs then
return none
if lemma.perm && !Expr.lt rhs e then
trace[Meta.Tactic.simp.rewrite] "{lemma}, perm rejected {e} ==> {rhs}"
return none
trace[Meta.Tactic.simp.rewrite] "{lemma}, {e} ==> {rhs}"
return some { expr := rhs, proof? := proof }
else
unless lhs.isMVar do
-- We do not report unification failures when `lhs` is a metavariable
-- Example: `x = ()`
-- TODO: reconsider if we want lemmas such as `(x : Unit) → x = ()`
trace[Meta.Tactic.simp.unify] "{lemma}, failed to unify {lhs} with {e}"
return none
/- Check whether we need something more sophisticated here.
This simple approach was good enough for Mathlib 3 -/
let mut extraArgs := #[]
let mut e := e
for i in [:numExtraArgs] do
extraArgs := extraArgs.push e.appArg!
e := e.appFn!
match (← go e) with
| none => return none
| some { expr := eNew, proof? := none } => return some { expr := mkAppN eNew extraArgs }
| some { expr := eNew, proof? := some proof } =>
let mut proof := proof
for extraArg in extraArgs do
proof ← mkCongrFun proof extraArg
return some { expr := mkAppN eNew extraArgs, proof? := some proof }
def tryLemmaWithExtraArgs? (e : Expr) (lemma : SimpLemma) (numExtraArgs : Nat) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) :=
withNewMCtxDepth do
let val ← lemma.getValue
let type ← inferType val
let (xs, bis, type) ← forallMetaTelescopeReducing type
let type ← whnf (← instantiateMVars type)
let lhs := type.appFn!.appArg!
let rec go (e : Expr) : SimpM (Option Result) := do
if (← isDefEq lhs e) then
unless (← synthesizeArgs lemma.getName xs bis discharge?) do
return none
let proof ← instantiateMVars (mkAppN val xs)
if ← hasAssignableMVar proof then
trace[Meta.Tactic.simp.rewrite] "{lemma}, has unassigned metavariables after unification"
return none
let rhs ← instantiateMVars type.appArg!
if e == rhs then
return none
if lemma.perm && !Expr.lt rhs e then
trace[Meta.Tactic.simp.rewrite] "{lemma}, perm rejected {e} ==> {rhs}"
return none
trace[Meta.Tactic.simp.rewrite] "{lemma}, {e} ==> {rhs}"
return some { expr := rhs, proof? := proof }
else
unless lhs.isMVar do
-- We do not report unification failures when `lhs` is a metavariable
-- Example: `x = ()`
-- TODO: reconsider if we want lemmas such as `(x : Unit) → x = ()`
trace[Meta.Tactic.simp.unify] "{lemma}, failed to unify {lhs} with {e}"
return none
let lhsNumArgs := lhs.getAppNumArgs
let eNumArgs := e.getAppNumArgs
if eNumArgs == lhsNumArgs then
go e
else if eNumArgs < lhsNumArgs then
return none
else
/- Check whether we need something more sophisticated here.
This simple approach was good enough for Mathlib 3 -/
let mut extraArgs := #[]
let mut e := e
for i in [:eNumArgs - lhsNumArgs] do
extraArgs := extraArgs.push e.appArg!
e := e.appFn!
match (← go e) with
| none => return none
| some { expr := eNew, proof? := none } => return some { expr := mkAppN eNew extraArgs }
| some { expr := eNew, proof? := some proof } =>
let mut proof := proof
for extraArg in extraArgs do
proof ← mkCongrFun proof extraArg
return some { expr := mkAppN eNew extraArgs, proof? := some proof }
tryLemmaCore lhs xs bis val type e lemma numExtraArgs discharge?
def tryLemma? (e : Expr) (lemma : SimpLemma) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := do
withNewMCtxDepth do
let val ← lemma.getValue
let type ← inferType val
let (xs, bis, type) ← forallMetaTelescopeReducing type
let type ← whnf (← instantiateMVars type)
let lhs := type.appFn!.appArg!
match (← tryLemmaCore lhs xs bis val type e lemma 0 discharge?) with
| some result => return some result
| none =>
let lhsNumArgs := lhs.getAppNumArgs
let eNumArgs := e.getAppNumArgs
if eNumArgs > lhsNumArgs then
tryLemmaCore lhs xs bis val type e lemma (eNumArgs - lhsNumArgs) discharge?
else
return none
/-
Remark: the parameter tag is used for creating trace messages. It is irrelevant otherwise.
-/
def rewrite (e : Expr) (s : DiscrTree SimpLemma) (erased : Std.PHashSet Name) (discharge? : Expr → SimpM (Option Expr)) (tag : String) : SimpM Result := do
let lemmas ← s.getMatch e (allowExtraArgs := true)
if lemmas.isEmpty then
let candidates ← s.getMatchWithExtra e
if candidates.isEmpty then
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
return { expr := e }
else
let lemmas := lemmas.insertionSort fun e₁ e₂ => e₁.priority < e₂.priority
for lemma in lemmas do
let candidates := candidates.insertionSort fun e₁ e₂ => e₁.1.priority < e₂.1.priority
for (lemma, numExtraArgs) in candidates do
unless inErasedSet lemma do
if let some result ← tryLemma? e lemma discharge? then
if let some result ← tryLemmaWithExtraArgs? e lemma numExtraArgs discharge? then
return result
return { expr := e }
where