feat: propagation for functions with singleton domain in grind (#9699)

This PR adds propagation rules for functions that take singleton types.
This feature is useful for discharging verification conditions produced
by `mvcgen`. For example:

```lean
example (h : (fun (_ : Unit) => x + 1) = (fun _ => 1 + y)) : x = y := by
  grind
```
This commit is contained in:
Leonardo de Moura 2025-08-03 14:00:29 +02:00 committed by GitHub
parent af473b085a
commit d0dc5dfd3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 74 additions and 0 deletions

View file

@ -240,6 +240,35 @@ def propagateBeta (lams : Array Expr) (fns : Array Expr) : GoalM Unit := do
args := args.push arg
curr := f
private def getUnitLikeValue? (type : Expr) : GoalM (Option Expr) := do
if let some u? := (← get).unitLike.map.find? { expr := type } then
return u?
else
let u? ← go?
modify fun s => { s with unitLike.map := s.unitLike.map.insert { expr := type } u? }
return u?
where
go? := do
let u ← getLevel type
let sub := mkApp (mkConst ``Subsingleton [u]) type
let some _ ← synthInstance? sub | return none
let inh := mkApp (mkConst ``Inhabited [u]) type
let some d ← synthInstance? inh | return none
let val ← preprocessLight <| mkApp2 (mkConst ``default [u]) type d
return some val
private def propagateUnitFuns (lams₁ lams₂ : Array Expr) : GoalM Unit := do
if h : lams₁.size = 0 then return () else
if h : lams₂.size = 0 then return () else
let .lam _ d₁ b₁ _ := lams₁[0] | return ()
let .lam _ d₂ b₂ _ := lams₂[0] | return ()
unless isSameExpr d₁ d₂ do return ()
let some u ← getUnitLikeValue? d₁ | return ()
let lhs := b₁.instantiate1 u
let rhs := b₂.instantiate1 u
let h ← mkEqProof lams₁[0] lams₂[0]
pushNewFact <| mkExpectedPropHint (← mkCongrFun h u) (← mkEq lhs rhs)
private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
let lhsNode ← getENode lhs
let rhsNode ← getENode rhs
@ -330,6 +359,7 @@ where
propagateUp parent
for e in toPropagateDown do
propagateDown e
propagateUnitFuns lams₁ lams₂
propagateOffset offsetTodo
propagateCutsat cutsatTodo
propagateCommRing ringTodo

View file

@ -734,6 +734,14 @@ structure Clean.State where
next : PHashMap Name Nat := {}
deriving Inhabited
/--
Cache for `Unit`-like types. It maps the type to its element.
We say a type is `Unit`-like if it is a subsingleton and is inhabited.
-/
structure UnitLike.State where
map : PHashMap ExprPtr (Option Expr) := {}
deriving Inhabited
/-- The `grind` goal. -/
structure Goal where
mvarId : MVarId
@ -768,6 +776,8 @@ structure Goal where
arith : Arith.State := {}
/-- State of the clean name generator. -/
clean : Clean.State := {}
/-- `UnitLike` cache -/
unitLike : UnitLike.State := {}
deriving Inhabited
def Goal.hasSameRoot (g : Goal) (a b : Expr) : Bool :=

View file

@ -0,0 +1,34 @@
example (h : (fun (_ : Unit) => x = 1) = (fun _ => True)) : x = 1 := by
grind
example
(h₁ : f = fun (_ : Unit) => x = 1)
(h₂ : g = fun (_ : Unit) => True)
(h₃ : f = g)
: x = 1 := by
grind
example
(h₁ : f = fun (_ : Unit × Unit) => x = 1)
(h₂ : g = fun (_ : Unit × Unit) => True)
(h₃ : f = g)
: x = 1 := by
grind
example (h : (fun (_ : True → Unit) (_ : Unit) => x + 1) = (fun _ _ => 1 + y)) : x = y := by
grind
example (h : (fun (_ : Unit) => x + 1) = (fun _ => 1 + y)) : x = y := by
grind
example (h : (fun (_ : Unit → Unit) => x + 1) = (fun _ => 1 + y)) : x = y := by
grind
example
(x y z : Nat)
(h₁ : f = fun (_ : Unit × Unit) => x + y)
(h₂ : g = fun (_ : Unit × Unit) => w)
(h₃ : f = g)
(h₄ : f = fun (_ : Unit × Unit) => y + z)
: x = z ∧ x + y = w := by
grind