fix: match literal pattern support

The equation lemmas were not using the standard representation for literals.
This commit is contained in:
Leonardo de Moura 2024-02-23 22:06:48 -08:00 committed by Leonardo de Moura
parent 66be8b9d4c
commit 056cb75ee0
7 changed files with 158 additions and 67 deletions

View file

@ -473,7 +473,7 @@ partial def normalize (e : Expr) : M Expr := do
let p ← normalize p
addVar h
return mkApp4 e.getAppFn (e.getArg! 0) x p h
else if isMatchValue e then
else if (← isMatchValue e) then
return e
else if e.isFVar then
if (← isExplicitPatternVar e) then
@ -571,8 +571,8 @@ private partial def toPattern (e : Expr) : MetaM Pattern := do
match e.getArg! 1, e.getArg! 3 with
| Expr.fvar x, Expr.fvar h => return Pattern.as x p h
| _, _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
else if isMatchValue e then
return Pattern.val e
else if (← isMatchValue e) then
return Pattern.val (← normLitValue e)
else if e.isFVar then
return Pattern.var e.fvarId!
else

View file

@ -16,6 +16,7 @@ It also provides support for the following exceptional cases.
- Bit-vectors encoded using `OfNat.ofNat` and `BitVec.ofNat`.
- Negative integers encoded using raw natural numbers.
- Characters encoded `Char.ofNat n` where `n` can be a raw natural number or an `OfNat.ofNat`.
- Nested `Expr.mdata`.
-/
/-- Returns `some n` if `e` is a raw natural number, i.e., it is of the form `.lit (.natVal n)`. -/
@ -26,7 +27,7 @@ def getRawNatValue? (e : Expr) : Option Nat :=
/-- Return `some (n, type)` if `e` is an `OfNat.ofNat`-application encoding `n` for a type with name `typeDeclName`. -/
def getOfNatValue? (e : Expr) (typeDeclName : Name) : MetaM (Option (Nat × Expr)) := OptionT.run do
guard <| e.isAppOfArity ``OfNat.ofNat 3
guard <| e.isAppOfArity' ``OfNat.ofNat 3
let type ← whnfD e.appFn!.appFn!.appArg!
guard <| type.getAppFn.isConstOf typeDeclName
let .lit (.natVal n) := e.appFn!.appArg! | failure
@ -43,15 +44,15 @@ def getNatValue? (e : Expr) : MetaM (Option Nat) := do
def getIntValue? (e : Expr) : MetaM (Option Int) := do
if let some (n, _) ← getOfNatValue? e ``Int then
return some n
if e.isAppOfArity ``Neg.neg 3 then
let some (n, _) ← getOfNatValue? e.appArg! ``Int | return none
if e.isAppOfArity' ``Neg.neg 3 then
let some (n, _) ← getOfNatValue? e.appArg!.consumeMData ``Int | return none
return some (-n)
return none
/-- Return `some c` if `e` is a `Char.ofNat`-application encoding character `c`. -/
def getCharValue? (e : Expr) : MetaM (Option Char) := OptionT.run do
guard <| e.isAppOfArity ``Char.ofNat 1
let n ← getNatValue? e.appArg!
guard <| e.isAppOfArity' ``Char.ofNat 1
let n ← getNatValue? e.appArg!.consumeMData
return Char.ofNat n
/-- Return `some s` if `e` is of the form `.lit (.strVal s)`. -/
@ -70,9 +71,9 @@ def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run
/-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `BitVec n` with value `v` -/
def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := OptionT.run do
if e.isAppOfArity ``BitVec.ofNat 2 then
let n ← getNatValue? e.appFn!.appArg!
let v ← getNatValue? e.appArg!
if e.isAppOfArity' ``BitVec.ofNat 2 then
let n ← getNatValue? e.appFn!.appArg!.consumeMData
let v ← getNatValue? e.appArg!.consumeMData
return ⟨n, BitVec.ofNat n v⟩
let (v, type) ← getOfNatValue? e ``BitVec
IO.println v
@ -99,4 +100,22 @@ def getUInt64Value? (e : Expr) : MetaM (Option UInt64) := OptionT.run do
let (n, _) ← getOfNatValue? e ``UInt64
return UInt64.ofNat n
/--
If `e` is literal value, ensure it is encoded using the standard representation.
Otherwise, just return `e`.
-/
def normLitValue (e : Expr) : MetaM Expr := do
let e ← instantiateMVars e
if let some n ← getNatValue? e then return toExpr n
if let some n ← getIntValue? e then return toExpr n
if let some ⟨_, n⟩ ← getFinValue? e then return toExpr n
if let some ⟨_, n⟩ ← getBitVecValue? e then return toExpr n
if let some s := getStringValue? e then return toExpr s
if let some c ← getCharValue? e then return toExpr c
if let some n ← getUInt8Value? e then return toExpr n
if let some n ← getUInt16Value? e then return toExpr n
if let some n ← getUInt32Value? e then return toExpr n
if let some n ← getUInt64Value? e then return toExpr n
return e
end Lean.Meta

