fix: match literal pattern support
The equation lemmas were not using the standard representation for literals.
This commit is contained in:
parent
66be8b9d4c
commit
056cb75ee0
7 changed files with 158 additions and 67 deletions
|
|
@ -473,7 +473,7 @@ partial def normalize (e : Expr) : M Expr := do
|
|||
let p ← normalize p
|
||||
addVar h
|
||||
return mkApp4 e.getAppFn (e.getArg! 0) x p h
|
||||
else if isMatchValue e then
|
||||
else if (← isMatchValue e) then
|
||||
return e
|
||||
else if e.isFVar then
|
||||
if (← isExplicitPatternVar e) then
|
||||
|
|
@ -571,8 +571,8 @@ private partial def toPattern (e : Expr) : MetaM Pattern := do
|
|||
match e.getArg! 1, e.getArg! 3 with
|
||||
| Expr.fvar x, Expr.fvar h => return Pattern.as x p h
|
||||
| _, _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
|
||||
else if isMatchValue e then
|
||||
return Pattern.val e
|
||||
else if (← isMatchValue e) then
|
||||
return Pattern.val (← normLitValue e)
|
||||
else if e.isFVar then
|
||||
return Pattern.var e.fvarId!
|
||||
else
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ It also provides support for the following exceptional cases.
|
|||
- Bit-vectors encoded using `OfNat.ofNat` and `BitVec.ofNat`.
|
||||
- Negative integers encoded using raw natural numbers.
|
||||
- Characters encoded `Char.ofNat n` where `n` can be a raw natural number or an `OfNat.ofNat`.
|
||||
- Nested `Expr.mdata`.
|
||||
-/
|
||||
|
||||
/-- Returns `some n` if `e` is a raw natural number, i.e., it is of the form `.lit (.natVal n)`. -/
|
||||
|
|
@ -26,7 +27,7 @@ def getRawNatValue? (e : Expr) : Option Nat :=
|
|||
|
||||
/-- Return `some (n, type)` if `e` is an `OfNat.ofNat`-application encoding `n` for a type with name `typeDeclName`. -/
|
||||
def getOfNatValue? (e : Expr) (typeDeclName : Name) : MetaM (Option (Nat × Expr)) := OptionT.run do
|
||||
guard <| e.isAppOfArity ``OfNat.ofNat 3
|
||||
guard <| e.isAppOfArity' ``OfNat.ofNat 3
|
||||
let type ← whnfD e.appFn!.appFn!.appArg!
|
||||
guard <| type.getAppFn.isConstOf typeDeclName
|
||||
let .lit (.natVal n) := e.appFn!.appArg! | failure
|
||||
|
|
@ -43,15 +44,15 @@ def getNatValue? (e : Expr) : MetaM (Option Nat) := do
|
|||
def getIntValue? (e : Expr) : MetaM (Option Int) := do
|
||||
if let some (n, _) ← getOfNatValue? e ``Int then
|
||||
return some n
|
||||
if e.isAppOfArity ``Neg.neg 3 then
|
||||
let some (n, _) ← getOfNatValue? e.appArg! ``Int | return none
|
||||
if e.isAppOfArity' ``Neg.neg 3 then
|
||||
let some (n, _) ← getOfNatValue? e.appArg!.consumeMData ``Int | return none
|
||||
return some (-n)
|
||||
return none
|
||||
|
||||
/-- Return `some c` if `e` is a `Char.ofNat`-application encoding character `c`. -/
|
||||
def getCharValue? (e : Expr) : MetaM (Option Char) := OptionT.run do
|
||||
guard <| e.isAppOfArity ``Char.ofNat 1
|
||||
let n ← getNatValue? e.appArg!
|
||||
guard <| e.isAppOfArity' ``Char.ofNat 1
|
||||
let n ← getNatValue? e.appArg!.consumeMData
|
||||
return Char.ofNat n
|
||||
|
||||
/-- Return `some s` if `e` is of the form `.lit (.strVal s)`. -/
|
||||
|
|
@ -70,9 +71,9 @@ def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run
|
|||
|
||||
/-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `BitVec n` with value `v` -/
|
||||
def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := OptionT.run do
|
||||
if e.isAppOfArity ``BitVec.ofNat 2 then
|
||||
let n ← getNatValue? e.appFn!.appArg!
|
||||
let v ← getNatValue? e.appArg!
|
||||
if e.isAppOfArity' ``BitVec.ofNat 2 then
|
||||
let n ← getNatValue? e.appFn!.appArg!.consumeMData
|
||||
let v ← getNatValue? e.appArg!.consumeMData
|
||||
return ⟨n, BitVec.ofNat n v⟩
|
||||
let (v, type) ← getOfNatValue? e ``BitVec
|
||||
IO.println v
|
||||
|
|
@ -99,4 +100,22 @@ def getUInt64Value? (e : Expr) : MetaM (Option UInt64) := OptionT.run do
|
|||
let (n, _) ← getOfNatValue? e ``UInt64
|
||||
return UInt64.ofNat n
|
||||
|
||||
/--
|
||||
If `e` is literal value, ensure it is encoded using the standard representation.
|
||||
Otherwise, just return `e`.
|
||||
-/
|
||||
def normLitValue (e : Expr) : MetaM Expr := do
|
||||
let e ← instantiateMVars e
|
||||
if let some n ← getNatValue? e then return toExpr n
|
||||
if let some n ← getIntValue? e then return toExpr n
|
||||
if let some ⟨_, n⟩ ← getFinValue? e then return toExpr n
|
||||
if let some ⟨_, n⟩ ← getBitVecValue? e then return toExpr n
|
||||
if let some s := getStringValue? e then return toExpr s
|
||||
if let some c ← getCharValue? e then return toExpr c
|
||||
if let some n ← getUInt8Value? e then return toExpr n
|
||||
if let some n ← getUInt16Value? e then return toExpr n
|
||||
if let some n ← getUInt32Value? e then return toExpr n
|
||||
if let some n ← getUInt64Value? e then return toExpr n
|
||||
return e
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
|
|
@ -343,7 +343,7 @@ partial def toPattern (e : Expr) : MetaM Pattern := do
|
|||
match e.getArg! 1, e.getArg! 3 with
|
||||
| Expr.fvar x, Expr.fvar h => return Pattern.as x p h
|
||||
| _, _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
|
||||
else if isMatchValue e then
|
||||
else if (← isMatchValue e) then
|
||||
return Pattern.val e
|
||||
else if e.isFVar then
|
||||
return Pattern.var e.fvarId!
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ private def caseValueAux (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hNa
|
|||
let tag ← mvarId.getTag
|
||||
mvarId.checkNotAssigned `caseValue
|
||||
let target ← mvarId.getType
|
||||
let xEqValue ← mkEq (mkFVar fvarId) (foldPatValue value)
|
||||
let xEqValue ← mkEq (mkFVar fvarId) (← normLitValue value)
|
||||
let xNeqValue := mkApp (mkConst `Not) xEqValue
|
||||
let thenTarget := Lean.mkForall hName BinderInfo.default xEqValue target
|
||||
let elseTarget := Lean.mkForall hName BinderInfo.default xNeqValue target
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Meta.LitValues
|
||||
import Lean.Meta.Check
|
||||
import Lean.Meta.Closure
|
||||
import Lean.Meta.Tactic.Cases
|
||||
|
|
@ -94,10 +95,11 @@ private def hasValPattern (p : Problem) : Bool :=
|
|||
| .val _ :: _ => true
|
||||
| _ => false
|
||||
|
||||
private def hasNatValPattern (p : Problem) : Bool :=
|
||||
p.alts.any fun alt => match alt.patterns with
|
||||
| .val v :: _ => v.isRawNatLit -- TODO: support `OfNat.ofNat`?
|
||||
| _ => false
|
||||
private def hasNatValPattern (p : Problem) : MetaM Bool :=
|
||||
p.alts.anyM fun alt => do
|
||||
match alt.patterns with
|
||||
| .val v :: _ => return (← getNatValue? v).isSome
|
||||
| _ => return false
|
||||
|
||||
private def hasVarPattern (p : Problem) : Bool :=
|
||||
p.alts.any fun alt => match alt.patterns with
|
||||
|
|
@ -137,13 +139,13 @@ private def isArrayLitTransition (p : Problem) : Bool :=
|
|||
| .var _ :: _ => true
|
||||
| _ => false
|
||||
|
||||
private def isNatValueTransition (p : Problem) : Bool :=
|
||||
hasNatValPattern p
|
||||
&& (!isNextVar p ||
|
||||
private def isNatValueTransition (p : Problem) : MetaM Bool := do
|
||||
unless (← hasNatValPattern p) do return false
|
||||
return !isNextVar p ||
|
||||
p.alts.any fun alt => match alt.patterns with
|
||||
| .ctor .. :: _ => true
|
||||
| .inaccessible _ :: _ => true
|
||||
| _ => false)
|
||||
| _ => false
|
||||
|
||||
private def processSkipInaccessible (p : Problem) : Problem := Id.run do
|
||||
let x :: xs := p.vars | unreachable!
|
||||
|
|
@ -584,12 +586,16 @@ private def processArrayLit (p : Problem) : MetaM (Array Problem) := do
|
|||
let newAlts := p.alts.filter isFirstPatternVar
|
||||
return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs }
|
||||
|
||||
private def expandNatValuePattern (p : Problem) : Problem :=
|
||||
let alts := p.alts.map fun alt => match alt.patterns with
|
||||
| .val (.lit (.natVal 0)) :: ps => { alt with patterns := .ctor ``Nat.zero [] [] [] :: ps }
|
||||
| .val (.lit (.natVal (n+1))) :: ps => { alt with patterns := .ctor ``Nat.succ [] [] [.val (mkRawNatLit n)] :: ps }
|
||||
| _ => alt
|
||||
{ p with alts := alts }
|
||||
private def expandNatValuePattern (p : Problem) : MetaM Problem := do
|
||||
let alts ← p.alts.mapM fun alt => do
|
||||
match alt.patterns with
|
||||
| .val n :: ps =>
|
||||
match (← getNatValue? n) with
|
||||
| some 0 => return { alt with patterns := .ctor ``Nat.zero [] [] [] :: ps }
|
||||
| some (n+1) => return { alt with patterns := .ctor ``Nat.succ [] [] [.val (toExpr n)] :: ps }
|
||||
| _ => return alt
|
||||
| _ => return alt
|
||||
return { p with alts := alts }
|
||||
|
||||
private def traceStep (msg : String) : StateRefT State MetaM Unit := do
|
||||
trace[Meta.Match.match] "{msg} step"
|
||||
|
|
@ -634,9 +640,9 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
|
|||
traceStep ("as-pattern")
|
||||
let p ← processAsPattern p
|
||||
process p
|
||||
else if isNatValueTransition p then
|
||||
else if (← isNatValueTransition p) then
|
||||
traceStep ("nat value to constructor")
|
||||
process (expandNatValuePattern p)
|
||||
process (← expandNatValuePattern p)
|
||||
else if !isNextVar p then
|
||||
traceStep ("non variable")
|
||||
let p ← processNonVariable p
|
||||
|
|
@ -654,11 +660,11 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
|
|||
else if isArrayLitTransition p then
|
||||
let ps ← processArrayLit p
|
||||
ps.forM process
|
||||
else if hasNatValPattern p then
|
||||
else if (← hasNatValPattern p) then
|
||||
-- This branch is reachable when `p`, for example, is just values without an else-alternative.
|
||||
-- We added it just to get better error messages.
|
||||
traceStep ("nat value to constructor")
|
||||
process (expandNatValuePattern p)
|
||||
process (← expandNatValuePattern p)
|
||||
else
|
||||
checkNextPatternTypes p
|
||||
throwNonSupported p
|
||||
|
|
|
|||
|
|
@ -4,46 +4,24 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Meta.LitValues
|
||||
import Lean.Expr
|
||||
|
||||
namespace Lean.Meta
|
||||
-- TODO: produce error for `USize` because `USize.decEq` depends on an opaque value: `System.Platform.numBits`.
|
||||
|
||||
-- TODO: move?
|
||||
private def UIntTypeNames : Array Name :=
|
||||
#[``UInt8, ``UInt16, ``UInt32, ``UInt64, ``USize]
|
||||
|
||||
private def isUIntTypeName (n : Name) : Bool :=
|
||||
UIntTypeNames.contains n
|
||||
|
||||
def isFinPatLit (e : Expr) : Bool :=
|
||||
e.isAppOfArity `Fin.ofNat 2 && e.appArg!.isRawNatLit
|
||||
|
||||
/-- Return `some (typeName, numLit)` if `v` is of the form `UInt*.mk (Fin.ofNat _ numLit)` -/
|
||||
def isUIntPatLit? (v : Expr) : Option (Name × Expr) :=
|
||||
match v with
|
||||
| Expr.app (Expr.const (Name.str typeName "mk" ..) ..) val .. =>
|
||||
if isUIntTypeName typeName && isFinPatLit val then
|
||||
some (typeName, val.appArg!)
|
||||
else
|
||||
none
|
||||
| _ => none
|
||||
|
||||
def isUIntPatLit (v : Expr) : Bool :=
|
||||
isUIntPatLit? v |>.isSome
|
||||
|
||||
/--
|
||||
The frontend expands uint numerals occurring in patterns into `UInt*.mk ..` constructor applications.
|
||||
This method convert them back into `UInt*.ofNat ..` applications.
|
||||
-/
|
||||
def foldPatValue (v : Expr) : Expr :=
|
||||
match isUIntPatLit? v with
|
||||
| some (typeName, numLit) => mkApp (mkConst (Name.mkStr typeName "ofNat")) numLit
|
||||
| _ => v
|
||||
|
||||
|
||||
/-- Return true is `e` is a term that should be processed by the `match`-compiler using `casesValues` -/
|
||||
def isMatchValue (e : Expr) : Bool :=
|
||||
e.isRawNatLit || e.isCharLit || e.isStringLit || isFinPatLit e || isUIntPatLit e
|
||||
def isMatchValue (e : Expr) : MetaM Bool := do
|
||||
let e ← instantiateMVars e
|
||||
if (← getNatValue? e).isSome then return true
|
||||
if (← getIntValue? e).isSome then return true
|
||||
if (← getFinValue? e).isSome then return true
|
||||
if (← getBitVecValue? e).isSome then return true
|
||||
if (getStringValue? e).isSome then return true
|
||||
if (← getCharValue? e).isSome then return true
|
||||
if (← getUInt8Value? e).isSome then return true
|
||||
if (← getUInt16Value? e).isSome then return true
|
||||
if (← getUInt32Value? e).isSome then return true
|
||||
if (← getUInt64Value? e).isSome then return true
|
||||
return false
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
88
tests/lean/run/match_lit_issues.lean
Normal file
88
tests/lean/run/match_lit_issues.lean
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
@[simp] def f1 (i : Int) (a b : Nat) : Nat :=
|
||||
match i, a with
|
||||
| -1, _ => b
|
||||
| _, 0 => b+1
|
||||
| _, a+1 => f1 (i-1) a (b*2)
|
||||
|
||||
#check f1._eq_1
|
||||
#check f1._eq_2
|
||||
|
||||
example : f1 (-1) a b = b := by simp -- should work
|
||||
example : f1 (-2) 0 b = b+1 := by simp
|
||||
example : f1 (-2) (a+1) b = f1 (-3) a (b*2) := by simp
|
||||
example (h : i ≠ -1) : f1 i (a+1) b = f1 (i-1) a (b*2) := by simp -- should work
|
||||
|
||||
@[simp] def f2 (c : Char) (a b : Nat) : Nat :=
|
||||
match c, a with
|
||||
| 'a', _ => b
|
||||
| _, 0 => b+1
|
||||
| _, a+1 => f2 c a (b*2)
|
||||
|
||||
example : f2 'a' a b = b := by simp
|
||||
example : f2 'b' 0 b = b+1 := by simp
|
||||
example : f2 'b' (a+1) b = f2 'b' a (b*2) := by simp
|
||||
example (h : c ≠ 'a') : f2 c (a+1) b = f2 c a (b*2) := by simp
|
||||
|
||||
@[simp] def f3 (i : Fin 5) (a b : Nat) : Nat :=
|
||||
match i, a with
|
||||
| 2, _ => b
|
||||
| _, 0 => b+1
|
||||
| _, a+1 => f3 (i+1) a (b*2)
|
||||
|
||||
#check f3._eq_1
|
||||
#check f3._eq_2
|
||||
|
||||
example : f3 2 a b = b := by simp -- should work
|
||||
example : f3 3 0 b = b+1 := by simp
|
||||
example : f3 1 (a+1) b = f3 2 a (b*2) := by simp
|
||||
example (h : i ≠ 2) : f3 i (a+1) b = f3 (i+1) a (b*2) := by simp; done -- should work
|
||||
|
||||
@[simp] def f4 (i : UInt16) (a b : Nat) : Nat :=
|
||||
match i, a with
|
||||
| 2, _ => b
|
||||
| _, 0 => b+1
|
||||
| _, a+1 => f4 (i+1) a (b*2)
|
||||
|
||||
#check f4._eq_1
|
||||
#check f4._eq_2
|
||||
|
||||
example : f4 2 a b = b := by simp -- should work
|
||||
example : f4 3 0 b = b+1 := by simp
|
||||
example : f4 1 (a+1) b = f4 2 a (b*2) := by simp
|
||||
example (h : i ≠ 2) : f4 i (a+1) b = f4 (i+1) a (b*2) := by simp -- should work
|
||||
|
||||
@[simp] def f5 (i : BitVec 8) (a b : Nat) : Nat :=
|
||||
match i, a with
|
||||
| 2, _ => b
|
||||
| _, 0 => b+1
|
||||
| _, a+1 => f5 (i+1) a (b*2)
|
||||
|
||||
#check f5._eq_1
|
||||
#check f5._eq_2
|
||||
|
||||
open BitVec
|
||||
|
||||
example : f5 2 a b = b := by simp -- should work
|
||||
example : f5 2#8 a b = b := by simp -- should work
|
||||
example : f5 3 0 b = b+1 := by simp
|
||||
example : f5 3#8 0 b = b+1 := by simp
|
||||
example : f5 1 (a+1) b = f5 2 a (b*2) := by simp
|
||||
example : f5 1#8 (a+1) b = f5 2 a (b*2) := by simp
|
||||
example (h : i ≠ 2#8) : f5 i (a+1) b = f5 (i+1) a (b*2) := by simp -- should work
|
||||
|
||||
@[simp] def f6 (i : BitVec 8) (a b : Nat) : Nat :=
|
||||
match i, a with
|
||||
| 2#8, _ => b
|
||||
| _, 0 => b+1
|
||||
| _, a+1 => f6 (i+1) a (b*2)
|
||||
|
||||
#check f6._eq_1
|
||||
#check f6._eq_2
|
||||
|
||||
example : f6 2#8 a b = b := by simp -- should work
|
||||
example : f6 2#8 a b = b := by simp -- should work
|
||||
example : f6 3 0 b = b+1 := by simp
|
||||
example : f6 3#8 0 b = b+1 := by simp
|
||||
example : f6 1 (a+1) b = f6 2 a (b*2) := by simp
|
||||
example : f6 1#8 (a+1) b = f6 2 a (b*2) := by simp
|
||||
example (h : i ≠ 2#8) : f6 i (a+1) b = f6 (i+1) a (b*2) := by simp -- should work
|
||||
Loading…
Add table
Reference in a new issue