feat: split on match-expressions in the grind tactic (#6569)

This PR adds support for case splitting on `match`-expressions in
`grind`.
We still need to add support for resolving the antecedents of
`match`-conditional equations.
This commit is contained in:
Leonardo de Moura 2025-01-07 19:10:11 -08:00 committed by GitHub
parent 9040108e2f
commit 00ef231a6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 121 additions and 19 deletions

View file

@ -21,6 +21,13 @@ def doNotSimp {α : Sort u} (a : α) : α := a
/-- Gadget for representing offsets `t+k` in patterns. -/
def offset (a b : Nat) : Nat := a + b
/--
Gadget for annotating the equalities in `match`-equations conclusions.
`_origin` is the term used to instantiate the `match`-equation using E-matching.
When `EqMatch a b origin` is `True`, we mark `origin` as a resolved case-split.
-/
def EqMatch (a b : α) {_origin : α} : Prop := a = b
theorem nestedProof_congr (p q : Prop) (h : p = q) (hp : p) (hq : q) : HEq (nestedProof p hp) (nestedProof q hq) := by
subst h; apply HEq.refl

View file

@ -33,7 +33,7 @@ private def mkEqAndProof (lhs rhs : Expr) : MetaM (Expr × Expr) := do
else
pure (mkApp4 (mkConst ``HEq [u]) lhsType lhs rhsType rhs, mkApp2 (mkConst ``HEq.refl [u]) lhsType lhs)
private partial def withNewEqs (targets targetsNew : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α :=
partial def withNewEqs (targets targetsNew : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α :=
let rec loop (i : Nat) (newEqs : Array Expr) (newRefls : Array Expr) := do
if i < targets.size then
let (newEqType, newRefl) ← mkEqAndProof targets[i]! targetsNew[i]!

View file

@ -23,7 +23,7 @@ import Lean.Meta.Tactic.Grind.Parser
import Lean.Meta.Tactic.Grind.EMatchTheorem
import Lean.Meta.Tactic.Grind.EMatch
import Lean.Meta.Tactic.Grind.Main
import Lean.Meta.Tactic.Grind.CasesMatch
namespace Lean
@ -52,5 +52,6 @@ builtin_initialize registerTraceClass `grind.debug.proj
builtin_initialize registerTraceClass `grind.debug.parent
builtin_initialize registerTraceClass `grind.debug.final
builtin_initialize registerTraceClass `grind.debug.forallPropagator
builtin_initialize registerTraceClass `grind.debug.split
end Lean

View file

@ -0,0 +1,53 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Cases
import Lean.Meta.Match.MatcherApp
namespace Lean.Meta.Grind
def casesMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
let some app ← matchMatcherApp? e
| throwTacticEx `grind.casesMatch mvarId m!"`match`-expression expected{indentExpr e}"
let (motive, eqRefls) ← mkMotiveAndRefls app
let target ← mvarId.getType
let mut us := app.matcherLevels
if let some i := app.uElimPos? then
us := us.set! i (← getLevel target)
let splitterName := (← Match.getEquationsFor app.matcherName).splitterName
let splitterApp := mkConst splitterName us.toList
let splitterApp := mkAppN splitterApp app.params
let splitterApp := mkApp splitterApp motive
let splitterApp := mkAppN splitterApp app.discrs
let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType splitterApp) app.alts.size (kind := .syntheticOpaque)
let splitterApp := mkAppN splitterApp mvars
let val := mkAppN splitterApp eqRefls
mvarId.assign val
updateTags mvars
return mvars.toList.map (·.mvarId!)
where
mkMotiveAndRefls (app : MatcherApp) : MetaM (Expr × Array Expr) := do
let dummy := mkSort 0
let aux := mkApp (mkAppN e.getAppFn app.params) dummy
forallBoundedTelescope (← inferType aux) app.discrs.size fun xs _ => do
withNewEqs app.discrs xs fun eqs eqRefls => do
let type ← mvarId.getType
let type ← mkForallFVars eqs type
let motive ← mkLambdaFVars xs type
return (motive, eqRefls)
updateTags (mvars : Array Expr) : MetaM Unit := do
let tag ← mvarId.getTag
if mvars.size == 1 then
mvars[0]!.mvarId!.setTag tag
else
let mut idx := 1
for mvar in mvars do
mvar.mvarId!.setTag (Name.num tag idx)
idx := idx + 1
end Lean.Meta.Grind

View file

@ -199,14 +199,18 @@ private def processContinue (c : Choice) (p : Expr) : M Unit := do
let c := { c with gen := Nat.max gen c.gen }
modify fun s => { s with choiceStack := c :: s.choiceStack }
/-- Helper function for marking parts of `match`-equation theorem as "do-not-simplify" -/
private partial def annotateMatchEqnType (prop : Expr) : M Expr := do
/--
Helper function for marking parts of `match`-equation theorem as "do-not-simplify"
`initApp` is the match-expression used to instantiate the `match`-equation.
-/
private partial def annotateMatchEqnType (prop : Expr) (initApp : Expr) : M Expr := do
if let .forallE n d b bi := prop then
withLocalDecl n bi (← markAsDoNotSimp d) fun x => do
mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x))
mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x) initApp)
else
let_expr f@Eq α lhs rhs := prop | return prop
return mkApp3 f α (← markAsDoNotSimp lhs) rhs
-- See comment at `Grind.EqMatch`
return mkApp4 (mkConst ``Grind.EqMatch f.constLevels!) α (← markAsDoNotSimp lhs) rhs initApp
/--
Stores new theorem instance in the state.
@ -218,9 +222,7 @@ private def addNewInstance (origin : Origin) (proof : Expr) (generation : Nat) :
check proof
let mut prop ← inferType proof
if Match.isMatchEqnTheorem (← getEnv) origin.key then
-- `initApp` is a match-application that we don't need to split at anymore.
markCaseSplitAsResolved (← read).initApp
prop ← annotateMatchEqnType prop
prop ← annotateMatchEqnType prop (← read).initApp
trace_goal[grind.ematch.instance] "{← origin.pp}: {prop}"
addTheoremInstance proof prop (generation+1)

View file

@ -134,6 +134,13 @@ builtin_grind_propagator propagateEqDown ↓Eq := fun e => do
let_expr Eq _ a b := e | return ()
pushEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e)
/-- Propagates `EqMatch` downwards -/
builtin_grind_propagator propagateEqMatchDown ↓Grind.EqMatch := fun e => do
if (← isEqTrue e) then
let_expr Grind.EqMatch _ a b origin := e | return ()
markCaseSplitAsResolved origin
pushEq a b <| mkApp2 (mkConst ``of_eq_true) e (← mkEqTrueProof e)
/-- Propagates `HEq` downwards -/
builtin_grind_propagator propagateHEqDown ↓HEq := fun e => do
if (← isEqTrue e) then

View file

@ -7,6 +7,7 @@ prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.CasesMatch
namespace Lean.Meta.Grind
@ -50,10 +51,10 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
return .ready
| _ =>
if (← isResolvedCaseSplit e) then
trace[grind.debug.split] "split resolved: {e}"
return .resolved
if (← isMatcherApp e) then
return .notReady -- TODO: implement splitters for `match`
-- return .ready
return .ready
let .const declName .. := e.getAppFn | unreachable!
if (← isInductivePredicate declName <&&> isEqTrue e) then
return .ready
@ -111,9 +112,11 @@ def splitNext : GrindTactic := fun goal => do
| return none
let gen ← getGeneration c
trace_goal[grind.split] "{c}, generation: {gen}"
-- TODO: `match`
let major ← mkCasesMajor c
let mvarIds ← cases (← get).mvarId major
let mvarIds ← if (← isMatcherApp c) then
casesMatch (← get).mvarId c
else
let major ← mkCasesMajor c
cases (← get).mvarId major
let goal ← get
let goals := mvarIds.map fun mvarId => { goal with mvarId }
let goals ← introNewHyp goals [] (gen+1)

View file

@ -24,15 +24,16 @@ info: [grind.assert] (match as, bs with
[grind.assert] a₁ :: f 0 = as
[grind.assert] f 0 = a₂ :: f 1
[grind.assert] ¬d = []
[grind.assert] Lean.Grind.EqMatch
(match a₁ :: a₂ :: f 1, [] with
| [], x => bs
| head :: head_1 :: tail, [] => []
| x :: xs, ys => x :: g xs ys)
[]
[grind.split.resolved] match as, bs with
| [], x => bs
| head :: head_1 :: tail, [] => []
| x :: xs, ys => x :: g xs ys
[grind.assert] (match a₁ :: a₂ :: f 1, [] with
| [], x => bs
| head :: head_1 :: tail, [] => []
| x :: xs, ys => x :: g xs ys) =
[]
-/
#guard_msgs (info) in
example (f : Nat → List Nat) : g as bs = d → bs = [] → a₁ :: f 0 = as → f 0 = a₂ :: f 1 → d = [] := by

View file

@ -0,0 +1,28 @@
def g (a : α) (as : List α) : List α :=
match as with
| [] => [a]
| b::bs => a::a::b::bs
set_option trace.grind true in
set_option trace.grind.assert true in
example : ¬ (g a as).isEmpty := by
unfold List.isEmpty
unfold g
grind
def h (as : List Nat) :=
match as with
| [] => 1
| [_] => 2
| _::_::_ => 3
/--
info: [grind] closed `grind.1`
[grind] closed `grind.2`
[grind] closed `grind.3`
-/
#guard_msgs (info) in
set_option trace.grind true in
example : h as ≠ 0 := by
unfold h
grind