feat: beta reduction in grind (#6700)
This PR adds support for beta reduction in the `grind` tactic. `grind` can now solve goals such as ```lean example (f : Nat → Nat) : f = (fun x : Nat => x + 5) → f 2 > 5 := by grind ```
This commit is contained in:
parent
645bdea23c
commit
a062eea204
12 changed files with 221 additions and 20 deletions
|
|
@ -53,6 +53,7 @@ builtin_initialize registerTraceClass `grind.offset.propagate
|
|||
builtin_initialize registerTraceClass `grind.offset.eq
|
||||
builtin_initialize registerTraceClass `grind.offset.eq.to (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.offset.eq.from (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.beta
|
||||
|
||||
/-! Trace options for `grind` developers -/
|
||||
builtin_initialize registerTraceClass `grind.debug
|
||||
|
|
@ -68,5 +69,6 @@ builtin_initialize registerTraceClass `grind.debug.canon
|
|||
builtin_initialize registerTraceClass `grind.debug.offset
|
||||
builtin_initialize registerTraceClass `grind.debug.offset.proof
|
||||
builtin_initialize registerTraceClass `grind.debug.ematch.pattern
|
||||
builtin_initialize registerTraceClass `grind.debug.beta
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
77
src/Lean/Meta/Tactic/Grind/Beta.lean
Normal file
77
src/Lean/Meta/Tactic/Grind/Beta.lean
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Meta.Tactic.Grind.Types
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-- Returns all lambda expressions in the equivalence class with root `root`. -/
|
||||
def getEqcLambdas (root : ENode) : GoalM (Array Expr) := do
|
||||
unless root.hasLambdas do return #[]
|
||||
foldEqc root.self (init := #[]) fun n lams =>
|
||||
if n.self.isLambda then return lams.push n.self else return lams
|
||||
|
||||
/--
|
||||
Returns the root of the functions in the equivalence class containing `e`.
|
||||
That is, if `f a` is in `root`s equivalence class, results contains the root of `f`.
|
||||
-/
|
||||
def getFnRoots (e : Expr) : GoalM (Array Expr) := do
|
||||
foldEqc e (init := #[]) fun n fns => do
|
||||
let fn := n.self.getAppFn
|
||||
let fnRoot := (← getRoot? fn).getD fn
|
||||
if Option.isNone <| fns.find? (isSameExpr · fnRoot) then
|
||||
return fns.push fnRoot
|
||||
else
|
||||
return fns
|
||||
|
||||
/--
|
||||
For each `lam` in `lams` s.t. `lam` and `f` are in the same equivalence class,
|
||||
propagate `f args = lam args`.
|
||||
-/
|
||||
def propagateBetaEqs (lams : Array Expr) (f : Expr) (args : Array Expr) : GoalM Unit := do
|
||||
if args.isEmpty then return ()
|
||||
for lam in lams do
|
||||
let rhs := lam.beta args
|
||||
unless rhs.isLambda do
|
||||
let mut gen := Nat.max (← getGeneration lam) (← getGeneration f)
|
||||
let lhs := mkAppN f args
|
||||
if (← hasSameType f lam) then
|
||||
let mut h ← mkEqProof f lam
|
||||
for arg in args do
|
||||
gen := Nat.max gen (← getGeneration arg)
|
||||
h ← mkCongrFun h arg
|
||||
let eq ← mkEq lhs rhs
|
||||
trace[grind.beta] "{eq}, using {lam}"
|
||||
addNewFact h eq (gen+1)
|
||||
|
||||
private def isPropagateBetaTarget (e : Expr) : GoalM Bool := do
|
||||
let .app f _ := e | return false
|
||||
go f
|
||||
where
|
||||
go (f : Expr) : GoalM Bool := do
|
||||
if let some root ← getRootENode? f then
|
||||
return root.hasLambdas
|
||||
let .app f _ := f | return false
|
||||
go f
|
||||
|
||||
/--
|
||||
Applies beta-reduction for lambdas in `f`s equivalence class.
|
||||
We use this function while internalizing new applications.
|
||||
-/
|
||||
def propagateBetaForNewApp (e : Expr) : GoalM Unit := do
|
||||
unless (← isPropagateBetaTarget e) do return ()
|
||||
let mut e := e
|
||||
let mut args := #[]
|
||||
repeat
|
||||
unless args.isEmpty do
|
||||
if let some root ← getRootENode? e then
|
||||
if root.hasLambdas then
|
||||
propagateBetaEqs (← getEqcLambdas root) e args.reverse
|
||||
let .app f arg := e | return ()
|
||||
e := f
|
||||
args := args.push arg
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
@ -11,6 +11,7 @@ import Lean.Meta.Tactic.Grind.Inv
|
|||
import Lean.Meta.Tactic.Grind.PP
|
||||
import Lean.Meta.Tactic.Grind.Ctor
|
||||
import Lean.Meta.Tactic.Grind.Util
|
||||
import Lean.Meta.Tactic.Grind.Beta
|
||||
import Lean.Meta.Tactic.Grind.Internalize
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
|
@ -40,7 +41,7 @@ Remove `root` parents from the congruence table.
|
|||
This is an auxiliary function performed while merging equivalence classes.
|
||||
-/
|
||||
private def removeParents (root : Expr) : GoalM ParentSet := do
|
||||
let parents ← getParentsAndReset root
|
||||
let parents ← getParents root
|
||||
for parent in parents do
|
||||
-- Recall that we may have `Expr.forallE` in `parents` because of `ForallProp.lean`
|
||||
if (← pure parent.isApp <&&> isCongrRoot parent) then
|
||||
|
|
@ -107,6 +108,31 @@ private def propagateOffsetEq (rhsRoot lhsRoot : ENode) : GoalM Unit := do
|
|||
if let some rhsOffset := rhsRoot.offset? then
|
||||
Arith.processNewOffsetEqLit rhsOffset lhsRoot.self
|
||||
|
||||
/--
|
||||
Tries to apply beta-reductiong using the parent applications of the functions in `fns` with
|
||||
the lambda expressions in `lams`.
|
||||
-/
|
||||
def propagateBeta (lams : Array Expr) (fns : Array Expr) : GoalM Unit := do
|
||||
if lams.isEmpty then return ()
|
||||
let lamRoot ← getRoot lams.back!
|
||||
trace[grind.debug.beta] "fns: {fns}, lams: {lams}"
|
||||
for fn in fns do
|
||||
trace[grind.debug.beta] "fn: {fn}, parents: {(← getParents fn).toArray}"
|
||||
for parent in (← getParents fn) do
|
||||
let mut args := #[]
|
||||
let mut curr := parent
|
||||
trace[grind.debug.beta] "parent: {parent}"
|
||||
repeat
|
||||
trace[grind.debug.beta] "curr: {curr}"
|
||||
if (← isEqv curr lamRoot) then
|
||||
propagateBetaEqs lams curr args.reverse
|
||||
let .app f arg := curr
|
||||
| break
|
||||
-- Remark: recall that we do not eagerly internalize partial applications.
|
||||
internalize curr (← getGeneration parent)
|
||||
args := args.push arg
|
||||
curr := f
|
||||
|
||||
private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
|
||||
let lhsNode ← getENode lhs
|
||||
let rhsNode ← getENode rhs
|
||||
|
|
@ -158,6 +184,10 @@ where
|
|||
proof? := proof
|
||||
flipped
|
||||
}
|
||||
let lams₁ ← getEqcLambdas lhsRoot
|
||||
let lams₂ ← getEqcLambdas rhsRoot
|
||||
let fns₁ ← if lams₁.isEmpty then pure #[] else getFnRoots rhsRoot.self
|
||||
let fns₂ ← if lams₂.isEmpty then pure #[] else getFnRoots lhsRoot.self
|
||||
let parents ← removeParents lhsRoot.self
|
||||
updateRoots lhs rhsNode.root
|
||||
trace_goal[grind.debug] "{← ppENodeRef lhs} new root {← ppENodeRef rhsNode.root}, {← ppENodeRef (← getRoot lhs)}"
|
||||
|
|
@ -172,6 +202,9 @@ where
|
|||
hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas
|
||||
heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs
|
||||
}
|
||||
propagateBeta lams₁ fns₁
|
||||
propagateBeta lams₂ fns₂
|
||||
resetParentsOf lhsRoot.self
|
||||
copyParentsTo parents rhsNode.root
|
||||
unless (← isInconsistent) do
|
||||
updateMT rhsRoot.self
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import Lean.Meta.Match.MatchEqsExt
|
|||
import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.Util
|
||||
import Lean.Meta.Tactic.Grind.Canon
|
||||
import Lean.Meta.Tactic.Grind.Beta
|
||||
import Lean.Meta.Tactic.Grind.Arith.Internalize
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
|
@ -194,7 +195,7 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr :=
|
|||
activateTheoremPatterns fName generation
|
||||
else
|
||||
internalize f generation e
|
||||
registerParent e f
|
||||
registerParent e f
|
||||
for h : i in [: args.size] do
|
||||
let arg := args[i]
|
||||
internalize arg generation e
|
||||
|
|
@ -204,6 +205,8 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr :=
|
|||
updateAppMap e
|
||||
Arith.internalize e parent?
|
||||
propagateUp e
|
||||
propagateBetaForNewApp e
|
||||
|
||||
end
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
|
|||
|
|
@ -501,8 +501,9 @@ def getENode (e : Expr) : GoalM ENode := do
|
|||
(← get).getENode e
|
||||
|
||||
/-- Returns the generation of the given term. Is assumes it has been internalized -/
|
||||
def getGeneration (e : Expr) : GoalM Nat :=
|
||||
return (← getENode e).generation
|
||||
def getGeneration (e : Expr) : GoalM Nat := do
|
||||
let some n ← getENode? e | return 0
|
||||
return n.generation
|
||||
|
||||
/-- Returns `true` if `e` is in the equivalence class of `True`. -/
|
||||
def isEqTrue (e : Expr) : GoalM Bool := do
|
||||
|
|
@ -519,8 +520,8 @@ def isEqv (a b : Expr) : GoalM Bool := do
|
|||
if isSameExpr a b then
|
||||
return true
|
||||
else
|
||||
let na ← getENode a
|
||||
let nb ← getENode b
|
||||
let some na ← getENode? a | return false
|
||||
let some nb ← getENode? b | return false
|
||||
return isSameExpr na.root nb.root
|
||||
|
||||
/-- Returns `true` if the root of its equivalence class. -/
|
||||
|
|
@ -549,6 +550,11 @@ def getRoot (e : Expr) : GoalM Expr := do
|
|||
def getRootENode (e : Expr) : GoalM ENode := do
|
||||
getENode (← getRoot e)
|
||||
|
||||
/-- Returns the root enode in the equivalence class of `e` if it is in an equivalence class. -/
|
||||
def getRootENode? (e : Expr) : GoalM (Option ENode) := do
|
||||
let some n ← getENode? e | return none
|
||||
getENode? n.root
|
||||
|
||||
/--
|
||||
Returns the next element in the equivalence class of `e`
|
||||
if `e` has been internalized in the given goal.
|
||||
|
|
@ -614,7 +620,7 @@ Records that `parent` is a parent of `child`. This function actually stores the
|
|||
information in the root (aka canonical representative) of `child`.
|
||||
-/
|
||||
def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do
|
||||
let some childRoot ← getRoot? child | return ()
|
||||
let childRoot := (← getRoot? child).getD child
|
||||
let parents := if let some parents := (← get).parents.find? { expr := childRoot } then parents else {}
|
||||
modify fun s => { s with parents := s.parents.insert { expr := childRoot } (parents.insert parent) }
|
||||
|
||||
|
|
@ -628,12 +634,10 @@ def getParents (e : Expr) : GoalM ParentSet := do
|
|||
return parents
|
||||
|
||||
/--
|
||||
Similar to `getParents`, but also removes the entry `e ↦ parents` from the parent map.
|
||||
Removes the entry `e ↦ parents` from the parent map.
|
||||
-/
|
||||
def getParentsAndReset (e : Expr) : GoalM ParentSet := do
|
||||
let parents ← getParents e
|
||||
def resetParentsOf (e : Expr) : GoalM Unit := do
|
||||
modify fun s => { s with parents := s.parents.erase { expr := e } }
|
||||
return parents
|
||||
|
||||
/--
|
||||
Copy `parents` to the parents of `root`.
|
||||
|
|
@ -800,6 +804,18 @@ def getENodes : GoalM (Array ENode) := do
|
|||
if isSameExpr n.next e then return ()
|
||||
curr := n.next
|
||||
|
||||
/-- Folds using `f` and `init` over the equivalence class containing `e` -/
|
||||
@[inline] def foldEqc (e : Expr) (init : α) (f : ENode → α → GoalM α) : GoalM α := do
|
||||
let mut curr := e
|
||||
let mut r := init
|
||||
repeat
|
||||
let n ← getENode curr
|
||||
r ← f n r
|
||||
if isSameExpr n.next e then return r
|
||||
curr := n.next
|
||||
unreachable!
|
||||
return r
|
||||
|
||||
def forEachENode (f : ENode → GoalM Unit) : GoalM Unit := do
|
||||
let nodes ← getENodes
|
||||
for n in nodes do
|
||||
|
|
|
|||
72
tests/lean/run/grind_beta.lean
Normal file
72
tests/lean/run/grind_beta.lean
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
def f (x : Nat) : Nat → Nat → Nat :=
|
||||
match x with
|
||||
| 0 => fun _ _ => 0
|
||||
| _+1 => fun a b => a + b
|
||||
|
||||
example : f 0 b c = 0 := by
|
||||
grind [f]
|
||||
|
||||
example : f (a+1) b c = b + c := by
|
||||
grind [f]
|
||||
|
||||
example : f x b c ≠ b + c → x = a + 1 → False := by
|
||||
grind [f]
|
||||
|
||||
example : x = a + 1 → f x b c ≠ b + c → False := by
|
||||
grind [f]
|
||||
|
||||
example : x = a + 1 → f x b c ≠ b + c → False := by
|
||||
grind [f]
|
||||
|
||||
example : f x b c > 0 → x = 0 → False := by
|
||||
grind [f]
|
||||
|
||||
example : f x b c > 0 → x ≠ 0 := by
|
||||
grind [f]
|
||||
|
||||
example (f : Nat → Nat → Nat) : f 2 3 ≠ 5 → f = (fun x y : Nat => x + y) → False := by
|
||||
grind
|
||||
|
||||
opaque bla : Nat → Nat → Nat → Nat
|
||||
|
||||
/--
|
||||
info: [grind.beta] f 2 3 = bla 2 3 2, using fun x y => bla x y x
|
||||
[grind.beta] f 2 3 = 2 + 3, using fun x y => x + y
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
set_option trace.grind.beta true in
|
||||
example (g h f : Nat → Nat → Nat) :
|
||||
f 2 3 ≠ 5 →
|
||||
g = (fun x y : Nat => x + y) →
|
||||
h = (fun x y => bla x y x) →
|
||||
g = h →
|
||||
f = h →
|
||||
False := by
|
||||
grind
|
||||
|
||||
example (g h f : Nat → Nat → Nat) :
|
||||
f 2 3 ≠ 5 →
|
||||
h = (fun x y => bla x y x) →
|
||||
g = (fun x y : Nat => x + y) →
|
||||
g = h →
|
||||
h = f →
|
||||
False := by
|
||||
grind
|
||||
|
||||
|
||||
example (f : Nat → Nat → Nat) : f = (fun x y : Nat => x + y) → f 2 3 = 5 := by
|
||||
grind
|
||||
|
||||
example (f g h : Nat → Nat → Nat) :
|
||||
h = (fun x y => bla x y x) →
|
||||
g = (fun x y : Nat => x + y) →
|
||||
g = h →
|
||||
h = f →
|
||||
f 2 3 = 5 := by
|
||||
grind
|
||||
|
||||
example (f : Nat → Nat) : f = (fun x : Nat => x + 5) → f 2 > 5 := by
|
||||
grind
|
||||
|
||||
example (f : Nat → Nat → Nat) : f a = (fun x : Nat => x + 5) → f a 2 > 5 := by
|
||||
grind
|
||||
|
|
@ -52,15 +52,13 @@ theorem left_comm [CommMonoid α] (a b c : α) : a * (b * c) = b * (a * c) := by
|
|||
|
||||
open Lean Meta Elab Tactic Grind in
|
||||
def fallback : Fallback := do
|
||||
let nodes ← filterENodes fun e => return e.self.isAppOf ``HMul.hMul
|
||||
let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``HMul.hMul
|
||||
trace[Meta.debug] "{nodes.toList.map (·.self)}"
|
||||
(← get).mvarId.admit
|
||||
|
||||
set_option trace.Meta.debug true
|
||||
|
||||
/--
|
||||
info: [Meta.debug] [b * c, a * (b * c), d * (b * c)]
|
||||
-/
|
||||
/-- info: [Meta.debug] [b * c, a * (b * c), d * (b * c)] -/
|
||||
#guard_msgs (info) in
|
||||
example (a b c d : Nat) : b * (a * c) = d * (b * c) → False := by
|
||||
rw [left_comm] -- Introduces a new (non-canonical) instance for `Mul Nat`
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ def f (a : α) := a
|
|||
|
||||
open Lean Meta Grind in
|
||||
def fallback : Fallback := do
|
||||
let nodes ← filterENodes fun e => return e.self.isAppOf ``f
|
||||
let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``f
|
||||
trace[Meta.debug] "{nodes.toList.map (·.self)}"
|
||||
(← get).mvarId.admit
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ def g (a : Nat) := a + a
|
|||
-- Prints the equivalence class containing a `f` application
|
||||
open Lean Meta Grind in
|
||||
def fallback : Fallback := do
|
||||
let #[n, _] ← filterENodes fun e => return e.self.isAppOf ``f | unreachable!
|
||||
let #[n, _] ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``f | unreachable!
|
||||
let eqc ← getEqc n.self
|
||||
trace[Meta.debug] eqc
|
||||
(← get).mvarId.admit
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ def fallback (n : Nat) : Fallback := do
|
|||
-- The `f 0` equivalence class contains `n+1` elements
|
||||
assert! (← getEqc f0).length == n + 1
|
||||
forEachENode fun node => do
|
||||
if node.self.isAppOf ``g then
|
||||
if node.self.isApp && node.self.isAppOf ``g then
|
||||
-- Any equivalence class containing a `g`-application contains 2 elements
|
||||
assert! (← getEqc (← getRoot node.self)).length == 2
|
||||
(← get).mvarId.admit
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ open Lean Meta Grind in
|
|||
def fallback : Fallback := do
|
||||
let nodes ← filterENodes fun e => return e.self.isAppOf ``Lean.Grind.nestedProof
|
||||
trace[Meta.debug] "{nodes.toList.map (·.self)}"
|
||||
let nodes ← filterENodes fun e => return e.self.isAppOf ``GetElem.getElem
|
||||
let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``GetElem.getElem
|
||||
let [_, n, _] := nodes.toList | unreachable!
|
||||
trace[Meta.debug] "{← getEqc n.self}"
|
||||
(← get).mvarId.admit
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ def g {α : Sort u} (a : α) := a
|
|||
|
||||
open Lean Meta Grind in
|
||||
def fallback : Fallback := do
|
||||
let nodes ← filterENodes fun e => return e.self.isAppOf ``g
|
||||
let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``g
|
||||
trace[Meta.debug] "{nodes.toList.map (·.self)}"
|
||||
(← get).mvarId.admit
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue