fix: match-expression when patterns cover all cases of a BitVec finite type (#3538)
This commit is contained in:
parent
e53ae5d89e
commit
37cd4cc996
2 changed files with 97 additions and 11 deletions
|
|
@ -139,15 +139,21 @@ private def isValueTransition (p : Problem) : Bool :=
|
|||
| .var _ :: _ => true
|
||||
| _ => false
|
||||
|
||||
private def isFinValueTransition (p : Problem) : MetaM Bool := do
|
||||
private def isValueOnlyTransitionCore (p : Problem) (isValue : Expr → MetaM Bool) : MetaM Bool := do
|
||||
if hasVarPattern p then return false
|
||||
if !hasValPattern p then return false
|
||||
p.alts.allM fun alt => do
|
||||
match alt.patterns with
|
||||
| .val v :: _ => return (← getFinValue? v).isSome
|
||||
| .val v :: _ => isValue v
|
||||
| .ctor .. :: _ => return true
|
||||
| _ => return false
|
||||
|
||||
private def isFinValueTransition (p : Problem) : MetaM Bool :=
|
||||
isValueOnlyTransitionCore p fun e => return (← getFinValue? e).isSome
|
||||
|
||||
private def isBitVecValueTransition (p : Problem) : MetaM Bool :=
|
||||
isValueOnlyTransitionCore p fun e => return (← getBitVecValue? e).isSome
|
||||
|
||||
private def isArrayLitTransition (p : Problem) : Bool :=
|
||||
hasArrayLitPattern p && hasVarPattern p
|
||||
&& p.alts.all fun alt => match alt.patterns with
|
||||
|
|
@ -647,15 +653,18 @@ private def expandIntValuePattern (p : Problem) : MetaM Problem := do
|
|||
|
||||
private def expandFinValuePattern (p : Problem) : MetaM Problem := do
|
||||
let alts ← p.alts.mapM fun alt => do
|
||||
match alt.patterns with
|
||||
| .val n :: ps =>
|
||||
match (← getFinValue? n) with
|
||||
| some ⟨n, v⟩ =>
|
||||
let p ← mkLt (toExpr v.val) (toExpr n)
|
||||
let h ← mkDecideProof p
|
||||
return { alt with patterns := .ctor ``Fin.mk [] [toExpr n] [.val (toExpr v.val), .inaccessible h] :: ps }
|
||||
| _ => return alt
|
||||
| _ => return alt
|
||||
let .val n :: ps := alt.patterns | return alt
|
||||
let some ⟨n, v⟩ ← getFinValue? n | return alt
|
||||
let p ← mkLt (toExpr v.val) (toExpr n)
|
||||
let h ← mkDecideProof p
|
||||
return { alt with patterns := .ctor ``Fin.mk [] [toExpr n] [.val (toExpr v.val), .inaccessible h] :: ps }
|
||||
return { p with alts := alts }
|
||||
|
||||
private def expandBitVecValuePattern (p : Problem) : MetaM Problem := do
|
||||
let alts ← p.alts.mapM fun alt => do
|
||||
let .val n :: ps := alt.patterns | return alt
|
||||
let some ⟨_, v⟩ ← getBitVecValue? n | return alt
|
||||
return { alt with patterns := .ctor ``BitVec.ofFin [] [] [.val (toExpr v.toFin)] :: ps }
|
||||
return { p with alts := alts }
|
||||
|
||||
private def traceStep (msg : String) : StateRefT State MetaM Unit := do
|
||||
|
|
@ -710,6 +719,9 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
|
|||
else if (← isFinValueTransition p) then
|
||||
traceStep ("fin value to constructor")
|
||||
process (← expandFinValuePattern p)
|
||||
else if (← isBitVecValueTransition p) then
|
||||
traceStep ("bitvec value to constructor")
|
||||
process (← expandBitVecValuePattern p)
|
||||
else if !isNextVar p then
|
||||
traceStep ("non variable")
|
||||
let p ← processNonVariable p
|
||||
|
|
|
|||
74
tests/lean/run/match_lit_fin_cover.lean
Normal file
74
tests/lean/run/match_lit_fin_cover.lean
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
/-
|
||||
Test for match-expression when we conver all possible
|
||||
values of a `Fin` or `BitVec` type.
|
||||
-/
|
||||
|
||||
def boo (x : Fin 3) : Nat :=
|
||||
match x with
|
||||
| 0 => 1
|
||||
| 1 => 2
|
||||
| 2 => 4
|
||||
|
||||
@[simp] def bla (x : Fin 3) (y : Nat) : Nat :=
|
||||
match x, y with
|
||||
| 0, _ => 10
|
||||
| 1, _ => 20
|
||||
| 2, 0 => 30
|
||||
| 2, y+1 => bla x y + 1
|
||||
|
||||
/--
|
||||
info: bla._eq_1 (y : Nat) : bla 0 y = 10
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check bla._eq_1
|
||||
|
||||
/--
|
||||
info: bla._eq_4 (y_2 : Nat) : bla 2 (Nat.succ y_2) = bla 2 y_2 + 1
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check bla._eq_4
|
||||
|
||||
open BitVec
|
||||
|
||||
def foo (x : BitVec 3) : Nat :=
|
||||
match x with
|
||||
| 0b000#3 => 7
|
||||
| 0b001#3 => 6
|
||||
| 0b010#3 => 5
|
||||
| 0b011#3 => 4
|
||||
| 0b100#3 => 3
|
||||
| 0b101#3 => 2
|
||||
| 0b110#3 => 1
|
||||
| 0b111#3 => 0
|
||||
|
||||
def foo' (x : BitVec 3) (y : Nat) : Nat :=
|
||||
match x, y with
|
||||
| 0b000#3, _ => 7
|
||||
| 0b001#3, _ => 6
|
||||
| 0b010#3, _ => 5
|
||||
| 0b011#3, _ => 4
|
||||
| 0b100#3, _ => 3
|
||||
| 0b101#3, _ => 2
|
||||
| 0b110#3, _ => 1
|
||||
| 0b111#3, 0 => 0
|
||||
| 0b111#3, y+1 => foo' 7 y + 1
|
||||
|
||||
attribute [simp] foo'
|
||||
|
||||
/--
|
||||
info: foo'._eq_1 (y : Nat) : foo' (0#3) y = 7
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check foo'._eq_1
|
||||
|
||||
/--
|
||||
info: foo'._eq_2 (y : Nat) : foo' (1#3) y = 6
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check foo'._eq_2
|
||||
|
||||
/--
|
||||
info: foo'._eq_9 (y_2 : Nat) : foo' (7#3) (Nat.succ y_2) = foo' 7 y_2 + 1
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check foo'._eq_9
|
||||
Loading…
Add table
Reference in a new issue