diff --git a/src/Lean/Meta/Tactic/Grind/Core.lean b/src/Lean/Meta/Tactic/Grind/Core.lean index a4ed1258de..800af64a2f 100644 --- a/src/Lean/Meta/Tactic/Grind/Core.lean +++ b/src/Lean/Meta/Tactic/Grind/Core.lean @@ -240,6 +240,35 @@ def propagateBeta (lams : Array Expr) (fns : Array Expr) : GoalM Unit := do args := args.push arg curr := f +private def getUnitLikeValue? (type : Expr) : GoalM (Option Expr) := do + if let some u? := (← get).unitLike.map.find? { expr := type } then + return u? + else + let u? ← go? + modify fun s => { s with unitLike.map := s.unitLike.map.insert { expr := type } u? } + return u? +where + go? := do + let u ← getLevel type + let sub := mkApp (mkConst ``Subsingleton [u]) type + let some _ ← synthInstance? sub | return none + let inh := mkApp (mkConst ``Inhabited [u]) type + let some d ← synthInstance? inh | return none + let val ← preprocessLight <| mkApp2 (mkConst ``default [u]) type d + return some val + +private def propagateUnitFuns (lams₁ lams₂ : Array Expr) : GoalM Unit := do + if h : lams₁.size = 0 then return () else + if h : lams₂.size = 0 then return () else + let .lam _ d₁ b₁ _ := lams₁[0] | return () + let .lam _ d₂ b₂ _ := lams₂[0] | return () + unless isSameExpr d₁ d₂ do return () + let some u ← getUnitLikeValue? d₁ | return () + let lhs := b₁.instantiate1 u + let rhs := b₂.instantiate1 u + let h ← mkEqProof lams₁[0] lams₂[0] + pushNewFact <| mkExpectedPropHint (← mkCongrFun h u) (← mkEq lhs rhs) + private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do let lhsNode ← getENode lhs let rhsNode ← getENode rhs @@ -330,6 +359,7 @@ where propagateUp parent for e in toPropagateDown do propagateDown e + propagateUnitFuns lams₁ lams₂ propagateOffset offsetTodo propagateCutsat cutsatTodo propagateCommRing ringTodo diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 05504af703..b1196e59fb 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -734,6 +734,14 @@ structure Clean.State where next : PHashMap Name Nat := {} deriving Inhabited +/-- +Cache for `Unit`-like types. It maps the type to its element. +We say a type is `Unit`-like if it is a subsingleton and is inhabited. +-/ +structure UnitLike.State where + map : PHashMap ExprPtr (Option Expr) := {} + deriving Inhabited + /-- The `grind` goal. -/ structure Goal where mvarId : MVarId @@ -768,6 +776,8 @@ structure Goal where arith : Arith.State := {} /-- State of the clean name generator. -/ clean : Clean.State := {} + /-- `UnitLike` cache -/ + unitLike : UnitLike.State := {} deriving Inhabited def Goal.hasSameRoot (g : Goal) (a b : Expr) : Bool := diff --git a/tests/lean/run/grind_fun_singleton.lean b/tests/lean/run/grind_fun_singleton.lean new file mode 100644 index 0000000000..f5436c529b --- /dev/null +++ b/tests/lean/run/grind_fun_singleton.lean @@ -0,0 +1,34 @@ +example (h : (fun (_ : Unit) => x = 1) = (fun _ => True)) : x = 1 := by + grind + +example + (h₁ : f = fun (_ : Unit) => x = 1) + (h₂ : g = fun (_ : Unit) => True) + (h₃ : f = g) + : x = 1 := by + grind + +example + (h₁ : f = fun (_ : Unit × Unit) => x = 1) + (h₂ : g = fun (_ : Unit × Unit) => True) + (h₃ : f = g) + : x = 1 := by + grind + +example (h : (fun (_ : True → Unit) (_ : Unit) => x + 1) = (fun _ _ => 1 + y)) : x = y := by + grind + +example (h : (fun (_ : Unit) => x + 1) = (fun _ => 1 + y)) : x = y := by + grind + +example (h : (fun (_ : Unit → Unit) => x + 1) = (fun _ => 1 + y)) : x = y := by + grind + +example + (x y z : Nat) + (h₁ : f = fun (_ : Unit × Unit) => x + y) + (h₂ : g = fun (_ : Unit × Unit) => w) + (h₃ : f = g) + (h₄ : f = fun (_ : Unit × Unit) => y + z) + : x = z ∧ x + y = w := by + grind