feat: add getMatchWithExtra and improve tryLemma at simp
This commit is contained in:
parent
87f49be5dd
commit
19a710ffc9
2 changed files with 119 additions and 97 deletions
|
|
@ -424,66 +424,76 @@ private def getStarResult (d : DiscrTree α) : Array α :=
|
|||
private abbrev findKey (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) :=
|
||||
cs.binSearch (k, arbitrary) (fun a b => a.1 < b.1)
|
||||
|
||||
private partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do
|
||||
match c with
|
||||
| Trie.node vs cs =>
|
||||
if todo.isEmpty then
|
||||
return result ++ vs
|
||||
else if cs.isEmpty then
|
||||
return result
|
||||
else
|
||||
let e := todo.back
|
||||
let todo := todo.pop
|
||||
let first := cs[0] /- Recall that `Key.star` is the minimal key -/
|
||||
let (k, args) ← getMatchKeyArgs e (root := false)
|
||||
/- We must always visit `Key.star` edges since they are wildcards.
|
||||
Thus, `todo` is not used linearly when there is `Key.star` edge
|
||||
and there is an edge for `k` and `k != Key.star`. -/
|
||||
let visitStar (result : Array α) : MetaM (Array α) :=
|
||||
if first.1 == Key.star then
|
||||
getMatchLoop todo first.2 result
|
||||
else
|
||||
return result
|
||||
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
|
||||
match findKey cs k with
|
||||
| none => result
|
||||
| some c => getMatchLoop (todo ++ args) c.2 result
|
||||
let result ← visitStar result
|
||||
match k with
|
||||
| Key.star => result
|
||||
/-
|
||||
Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are `(Key.arrow, #[a, b])`.
|
||||
A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`). Thus, we also visit the `Key.other` child.
|
||||
-/
|
||||
| Key.arrow => visitNonStar Key.other #[] (← visitNonStar k args result)
|
||||
| _ => visitNonStar k args result
|
||||
|
||||
private def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
|
||||
match d.root.find? k with
|
||||
| none => return result
|
||||
| some c => getMatchLoop args c result
|
||||
|
||||
/--
|
||||
Find values that match `e` in `d`.
|
||||
If `allowExtraArgs == true`, we also return solutions that match prefixes of `e`.
|
||||
-/
|
||||
partial def getMatch (d : DiscrTree α) (e : Expr) (allowExtraArgs := false) : MetaM (Array α) :=
|
||||
partial def getMatch (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
|
||||
withReducible do
|
||||
let result := getStarResult d
|
||||
let (k, args) ← getMatchKeyArgs e (root := true)
|
||||
match k with
|
||||
| Key.star => return result
|
||||
| _ => if allowExtraArgs then processRootWithExtra k args result else processRoot k args result
|
||||
where
|
||||
processRoot (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := do
|
||||
match d.root.find? k with
|
||||
| none => return result
|
||||
| some c => process args c result
|
||||
| _ => getMatchRoot d k args result
|
||||
|
||||
processRootWithExtra (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := do
|
||||
let result ← processRoot k args result
|
||||
/--
|
||||
Similar to `getMatch`, but returns solutions that are prefixes of `e`.
|
||||
We store the number of ignored arguments in the result.-/
|
||||
partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : MetaM (Array (α × Nat)) :=
|
||||
withReducible do
|
||||
let result := getStarResult d |>.map (., 0)
|
||||
let (k, args) ← getMatchKeyArgs e (root := true)
|
||||
match k with
|
||||
| Key.star => return result
|
||||
| _ => process k args 0 result
|
||||
where
|
||||
process (k : Key) (args : Array Expr) (numExtraArgs : Nat) (result : Array (α × Nat)) : MetaM (Array (α × Nat)) := do
|
||||
let result := result ++ ((← getMatchRoot d k args #[]).map (., numExtraArgs))
|
||||
match k with
|
||||
| Key.const f 0 => return result
|
||||
| Key.const f (n+1) => processRootWithExtra (Key.const f n) args.pop result
|
||||
| Key.const f (n+1) => process (Key.const f n) args.pop (numExtraArgs + 1) result
|
||||
| Key.fvar f 0 => return result
|
||||
| Key.fvar f (n+1) => processRootWithExtra (Key.fvar f n) args.pop result
|
||||
| Key.fvar f (n+1) => process (Key.fvar f n) args.pop (numExtraArgs + 1) result
|
||||
| _ => return result
|
||||
|
||||
process (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do
|
||||
match c with
|
||||
| Trie.node vs cs =>
|
||||
if todo.isEmpty then
|
||||
return result ++ vs
|
||||
else if cs.isEmpty then
|
||||
return result
|
||||
else
|
||||
let e := todo.back
|
||||
let todo := todo.pop
|
||||
let first := cs[0] /- Recall that `Key.star` is the minimal key -/
|
||||
let (k, args) ← getMatchKeyArgs e (root := false)
|
||||
/- We must always visit `Key.star` edges since they are wildcards.
|
||||
Thus, `todo` is not used linearly when there is `Key.star` edge
|
||||
and there is an edge for `k` and `k != Key.star`. -/
|
||||
let visitStar (result : Array α) : MetaM (Array α) :=
|
||||
if first.1 == Key.star then
|
||||
process todo first.2 result
|
||||
else
|
||||
return result
|
||||
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
|
||||
match findKey cs k with
|
||||
| none => result
|
||||
| some c => process (todo ++ args) c.2 result
|
||||
let result ← visitStar result
|
||||
match k with
|
||||
| Key.star => result
|
||||
/-
|
||||
Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are `(Key.arrow, #[a, b])`.
|
||||
A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`). Thus, we also visit the `Key.other` child.
|
||||
-/
|
||||
| Key.arrow => visitNonStar Key.other #[] (← visitNonStar k args result)
|
||||
| _ => visitNonStar k args result
|
||||
|
||||
partial def getUnify (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
|
||||
withReducible do
|
||||
let (k, args) ← getUnifyKeyArgs e (root := true)
|
||||
|
|
|
|||
|
|
@ -42,72 +42,84 @@ where
|
|||
trace[Meta.Tactic.simp.discharge] "{lemmaName}, failed to synthesize instance{indentExpr type}"
|
||||
return false
|
||||
|
||||
def tryLemma? (e : Expr) (lemma : SimpLemma) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) :=
|
||||
private def tryLemmaCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (lemma : SimpLemma) (numExtraArgs : Nat) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := do
|
||||
let rec go (e : Expr) : SimpM (Option Result) := do
|
||||
if (← isDefEq lhs e) then
|
||||
unless (← synthesizeArgs lemma.getName xs bis discharge?) do
|
||||
return none
|
||||
let proof ← instantiateMVars (mkAppN val xs)
|
||||
if ← hasAssignableMVar proof then
|
||||
trace[Meta.Tactic.simp.rewrite] "{lemma}, has unassigned metavariables after unification"
|
||||
return none
|
||||
let rhs ← instantiateMVars type.appArg!
|
||||
if e == rhs then
|
||||
return none
|
||||
if lemma.perm && !Expr.lt rhs e then
|
||||
trace[Meta.Tactic.simp.rewrite] "{lemma}, perm rejected {e} ==> {rhs}"
|
||||
return none
|
||||
trace[Meta.Tactic.simp.rewrite] "{lemma}, {e} ==> {rhs}"
|
||||
return some { expr := rhs, proof? := proof }
|
||||
else
|
||||
unless lhs.isMVar do
|
||||
-- We do not report unification failures when `lhs` is a metavariable
|
||||
-- Example: `x = ()`
|
||||
-- TODO: reconsider if we want lemmas such as `(x : Unit) → x = ()`
|
||||
trace[Meta.Tactic.simp.unify] "{lemma}, failed to unify {lhs} with {e}"
|
||||
return none
|
||||
/- Check whether we need something more sophisticated here.
|
||||
This simple approach was good enough for Mathlib 3 -/
|
||||
let mut extraArgs := #[]
|
||||
let mut e := e
|
||||
for i in [:numExtraArgs] do
|
||||
extraArgs := extraArgs.push e.appArg!
|
||||
e := e.appFn!
|
||||
match (← go e) with
|
||||
| none => return none
|
||||
| some { expr := eNew, proof? := none } => return some { expr := mkAppN eNew extraArgs }
|
||||
| some { expr := eNew, proof? := some proof } =>
|
||||
let mut proof := proof
|
||||
for extraArg in extraArgs do
|
||||
proof ← mkCongrFun proof extraArg
|
||||
return some { expr := mkAppN eNew extraArgs, proof? := some proof }
|
||||
|
||||
def tryLemmaWithExtraArgs? (e : Expr) (lemma : SimpLemma) (numExtraArgs : Nat) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) :=
|
||||
withNewMCtxDepth do
|
||||
let val ← lemma.getValue
|
||||
let type ← inferType val
|
||||
let (xs, bis, type) ← forallMetaTelescopeReducing type
|
||||
let type ← whnf (← instantiateMVars type)
|
||||
let lhs := type.appFn!.appArg!
|
||||
let rec go (e : Expr) : SimpM (Option Result) := do
|
||||
if (← isDefEq lhs e) then
|
||||
unless (← synthesizeArgs lemma.getName xs bis discharge?) do
|
||||
return none
|
||||
let proof ← instantiateMVars (mkAppN val xs)
|
||||
if ← hasAssignableMVar proof then
|
||||
trace[Meta.Tactic.simp.rewrite] "{lemma}, has unassigned metavariables after unification"
|
||||
return none
|
||||
let rhs ← instantiateMVars type.appArg!
|
||||
if e == rhs then
|
||||
return none
|
||||
if lemma.perm && !Expr.lt rhs e then
|
||||
trace[Meta.Tactic.simp.rewrite] "{lemma}, perm rejected {e} ==> {rhs}"
|
||||
return none
|
||||
trace[Meta.Tactic.simp.rewrite] "{lemma}, {e} ==> {rhs}"
|
||||
return some { expr := rhs, proof? := proof }
|
||||
else
|
||||
unless lhs.isMVar do
|
||||
-- We do not report unification failures when `lhs` is a metavariable
|
||||
-- Example: `x = ()`
|
||||
-- TODO: reconsider if we want lemmas such as `(x : Unit) → x = ()`
|
||||
trace[Meta.Tactic.simp.unify] "{lemma}, failed to unify {lhs} with {e}"
|
||||
return none
|
||||
let lhsNumArgs := lhs.getAppNumArgs
|
||||
let eNumArgs := e.getAppNumArgs
|
||||
if eNumArgs == lhsNumArgs then
|
||||
go e
|
||||
else if eNumArgs < lhsNumArgs then
|
||||
return none
|
||||
else
|
||||
/- Check whether we need something more sophisticated here.
|
||||
This simple approach was good enough for Mathlib 3 -/
|
||||
let mut extraArgs := #[]
|
||||
let mut e := e
|
||||
for i in [:eNumArgs - lhsNumArgs] do
|
||||
extraArgs := extraArgs.push e.appArg!
|
||||
e := e.appFn!
|
||||
match (← go e) with
|
||||
| none => return none
|
||||
| some { expr := eNew, proof? := none } => return some { expr := mkAppN eNew extraArgs }
|
||||
| some { expr := eNew, proof? := some proof } =>
|
||||
let mut proof := proof
|
||||
for extraArg in extraArgs do
|
||||
proof ← mkCongrFun proof extraArg
|
||||
return some { expr := mkAppN eNew extraArgs, proof? := some proof }
|
||||
tryLemmaCore lhs xs bis val type e lemma numExtraArgs discharge?
|
||||
|
||||
def tryLemma? (e : Expr) (lemma : SimpLemma) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := do
|
||||
withNewMCtxDepth do
|
||||
let val ← lemma.getValue
|
||||
let type ← inferType val
|
||||
let (xs, bis, type) ← forallMetaTelescopeReducing type
|
||||
let type ← whnf (← instantiateMVars type)
|
||||
let lhs := type.appFn!.appArg!
|
||||
match (← tryLemmaCore lhs xs bis val type e lemma 0 discharge?) with
|
||||
| some result => return some result
|
||||
| none =>
|
||||
let lhsNumArgs := lhs.getAppNumArgs
|
||||
let eNumArgs := e.getAppNumArgs
|
||||
if eNumArgs > lhsNumArgs then
|
||||
tryLemmaCore lhs xs bis val type e lemma (eNumArgs - lhsNumArgs) discharge?
|
||||
else
|
||||
return none
|
||||
/-
|
||||
Remark: the parameter tag is used for creating trace messages. It is irrelevant otherwise.
|
||||
-/
|
||||
def rewrite (e : Expr) (s : DiscrTree SimpLemma) (erased : Std.PHashSet Name) (discharge? : Expr → SimpM (Option Expr)) (tag : String) : SimpM Result := do
|
||||
let lemmas ← s.getMatch e (allowExtraArgs := true)
|
||||
if lemmas.isEmpty then
|
||||
let candidates ← s.getMatchWithExtra e
|
||||
if candidates.isEmpty then
|
||||
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
|
||||
return { expr := e }
|
||||
else
|
||||
let lemmas := lemmas.insertionSort fun e₁ e₂ => e₁.priority < e₂.priority
|
||||
for lemma in lemmas do
|
||||
let candidates := candidates.insertionSort fun e₁ e₂ => e₁.1.priority < e₂.1.priority
|
||||
for (lemma, numExtraArgs) in candidates do
|
||||
unless inErasedSet lemma do
|
||||
if let some result ← tryLemma? e lemma discharge? then
|
||||
if let some result ← tryLemmaWithExtraArgs? e lemma numExtraArgs discharge? then
|
||||
return result
|
||||
return { expr := e }
|
||||
where
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue