From 6c8976abbea74e6abb059de6267a0804cf14f614 Mon Sep 17 00:00:00 2001 From: Joe Hendrix Date: Sat, 23 Mar 2024 01:01:35 -0400 Subject: [PATCH] feat: upstream rw? tactic (#3719) This updates the rw? tactic from Mathlib to use lazy discriminator trees and upstreams it. --------- Co-authored-by: Scott Morrison --- src/Init/Tactics.lean | 16 ++ src/Lean/Elab/Tactic.lean | 1 + src/Lean/Elab/Tactic/Rewrites.lean | 69 +++++ src/Lean/Expr.lean | 16 ++ src/Lean/Meta/LazyDiscrTree.lean | 218 ++++++++++----- src/Lean/Meta/Tactic.lean | 1 + src/Lean/Meta/Tactic/LibrarySearch.lean | 22 +- src/Lean/Meta/Tactic/Rewrites.lean | 339 ++++++++++++++++++++++++ tests/lean/run/rewrites.lean | 126 +++++++++ 9 files changed, 737 insertions(+), 71 deletions(-) create mode 100644 src/Lean/Elab/Tactic/Rewrites.lean create mode 100644 src/Lean/Meta/Tactic/Rewrites.lean create mode 100644 tests/lean/run/rewrites.lean diff --git a/src/Init/Tactics.lean b/src/Init/Tactics.lean index ee564806ae..416c69bbfa 100644 --- a/src/Init/Tactics.lean +++ b/src/Init/Tactics.lean @@ -1318,6 +1318,22 @@ used when closing the goal. -/ syntax (name := apply?) "apply?" (" using " (colGt term),+)? : tactic +/-- +Syntax for excluding some names, e.g. `[-my_lemma, -my_theorem]`. +-/ +syntax rewrites_forbidden := " [" (("-" ident),*,?) "]" + +/-- +`rw?` tries to find a lemma which can rewrite the goal. + +`rw?` should not be left in proofs; it is a search tool, like `apply?`. + +Suggestions are printed as `rw [h]` or `rw [← h]`. + +You can use `rw? [-my_lemma, -my_theorem]` to prevent `rw?` using the named lemmas. +-/ +syntax (name := rewrites?) "rw?" (ppSpace location)? (rewrites_forbidden)? : tactic + /-- `show_term tac` runs `tac`, then prints the generated term in the form "exact X Y Z" or "refine X ?_ Z" if there are remaining subgoals. diff --git a/src/Lean/Elab/Tactic.lean b/src/Lean/Elab/Tactic.lean index 493b2e46bd..5d594fc8b2 100644 --- a/src/Lean/Elab/Tactic.lean +++ b/src/Lean/Elab/Tactic.lean @@ -39,3 +39,4 @@ import Lean.Elab.Tactic.SolveByElim import Lean.Elab.Tactic.LibrarySearch import Lean.Elab.Tactic.ShowTerm import Lean.Elab.Tactic.Rfl +import Lean.Elab.Tactic.Rewrites diff --git a/src/Lean/Elab/Tactic/Rewrites.lean b/src/Lean/Elab/Tactic/Rewrites.lean new file mode 100644 index 0000000000..49bb7ab35f --- /dev/null +++ b/src/Lean/Elab/Tactic/Rewrites.lean @@ -0,0 +1,69 @@ +/- +Copyright (c) 2023 Scott Morrison. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Scott Morrison +-/ +prelude +import Lean.Elab.Tactic.Location +import Lean.Meta.Tactic.Replace +import Lean.Meta.Tactic.Rewrites + +/-! +# The `rewrites` tactic. + +`rw?` tries to find a lemma which can rewrite the goal. + +`rw?` should not be left in proofs; it is a search tool, like `apply?`. + +Suggestions are printed as `rw [h]` or `rw [← h]`. + +-/ +namespace Lean.Elab.Rewrites + +open Lean Meta Rewrites +open Lean.Parser.Tactic + +open Lean Elab Tactic + +@[builtin_tactic Lean.Parser.Tactic.rewrites?] +def evalExact : Tactic := fun stx => do + let `(tactic| rw?%$tk $[$loc]? $[[ $[-$forbidden],* ]]?) := stx + | throwUnsupportedSyntax + let moduleRef ← createModuleTreeRef + let forbidden : NameSet := + ((forbidden.getD #[]).map Syntax.getId).foldl (init := ∅) fun s n => s.insert n + reportOutOfHeartbeats `findRewrites tk + let goal ← getMainGoal + withLocation (expandOptLocation (Lean.mkOptionalNode loc)) + fun f => do + let some a ← f.findDecl? | return + if a.isImplementationDetail then return + let target ← instantiateMVars (← f.getType) + let hyps ← localHypotheses (except := [f]) + let results ← findRewrites hyps moduleRef goal target (stopAtRfl := false) forbidden + reportOutOfHeartbeats `rewrites tk + if results.isEmpty then + throwError "Could not find any lemmas which can rewrite the hypothesis {← f.getUserName}" + for r in results do withMCtx r.mctx do + Tactic.TryThis.addRewriteSuggestion tk [(r.expr, r.symm)] + r.result.eNew (loc? := .some (.fvar f)) (origSpan? := ← getRef) + if let some r := results[0]? then + setMCtx r.mctx + let replaceResult ← goal.replaceLocalDecl f r.result.eNew r.result.eqProof + replaceMainGoal (replaceResult.mvarId :: r.result.mvarIds) + do + let target ← instantiateMVars (← goal.getType) + let hyps ← localHypotheses + let results ← findRewrites hyps moduleRef goal target (stopAtRfl := true) forbidden + reportOutOfHeartbeats `rewrites tk + if results.isEmpty then + throwError "Could not find any lemmas which can rewrite the goal" + results.forM (·.addSuggestion tk) + if let some r := results[0]? then + setMCtx r.mctx + replaceMainGoal + ((← goal.replaceTargetEq r.result.eNew r.result.eqProof) :: r.result.mvarIds) + evalTactic (← `(tactic| try rfl)) + (fun _ => throwError "Failed to find a rewrite for some location") + +end Lean.Elab.Rewrites diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 1f800f4a94..56d64875c9 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -1881,6 +1881,22 @@ def letFunAppArgs? (e : Expr) : Option (Array Expr × Name × Expr × Expr × Ex | .lam n _ b _ => some (rest, n, t, v, b) | _ => some (rest, .anonymous, t, v, .app f (.bvar 0)) +/-- Maps `f` on each immediate child of the given expression. -/ +@[specialize] +def traverseChildren [Applicative M] (f : Expr → M Expr) : Expr → M Expr + | e@(forallE _ d b _) => pure e.updateForallE! <*> f d <*> f b + | e@(lam _ d b _) => pure e.updateLambdaE! <*> f d <*> f b + | e@(mdata _ b) => e.updateMData! <$> f b + | e@(letE _ t v b _) => pure e.updateLet! <*> f t <*> f v <*> f b + | e@(app l r) => pure e.updateApp! <*> f l <*> f r + | e@(proj _ _ b) => e.updateProj! <$> f b + | e => pure e + +/-- `e.foldlM f a` folds the monadic function `f` over the subterms of the expression `e`, +with initial value `a`. -/ +def foldlM {α : Type} {m} [Monad m] (f : α → Expr → m α) (init : α) (e : Expr) : m α := + Prod.snd <$> StateT.run (e.traverseChildren (fun e' => fun a => Prod.mk e' <$> f a e')) init + end Expr /-- diff --git a/src/Lean/Meta/LazyDiscrTree.lean b/src/Lean/Meta/LazyDiscrTree.lean index d3bf5f9315..103b760510 100644 --- a/src/Lean/Meta/LazyDiscrTree.lean +++ b/src/Lean/Meta/LazyDiscrTree.lean @@ -393,26 +393,37 @@ Get the root key and rest of terms of an expression using the specified config. private def rootKey (cfg: WhnfCoreConfig) (e : Expr) : MetaM (Key × Array Expr) := pushArgs true (Array.mkEmpty initCapacity) e cfg -private partial def mkPathAux (root : Bool) (todo : Array Expr) (keys : Array Key) - (config : WhnfCoreConfig) : MetaM (Array Key) := do +private partial def buildPath (op : Bool → Array Expr → Expr → MetaM (Key × Array Expr)) (root : Bool) (todo : Array Expr) (keys : Array Key) : MetaM (Array Key) := do if todo.isEmpty then return keys else let e := todo.back let todo := todo.pop - let (k, todo) ← pushArgs root todo e config - mkPathAux false todo (keys.push k) config + let (k, todo) ← op root todo e + buildPath op false todo (keys.push k) /-- -Create a path from an expression. +Create a key path from an expression using the function used for patterns. -This differs from Lean.Meta.DiscrTree.mkPath in that the expression +This differs from Lean.Meta.DiscrTree.mkPath and targetPath in that the expression should uses free variables rather than meta-variables for holes. -/ -private def mkPath (e : Expr) (config : WhnfCoreConfig) : MetaM (Array Key) := do +def patternPath (e : Expr) (config : WhnfCoreConfig) : MetaM (Array Key) := do let todo : Array Expr := .mkEmpty initCapacity - let keys : Array Key := .mkEmpty initCapacity - mkPathAux (root := true) (todo.push e) keys config + let op root todo e := pushArgs root todo e config + buildPath op (root := true) (todo.push e) (.mkEmpty initCapacity) + +/-- +Create a key path from an expression we are matching against. + +This should have mvars instantiated where feasible. +-/ +def targetPath (e : Expr) (config : WhnfCoreConfig) : MetaM (Array Key) := do + let todo : Array Expr := .mkEmpty initCapacity + let op root todo e := do + let (k, args) ← MatchClone.getMatchKeyArgs e root config + pure (k, todo ++ args) + buildPath op (root := true) (todo.push e) (.mkEmpty initCapacity) /- Monad for finding matches while resolving deferred patterns. -/ @[reducible] @@ -512,7 +523,7 @@ A match result contains the terms formed from matching a term against patterns in the discrimination tree. -/ -private structure MatchResult (α : Type) where +structure MatchResult (α : Type) where /-- The elements in the match result. @@ -525,7 +536,9 @@ private structure MatchResult (α : Type) where -/ elts : Array (Array (Array α)) := #[] -private def MatchResult.push (r : MatchResult α) (score : Nat) (e : Array α) : MatchResult α := +namespace MatchResult + +private def push (r : MatchResult α) (score : Nat) (e : Array α) : MatchResult α := if e.isEmpty then r else if score < r.elts.size then @@ -539,14 +552,28 @@ private def MatchResult.push (r : MatchResult α) (score : Nat) (e : Array α) : termination_by score - a.size loop r.elts -private partial def MatchResult.toArray (mr : MatchResult α) : Array α := - loop (Array.mkEmpty n) mr.elts - where n := mr.elts.foldl (fun i a => a.foldl (fun n a => n + a.size) i) 0 - loop (r : Array α) (a : Array (Array (Array α))) := - if a.isEmpty then - r - else - loop (a.back.foldl (init := r) (fun r a => r ++ a)) a.pop +/-- +Number of elements in result +-/ +partial def size (mr : MatchResult α) : Nat := + mr.elts.foldl (fun i a => a.foldl (fun n a => n + a.size) i) 0 + +/-- +Append results to array +-/ +@[specialize] +partial def appendResultsAux (mr : MatchResult α) (a : Array β) (f : Nat → α → β) : Array β := + let aa := mr.elts + let n := aa.size + Nat.fold (n := n) (init := a) fun i r => + let j := n-1-i + let b := aa[j]! + b.foldl (init := r) (· ++ ·.map (f j)) + +partial def appendResults (mr : MatchResult α) (a : Array α) : Array α := + mr.appendResultsAux a (fun _ a => a) + +end MatchResult private partial def getMatchLoop (todo : Array Expr) (score : Nat) (c : TrieIndex) (result : MatchResult α) : MatchM α (MatchResult α) := do @@ -619,8 +646,8 @@ private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) : The results are ordered so that the longest matches in terms of number of non-star keys are first with ties going to earlier operators first. -/ -def getMatch (d : LazyDiscrTree α) (e : Expr) : MetaM (Array α × LazyDiscrTree α) := - withReducible <| runMatch d <| (·.toArray) <$> getMatchCore d.roots e +def getMatch (d : LazyDiscrTree α) (e : Expr) : MetaM (MatchResult α × LazyDiscrTree α) := + withReducible <| runMatch d <| getMatchCore d.roots e /-- Structure for quickly initializing a lazy discrimination tree with a large number @@ -845,21 +872,11 @@ def createLocalPreDiscrTree let r ← (env.constants.map₂.foldlM (init := {}) act : BaseIO (PreDiscrTree α)) pure r -/-- Create an imported environment for tree. -/ -def createLocalEnvironment - (act : Name → ConstantInfo → MetaM (Array (InitEntry α))) : - CoreM (LazyDiscrTree α) := do - let env ← getEnv - let ngen ← getChildNgen - let d ← ImportData.new - let t ← createLocalPreDiscrTree ngen env d act - let errors ← d.errors.get - if p : errors.size > 0 then - throw errors[0].exception - pure <| t.toLazy +def dropKeys (t : LazyDiscrTree α) (keys : List (List LazyDiscrTree.Key)) : MetaM (LazyDiscrTree α) := do + keys.foldlM (init := t) (·.dropKey ·) -/-- Create an imported environment for tree. -/ -def createImportedEnvironment (ngen : NameGenerator) (env : Environment) +/-- Create a discriminator tree for imported environment. -/ +def createImportedDiscrTree (ngen : NameGenerator) (env : Environment) (act : Name → ConstantInfo → MetaM (Array (InitEntry α))) (constantsPerTask : Nat := 1000) : EIO Exception (LazyDiscrTree α) := do @@ -889,23 +906,12 @@ def createImportedEnvironment (ngen : NameGenerator) (env : Environment) throw r.errors[0].exception pure <| r.tree.toLazy -def dropKeys (t : LazyDiscrTree α) (keys : List (List LazyDiscrTree.Key)) : MetaM (LazyDiscrTree α) := do - keys.foldlM (init := t) (·.dropKey ·) - -/-- -`findCandidates` searches for entries in a lazily initialized discriminator tree. - -* `ext` should be an environment extension with an IO.Ref for caching the import lazy - discriminator tree. -* `addEntry` is the function for creating discriminator tree entries from constants. -* `droppedKeys` contains keys we do not want to consider when searching for matches. - It is used for dropping very general keys. --/ -def findCandidates (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α)))) - (addEntry : Name → ConstantInfo → MetaM (Array (InitEntry α))) - (droppedKeys : List (List LazyDiscrTree.Key) := []) - (constantsPerTask : Nat := 1000) - (ty : Expr) : MetaM (Array α) := do +def findImportMatches + (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α)))) + (addEntry : Name → ConstantInfo → MetaM (Array (InitEntry α))) + (droppedKeys : List (List LazyDiscrTree.Key) := []) + (constantsPerTask : Nat := 1000) + (ty : Expr) : MetaM (MatchResult α) := do let ngen ← getNGen let (cNGen, ngen) := ngen.mkChild setNGen ngen @@ -913,14 +919,106 @@ def findCandidates (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α)))) let ref := @EnvExtension.getState _ ⟨dummy⟩ ext (←getEnv) let importTree ← (←ref.get).getDM $ do profileitM Exception "lazy discriminator import initialization" (←getOptions) $ do - let t ← createImportedEnvironment cNGen (←getEnv) addEntry + let t ← createImportedDiscrTree cNGen (←getEnv) addEntry (constantsPerTask := constantsPerTask) dropKeys t droppedKeys - let (localCandidates, _) ← - profileitM Exception "lazy discriminator local search" (←getOptions) $ do - let t ← createLocalEnvironment addEntry - let t ← dropKeys t droppedKeys - t.getMatch ty let (importCandidates, importTree) ← importTree.getMatch ty - ref.set importTree - pure (localCandidates ++ importCandidates) + ref.set (some importTree) + pure importCandidates + +/-- +A discriminator tree for the current module's declarations only. + +Note. We use different discriminator trees for imported and current module +declarations since imported declarations are typically much more numerous but +not changed after the environment is created. +-/ +structure ModuleDiscrTreeRef (α : Type _) where + ref : IO.Ref (LazyDiscrTree α) + +/-- Create a discriminator tree for current module declarations. -/ +def createModuleDiscrTree + (entriesForConst : Name → ConstantInfo → MetaM (Array (InitEntry α))) : + CoreM (LazyDiscrTree α) := do + let env ← getEnv + let ngen ← getChildNgen + let d ← ImportData.new + let t ← createLocalPreDiscrTree ngen env d entriesForConst + let errors ← d.errors.get + if p : errors.size > 0 then + throw errors[0].exception + pure <| t.toLazy + +/-- +Creates reference for lazy discriminator tree that only contains this module's definitions. +-/ +def createModuleTreeRef (entriesForConst : Name → ConstantInfo → MetaM (Array (InitEntry α))) + (droppedKeys : List (List LazyDiscrTree.Key)) : MetaM (ModuleDiscrTreeRef α) := do + profileitM Exception "build module discriminator tree" (←getOptions) $ do + let t ← createModuleDiscrTree entriesForConst + let t ← dropKeys t droppedKeys + pure { ref := ← IO.mkRef t } + +/-- +Returns candidates from this module in this module that match the expression. + +* `moduleRef` is a references to a lazy discriminator tree only containing +this module's definitions. +-/ +def findModuleMatches (moduleRef : ModuleDiscrTreeRef α) (ty : Expr) : MetaM (MatchResult α) := do + profileitM Exception "lazy discriminator local search" (←getOptions) $ do + let discrTree ← moduleRef.ref.get + let (localCandidates, localTree) ← discrTree.getMatch ty + moduleRef.ref.set localTree + pure localCandidates + +/-- +`findMatchesExt` searches for entries in a lazily initialized discriminator tree. + +It provides some additional capabilities beyond `findMatches` to adjust results +based on priority and cache module declarations + +* `modulesTreeRef` points to the discriminator tree for local environment. + Used for caching and created by `createLocalTree`. +* `ext` should be an environment extension with an IO.Ref for caching the import lazy + discriminator tree. +* `addEntry` is the function for creating discriminator tree entries from constants. +* `droppedKeys` contains keys we do not want to consider when searching for matches. + It is used for dropping very general keys. +* `constantsPerTask` stores number of constants in imported modules used to + decide when to create new task. +* `adjustResult` takes the priority and value to produce a final result. +* `ty` is the expression type. +-/ +def findMatchesExt + (moduleTreeRef : ModuleDiscrTreeRef α) + (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α)))) + (addEntry : Name → ConstantInfo → MetaM (Array (InitEntry α))) + (droppedKeys : List (List LazyDiscrTree.Key) := []) + (constantsPerTask : Nat := 1000) + (adjustResult : Nat → α → β) + (ty : Expr) : MetaM (Array β) := do + let moduleMatches ← findModuleMatches moduleTreeRef ty + let importMatches ← findImportMatches ext addEntry droppedKeys constantsPerTask ty + return Array.mkEmpty (moduleMatches.size + importMatches.size) + |> moduleMatches.appendResultsAux (f := adjustResult) + |> importMatches.appendResultsAux (f := adjustResult) + +/-- +`findMatches` searches for entries in a lazily initialized discriminator tree. + +* `ext` should be an environment extension with an IO.Ref for caching the import lazy + discriminator tree. +* `addEntry` is the function for creating discriminator tree entries from constants. +* `droppedKeys` contains keys we do not want to consider when searching for matches. + It is used for dropping very general keys. +-/ +def findMatches (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α)))) + (addEntry : Name → ConstantInfo → MetaM (Array (InitEntry α))) + (droppedKeys : List (List LazyDiscrTree.Key) := []) + (constantsPerTask : Nat := 1000) + (ty : Expr) : MetaM (Array α) := do + + let moduleTreeRef ← createModuleTreeRef addEntry droppedKeys + let incPrio _ v := v + findMatchesExt moduleTreeRef ext addEntry droppedKeys constantsPerTask incPrio ty diff --git a/src/Lean/Meta/Tactic.lean b/src/Lean/Meta/Tactic.lean index 188522c689..b50abef054 100644 --- a/src/Lean/Meta/Tactic.lean +++ b/src/Lean/Meta/Tactic.lean @@ -39,3 +39,4 @@ import Lean.Meta.Tactic.Backtrack import Lean.Meta.Tactic.SolveByElim import Lean.Meta.Tactic.FunInd import Lean.Meta.Tactic.Rfl +import Lean.Meta.Tactic.Rewrites diff --git a/src/Lean/Meta/Tactic/LibrarySearch.lean b/src/Lean/Meta/Tactic/LibrarySearch.lean index 55a3d7f3c3..f7b23466c8 100644 --- a/src/Lean/Meta/Tactic/LibrarySearch.lean +++ b/src/Lean/Meta/Tactic/LibrarySearch.lean @@ -67,7 +67,7 @@ to find candidate lemmas. @[reducible] def CandidateFinder := Expr → MetaM (Array (Name × DeclMod)) -open LazyDiscrTree (InitEntry findCandidates) +open LazyDiscrTree (InitEntry findMatches) private def addImport (name : Name) (constInfo : ConstantInfo) : MetaM (Array (InitEntry (Name × DeclMod))) := @@ -111,7 +111,7 @@ private def constantsPerImportTask : Nat := 6500 /-- Create function for finding relevant declarations. -/ def libSearchFindDecls : Expr → MetaM (Array (Name × DeclMod)) := - findCandidates ext addImport + findMatches ext addImport (droppedKeys := droppedKeys) (constantsPerTask := constantsPerImportTask) @@ -278,15 +278,15 @@ private def librarySearch' (goal : MVarId) MetaM (Option (Array (List MVarId × MetavarContext))) := do withTraceNode `Tactic.librarySearch (return m!"{librarySearchEmoji ·} {← goal.getType}") do profileitM Exception "librarySearch" (← getOptions) do - -- Create predicate that returns true when running low on heartbeats. - let candidates ← librarySearchSymm libSearchFindDecls goal - let cfg : ApplyConfig := { allowSynthFailures := true } - let shouldAbort ← mkHeartbeatCheck leavePercentHeartbeats - let act := fun cand => do - if ←shouldAbort then - abortSpeculation - librarySearchLemma cfg tactic allowFailure cand - tryOnEach act candidates + -- Create predicate that returns true when running low on heartbeats. + let candidates ← librarySearchSymm libSearchFindDecls goal + let cfg : ApplyConfig := { allowSynthFailures := true } + let shouldAbort ← mkHeartbeatCheck leavePercentHeartbeats + let act := fun cand => do + if ←shouldAbort then + abortSpeculation + librarySearchLemma cfg tactic allowFailure cand + tryOnEach act candidates /-- Tries to solve the goal either by: diff --git a/src/Lean/Meta/Tactic/Rewrites.lean b/src/Lean/Meta/Tactic/Rewrites.lean new file mode 100644 index 0000000000..c483235199 --- /dev/null +++ b/src/Lean/Meta/Tactic/Rewrites.lean @@ -0,0 +1,339 @@ +/- +Copyright (c) 2023 Scott Morrison. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Scott Morrison +-/ +prelude +import Lean.Meta.LazyDiscrTree +import Lean.Meta.Tactic.Assumption +import Lean.Meta.Tactic.Rewrite +import Lean.Meta.Tactic.Rfl +import Lean.Meta.Tactic.SolveByElim +import Lean.Meta.Tactic.TryThis +import Lean.Util.Heartbeats + +namespace Lean.Meta.Rewrites + +open Lean.Meta.LazyDiscrTree (InitEntry MatchResult) +open Lean.Meta.SolveByElim + +builtin_initialize registerTraceClass `Tactic.rewrites +builtin_initialize registerTraceClass `Tactic.rewrites.lemmas + +/-- Extract the lemma, with arguments, that was used to produce a `RewriteResult`. -/ +-- This assumes that `r.eqProof` was constructed as: +-- `mkApp6 (.const ``congrArg _) α eType lhs rhs motive heq` +-- in `Lean.Meta.Tactic.Rewrite` and we want `heq`. +def rewriteResultLemma (r : RewriteResult) : Option Expr := + if r.eqProof.isAppOfArity ``congrArg 6 then + r.eqProof.getArg! 5 + else + none + +/-- Weight to multiply the "specificity" of a rewrite lemma by when rewriting forwards. -/ +def forwardWeight := 2 +/-- Weight to multiply the "specificity" of a rewrite lemma by when rewriting backwards. -/ +def backwardWeight := 1 + + +private def addImport (name : Name) (constInfo : ConstantInfo) : + MetaM (Array (InitEntry (Name × Bool × Nat))) := do + if constInfo.isUnsafe then return #[] + if !allowCompletion (←getEnv) name then return #[] + -- We now remove some injectivity lemmas which are not useful to rewrite by. + if name matches .str _ "injEq" then return #[] + if name matches .str _ "sizeOf_spec" then return #[] + match name with + | .str _ n => if n.endsWith "_inj" ∨ n.endsWith "_inj'" then return #[] + | _ => pure () + withNewMCtxDepth do withReducible do + forallTelescopeReducing constInfo.type fun _ type => do + match type.getAppFnArgs with + | (``Eq, #[_, lhs, rhs]) + | (``Iff, #[lhs, rhs]) => do + let a := Array.mkEmpty 2 + let a := a.push (← InitEntry.fromExpr lhs (name, false, forwardWeight)) + let a := a.push (← InitEntry.fromExpr rhs (name, true, backwardWeight)) + pure a + | _ => return #[] + +/-- Configuration for `DiscrTree`. -/ +def discrTreeConfig : WhnfCoreConfig := {} + +/-- Select `=` and `↔` local hypotheses. -/ +def localHypotheses (except : List FVarId := []) : MetaM (Array (Expr × Bool × Nat)) := do + let r ← getLocalHyps + let mut result := #[] + for h in r do + if except.contains h.fvarId! then continue + let (_, _, type) ← forallMetaTelescopeReducing (← inferType h) + let type ← whnfR type + match type.getAppFnArgs with + | (``Eq, #[_, lhs, rhs]) + | (``Iff, #[lhs, rhs]) => do + let lhsKey : Array DiscrTree.Key ← DiscrTree.mkPath lhs discrTreeConfig + let rhsKey : Array DiscrTree.Key ← DiscrTree.mkPath rhs discrTreeConfig + result := result.push (h, false, forwardWeight * lhsKey.size) + |>.push (h, true, backwardWeight * rhsKey.size) + | _ => pure () + return result + +/-- +We drop `.star` and `Eq * * *` from the discriminator trees because +they match too much. +-/ +def droppedKeys : List (List LazyDiscrTree.Key) := [[.star], [.const `Eq 3, .star, .star, .star]] + +def createModuleTreeRef : MetaM (LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) := + LazyDiscrTree.createModuleTreeRef addImport droppedKeys + +private def ExtState := IO.Ref (Option (LazyDiscrTree (Name × Bool × Nat))) + +private builtin_initialize ExtState.default : IO.Ref (Option (LazyDiscrTree (Name × Bool × Nat))) ← do + IO.mkRef .none + +private instance : Inhabited ExtState where + default := ExtState.default + +private builtin_initialize ext : EnvExtension ExtState ← + registerEnvExtension (IO.mkRef .none) + +/-- +The maximum number of constants an individual task may perform. + +The value was picked because it roughly correponded to 50ms of work on the +machine this was developed on. Smaller numbers did not seem to improve +performance when importing Std and larger numbers (<10k) seemed to degrade +initialization performance. +-/ +private def constantsPerImportTask : Nat := 6500 + +def incPrio : Nat → Name × Bool × Nat → Name × Bool × Nat +| p, (nm, d, prio) => (nm, d, prio * 100 + p) + +/-- Create function for finding relevant declarations. -/ +def rwFindDecls (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) : Expr → MetaM (Array (Name × Bool × Nat)) := + LazyDiscrTree.findMatchesExt moduleRef ext addImport + (droppedKeys := droppedKeys) + (constantsPerTask := constantsPerImportTask) + (adjustResult := incPrio) + +/-- Data structure recording a potential rewrite to report from the `rw?` tactic. -/ +structure RewriteResult where + /-- The lemma we rewrote by. + This is `Expr`, not just a `Name`, as it may be a local hypothesis. -/ + expr : Expr + /-- `True` if we rewrote backwards (i.e. with `rw [← h]`). -/ + symm : Bool + /-- The "weight" of the rewrite. This is calculated based on how specific the rewrite rule was. -/ + weight : Nat + /-- The result from the `rw` tactic. -/ + result : Meta.RewriteResult + /-- The metavariable context after the rewrite. + This needs to be stored as part of the result so we can backtrack the state. -/ + mctx : MetavarContext + rfl? : Bool + +/-- Update a `RewriteResult` by filling in the `rfl?` field if it is currently `none`, +to reflect whether the remaining goal can be closed by `with_reducible rfl`. -/ +def computeRfl (mctx : MetavarContext) (res : Meta.RewriteResult) : MetaM Bool := do + try + withoutModifyingState <| withMCtx mctx do + -- We use `withReducible` here to follow the behaviour of `rw`. + withReducible (← mkFreshExprMVar res.eNew).mvarId!.applyRfl + -- We do not need to record the updated `MetavarContext` here. + pure true + catch _e => + pure false + +/-- +Pretty print the result of the rewrite. +-/ +private def RewriteResult.ppResult (r : RewriteResult) : MetaM String := + return (← ppExpr r.result.eNew).pretty + + +/-- Should we try discharging side conditions? If so, using `assumption`, or `solve_by_elim`? -/ +inductive SideConditions +| none +| assumption +| solveByElim + +/-- Shortcut for calling `solveByElim`. -/ +def solveByElim (goals : List MVarId) (depth : Nat := 6) : MetaM PUnit := do + -- There is only a marginal decrease in performance for using the `symm` option for `solveByElim`. + -- (measured via `lake build && time lake env lean test/librarySearch.lean`). + let cfg : SolveByElimConfig := { maxDepth := depth, exfalso := false, symm := true } + let ⟨lemmas, ctx⟩ ← mkAssumptionSet false false [] [] #[] + let [] ← SolveByElim.solveByElim cfg lemmas ctx goals + | failure + +def rwLemma (ctx : MetavarContext) (goal : MVarId) (target : Expr) (side : SideConditions := .solveByElim) + (lem : Expr ⊕ Name) (symm : Bool) (weight : Nat) : MetaM (Option RewriteResult) := + withMCtx ctx do + let some expr ← (match lem with + | .inl hyp => pure (some hyp) + | .inr lem => some <$> mkConstWithFreshMVarLevels lem <|> pure none) + | return none + trace[Tactic.rewrites] m!"considering {if symm then "← " else ""}{expr}" + let some result ← some <$> goal.rewrite target expr symm <|> pure none + | return none + if result.mvarIds.isEmpty then + let mctx ← getMCtx + let rfl? ← computeRfl mctx result + return some { expr, symm, weight, result, mctx, rfl? } + else + -- There are side conditions, which we try to discharge using local hypotheses. + let discharge ← + match side with + | .none => pure false + | .assumption => ((fun _ => true) <$> result.mvarIds.mapM fun m => m.assumption) <|> pure false + | .solveByElim => (solveByElim result.mvarIds >>= fun _ => pure true) <|> pure false + match discharge with + | false => + return none + | true => + -- If we succeed, we need to reconstruct the expression to report that we rewrote by. + let some expr := rewriteResultLemma result | return none + let expr ← instantiateMVars expr + let (expr, symm) := if expr.isAppOfArity ``Eq.symm 4 then + (expr.getArg! 3, true) + else + (expr, false) + let mctx ← getMCtx + let rfl? ← computeRfl mctx result + return some { expr, symm, weight, result, mctx, rfl? } + +/-- +Find keys which match the expression, or some subexpression. + +Note that repeated subexpressions will be visited each time they appear, +making this operation potentially very expensive. +It would be good to solve this problem! + +Implementation: we reverse the results from `getMatch`, +so that we return lemmas matching larger subexpressions first, +and amongst those we return more specific lemmas first. +-/ +partial def getSubexpressionMatches (op : Expr → MetaM (Array α)) (e : Expr) : MetaM (Array α) := do + match e with + | .bvar _ => return #[] + | .forallE _ _ _ _ => + forallTelescope e fun args body => do + args.foldlM (fun acc arg => return acc ++ (← getSubexpressionMatches op (← inferType arg))) + (← getSubexpressionMatches op body).reverse + | .lam _ _ _ _ + | .letE _ _ _ _ _ => + lambdaLetTelescope e (fun args body => do + args.foldlM (fun acc arg => return acc ++ (← getSubexpressionMatches op (← inferType arg))) + (← getSubexpressionMatches op body).reverse) + | _ => + let init := ((← op e).reverse) + e.foldlM (init := init) (fun a f => return a ++ (← getSubexpressionMatches op f)) + +/-- +Find lemmas which can rewrite the goal. + +See also `rewrites` for a more convenient interface. +-/ +def rewriteCandidates (hyps : Array (Expr × Bool × Nat)) + (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) + (target : Expr) + (forbidden : NameSet := ∅) : + MetaM (Array ((Expr ⊕ Name) × Bool × Nat)) := do + -- Get all lemmas which could match some subexpression + let candidates ← getSubexpressionMatches (rwFindDecls moduleRef) target + -- Sort them by our preferring weighting + -- (length of discriminant key, doubled for the forward implication) + let candidates := candidates.insertionSort fun (_, _, rp) (_, _, sp) => rp > sp + + -- Now deduplicate. We can't use `Array.deduplicateSorted` as we haven't completely sorted, + -- and in fact want to keep some of the residual ordering from the discrimination tree. + let mut forward : NameSet := ∅ + let mut backward : NameSet := ∅ + let mut deduped := #[] + for (l, s, w) in candidates do + if forbidden.contains l then continue + if s then + if ¬ backward.contains l then + deduped := deduped.push (l, s, w) + backward := backward.insert l + else + if ¬ forward.contains l then + deduped := deduped.push (l, s, w) + forward := forward.insert l + + trace[Tactic.rewrites.lemmas] m!"Candidate rewrite lemmas:\n{deduped}" + + let hyps := hyps.map fun ⟨hyp, symm, weight⟩ => (Sum.inl hyp, symm, weight) + let lemmas := deduped.map fun ⟨lem, symm, weight⟩ => (Sum.inr lem, symm, weight) + pure <| hyps ++ lemmas + +def RewriteResult.newGoal (r : RewriteResult) : Option Expr := + if r.rfl? = true then + some (Expr.lit (.strVal "no goals")) + else + some r.result.eNew + +def RewriteResult.addSuggestion (ref : Syntax) (r : RewriteResult) : Elab.TermElabM Unit := do + withMCtx r.mctx do + Tactic.TryThis.addRewriteSuggestion ref [(r.expr, r.symm)] (type? := r.newGoal) (origSpan? := ← getRef) + +structure RewriteResultConfig where + stopAtRfl : Bool + max : Nat + minHeartbeats : Nat + goal : MVarId + target : Expr + side : SideConditions := .solveByElim + mctx : MetavarContext + +def takeListAux (cfg : RewriteResultConfig) (seen : HashMap String Unit) (acc : Array RewriteResult) + (xs : List ((Expr ⊕ Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do + let mut seen := seen + let mut acc := acc + for (lem, symm, weight) in xs do + if (← getRemainingHeartbeats) < cfg.minHeartbeats then + return acc + if acc.size ≥ cfg.max then + return acc + let res ← + withoutModifyingState <| withMCtx cfg.mctx do + rwLemma cfg.mctx cfg.goal cfg.target cfg.side lem symm weight + match res with + | none => continue + | some r => + let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult + if seen.contains s then + continue + let rfl? ← computeRfl r.mctx r.result + if cfg.stopAtRfl then + if rfl? then + return #[r] + else + seen := seen.insert s () + acc := acc.push r + else + seen := seen.insert s () + acc := acc.push r + return acc + +/-- Find lemmas which can rewrite the goal. -/ +def findRewrites (hyps : Array (Expr × Bool × Nat)) + (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) + (goal : MVarId) (target : Expr) + (forbidden : NameSet := ∅) (side : SideConditions := .solveByElim) + (stopAtRfl : Bool) (max : Nat := 20) + (leavePercentHeartbeats : Nat := 10) : MetaM (List RewriteResult) := do + let mctx ← getMCtx + let candidates ← rewriteCandidates hyps moduleRef target forbidden + let minHeartbeats : Nat := + if (← getMaxHeartbeats) = 0 then + 0 + else + leavePercentHeartbeats * (← getRemainingHeartbeats) / 100 + let cfg : RewriteResultConfig := + { stopAtRfl, minHeartbeats, max, mctx, goal, target, side } + return (← takeListAux cfg {} (Array.mkEmpty max) candidates.toList).toList + +end Lean.Meta.Rewrites diff --git a/tests/lean/run/rewrites.lean b/tests/lean/run/rewrites.lean new file mode 100644 index 0000000000..844bb721e8 --- /dev/null +++ b/tests/lean/run/rewrites.lean @@ -0,0 +1,126 @@ +attribute [refl] Eq.refl + +private axiom test_sorry : ∀ {α}, α + +-- To see the (sorted) list of lemmas that `rw?` will try rewriting by, use: +-- set_option trace.Tactic.rewrites.lemmas true + +/-- +info: Try this: rw [@List.map_append] +-- "no goals" +-/ +#guard_msgs in +example (f : α → β) (L M : List α) : (L ++ M).map f = L.map f ++ M.map f := by + rw? + +/-- +info: Try this: rw [Nat.one_mul] +-- "no goals" +-/ +#guard_msgs in +example (h : Nat) : 1 * h = h := by + rw? + +#guard_msgs(drop info) in +example (h : Int) (hyp : g * 1 = h) : g = h := by + rw? at hyp + assumption + +#guard_msgs(drop info) in +example : ∀ (x y : Nat), x ≤ y := by + intros x y + rw? -- Used to be an error here https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/panic.20and.20error.20with.20rw.3F/near/370495531 + exact test_sorry + +example : ∀ (x y : Nat), x ≤ y := by + -- Used to be a panic here https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/panic.20and.20error.20with.20rw.3F/near/370495531 + fail_if_success rw? + exact test_sorry + +axiom K : Type +@[instance] axiom K.hasOne : OfNat K 1 +@[instance] axiom K.hasIntCoe : Coe K Int + +noncomputable def foo : K → K := test_sorry + +#guard_msgs(drop info) in +example : foo x = 1 ↔ ∃ k : Int, x = k := by + rw? -- Used to panic, see https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/panic.20and.20error.20with.20rw.3F/near/370598036 + exact test_sorry + +theorem six_eq_seven : 6 = 7 := test_sorry + +-- This test also verifies that we are removing duplicate results; +-- it previously also reported `Nat.cast_ofNat` +#guard_msgs(drop info) in +example : ∀ (x : Nat), x ≤ 6 := by + rw? + guard_target = ∀ (x : Nat), x ≤ 7 + exact test_sorry + +#guard_msgs(drop info) in +example : ∀ (x : Nat) (_w : x ≤ 6), x ≤ 8 := by + rw? + guard_target = ∀ (x : Nat) (_w : x ≤ 7), x ≤ 8 + exact test_sorry + +-- check we can look inside let expressions +#guard_msgs(drop info) in +example (n : Nat) : let y := 3; n + y = 3 + n := by + rw? + +axiom α : Type +axiom f : α → α +axiom z : α +axiom f_eq (n) : f n = z + +-- Check that the same lemma isn't used multiple times. +-- This used to report two redundant copies of `f_eq`. +-- It be lovely if `rw?` could produce two *different* rewrites by `f_eq` here! +#guard_msgs(drop info) in +theorem test : f n = f m := by + fail_if_success rw? [-f_eq] -- Check that we can forbid lemmas. + rw? + rw [f_eq] + +-- Check that we can rewrite by local hypotheses. +#guard_msgs(drop info) in +example (h : 1 = 2) : 2 = 1 := by + rw? + +def zero : Nat := 0 + +-- This used to (incorrectly!) succeed because `rw?` would try `rfl`, +-- rather than `withReducible` `rfl`. +#guard_msgs(drop info) in +example : zero = 0 := by + rw? + exact test_sorry + +-- Discharge side conditions from local hypotheses. +/-- +info: Try this: rw [h p] +-- "no goals" +-/ +#guard_msgs in +example {P : Prop} (p : P) (h : P → 1 = 2) : 2 = 1 := by + rw? + +-- Use `solve_by_elim` to discharge side conditions. +/-- +info: Try this: rw [h (f p)] +-- "no goals" +-/ +#guard_msgs in +example {P Q : Prop} (p : P) (f : P → Q) (h : Q → 1 = 2) : 2 = 1 := by + rw? + +-- Rewrite in reverse, discharging side conditions from local hypotheses. +/-- +info: Try this: rw [← h₁ p] +-- Q a +-/ +#guard_msgs in +example {P : Prop} (p : P) (Q : α → Prop) (a b : α) (h₁ : P → a = b) (w : Q a) : Q b := by + rw? + exact w