diff --git a/src/Lean/Meta/DiscrTree.lean b/src/Lean/Meta/DiscrTree.lean index 92f122adc3..67c0d27488 100644 --- a/src/Lean/Meta/DiscrTree.lean +++ b/src/Lean/Meta/DiscrTree.lean @@ -424,66 +424,76 @@ private def getStarResult (d : DiscrTree α) : Array α := private abbrev findKey (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) := cs.binSearch (k, arbitrary) (fun a b => a.1 < b.1) +private partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do + match c with + | Trie.node vs cs => + if todo.isEmpty then + return result ++ vs + else if cs.isEmpty then + return result + else + let e := todo.back + let todo := todo.pop + let first := cs[0] /- Recall that `Key.star` is the minimal key -/ + let (k, args) ← getMatchKeyArgs e (root := false) + /- We must always visit `Key.star` edges since they are wildcards. + Thus, `todo` is not used linearly when there is `Key.star` edge + and there is an edge for `k` and `k != Key.star`. -/ + let visitStar (result : Array α) : MetaM (Array α) := + if first.1 == Key.star then + getMatchLoop todo first.2 result + else + return result + let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := + match findKey cs k with + | none => result + | some c => getMatchLoop (todo ++ args) c.2 result + let result ← visitStar result + match k with + | Key.star => result + /- + Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are `(Key.arrow, #[a, b])`. + A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`). Thus, we also visit the `Key.other` child. + -/ + | Key.arrow => visitNonStar Key.other #[] (← visitNonStar k args result) + | _ => visitNonStar k args result + +private def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := + match d.root.find? k with + | none => return result + | some c => getMatchLoop args c result + /-- Find values that match `e` in `d`. - If `allowExtraArgs == true`, we also return solutions that match prefixes of `e`. -/ -partial def getMatch (d : DiscrTree α) (e : Expr) (allowExtraArgs := false) : MetaM (Array α) := +partial def getMatch (d : DiscrTree α) (e : Expr) : MetaM (Array α) := withReducible do let result := getStarResult d let (k, args) ← getMatchKeyArgs e (root := true) match k with | Key.star => return result - | _ => if allowExtraArgs then processRootWithExtra k args result else processRoot k args result -where - processRoot (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := do - match d.root.find? k with - | none => return result - | some c => process args c result + | _ => getMatchRoot d k args result - processRootWithExtra (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := do - let result ← processRoot k args result +/-- + Similar to `getMatch`, but returns solutions that are prefixes of `e`. + We store the number of ignored arguments in the result.-/ +partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : MetaM (Array (α × Nat)) := + withReducible do + let result := getStarResult d |>.map (., 0) + let (k, args) ← getMatchKeyArgs e (root := true) + match k with + | Key.star => return result + | _ => process k args 0 result +where + process (k : Key) (args : Array Expr) (numExtraArgs : Nat) (result : Array (α × Nat)) : MetaM (Array (α × Nat)) := do + let result := result ++ ((← getMatchRoot d k args #[]).map (., numExtraArgs)) match k with | Key.const f 0 => return result - | Key.const f (n+1) => processRootWithExtra (Key.const f n) args.pop result + | Key.const f (n+1) => process (Key.const f n) args.pop (numExtraArgs + 1) result | Key.fvar f 0 => return result - | Key.fvar f (n+1) => processRootWithExtra (Key.fvar f n) args.pop result + | Key.fvar f (n+1) => process (Key.fvar f n) args.pop (numExtraArgs + 1) result | _ => return result - process (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do - match c with - | Trie.node vs cs => - if todo.isEmpty then - return result ++ vs - else if cs.isEmpty then - return result - else - let e := todo.back - let todo := todo.pop - let first := cs[0] /- Recall that `Key.star` is the minimal key -/ - let (k, args) ← getMatchKeyArgs e (root := false) - /- We must always visit `Key.star` edges since they are wildcards. - Thus, `todo` is not used linearly when there is `Key.star` edge - and there is an edge for `k` and `k != Key.star`. -/ - let visitStar (result : Array α) : MetaM (Array α) := - if first.1 == Key.star then - process todo first.2 result - else - return result - let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) := - match findKey cs k with - | none => result - | some c => process (todo ++ args) c.2 result - let result ← visitStar result - match k with - | Key.star => result - /- - Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are `(Key.arrow, #[a, b])`. - A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`). Thus, we also visit the `Key.other` child. - -/ - | Key.arrow => visitNonStar Key.other #[] (← visitNonStar k args result) - | _ => visitNonStar k args result - partial def getUnify (d : DiscrTree α) (e : Expr) : MetaM (Array α) := withReducible do let (k, args) ← getUnifyKeyArgs e (root := true) diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index 1e3b1fbd01..d41f7e48dd 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -42,72 +42,84 @@ where trace[Meta.Tactic.simp.discharge] "{lemmaName}, failed to synthesize instance{indentExpr type}" return false -def tryLemma? (e : Expr) (lemma : SimpLemma) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := +private def tryLemmaCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (lemma : SimpLemma) (numExtraArgs : Nat) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := do + let rec go (e : Expr) : SimpM (Option Result) := do + if (← isDefEq lhs e) then + unless (← synthesizeArgs lemma.getName xs bis discharge?) do + return none + let proof ← instantiateMVars (mkAppN val xs) + if ← hasAssignableMVar proof then + trace[Meta.Tactic.simp.rewrite] "{lemma}, has unassigned metavariables after unification" + return none + let rhs ← instantiateMVars type.appArg! + if e == rhs then + return none + if lemma.perm && !Expr.lt rhs e then + trace[Meta.Tactic.simp.rewrite] "{lemma}, perm rejected {e} ==> {rhs}" + return none + trace[Meta.Tactic.simp.rewrite] "{lemma}, {e} ==> {rhs}" + return some { expr := rhs, proof? := proof } + else + unless lhs.isMVar do + -- We do not report unification failures when `lhs` is a metavariable + -- Example: `x = ()` + -- TODO: reconsider if we want lemmas such as `(x : Unit) → x = ()` + trace[Meta.Tactic.simp.unify] "{lemma}, failed to unify {lhs} with {e}" + return none + /- Check whether we need something more sophisticated here. + This simple approach was good enough for Mathlib 3 -/ + let mut extraArgs := #[] + let mut e := e + for i in [:numExtraArgs] do + extraArgs := extraArgs.push e.appArg! + e := e.appFn! + match (← go e) with + | none => return none + | some { expr := eNew, proof? := none } => return some { expr := mkAppN eNew extraArgs } + | some { expr := eNew, proof? := some proof } => + let mut proof := proof + for extraArg in extraArgs do + proof ← mkCongrFun proof extraArg + return some { expr := mkAppN eNew extraArgs, proof? := some proof } + +def tryLemmaWithExtraArgs? (e : Expr) (lemma : SimpLemma) (numExtraArgs : Nat) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := withNewMCtxDepth do let val ← lemma.getValue let type ← inferType val let (xs, bis, type) ← forallMetaTelescopeReducing type let type ← whnf (← instantiateMVars type) let lhs := type.appFn!.appArg! - let rec go (e : Expr) : SimpM (Option Result) := do - if (← isDefEq lhs e) then - unless (← synthesizeArgs lemma.getName xs bis discharge?) do - return none - let proof ← instantiateMVars (mkAppN val xs) - if ← hasAssignableMVar proof then - trace[Meta.Tactic.simp.rewrite] "{lemma}, has unassigned metavariables after unification" - return none - let rhs ← instantiateMVars type.appArg! - if e == rhs then - return none - if lemma.perm && !Expr.lt rhs e then - trace[Meta.Tactic.simp.rewrite] "{lemma}, perm rejected {e} ==> {rhs}" - return none - trace[Meta.Tactic.simp.rewrite] "{lemma}, {e} ==> {rhs}" - return some { expr := rhs, proof? := proof } - else - unless lhs.isMVar do - -- We do not report unification failures when `lhs` is a metavariable - -- Example: `x = ()` - -- TODO: reconsider if we want lemmas such as `(x : Unit) → x = ()` - trace[Meta.Tactic.simp.unify] "{lemma}, failed to unify {lhs} with {e}" - return none - let lhsNumArgs := lhs.getAppNumArgs - let eNumArgs := e.getAppNumArgs - if eNumArgs == lhsNumArgs then - go e - else if eNumArgs < lhsNumArgs then - return none - else - /- Check whether we need something more sophisticated here. - This simple approach was good enough for Mathlib 3 -/ - let mut extraArgs := #[] - let mut e := e - for i in [:eNumArgs - lhsNumArgs] do - extraArgs := extraArgs.push e.appArg! - e := e.appFn! - match (← go e) with - | none => return none - | some { expr := eNew, proof? := none } => return some { expr := mkAppN eNew extraArgs } - | some { expr := eNew, proof? := some proof } => - let mut proof := proof - for extraArg in extraArgs do - proof ← mkCongrFun proof extraArg - return some { expr := mkAppN eNew extraArgs, proof? := some proof } + tryLemmaCore lhs xs bis val type e lemma numExtraArgs discharge? +def tryLemma? (e : Expr) (lemma : SimpLemma) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := do + withNewMCtxDepth do + let val ← lemma.getValue + let type ← inferType val + let (xs, bis, type) ← forallMetaTelescopeReducing type + let type ← whnf (← instantiateMVars type) + let lhs := type.appFn!.appArg! + match (← tryLemmaCore lhs xs bis val type e lemma 0 discharge?) with + | some result => return some result + | none => + let lhsNumArgs := lhs.getAppNumArgs + let eNumArgs := e.getAppNumArgs + if eNumArgs > lhsNumArgs then + tryLemmaCore lhs xs bis val type e lemma (eNumArgs - lhsNumArgs) discharge? + else + return none /- Remark: the parameter tag is used for creating trace messages. It is irrelevant otherwise. -/ def rewrite (e : Expr) (s : DiscrTree SimpLemma) (erased : Std.PHashSet Name) (discharge? : Expr → SimpM (Option Expr)) (tag : String) : SimpM Result := do - let lemmas ← s.getMatch e (allowExtraArgs := true) - if lemmas.isEmpty then + let candidates ← s.getMatchWithExtra e + if candidates.isEmpty then trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}" return { expr := e } else - let lemmas := lemmas.insertionSort fun e₁ e₂ => e₁.priority < e₂.priority - for lemma in lemmas do + let candidates := candidates.insertionSort fun e₁ e₂ => e₁.1.priority < e₂.1.priority + for (lemma, numExtraArgs) in candidates do unless inErasedSet lemma do - if let some result ← tryLemma? e lemma discharge? then + if let some result ← tryLemmaWithExtraArgs? e lemma numExtraArgs discharge? then return result return { expr := e } where