From aa5b392e35fc73f74a19146a3b056ced7296dc1f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 22 Jul 2025 21:10:21 -0700 Subject: [PATCH] fix: canonicalization of non-standard `OfNat.ofNat` terms (#9481) This PR fixes a kernel type mismatch that occurs when using `grind` on goals containing non-standard `OfNat.ofNat` terms. For example, in issue #9477, the `0` in the theorem `range_lower` has the form: ```lean (@OfNat.ofNat (Std.PRange.Bound (Std.PRange.RangeShape.lower (Std.PRange.RangeShape.mk Std.PRange.BoundShape.closed Std.PRange.BoundShape.open)) Nat) (nat_lit 0) (instOfNatNat (nat_lit 0))) ``` instead of the more standard form: ```lean (@OfNat.ofNat Nat (nat_lit 0) (instOfNatNat (nat_lit 0))) ``` Closes #9477 --- src/Lean/Meta/Tactic/Grind/Canon.lean | 42 ++++++++++++++++++++++++++- tests/lean/run/grind_9477.lean | 29 ++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 tests/lean/run/grind_9477.lean diff --git a/src/Lean/Meta/Tactic/Grind/Canon.lean b/src/Lean/Meta/Tactic/Grind/Canon.lean index ca486bd1b9..c213b76c45 100644 --- a/src/Lean/Meta/Tactic/Grind/Canon.lean +++ b/src/Lean/Meta/Tactic/Grind/Canon.lean @@ -156,6 +156,40 @@ def shouldCanon (pinfos : Array ParamInfo) (i : Nat) (arg : Expr) : MetaM Should else return .visit +/-- +Auxiliary function for normalizing the arguments of `OfNat.ofNat` during canonicalization. +This is needed because satellite solvers create `Nat` and `Int` numerals using the +APIs `mkNatLit` and `mkIntLit`, which produce terms of the form +`@OfNat.ofNat Nat inst` and `@OfNat.ofNat Int inst`. +This becomes a problem when a term in the input goal has already been canonicalized +and its type is not exactly `Nat` or `Int`. For example, in issue #9477, we have: +``` +structure T where +upper_bound : Nat +def T.range (a : T) := 0...a.upper_bound +theorem range\_lower (a : T) : a.range.lower = 0 := by rfl +``` +Here, the `0` in `range_lower` is actually represented as: +``` +(@OfNat.ofNat + (Std.PRange.Bound (Std.PRange.RangeShape.lower (Std.PRange.RangeShape.mk Std.PRange.BoundShape.closed Std.PRange.BoundShape.open)) Nat) + (nat_lit 0) + (instOfNatNat (nat_lit 0))) +``` +Without this normalization step, the satellite solver would need to handle multiple +representations for `(0 : Nat)` and `(0 : Int)`, complicating reasoning. +-/ +-- Remark: This is not a great solution. We should consider writing a custom canonicalizer for +-- `OfNat.ofNat` and other constants with built-in support in `grind`. +private def normOfNatArgs? (args : Array Expr) : MetaM (Option (Array Expr)) := do + if h : args.size = 3 then + let inst := args[2] + if (← isInstOfNatNat inst) && !args[0].isConstOf ``Nat then + return some <| args.set 0 Nat.mkType + else if (← isInstOfNatInt inst) && !args[0].isConstOf ``Int then + return some <| args.set 0 Int.mkType + return none + /-- Canonicalizes nested types, type formers, and instances in `e`. -/ partial def canon (e : Expr) : GoalM Expr := do profileitM Exception "grind canon" (← getOptions) do trace_goal[grind.debug.canon] "{e}" @@ -184,8 +218,14 @@ where let e' := if isSameExpr prop prop' then e else mkAppN f (args.set! 0 prop') pure e' else - let pinfos := (← getFunInfo f).paramInfo let mut modified := false + let args ← if f.isConstOf ``OfNat.ofNat then + let some args ← normOfNatArgs? args | pure args + modified := true + pure args + else + pure args + let pinfos := (← getFunInfo f).paramInfo let mut args := args.toVector for h : i in *...args.size do let arg := args[i] diff --git a/tests/lean/run/grind_9477.lean b/tests/lean/run/grind_9477.lean new file mode 100644 index 0000000000..ae7a7c6aa4 --- /dev/null +++ b/tests/lean/run/grind_9477.lean @@ -0,0 +1,29 @@ +structure T where + upper_bound : Nat + +def T.range (a : T) := 0...a.upper_bound + +theorem range_lower (a : T) : a.range.lower = 0 := by rfl + +/-- +info: range_lower (a : T) : + @Eq (Std.PRange.Bound { lower := Std.PRange.BoundShape.closed, upper := Std.PRange.BoundShape.open }.lower Nat) + (@Std.PRange.lower { lower := Std.PRange.BoundShape.closed, upper := Std.PRange.BoundShape.open } Nat a.range) + (@OfNat.ofNat + (Std.PRange.Bound { lower := Std.PRange.BoundShape.closed, upper := Std.PRange.BoundShape.open }.lower Nat) + (nat_lit 0) (instOfNatNat (nat_lit 0))) +-/ +#guard_msgs in +set_option pp.explicit true in +#check range_lower + +set_option warn.sorry false + +#guard_msgs in +theorem test (p : T) (n: Nat) : n ≤ p.range.upper := by + fail_if_success grind only [range_lower] + sorry + +example (p : T) (n: Nat) : n ≥ p.range.lower := by + set_option trace.Meta.debug true in + grind only [range_lower]