fix: grind canonicalizer (#10469)
This PR fixes an incorrect optimization in the `grind` canonicalizer. See the new test for an example that exposes the problem.
This commit is contained in:
parent
c6abc3c036
commit
d898c9ed17
3 changed files with 72 additions and 31 deletions
|
|
@ -74,40 +74,53 @@ If `useIsDefEqBounded` is `true`, we try `isDefEqBounded` before returning false
|
|||
-/
|
||||
private def canonElemCore (parent : Expr) (f : Expr) (i : Nat) (e : Expr) (useIsDefEqBounded : Bool) : GoalM Expr := do
|
||||
let s ← get'
|
||||
if let some c := s.canon.find? e then
|
||||
let key := { f, i, arg := e : CanonArgKey }
|
||||
/-
|
||||
**Note**: We used to use `s.canon.find? e` instead of `s.canonArg.find? key`. This was incorrect.
|
||||
First, for types and implicit arguments, we recursively visit `e` before invoking this function.
|
||||
Thus, `s.canon.find? e` always returns some value `c`, causing us to miss possible canonicalization opportunities.
|
||||
Moreover, `e` may be the argument of two different `f` functions.
|
||||
-/
|
||||
if let some c := s.canonArg.find? key then
|
||||
return c
|
||||
let key := (f, i)
|
||||
let eType ← inferType e
|
||||
let cs := s.argMap.find? key |>.getD []
|
||||
for (c, cType) in cs do
|
||||
/-
|
||||
We first check the types
|
||||
The following checks are a performance bottleneck.
|
||||
For example, in the test `grind_ite.lean`, there are many checks of the form:
|
||||
```
|
||||
w_4 ∈ assign.insert v true → Prop =?= w_1 ∈ assign.insert v false → Prop
|
||||
```
|
||||
where `grind` unfolds the definition of `DHashMap.insert` and `TreeMap.insert`.
|
||||
-/
|
||||
if (← isDefEqD eType cType) then
|
||||
if (← isDefEq e c) then
|
||||
-- We used to check `c.fvarsSubset e` because it is not
|
||||
-- in general safe to replace `e` with `c` if `c` has more free variables than `e`.
|
||||
-- However, we don't revert previously canonicalized elements in the `grind` tactic.
|
||||
-- Moreover, we store the canonicalizer state in the `Goal` because we case-split
|
||||
-- and different locals are added in different branches.
|
||||
modify' fun s => { s with canon := s.canon.insert e c }
|
||||
trace_goal[grind.debug.canon] "found {e} ===> {c}"
|
||||
return c
|
||||
if useIsDefEqBounded then
|
||||
-- If `e` and `c` are not types, we use `isDefEqBounded`
|
||||
if (← isDefEqBounded e c parent) then
|
||||
let c ← go
|
||||
modify' fun s => { s with canonArg := s.canonArg.insert key c }
|
||||
return c
|
||||
where
|
||||
go : GoalM Expr := do
|
||||
let s ← get'
|
||||
let key := (f, i)
|
||||
let eType ← inferType e
|
||||
let cs := s.argMap.find? key |>.getD []
|
||||
for (c, cType) in cs do
|
||||
/-
|
||||
We first check the types
|
||||
The following checks are a performance bottleneck.
|
||||
For example, in the test `grind_ite.lean`, there are many checks of the form:
|
||||
```
|
||||
w_4 ∈ assign.insert v true → Prop =?= w_1 ∈ assign.insert v false → Prop
|
||||
```
|
||||
where `grind` unfolds the definition of `DHashMap.insert` and `TreeMap.insert`.
|
||||
-/
|
||||
if (← isDefEqD eType cType) then
|
||||
if (← isDefEq e c) then
|
||||
-- We used to check `c.fvarsSubset e` because it is not
|
||||
-- in general safe to replace `e` with `c` if `c` has more free variables than `e`.
|
||||
-- However, we don't revert previously canonicalized elements in the `grind` tactic.
|
||||
-- Moreover, we store the canonicalizer state in the `Goal` because we case-split
|
||||
-- and different locals are added in different branches.
|
||||
modify' fun s => { s with canon := s.canon.insert e c }
|
||||
trace_goal[grind.debug.canon] "found using `isDefEqBounded`: {e} ===> {c}"
|
||||
trace_goal[grind.debug.canon] "found {e} ===> {c}"
|
||||
return c
|
||||
trace_goal[grind.debug.canon] "({f}, {i}) ↦ {e}"
|
||||
modify' fun s => { s with canon := s.canon.insert e e, argMap := s.argMap.insert key ((e, eType)::cs) }
|
||||
return e
|
||||
if useIsDefEqBounded then
|
||||
-- If `e` and `c` are not types, we use `isDefEqBounded`
|
||||
if (← isDefEqBounded e c parent) then
|
||||
modify' fun s => { s with canon := s.canon.insert e c }
|
||||
trace_goal[grind.debug.canon] "found using `isDefEqBounded`: {e} ===> {c}"
|
||||
return c
|
||||
trace_goal[grind.debug.canon] "({f}, {i}) ↦ {e}"
|
||||
modify' fun s => { s with canon := s.canon.insert e e, argMap := s.argMap.insert key ((e, eType)::cs) }
|
||||
return e
|
||||
|
||||
private abbrev canonType (parent f : Expr) (i : Nat) (e : Expr) := withDefault <| canonElemCore parent f i e (useIsDefEqBounded := false)
|
||||
private abbrev canonInst (parent f : Expr) (i : Nat) (e : Expr) := withReducibleAndInstances <| canonElemCore parent f i e (useIsDefEqBounded := true)
|
||||
|
|
|
|||
|
|
@ -600,11 +600,18 @@ structure NewRawFact where
|
|||
splitSource : SplitSource
|
||||
deriving Inhabited
|
||||
|
||||
structure CanonArgKey where
|
||||
f : Expr
|
||||
i : Nat
|
||||
arg : Expr
|
||||
deriving BEq, Hashable
|
||||
|
||||
/-- Canonicalizer state. See `Canon.lean` for additional details. -/
|
||||
structure Canon.State where
|
||||
argMap : PHashMap (Expr × Nat) (List (Expr × Expr)) := {}
|
||||
canon : PHashMap Expr Expr := {}
|
||||
proofCanon : PHashMap Expr Expr := {}
|
||||
canonArg : PHashMap CanonArgKey Expr := {}
|
||||
deriving Inhabited
|
||||
|
||||
/-- Trace information for a case split. -/
|
||||
|
|
|
|||
21
tests/lean/run/grind_canon_bug_2.lean
Normal file
21
tests/lean/run/grind_canon_bug_2.lean
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
import Std.Data.ExtHashMap
|
||||
open Std
|
||||
set_option warn.sorry false
|
||||
|
||||
-- The following trace should contain only one `m[k]` and `(m.insert 1 3)[k]`
|
||||
/--
|
||||
trace: [grind.cutsat.model] k := 101
|
||||
[grind.cutsat.model] (ExtHashMap.filter (fun k x => decide (101 ≤ k)) (m.insert 1 3))[k] := 4
|
||||
[grind.cutsat.model] (m.insert 1 2)[k] := 4
|
||||
[grind.cutsat.model] (m.insert 1 3)[k] := 4
|
||||
[grind.cutsat.model] m[k] := 4
|
||||
[grind.cutsat.model] (m.insert 1 2).getKey k ⋯ := 101
|
||||
[grind.cutsat.model] m.getKey k ⋯ := 101
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (m : ExtHashMap Nat Nat) :
|
||||
(m.insert 1 2).filter (fun k _ => k > 1000) = (m.insert 1 3).filter fun k _ => k > 100 := by
|
||||
ext1 k
|
||||
set_option trace.grind.cutsat.model true in
|
||||
fail_if_success grind (splits := 4)
|
||||
sorry
|
||||
Loading…
Add table
Reference in a new issue