From 2d3501be61e02ef43ba4fa6dd8845e29c77d6291 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 5 Aug 2025 09:19:51 -0700 Subject: [PATCH] feat: constant functions in `grind` (#9735) This PR extends the propagation rule implemented in #9699 to constant functions. --- src/Lean/Meta/Tactic/Grind/Core.lean | 51 ++++++++++++------------- src/Lean/Meta/Tactic/Grind/Types.lean | 2 - tests/lean/run/grind_fun_singleton.lean | 6 +++ 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index cf2dd92656..afc4d6a012 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -240,34 +240,33 @@ def propagateBeta (lams : Array Expr) (fns : Array Expr) : GoalM Unit := do args := args.push arg curr := f -private def getUnitLikeValue? (type : Expr) : GoalM (Option Expr) := do - if let some u? := (← get).unitLike.map.find? { expr := type } then - return u? - else - let u? ← go? - modify fun s => { s with unitLike.map := s.unitLike.map.insert { expr := type } u? } - return u? -where - go? := do - let u ← getLevel type - let sub := mkApp (mkConst ``Subsingleton [u]) type - let some _ ← synthInstance? sub | return none - let inh := mkApp (mkConst ``Inhabited [u]) type - let some d ← synthInstance? inh | return none - let val ← preprocessLight <| mkApp2 (mkConst ``default [u]) type d - return some val +private def getFunWithGivenDomain? (lams : Array Expr) (d : Expr) : Option Expr := + lams.find? fun + | .lam _ d' _ _ => isSameExpr d d' + | _ => false -private def propagateUnitFuns (lams₁ lams₂ : Array Expr) : GoalM Unit := do +private def propagateUnitConstFuns (lams₁ lams₂ : Array Expr) : GoalM Unit := do if h : lams₁.size = 0 then return () else if h : lams₂.size = 0 then return () else - let .lam _ d₁ b₁ _ := lams₁[0] | return () - let .lam _ d₂ b₂ _ := lams₂[0] | return () - unless isSameExpr d₁ d₂ do return () - let some u ← getUnitLikeValue? d₁ | return () - let lhs := b₁.instantiate1 u - let rhs := b₂.instantiate1 u - let h ← mkEqProof lams₁[0] lams₂[0] - pushNewFact <| mkExpectedPropHint (← mkCongrFun h u) (← mkEq lhs rhs) + for lam₁ in lams₁ do + -- Remark: we have heterogeneous equivalence classes. So, we may have functions + -- with different domains in the same equivalence class. + let .lam _ d₁ b₁ _ := lam₁ | pure () + let u ← getLevel d₁ + let inh := mkApp (mkConst ``Inhabited [u]) d₁ + let some inhInst ← synthInstance? inh | pure () + let isTarget ← if !b₁.hasLooseBVars then + pure true + else + let sub := mkApp (mkConst ``Subsingleton [u]) d₁ + pure (← synthInstance? sub).isSome + if isTarget then + let some (.lam _ d₁ b₂ _) := getFunWithGivenDomain? lams₂ d₁ | pure () + let val ← preprocessLight <| mkApp2 (mkConst ``default [u]) d₁ inhInst + let lhs := b₁.instantiate1 val + let rhs := b₂.instantiate1 val + let h ← mkEqProof lams₁[0] lams₂[0] + pushNewFact <| mkExpectedPropHint (← mkCongrFun h val) (← mkEq lhs rhs) private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do let lhsNode ← getENode lhs @@ -359,7 +358,7 @@ where propagateUp parent for e in toPropagateDown do propagateDown e - propagateUnitFuns lams₁ lams₂ + propagateUnitConstFuns lams₁ lams₂ propagateOffset offsetTodo propagateCutsat cutsatTodo propagateCommRing ringTodo diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index b1196e59fb..4d256c7c1c 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -776,8 +776,6 @@ structure Goal where arith : Arith.State := {} /-- State of the clean name generator. -/ clean : Clean.State := {} - /-- `UnitLike` cache -/ - unitLike : UnitLike.State := {} deriving Inhabited def Goal.hasSameRoot (g : Goal) (a b : Expr) : Bool := diff --git a/tests/lean/run/grind_fun_singleton.lean b/tests/lean/run/grind_fun_singleton.lean index f5436c529b..6aa79d6ef3 100644 --- a/tests/lean/run/grind_fun_singleton.lean +++ b/tests/lean/run/grind_fun_singleton.lean @@ -32,3 +32,9 @@ example (h₄ : f = fun (_ : Unit × Unit) => y + z) : x = z ∧ x + y = w := by grind + +example [Inhabited α] : ((fun (_ : α) => x = a + 1) = fun (_ : α) => True) → x = a + 1 := by + grind + +example : c = 5 → ((fun (_ : Nat × Nat) => { down := a + c = b + 5 : ULift Prop }) = fun (_ : Nat × Nat) => { down := c < 10 : ULift Prop }) → a = b := by + grind