From 37cd4cc99601ff6c991c3fc1280dc8a95c17d2ab Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 28 Feb 2024 18:24:47 -0800 Subject: [PATCH] fix: match-expression when patterns cover all cases of a `BitVec` finite type (#3538) --- src/Lean/Meta/Match/Match.lean | 34 ++++++++---- tests/lean/run/match_lit_fin_cover.lean | 74 +++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 11 deletions(-) create mode 100644 tests/lean/run/match_lit_fin_cover.lean diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index b3183a08fe..fada4e3b97 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -139,15 +139,21 @@ private def isValueTransition (p : Problem) : Bool := | .var _ :: _ => true | _ => false -private def isFinValueTransition (p : Problem) : MetaM Bool := do +private def isValueOnlyTransitionCore (p : Problem) (isValue : Expr → MetaM Bool) : MetaM Bool := do if hasVarPattern p then return false if !hasValPattern p then return false p.alts.allM fun alt => do match alt.patterns with - | .val v :: _ => return (← getFinValue? v).isSome + | .val v :: _ => isValue v | .ctor .. :: _ => return true | _ => return false +private def isFinValueTransition (p : Problem) : MetaM Bool := + isValueOnlyTransitionCore p fun e => return (← getFinValue? e).isSome + +private def isBitVecValueTransition (p : Problem) : MetaM Bool := + isValueOnlyTransitionCore p fun e => return (← getBitVecValue? e).isSome + private def isArrayLitTransition (p : Problem) : Bool := hasArrayLitPattern p && hasVarPattern p && p.alts.all fun alt => match alt.patterns with @@ -647,15 +653,18 @@ private def expandIntValuePattern (p : Problem) : MetaM Problem := do private def expandFinValuePattern (p : Problem) : MetaM Problem := do let alts ← p.alts.mapM fun alt => do - match alt.patterns with - | .val n :: ps => - match (← getFinValue? n) with - | some ⟨n, v⟩ => - let p ← mkLt (toExpr v.val) (toExpr n) - let h ← mkDecideProof p - return { alt with patterns := .ctor ``Fin.mk [] [toExpr n] [.val (toExpr v.val), .inaccessible h] :: ps } - | _ => return alt - | _ => return alt + let .val n :: ps := alt.patterns | return alt + let some ⟨n, v⟩ ← getFinValue? n | return alt + let p ← mkLt (toExpr v.val) (toExpr n) + let h ← mkDecideProof p + return { alt with patterns := .ctor ``Fin.mk [] [toExpr n] [.val (toExpr v.val), .inaccessible h] :: ps } + return { p with alts := alts } + +private def expandBitVecValuePattern (p : Problem) : MetaM Problem := do + let alts ← p.alts.mapM fun alt => do + let .val n :: ps := alt.patterns | return alt + let some ⟨_, v⟩ ← getBitVecValue? n | return alt + return { alt with patterns := .ctor ``BitVec.ofFin [] [] [.val (toExpr v.toFin)] :: ps } return { p with alts := alts } private def traceStep (msg : String) : StateRefT State MetaM Unit := do @@ -710,6 +719,9 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do else if (← isFinValueTransition p) then traceStep ("fin value to constructor") process (← expandFinValuePattern p) + else if (← isBitVecValueTransition p) then + traceStep ("bitvec value to constructor") + process (← expandBitVecValuePattern p) else if !isNextVar p then traceStep ("non variable") let p ← processNonVariable p diff --git a/tests/lean/run/match_lit_fin_cover.lean b/tests/lean/run/match_lit_fin_cover.lean new file mode 100644 index 0000000000..ed219a99ca --- /dev/null +++ b/tests/lean/run/match_lit_fin_cover.lean @@ -0,0 +1,74 @@ +/- +Test for match-expression when we conver all possible +values of a `Fin` or `BitVec` type. +-/ + +def boo (x : Fin 3) : Nat := + match x with + | 0 => 1 + | 1 => 2 + | 2 => 4 + +@[simp] def bla (x : Fin 3) (y : Nat) : Nat := + match x, y with + | 0, _ => 10 + | 1, _ => 20 + | 2, 0 => 30 + | 2, y+1 => bla x y + 1 + +/-- +info: bla._eq_1 (y : Nat) : bla 0 y = 10 +-/ +#guard_msgs in +#check bla._eq_1 + +/-- +info: bla._eq_4 (y_2 : Nat) : bla 2 (Nat.succ y_2) = bla 2 y_2 + 1 +-/ +#guard_msgs in +#check bla._eq_4 + +open BitVec + +def foo (x : BitVec 3) : Nat := + match x with + | 0b000#3 => 7 + | 0b001#3 => 6 + | 0b010#3 => 5 + | 0b011#3 => 4 + | 0b100#3 => 3 + | 0b101#3 => 2 + | 0b110#3 => 1 + | 0b111#3 => 0 + +def foo' (x : BitVec 3) (y : Nat) : Nat := + match x, y with + | 0b000#3, _ => 7 + | 0b001#3, _ => 6 + | 0b010#3, _ => 5 + | 0b011#3, _ => 4 + | 0b100#3, _ => 3 + | 0b101#3, _ => 2 + | 0b110#3, _ => 1 + | 0b111#3, 0 => 0 + | 0b111#3, y+1 => foo' 7 y + 1 + +attribute [simp] foo' + +/-- +info: foo'._eq_1 (y : Nat) : foo' (0#3) y = 7 +-/ +#guard_msgs in +#check foo'._eq_1 + +/-- +info: foo'._eq_2 (y : Nat) : foo' (1#3) y = 6 +-/ +#guard_msgs in +#check foo'._eq_2 + +/-- +info: foo'._eq_9 (y_2 : Nat) : foo' (7#3) (Nat.succ y_2) = foo' 7 y_2 + 1 +-/ +#guard_msgs in +#check foo'._eq_9