View file

@ -343,7 +343,7 @@ partial def toPattern (e : Expr) : MetaM Pattern := do
match e.getArg! 1, e.getArg! 3 with
| Expr.fvar x, Expr.fvar h => return Pattern.as x p h
| _, _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
else if isMatchValue e then
else if (← isMatchValue e) then
return Pattern.val e
else if e.isFVar then
return Pattern.var e.fvarId!

View file

@ -30,7 +30,7 @@ private def caseValueAux (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hNa
let tag ← mvarId.getTag
mvarId.checkNotAssigned `caseValue
let target ← mvarId.getType
let xEqValue ← mkEq (mkFVar fvarId) (foldPatValue value)
let xEqValue ← mkEq (mkFVar fvarId) (← normLitValue value)
let xNeqValue := mkApp (mkConst `Not) xEqValue
let thenTarget := Lean.mkForall hName BinderInfo.default xEqValue target
let elseTarget := Lean.mkForall hName BinderInfo.default xNeqValue target

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.LitValues
import Lean.Meta.Check
import Lean.Meta.Closure
import Lean.Meta.Tactic.Cases
@ -94,10 +95,11 @@ private def hasValPattern (p : Problem) : Bool :=
| .val _ :: _ => true
| _ => false
private def hasNatValPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
| .val v :: _ => v.isRawNatLit -- TODO: support `OfNat.ofNat`?
| _ => false
private def hasNatValPattern (p : Problem) : MetaM Bool :=
p.alts.anyM fun alt => do
match alt.patterns with
| .val v :: _ => return (← getNatValue? v).isSome
| _ => return false
private def hasVarPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
@ -137,13 +139,13 @@ private def isArrayLitTransition (p : Problem) : Bool :=
| .var _ :: _ => true
| _ => false
private def isNatValueTransition (p : Problem) : Bool :=
hasNatValPattern p
&& (!isNextVar p ||
private def isNatValueTransition (p : Problem) : MetaM Bool := do
unless (← hasNatValPattern p) do return false
return !isNextVar p ||
p.alts.any fun alt => match alt.patterns with
| .ctor .. :: _ => true
| .inaccessible _ :: _ => true
| _ => false)
| _ => false
private def processSkipInaccessible (p : Problem) : Problem := Id.run do
let x :: xs := p.vars | unreachable!
@ -584,12 +586,16 @@ private def processArrayLit (p : Problem) : MetaM (Array Problem) := do
let newAlts := p.alts.filter isFirstPatternVar
return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs }
private def expandNatValuePattern (p : Problem) : Problem :=
let alts := p.alts.map fun alt => match alt.patterns with
| .val (.lit (.natVal 0)) :: ps => { alt with patterns := .ctor ``Nat.zero [] [] [] :: ps }
| .val (.lit (.natVal (n+1))) :: ps => { alt with patterns := .ctor ``Nat.succ [] [] [.val (mkRawNatLit n)] :: ps }
| _ => alt
{ p with alts := alts }
private def expandNatValuePattern (p : Problem) : MetaM Problem := do
let alts ← p.alts.mapM fun alt => do
match alt.patterns with
| .val n :: ps =>
match (← getNatValue? n) with
| some 0 => return { alt with patterns := .ctor ``Nat.zero [] [] [] :: ps }
| some (n+1) => return { alt with patterns := .ctor ``Nat.succ [] [] [.val (toExpr n)] :: ps }
| _ => return alt
| _ => return alt
return { p with alts := alts }
private def traceStep (msg : String) : StateRefT State MetaM Unit := do
trace[Meta.Match.match] "{msg} step"
@ -634,9 +640,9 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
traceStep ("as-pattern")
let p ← processAsPattern p
process p
else if isNatValueTransition p then
else if (← isNatValueTransition p) then
traceStep ("nat value to constructor")
process (expandNatValuePattern p)
process (expandNatValuePattern p)
else if !isNextVar p then
traceStep ("non variable")
let p ← processNonVariable p
@ -654,11 +660,11 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
else if isArrayLitTransition p then
let ps ← processArrayLit p
ps.forM process
else if hasNatValPattern p then
else if (← hasNatValPattern p) then
-- This branch is reachable when `p`, for example, is just values without an else-alternative.
-- We added it just to get better error messages.
traceStep ("nat value to constructor")
process (expandNatValuePattern p)
process (expandNatValuePattern p)
else
checkNextPatternTypes p
throwNonSupported p

View file

@ -4,46 +4,24 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.LitValues
import Lean.Expr
namespace Lean.Meta
-- TODO: produce error for `USize` because `USize.decEq` depends on an opaque value: `System.Platform.numBits`.
-- TODO: move?
private def UIntTypeNames : Array Name :=
#[``UInt8, ``UInt16, ``UInt32, ``UInt64, ``USize]
private def isUIntTypeName (n : Name) : Bool :=
UIntTypeNames.contains n
def isFinPatLit (e : Expr) : Bool :=
e.isAppOfArity `Fin.ofNat 2 && e.appArg!.isRawNatLit
/-- Return `some (typeName, numLit)` if `v` is of the form `UInt*.mk (Fin.ofNat _ numLit)` -/
def isUIntPatLit? (v : Expr) : Option (Name × Expr) :=
match v with
| Expr.app (Expr.const (Name.str typeName "mk" ..) ..) val .. =>
if isUIntTypeName typeName && isFinPatLit val then
some (typeName, val.appArg!)
else
none
| _ => none
def isUIntPatLit (v : Expr) : Bool :=
isUIntPatLit? v |>.isSome
/--
The frontend expands uint numerals occurring in patterns into `UInt*.mk ..` constructor applications.
This method convert them back into `UInt*.ofNat ..` applications.
-/
def foldPatValue (v : Expr) : Expr :=
match isUIntPatLit? v with
| some (typeName, numLit) => mkApp (mkConst (Name.mkStr typeName "ofNat")) numLit
| _ => v
/-- Return true is `e` is a term that should be processed by the `match`-compiler using `casesValues` -/
def isMatchValue (e : Expr) : Bool :=
e.isRawNatLit || e.isCharLit || e.isStringLit || isFinPatLit e || isUIntPatLit e
def isMatchValue (e : Expr) : MetaM Bool := do
let e ← instantiateMVars e
if (← getNatValue? e).isSome then return true
if (← getIntValue? e).isSome then return true
if (← getFinValue? e).isSome then return true
if (← getBitVecValue? e).isSome then return true
if (getStringValue? e).isSome then return true
if (← getCharValue? e).isSome then return true
if (← getUInt8Value? e).isSome then return true
if (← getUInt16Value? e).isSome then return true
if (← getUInt32Value? e).isSome then return true
if (← getUInt64Value? e).isSome then return true
return false
end Lean.Meta

View file

@ -0,0 +1,88 @@
@[simp] def f1 (i : Int) (a b : Nat) : Nat :=
match i, a with
| -1, _ => b
| _, 0 => b+1
| _, a+1 => f1 (i-1) a (b*2)
#check f1._eq_1
#check f1._eq_2
example : f1 (-1) a b = b := by simp -- should work
example : f1 (-2) 0 b = b+1 := by simp
example : f1 (-2) (a+1) b = f1 (-3) a (b*2) := by simp
example (h : i ≠ -1) : f1 i (a+1) b = f1 (i-1) a (b*2) := by simp -- should work
@[simp] def f2 (c : Char) (a b : Nat) : Nat :=
match c, a with
| 'a', _ => b
| _, 0 => b+1
| _, a+1 => f2 c a (b*2)
example : f2 'a' a b = b := by simp
example : f2 'b' 0 b = b+1 := by simp
example : f2 'b' (a+1) b = f2 'b' a (b*2) := by simp
example (h : c ≠ 'a') : f2 c (a+1) b = f2 c a (b*2) := by simp
@[simp] def f3 (i : Fin 5) (a b : Nat) : Nat :=
match i, a with
| 2, _ => b
| _, 0 => b+1
| _, a+1 => f3 (i+1) a (b*2)
#check f3._eq_1
#check f3._eq_2
example : f3 2 a b = b := by simp -- should work
example : f3 3 0 b = b+1 := by simp
example : f3 1 (a+1) b = f3 2 a (b*2) := by simp
example (h : i ≠ 2) : f3 i (a+1) b = f3 (i+1) a (b*2) := by simp; done -- should work
@[simp] def f4 (i : UInt16) (a b : Nat) : Nat :=
match i, a with
| 2, _ => b
| _, 0 => b+1
| _, a+1 => f4 (i+1) a (b*2)
#check f4._eq_1
#check f4._eq_2
example : f4 2 a b = b := by simp -- should work
example : f4 3 0 b = b+1 := by simp
example : f4 1 (a+1) b = f4 2 a (b*2) := by simp
example (h : i ≠ 2) : f4 i (a+1) b = f4 (i+1) a (b*2) := by simp -- should work
@[simp] def f5 (i : BitVec 8) (a b : Nat) : Nat :=
match i, a with
| 2, _ => b
| _, 0 => b+1
| _, a+1 => f5 (i+1) a (b*2)
#check f5._eq_1
#check f5._eq_2
open BitVec
example : f5 2 a b = b := by simp -- should work
example : f5 2#8 a b = b := by simp -- should work
example : f5 3 0 b = b+1 := by simp
example : f5 3#8 0 b = b+1 := by simp
example : f5 1 (a+1) b = f5 2 a (b*2) := by simp
example : f5 1#8 (a+1) b = f5 2 a (b*2) := by simp
example (h : i ≠ 2#8) : f5 i (a+1) b = f5 (i+1) a (b*2) := by simp -- should work
@[simp] def f6 (i : BitVec 8) (a b : Nat) : Nat :=
match i, a with
| 2#8, _ => b
| _, 0 => b+1
| _, a+1 => f6 (i+1) a (b*2)
#check f6._eq_1
#check f6._eq_2
example : f6 2#8 a b = b := by simp -- should work
example : f6 2#8 a b = b := by simp -- should work
example : f6 3 0 b = b+1 := by simp
example : f6 3#8 0 b = b+1 := by simp
example : f6 1 (a+1) b = f6 2 a (b*2) := by simp
example : f6 1#8 (a+1) b = f6 2 a (b*2) := by simp
example (h : i ≠ 2#8) : f6 i (a+1) b = f6 (i+1) a (b*2) := by simp -- should work