fix: bug at mkCongrSimpCore? (#9395)

This PR fixes a bug at `mkCongrSimpCore?`. It fixes the issue reported
by @joehendrix at #9388.
The fix is just commit: afc4ba617fe2ca5828e0e252558d893d7791d56b. The
rest of the PR is just cleaning up the file.

closes #9388
This commit is contained in:
Leonardo de Moura 2025-07-15 17:54:31 -07:00 committed by GitHub
parent 62ded77e81
commit dc2f256448
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 90 additions and 23 deletions

View file

@ -87,10 +87,10 @@ where
let yType := (← inferType y).cleanupAnnotations
if xType == yType then
withLocalDeclD ((`e).appendIndexAfter (i+1)) (← mkEq x y) fun h =>
loop (i+1) (eqs.push h) (kinds.push CongrArgKind.eq)
loop (i+1) (eqs.push h) (kinds.push .eq)
else
withLocalDeclD ((`e).appendIndexAfter (i+1)) (← mkHEq x y) fun h =>
loop (i+1) (eqs.push h) (kinds.push CongrArgKind.heq)
loop (i+1) (eqs.push h) (kinds.push .heq)
else
k eqs kinds
loop 0 #[] #[]
@ -120,23 +120,24 @@ def mkHCongr (f : Expr) : MetaM CongrTheorem := do
mkHCongrWithArity f (← getFunInfo f).getArity
/--
Ensure that all dependencies for `congr_arg_kind::Eq` are `congr_arg_kind::Fixed`.
Ensures all dependencies for `.eq` are `.fixed`.
-/
private def fixKindsForDependencies (info : FunInfo) (kinds : Array CongrArgKind) : Array CongrArgKind := Id.run do
let mut kinds := kinds
for i in *...info.paramInfo.size do
for hj : j in (i+1)...info.paramInfo.size do
if info.paramInfo[j].backDeps.contains i then
if kinds[j]! matches CongrArgKind.eq || kinds[j]! matches CongrArgKind.fixed then
if kinds[j]! matches .eq | .fixed then
-- We must fix `i` because there is a `j` that depends on `i` and `j` is not cast-fixed.
kinds := kinds.set! i CongrArgKind.fixed
kinds := kinds.set! i .fixed
break
return kinds
/--
(Try to) cast expression `e` to the given type using the equations `eqs`.
`deps` contains the indices of the relevant equalities.
Remark: deps is sorted. -/
(Tries to) cast expression `e` to the given type using the equations `eqs`.
`deps` contains the indices of the relevant equalities.
Remark: deps is sorted.
-/
private partial def mkCast (e : Expr) (type : Expr) (deps : Array Nat) (eqs : Array (Option Expr)) : MetaM Expr := do
let rec go (i : Nat) (type : Expr) : MetaM Expr := do
if i < deps.size then
@ -144,7 +145,7 @@ private partial def mkCast (e : Expr) (type : Expr) (deps : Array Nat) (eqs : Ar
| none => go (i+1) type
| some major =>
let some (_, lhs, rhs) := (← inferType major).eq? | unreachable!
if (← dependsOn type major.fvarId!) then
if (← pure major.isFVar <&&> dependsOn type major.fvarId!) then
let motive ← mkLambdaFVars #[rhs, major] type
let typeNew := type.replaceFVar rhs lhs |>.replaceFVar major (← mkEqRefl lhs)
let minor ← go (i+1) typeNew
@ -158,19 +159,21 @@ private partial def mkCast (e : Expr) (type : Expr) (deps : Array Nat) (eqs : Ar
return e
go 0 type
/-- Returns `true` if `kinds` contains `.cast` or `.subsingletonInst` -/
private def hasCastLike (kinds : Array CongrArgKind) : Bool :=
kinds.any fun kind => kind matches CongrArgKind.cast || kind matches CongrArgKind.subsingletonInst
kinds.any fun kind => kind matches .cast | .subsingletonInst
private def withNext (type : Expr) (k : Expr → Expr → MetaM α) : MetaM α := do
forallBoundedTelescope type (some 1) (cleanupAnnotations := true) fun xs type => k xs[0]! type
/--
Test whether we should use `subsingletonInst` kind for instances which depend on `eq`.
(Otherwise `fixKindsForDependencies`will downgrade them to Fixed -/
Tests whether we should use `subsingletonInst` kind for instances which depend on `eq`.
(Otherwise `fixKindsForDependencies`will downgrade them to Fixed
-/
private def shouldUseSubsingletonInst (info : FunInfo) (kinds : Array CongrArgKind) (i : Nat) : Bool := Id.run do
if info.paramInfo[i]!.isDecInst then
for j in info.paramInfo[i]!.backDeps do
if kinds[j]! matches CongrArgKind.eq then
if kinds[j]! matches .eq then
return true
return false
@ -196,7 +199,7 @@ private def getClassSubobjectMask? (f : Expr) : MetaM (Option (Array Bool)) := d
mask := mask.push (isSubobjectField? env val.induct localDecl.userName).isSome
return some mask
/-- Compute `CongrArgKind`s for a simp congruence theorem. -/
/-- Computes `CongrArgKind`s for a simp congruence theorem. -/
def getCongrSimpKinds (f : Expr) (info : FunInfo) : MetaM (Array CongrArgKind) := do
/-
The default `CongrArgKind` is `eq`, which allows `simp` to rewrite this
@ -223,7 +226,7 @@ def getCongrSimpKinds (f : Expr) (info : FunInfo) : MetaM (Array CongrArgKind) :
if let some mask := mask? then
if h2 : i < mask.size then
if mask[i] then
-- Parameter is a subobect field of a class constructor. See comment above.
-- Parameter is a subobject field of a class constructor. See comment above.
result := result.push .eq
continue
if shouldUseSubsingletonInst info result i then
@ -236,8 +239,8 @@ def getCongrSimpKinds (f : Expr) (info : FunInfo) : MetaM (Array CongrArgKind) :
/--
Variant of `getCongrSimpKinds` for rewriting just argument 0.
If it is possible to rewrite, the 0th `CongrArgKind` is `CongrArgKind.eq`,
and otherwise it is `CongrArgKind.fixed`. This is used for the `arg` conv tactic.
If it is possible to rewrite, the 0th `CongrArgKind` is `.eq`,
and otherwise it is `.fixed`. This is used for the `arg` conv tactic.
-/
def getCongrSimpKindsForArgZero (info : FunInfo) : MetaM (Array CongrArgKind) := do
let mut result := #[]
@ -258,15 +261,14 @@ def getCongrSimpKindsForArgZero (info : FunInfo) : MetaM (Array CongrArgKind) :=
return fixKindsForDependencies info result
/--
Create a congruence theorem that is useful for the simplifier and `congr` tactic.
Creates a congruence theorem that is useful for the simplifier and `congr` tactic.
-/
partial def mkCongrSimpCore? (f : Expr) (info : FunInfo) (kinds : Array CongrArgKind) (subsingletonInstImplicitRhs : Bool := true) : MetaM (Option CongrTheorem) := do
if let some result ← mk? f info kinds then
return some result
else if hasCastLike kinds then
-- Simplify kinds and try again
let kinds := kinds.map fun kind =>
if kind matches CongrArgKind.cast || kind matches CongrArgKind.subsingletonInst then CongrArgKind.fixed else kind
let kinds := kinds.map fun kind => if kind matches .cast | .subsingletonInst then .fixed else kind
mk? f info kinds
else
return none
@ -307,10 +309,14 @@ where
| .subsingletonInst =>
-- The `lhs` does not need to instance implicit since it can be inferred from the LHS
withNewBinderInfos #[(lhss[i]!.fvarId!, .implicit)] do
let rhsType := (← inferType lhss[i]!).replaceFVars (lhss[*...rhss.size]) rhss
let lhs := lhss[i]!
let lhsType ← inferType lhs
let rhsType := lhsType.replaceFVars (lhss[*...rhss.size]) rhss
let rhsBi := if subsingletonInstImplicitRhs then .instImplicit else .implicit
withLocalDecl (← lhss[i]!.fvarId!.getDecl).userName rhsBi rhsType fun rhs =>
go (i+1) (rhss.push rhs) (eqs.push none) (hyps.push rhs)
withLocalDecl (← lhss[i]!.fvarId!.getDecl).userName rhsBi rhsType fun rhs => do
let lhs' ← mkCast lhs rhsType info.paramInfo[i]!.backDeps eqs
let heq ← mkAppM ``Subsingleton.elim #[lhs', rhs]
go (i+1) (rhss.push rhs) (eqs.push heq) (hyps.push rhs)
return some (← go 0 #[] #[] #[])
catch _ =>
return none

View file

@ -0,0 +1,61 @@
set_option warn.sorry false
set_option pp.proofs true
inductive Expr (Identifier : Type) : Type where
| mk (c : String)
def fv {I:Type} (e : Expr I) : List I := sorry
def eql {I:Type} [inst : DecidableEq I] (e : Expr I) (_h1 : fv e == []) : Nat := sorry
def eval {I:Type} [inst : DecidableEq I] (n : Nat) (e : Expr I) : Nat :=
match n with
| 0 => 0
| Nat.succ n' =>
let e2' := eval n' e
eql e sorry
termination_by n
/--
info: eql.congr_simp {I : Type} {inst : DecidableEq I} [inst✝ : DecidableEq I] (e e✝ : Expr I) (e_e : e = e✝)
(_h1 : (fv e == []) = true) : eql e _h1 = eql e✝ (Subsingleton.elim inst inst✝ ▸ e_e ▸ _h1)
-/
#guard_msgs in
#check eql.congr_simp
/--
info: eval.congr_simp {I : Type} {inst : DecidableEq I} [inst✝ : DecidableEq I] (n n✝ : Nat) (e_n : n = n✝) (e e✝ : Expr I)
(e_e : e = e✝) : eval n e = eval n✝ e✝
-/
#guard_msgs in
#check eval.congr_simp
def test4 {α} [DecidableEq α] (x : Nat) : Nat := sorry
/--
info: test4.congr_simp.{u_1} {α α✝ : Sort u_1} (e_α : α = α✝) {inst✝ : DecidableEq α} [DecidableEq α✝] (x x✝ : Nat)
(e_x : x = x✝) : test4 x = test4 x✝
-/
#guard_msgs in
#check test4.congr_simp
structure Dep (p : Prop) [Decidable p] : Type where
def test5 {p} [Decidable p] (x : Dep p) : Nat := sorry
/--
info: test5.congr_simp {p : Prop} [Decidable p] (x x✝ : Dep p) (e_x : x = x✝) : test5 x = test5 x✝
-/
#guard_msgs in
#check test5.congr_simp
def test6 (x y : Nat) : Fin x := sorry
/-- info: test6.congr_simp (x y y✝ : Nat) (e_y : y = y✝) : test6 x y = test6 x y✝ -/
#guard_msgs in
#check test6.congr_simp
def test7 {α : Type u} [i : DecidableEq α] {x : α} (h : (x == x) = true) : Nat := sorry
/--
info: test7.congr_simp.{u} {α : Type u} {i : DecidableEq α} [i✝ : DecidableEq α] {x x✝ : α} (e_x : x = x✝)
(h : (x == x) = true) : test7 h = test7 (Subsingleton.elim i i✝ ▸ e_x ▸ h)
-/
#guard_msgs in
#check test7.congr_simp