From a062eea204a9c982e318c0c98cd6af63e9f61050 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 19 Jan 2025 13:29:24 -0800 Subject: [PATCH] feat: beta reduction in `grind` (#6700) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds support for beta reduction in the `grind` tactic. `grind` can now solve goals such as ```lean example (f : Nat → Nat) : f = (fun x : Nat => x + 5) → f 2 > 5 := by grind ``` --- src/Lean/Meta/Tactic/Grind.lean | 2 + src/Lean/Meta/Tactic/Grind/Beta.lean | 77 +++++++++++++++++++++ src/Lean/Meta/Tactic/Grind/Core.lean | 35 +++++++++- src/Lean/Meta/Tactic/Grind/Internalize.lean | 5 +- src/Lean/Meta/Tactic/Grind/Types.lean | 34 ++++++--- tests/lean/run/grind_beta.lean | 72 +++++++++++++++++++ tests/lean/run/grind_canon_insts.lean | 6 +- tests/lean/run/grind_canon_types.lean | 2 +- tests/lean/run/grind_congr.lean | 2 +- tests/lean/run/grind_many_eqs.lean | 2 +- tests/lean/run/grind_nested_proofs.lean | 2 +- tests/lean/run/grind_norm_levels.lean | 2 +- 12 files changed, 221 insertions(+), 20 deletions(-) create mode 100644 src/Lean/Meta/Tactic/Grind/Beta.lean create mode 100644 tests/lean/run/grind_beta.lean diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index 0cd55eef6e..0bb04bafe0 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -53,6 +53,7 @@ builtin_initialize registerTraceClass `grind.offset.propagate builtin_initialize registerTraceClass `grind.offset.eq builtin_initialize registerTraceClass `grind.offset.eq.to (inherited := true) builtin_initialize registerTraceClass `grind.offset.eq.from (inherited := true) +builtin_initialize registerTraceClass `grind.beta /-! Trace options for `grind` developers -/ builtin_initialize registerTraceClass `grind.debug @@ -68,5 +69,6 @@ builtin_initialize registerTraceClass `grind.debug.canon builtin_initialize registerTraceClass `grind.debug.offset builtin_initialize registerTraceClass `grind.debug.offset.proof builtin_initialize registerTraceClass `grind.debug.ematch.pattern +builtin_initialize registerTraceClass `grind.debug.beta end Lean diff --git a/src/Lean/Meta/Tactic/Grind/Beta.lean b/src/Lean/Meta/Tactic/Grind/Beta.lean new file mode 100644 index 0000000000..326c489bd8 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Beta.lean @@ -0,0 +1,77 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Types + +namespace Lean.Meta.Grind + +/-- Returns all lambda expressions in the equivalence class with root `root`. -/ +def getEqcLambdas (root : ENode) : GoalM (Array Expr) := do + unless root.hasLambdas do return #[] + foldEqc root.self (init := #[]) fun n lams => + if n.self.isLambda then return lams.push n.self else return lams + +/-- +Returns the root of the functions in the equivalence class containing `e`. +That is, if `f a` is in `root`s equivalence class, results contains the root of `f`. +-/ +def getFnRoots (e : Expr) : GoalM (Array Expr) := do + foldEqc e (init := #[]) fun n fns => do + let fn := n.self.getAppFn + let fnRoot := (← getRoot? fn).getD fn + if Option.isNone <| fns.find? (isSameExpr · fnRoot) then + return fns.push fnRoot + else + return fns + +/-- +For each `lam` in `lams` s.t. `lam` and `f` are in the same equivalence class, +propagate `f args = lam args`. +-/ +def propagateBetaEqs (lams : Array Expr) (f : Expr) (args : Array Expr) : GoalM Unit := do + if args.isEmpty then return () + for lam in lams do + let rhs := lam.beta args + unless rhs.isLambda do + let mut gen := Nat.max (← getGeneration lam) (← getGeneration f) + let lhs := mkAppN f args + if (← hasSameType f lam) then + let mut h ← mkEqProof f lam + for arg in args do + gen := Nat.max gen (← getGeneration arg) + h ← mkCongrFun h arg + let eq ← mkEq lhs rhs + trace[grind.beta] "{eq}, using {lam}" + addNewFact h eq (gen+1) + +private def isPropagateBetaTarget (e : Expr) : GoalM Bool := do + let .app f _ := e | return false + go f +where + go (f : Expr) : GoalM Bool := do + if let some root ← getRootENode? f then + return root.hasLambdas + let .app f _ := f | return false + go f + +/-- +Applies beta-reduction for lambdas in `f`s equivalence class. +We use this function while internalizing new applications. +-/ +def propagateBetaForNewApp (e : Expr) : GoalM Unit := do + unless (← isPropagateBetaTarget e) do return () + let mut e := e + let mut args := #[] + repeat + unless args.isEmpty do + if let some root ← getRootENode? e then + if root.hasLambdas then + propagateBetaEqs (← getEqcLambdas root) e args.reverse + let .app f arg := e | return () + e := f + args := args.push arg + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index e960ce2d50..e5158d2de9 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -11,6 +11,7 @@ import Lean.Meta.Tactic.Grind.Inv import Lean.Meta.Tactic.Grind.PP import Lean.Meta.Tactic.Grind.Ctor import Lean.Meta.Tactic.Grind.Util +import Lean.Meta.Tactic.Grind.Beta import Lean.Meta.Tactic.Grind.Internalize namespace Lean.Meta.Grind @@ -40,7 +41,7 @@ Remove `root` parents from the congruence table. This is an auxiliary function performed while merging equivalence classes. -/ private def removeParents (root : Expr) : GoalM ParentSet := do - let parents ← getParentsAndReset root + let parents ← getParents root for parent in parents do -- Recall that we may have `Expr.forallE` in `parents` because of `ForallProp.lean` if (← pure parent.isApp <&&> isCongrRoot parent) then @@ -107,6 +108,31 @@ private def propagateOffsetEq (rhsRoot lhsRoot : ENode) : GoalM Unit := do if let some rhsOffset := rhsRoot.offset? then Arith.processNewOffsetEqLit rhsOffset lhsRoot.self +/-- +Tries to apply beta-reductiong using the parent applications of the functions in `fns` with +the lambda expressions in `lams`. +-/ +def propagateBeta (lams : Array Expr) (fns : Array Expr) : GoalM Unit := do + if lams.isEmpty then return () + let lamRoot ← getRoot lams.back! + trace[grind.debug.beta] "fns: {fns}, lams: {lams}" + for fn in fns do + trace[grind.debug.beta] "fn: {fn}, parents: {(← getParents fn).toArray}" + for parent in (← getParents fn) do + let mut args := #[] + let mut curr := parent + trace[grind.debug.beta] "parent: {parent}" + repeat + trace[grind.debug.beta] "curr: {curr}" + if (← isEqv curr lamRoot) then + propagateBetaEqs lams curr args.reverse + let .app f arg := curr + | break + -- Remark: recall that we do not eagerly internalize partial applications. + internalize curr (← getGeneration parent) + args := args.push arg + curr := f + private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do let lhsNode ← getENode lhs let rhsNode ← getENode rhs @@ -158,6 +184,10 @@ where proof? := proof flipped } + let lams₁ ← getEqcLambdas lhsRoot + let lams₂ ← getEqcLambdas rhsRoot + let fns₁ ← if lams₁.isEmpty then pure #[] else getFnRoots rhsRoot.self + let fns₂ ← if lams₂.isEmpty then pure #[] else getFnRoots lhsRoot.self let parents ← removeParents lhsRoot.self updateRoots lhs rhsNode.root trace_goal[grind.debug] "{← ppENodeRef lhs} new root {← ppENodeRef rhsNode.root}, {← ppENodeRef (← getRoot lhs)}" @@ -172,6 +202,9 @@ where hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs } + propagateBeta lams₁ fns₁ + propagateBeta lams₂ fns₂ + resetParentsOf lhsRoot.self copyParentsTo parents rhsNode.root unless (← isInconsistent) do updateMT rhsRoot.self diff --git a/src/Lean/Meta/Tactic/Grind/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Internalize.lean index 705eb29853..7d9f7723e8 100644 --- a/src/Lean/Meta/Tactic/Grind/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Internalize.lean @@ -12,6 +12,7 @@ import Lean.Meta.Match.MatchEqsExt import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Util import Lean.Meta.Tactic.Grind.Canon +import Lean.Meta.Tactic.Grind.Beta import Lean.Meta.Tactic.Grind.Arith.Internalize namespace Lean.Meta.Grind @@ -194,7 +195,7 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr := activateTheoremPatterns fName generation else internalize f generation e - registerParent e f + registerParent e f for h : i in [: args.size] do let arg := args[i] internalize arg generation e @@ -204,6 +205,8 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr := updateAppMap e Arith.internalize e parent? propagateUp e + propagateBetaForNewApp e + end end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index d055030bbb..663a05ca7f 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -501,8 +501,9 @@ def getENode (e : Expr) : GoalM ENode := do (← get).getENode e /-- Returns the generation of the given term. Is assumes it has been internalized -/ -def getGeneration (e : Expr) : GoalM Nat := - return (← getENode e).generation +def getGeneration (e : Expr) : GoalM Nat := do + let some n ← getENode? e | return 0 + return n.generation /-- Returns `true` if `e` is in the equivalence class of `True`. -/ def isEqTrue (e : Expr) : GoalM Bool := do @@ -519,8 +520,8 @@ def isEqv (a b : Expr) : GoalM Bool := do if isSameExpr a b then return true else - let na ← getENode a - let nb ← getENode b + let some na ← getENode? a | return false + let some nb ← getENode? b | return false return isSameExpr na.root nb.root /-- Returns `true` if the root of its equivalence class. -/ @@ -549,6 +550,11 @@ def getRoot (e : Expr) : GoalM Expr := do def getRootENode (e : Expr) : GoalM ENode := do getENode (← getRoot e) +/-- Returns the root enode in the equivalence class of `e` if it is in an equivalence class. -/ +def getRootENode? (e : Expr) : GoalM (Option ENode) := do + let some n ← getENode? e | return none + getENode? n.root + /-- Returns the next element in the equivalence class of `e` if `e` has been internalized in the given goal. @@ -614,7 +620,7 @@ Records that `parent` is a parent of `child`. This function actually stores the information in the root (aka canonical representative) of `child`. -/ def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do - let some childRoot ← getRoot? child | return () + let childRoot := (← getRoot? child).getD child let parents := if let some parents := (← get).parents.find? { expr := childRoot } then parents else {} modify fun s => { s with parents := s.parents.insert { expr := childRoot } (parents.insert parent) } @@ -628,12 +634,10 @@ def getParents (e : Expr) : GoalM ParentSet := do return parents /-- -Similar to `getParents`, but also removes the entry `e ↦ parents` from the parent map. +Removes the entry `e ↦ parents` from the parent map. -/ -def getParentsAndReset (e : Expr) : GoalM ParentSet := do - let parents ← getParents e +def resetParentsOf (e : Expr) : GoalM Unit := do modify fun s => { s with parents := s.parents.erase { expr := e } } - return parents /-- Copy `parents` to the parents of `root`. @@ -800,6 +804,18 @@ def getENodes : GoalM (Array ENode) := do if isSameExpr n.next e then return () curr := n.next +/-- Folds using `f` and `init` over the equivalence class containing `e` -/ +@[inline] def foldEqc (e : Expr) (init : α) (f : ENode → α → GoalM α) : GoalM α := do + let mut curr := e + let mut r := init + repeat + let n ← getENode curr + r ← f n r + if isSameExpr n.next e then return r + curr := n.next + unreachable! + return r + def forEachENode (f : ENode → GoalM Unit) : GoalM Unit := do let nodes ← getENodes for n in nodes do diff --git a/tests/lean/run/grind_beta.lean b/tests/lean/run/grind_beta.lean new file mode 100644 index 0000000000..6e24e05d22 --- /dev/null +++ b/tests/lean/run/grind_beta.lean @@ -0,0 +1,72 @@ +def f (x : Nat) : Nat → Nat → Nat := + match x with + | 0 => fun _ _ => 0 + | _+1 => fun a b => a + b + +example : f 0 b c = 0 := by + grind [f] + +example : f (a+1) b c = b + c := by + grind [f] + +example : f x b c ≠ b + c → x = a + 1 → False := by + grind [f] + +example : x = a + 1 → f x b c ≠ b + c → False := by + grind [f] + +example : x = a + 1 → f x b c ≠ b + c → False := by + grind [f] + +example : f x b c > 0 → x = 0 → False := by + grind [f] + +example : f x b c > 0 → x ≠ 0 := by + grind [f] + +example (f : Nat → Nat → Nat) : f 2 3 ≠ 5 → f = (fun x y : Nat => x + y) → False := by + grind + +opaque bla : Nat → Nat → Nat → Nat + +/-- +info: [grind.beta] f 2 3 = bla 2 3 2, using fun x y => bla x y x +[grind.beta] f 2 3 = 2 + 3, using fun x y => x + y +-/ +#guard_msgs (info) in +set_option trace.grind.beta true in +example (g h f : Nat → Nat → Nat) : + f 2 3 ≠ 5 → + g = (fun x y : Nat => x + y) → + h = (fun x y => bla x y x) → + g = h → + f = h → + False := by + grind + +example (g h f : Nat → Nat → Nat) : + f 2 3 ≠ 5 → + h = (fun x y => bla x y x) → + g = (fun x y : Nat => x + y) → + g = h → + h = f → + False := by + grind + + +example (f : Nat → Nat → Nat) : f = (fun x y : Nat => x + y) → f 2 3 = 5 := by + grind + +example (f g h : Nat → Nat → Nat) : + h = (fun x y => bla x y x) → + g = (fun x y : Nat => x + y) → + g = h → + h = f → + f 2 3 = 5 := by + grind + +example (f : Nat → Nat) : f = (fun x : Nat => x + 5) → f 2 > 5 := by + grind + +example (f : Nat → Nat → Nat) : f a = (fun x : Nat => x + 5) → f a 2 > 5 := by + grind diff --git a/tests/lean/run/grind_canon_insts.lean b/tests/lean/run/grind_canon_insts.lean index 266f415363..287f735ed6 100644 --- a/tests/lean/run/grind_canon_insts.lean +++ b/tests/lean/run/grind_canon_insts.lean @@ -52,15 +52,13 @@ theorem left_comm [CommMonoid α] (a b c : α) : a * (b * c) = b * (a * c) := by open Lean Meta Elab Tactic Grind in def fallback : Fallback := do - let nodes ← filterENodes fun e => return e.self.isAppOf ``HMul.hMul + let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``HMul.hMul trace[Meta.debug] "{nodes.toList.map (·.self)}" (← get).mvarId.admit set_option trace.Meta.debug true -/-- -info: [Meta.debug] [b * c, a * (b * c), d * (b * c)] --/ +/-- info: [Meta.debug] [b * c, a * (b * c), d * (b * c)] -/ #guard_msgs (info) in example (a b c d : Nat) : b * (a * c) = d * (b * c) → False := by rw [left_comm] -- Introduces a new (non-canonical) instance for `Mul Nat` diff --git a/tests/lean/run/grind_canon_types.lean b/tests/lean/run/grind_canon_types.lean index 8870a41c62..05fa1ee350 100644 --- a/tests/lean/run/grind_canon_types.lean +++ b/tests/lean/run/grind_canon_types.lean @@ -5,7 +5,7 @@ def f (a : α) := a open Lean Meta Grind in def fallback : Fallback := do - let nodes ← filterENodes fun e => return e.self.isAppOf ``f + let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``f trace[Meta.debug] "{nodes.toList.map (·.self)}" (← get).mvarId.admit diff --git a/tests/lean/run/grind_congr.lean b/tests/lean/run/grind_congr.lean index 4f09c4809f..8c13ab8b7d 100644 --- a/tests/lean/run/grind_congr.lean +++ b/tests/lean/run/grind_congr.lean @@ -6,7 +6,7 @@ def g (a : Nat) := a + a -- Prints the equivalence class containing a `f` application open Lean Meta Grind in def fallback : Fallback := do - let #[n, _] ← filterENodes fun e => return e.self.isAppOf ``f | unreachable! + let #[n, _] ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``f | unreachable! let eqc ← getEqc n.self trace[Meta.debug] eqc (← get).mvarId.admit diff --git a/tests/lean/run/grind_many_eqs.lean b/tests/lean/run/grind_many_eqs.lean index 25c254917c..efddd3e5f2 100644 --- a/tests/lean/run/grind_many_eqs.lean +++ b/tests/lean/run/grind_many_eqs.lean @@ -14,7 +14,7 @@ def fallback (n : Nat) : Fallback := do -- The `f 0` equivalence class contains `n+1` elements assert! (← getEqc f0).length == n + 1 forEachENode fun node => do - if node.self.isAppOf ``g then + if node.self.isApp && node.self.isAppOf ``g then -- Any equivalence class containing a `g`-application contains 2 elements assert! (← getEqc (← getRoot node.self)).length == 2 (← get).mvarId.admit diff --git a/tests/lean/run/grind_nested_proofs.lean b/tests/lean/run/grind_nested_proofs.lean index 6b4a73326d..062b7b85c3 100644 --- a/tests/lean/run/grind_nested_proofs.lean +++ b/tests/lean/run/grind_nested_proofs.lean @@ -6,7 +6,7 @@ open Lean Meta Grind in def fallback : Fallback := do let nodes ← filterENodes fun e => return e.self.isAppOf ``Lean.Grind.nestedProof trace[Meta.debug] "{nodes.toList.map (·.self)}" - let nodes ← filterENodes fun e => return e.self.isAppOf ``GetElem.getElem + let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``GetElem.getElem let [_, n, _] := nodes.toList | unreachable! trace[Meta.debug] "{← getEqc n.self}" (← get).mvarId.admit diff --git a/tests/lean/run/grind_norm_levels.lean b/tests/lean/run/grind_norm_levels.lean index 9e68e92678..73f86e08fb 100644 --- a/tests/lean/run/grind_norm_levels.lean +++ b/tests/lean/run/grind_norm_levels.lean @@ -4,7 +4,7 @@ def g {α : Sort u} (a : α) := a open Lean Meta Grind in def fallback : Fallback := do - let nodes ← filterENodes fun e => return e.self.isAppOf ``g + let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``g trace[Meta.debug] "{nodes.toList.map (·.self)}" (← get).mvarId.admit