diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 0d82df4dbd..97bf6f84b4 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -473,7 +473,7 @@ partial def normalize (e : Expr) : M Expr := do let p ← normalize p addVar h return mkApp4 e.getAppFn (e.getArg! 0) x p h - else if isMatchValue e then + else if (← isMatchValue e) then return e else if e.isFVar then if (← isExplicitPatternVar e) then @@ -571,8 +571,8 @@ private partial def toPattern (e : Expr) : MetaM Pattern := do match e.getArg! 1, e.getArg! 3 with | Expr.fvar x, Expr.fvar h => return Pattern.as x p h | _, _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'" - else if isMatchValue e then - return Pattern.val e + else if (← isMatchValue e) then + return Pattern.val (← normLitValue e) else if e.isFVar then return Pattern.var e.fvarId! else diff --git a/src/Lean/Meta/LitValues.lean b/src/Lean/Meta/LitValues.lean index 99404c4283..912cea0a24 100644 --- a/src/Lean/Meta/LitValues.lean +++ b/src/Lean/Meta/LitValues.lean @@ -16,6 +16,7 @@ It also provides support for the following exceptional cases. - Bit-vectors encoded using `OfNat.ofNat` and `BitVec.ofNat`. - Negative integers encoded using raw natural numbers. - Characters encoded `Char.ofNat n` where `n` can be a raw natural number or an `OfNat.ofNat`. +- Nested `Expr.mdata`. -/ /-- Returns `some n` if `e` is a raw natural number, i.e., it is of the form `.lit (.natVal n)`. -/ @@ -26,7 +27,7 @@ def getRawNatValue? (e : Expr) : Option Nat := /-- Return `some (n, type)` if `e` is an `OfNat.ofNat`-application encoding `n` for a type with name `typeDeclName`. -/ def getOfNatValue? (e : Expr) (typeDeclName : Name) : MetaM (Option (Nat × Expr)) := OptionT.run do - guard <| e.isAppOfArity ``OfNat.ofNat 3 + guard <| e.isAppOfArity' ``OfNat.ofNat 3 let type ← whnfD e.appFn!.appFn!.appArg! guard <| type.getAppFn.isConstOf typeDeclName let .lit (.natVal n) := e.appFn!.appArg! | failure @@ -43,15 +44,15 @@ def getNatValue? (e : Expr) : MetaM (Option Nat) := do def getIntValue? (e : Expr) : MetaM (Option Int) := do if let some (n, _) ← getOfNatValue? e ``Int then return some n - if e.isAppOfArity ``Neg.neg 3 then - let some (n, _) ← getOfNatValue? e.appArg! ``Int | return none + if e.isAppOfArity' ``Neg.neg 3 then + let some (n, _) ← getOfNatValue? e.appArg!.consumeMData ``Int | return none return some (-n) return none /-- Return `some c` if `e` is a `Char.ofNat`-application encoding character `c`. -/ def getCharValue? (e : Expr) : MetaM (Option Char) := OptionT.run do - guard <| e.isAppOfArity ``Char.ofNat 1 - let n ← getNatValue? e.appArg! + guard <| e.isAppOfArity' ``Char.ofNat 1 + let n ← getNatValue? e.appArg!.consumeMData return Char.ofNat n /-- Return `some s` if `e` is of the form `.lit (.strVal s)`. -/ @@ -70,9 +71,9 @@ def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run /-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `BitVec n` with value `v` -/ def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := OptionT.run do - if e.isAppOfArity ``BitVec.ofNat 2 then - let n ← getNatValue? e.appFn!.appArg! - let v ← getNatValue? e.appArg! + if e.isAppOfArity' ``BitVec.ofNat 2 then + let n ← getNatValue? e.appFn!.appArg!.consumeMData + let v ← getNatValue? e.appArg!.consumeMData return ⟨n, BitVec.ofNat n v⟩ let (v, type) ← getOfNatValue? e ``BitVec IO.println v @@ -99,4 +100,22 @@ def getUInt64Value? (e : Expr) : MetaM (Option UInt64) := OptionT.run do let (n, _) ← getOfNatValue? e ``UInt64 return UInt64.ofNat n +/-- +If `e` is literal value, ensure it is encoded using the standard representation. +Otherwise, just return `e`. +-/ +def normLitValue (e : Expr) : MetaM Expr := do + let e ← instantiateMVars e + if let some n ← getNatValue? e then return toExpr n + if let some n ← getIntValue? e then return toExpr n + if let some ⟨_, n⟩ ← getFinValue? e then return toExpr n + if let some ⟨_, n⟩ ← getBitVecValue? e then return toExpr n + if let some s := getStringValue? e then return toExpr s + if let some c ← getCharValue? e then return toExpr c + if let some n ← getUInt8Value? e then return toExpr n + if let some n ← getUInt16Value? e then return toExpr n + if let some n ← getUInt32Value? e then return toExpr n + if let some n ← getUInt64Value? e then return toExpr n + return e + end Lean.Meta diff --git a/src/Lean/Meta/Match/Basic.lean b/src/Lean/Meta/Match/Basic.lean index 82679f2719..14bd8854d6 100644 --- a/src/Lean/Meta/Match/Basic.lean +++ b/src/Lean/Meta/Match/Basic.lean @@ -343,7 +343,7 @@ partial def toPattern (e : Expr) : MetaM Pattern := do match e.getArg! 1, e.getArg! 3 with | Expr.fvar x, Expr.fvar h => return Pattern.as x p h | _, _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'" - else if isMatchValue e then + else if (← isMatchValue e) then return Pattern.val e else if e.isFVar then return Pattern.var e.fvarId! diff --git a/src/Lean/Meta/Match/CaseValues.lean b/src/Lean/Meta/Match/CaseValues.lean index 93726bd213..39aa755c24 100644 --- a/src/Lean/Meta/Match/CaseValues.lean +++ b/src/Lean/Meta/Match/CaseValues.lean @@ -30,7 +30,7 @@ private def caseValueAux (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hNa let tag ← mvarId.getTag mvarId.checkNotAssigned `caseValue let target ← mvarId.getType - let xEqValue ← mkEq (mkFVar fvarId) (foldPatValue value) + let xEqValue ← mkEq (mkFVar fvarId) (← normLitValue value) let xNeqValue := mkApp (mkConst `Not) xEqValue let thenTarget := Lean.mkForall hName BinderInfo.default xEqValue target let elseTarget := Lean.mkForall hName BinderInfo.default xNeqValue target diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index 923b045235..4c5332d869 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude +import Lean.Meta.LitValues import Lean.Meta.Check import Lean.Meta.Closure import Lean.Meta.Tactic.Cases @@ -94,10 +95,11 @@ private def hasValPattern (p : Problem) : Bool := | .val _ :: _ => true | _ => false -private def hasNatValPattern (p : Problem) : Bool := - p.alts.any fun alt => match alt.patterns with - | .val v :: _ => v.isRawNatLit -- TODO: support `OfNat.ofNat`? - | _ => false +private def hasNatValPattern (p : Problem) : MetaM Bool := + p.alts.anyM fun alt => do + match alt.patterns with + | .val v :: _ => return (← getNatValue? v).isSome + | _ => return false private def hasVarPattern (p : Problem) : Bool := p.alts.any fun alt => match alt.patterns with @@ -137,13 +139,13 @@ private def isArrayLitTransition (p : Problem) : Bool := | .var _ :: _ => true | _ => false -private def isNatValueTransition (p : Problem) : Bool := - hasNatValPattern p - && (!isNextVar p || +private def isNatValueTransition (p : Problem) : MetaM Bool := do + unless (← hasNatValPattern p) do return false + return !isNextVar p || p.alts.any fun alt => match alt.patterns with | .ctor .. :: _ => true | .inaccessible _ :: _ => true - | _ => false) + | _ => false private def processSkipInaccessible (p : Problem) : Problem := Id.run do let x :: xs := p.vars | unreachable! @@ -584,12 +586,16 @@ private def processArrayLit (p : Problem) : MetaM (Array Problem) := do let newAlts := p.alts.filter isFirstPatternVar return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs } -private def expandNatValuePattern (p : Problem) : Problem := - let alts := p.alts.map fun alt => match alt.patterns with - | .val (.lit (.natVal 0)) :: ps => { alt with patterns := .ctor ``Nat.zero [] [] [] :: ps } - | .val (.lit (.natVal (n+1))) :: ps => { alt with patterns := .ctor ``Nat.succ [] [] [.val (mkRawNatLit n)] :: ps } - | _ => alt - { p with alts := alts } +private def expandNatValuePattern (p : Problem) : MetaM Problem := do + let alts ← p.alts.mapM fun alt => do + match alt.patterns with + | .val n :: ps => + match (← getNatValue? n) with + | some 0 => return { alt with patterns := .ctor ``Nat.zero [] [] [] :: ps } + | some (n+1) => return { alt with patterns := .ctor ``Nat.succ [] [] [.val (toExpr n)] :: ps } + | _ => return alt + | _ => return alt + return { p with alts := alts } private def traceStep (msg : String) : StateRefT State MetaM Unit := do trace[Meta.Match.match] "{msg} step" @@ -634,9 +640,9 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do traceStep ("as-pattern") let p ← processAsPattern p process p - else if isNatValueTransition p then + else if (← isNatValueTransition p) then traceStep ("nat value to constructor") - process (expandNatValuePattern p) + process (← expandNatValuePattern p) else if !isNextVar p then traceStep ("non variable") let p ← processNonVariable p @@ -654,11 +660,11 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do else if isArrayLitTransition p then let ps ← processArrayLit p ps.forM process - else if hasNatValPattern p then + else if (← hasNatValPattern p) then -- This branch is reachable when `p`, for example, is just values without an else-alternative. -- We added it just to get better error messages. traceStep ("nat value to constructor") - process (expandNatValuePattern p) + process (← expandNatValuePattern p) else checkNextPatternTypes p throwNonSupported p diff --git a/src/Lean/Meta/Match/Value.lean b/src/Lean/Meta/Match/Value.lean index 1a45451a8b..9ccd0a390a 100644 --- a/src/Lean/Meta/Match/Value.lean +++ b/src/Lean/Meta/Match/Value.lean @@ -4,46 +4,24 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude +import Lean.Meta.LitValues import Lean.Expr namespace Lean.Meta --- TODO: produce error for `USize` because `USize.decEq` depends on an opaque value: `System.Platform.numBits`. - --- TODO: move? -private def UIntTypeNames : Array Name := - #[``UInt8, ``UInt16, ``UInt32, ``UInt64, ``USize] - -private def isUIntTypeName (n : Name) : Bool := - UIntTypeNames.contains n - -def isFinPatLit (e : Expr) : Bool := - e.isAppOfArity `Fin.ofNat 2 && e.appArg!.isRawNatLit - -/-- Return `some (typeName, numLit)` if `v` is of the form `UInt*.mk (Fin.ofNat _ numLit)` -/ -def isUIntPatLit? (v : Expr) : Option (Name × Expr) := - match v with - | Expr.app (Expr.const (Name.str typeName "mk" ..) ..) val .. => - if isUIntTypeName typeName && isFinPatLit val then - some (typeName, val.appArg!) - else - none - | _ => none - -def isUIntPatLit (v : Expr) : Bool := - isUIntPatLit? v |>.isSome - -/-- - The frontend expands uint numerals occurring in patterns into `UInt*.mk ..` constructor applications. - This method convert them back into `UInt*.ofNat ..` applications. --/ -def foldPatValue (v : Expr) : Expr := - match isUIntPatLit? v with - | some (typeName, numLit) => mkApp (mkConst (Name.mkStr typeName "ofNat")) numLit - | _ => v - /-- Return true is `e` is a term that should be processed by the `match`-compiler using `casesValues` -/ -def isMatchValue (e : Expr) : Bool := - e.isRawNatLit || e.isCharLit || e.isStringLit || isFinPatLit e || isUIntPatLit e +def isMatchValue (e : Expr) : MetaM Bool := do + let e ← instantiateMVars e + if (← getNatValue? e).isSome then return true + if (← getIntValue? e).isSome then return true + if (← getFinValue? e).isSome then return true + if (← getBitVecValue? e).isSome then return true + if (getStringValue? e).isSome then return true + if (← getCharValue? e).isSome then return true + if (← getUInt8Value? e).isSome then return true + if (← getUInt16Value? e).isSome then return true + if (← getUInt32Value? e).isSome then return true + if (← getUInt64Value? e).isSome then return true + return false end Lean.Meta diff --git a/tests/lean/run/match_lit_issues.lean b/tests/lean/run/match_lit_issues.lean new file mode 100644 index 0000000000..14b434e38e --- /dev/null +++ b/tests/lean/run/match_lit_issues.lean @@ -0,0 +1,88 @@ +@[simp] def f1 (i : Int) (a b : Nat) : Nat := + match i, a with + | -1, _ => b + | _, 0 => b+1 + | _, a+1 => f1 (i-1) a (b*2) + +#check f1._eq_1 +#check f1._eq_2 + +example : f1 (-1) a b = b := by simp -- should work +example : f1 (-2) 0 b = b+1 := by simp +example : f1 (-2) (a+1) b = f1 (-3) a (b*2) := by simp +example (h : i ≠ -1) : f1 i (a+1) b = f1 (i-1) a (b*2) := by simp -- should work + +@[simp] def f2 (c : Char) (a b : Nat) : Nat := + match c, a with + | 'a', _ => b + | _, 0 => b+1 + | _, a+1 => f2 c a (b*2) + +example : f2 'a' a b = b := by simp +example : f2 'b' 0 b = b+1 := by simp +example : f2 'b' (a+1) b = f2 'b' a (b*2) := by simp +example (h : c ≠ 'a') : f2 c (a+1) b = f2 c a (b*2) := by simp + +@[simp] def f3 (i : Fin 5) (a b : Nat) : Nat := + match i, a with + | 2, _ => b + | _, 0 => b+1 + | _, a+1 => f3 (i+1) a (b*2) + +#check f3._eq_1 +#check f3._eq_2 + +example : f3 2 a b = b := by simp -- should work +example : f3 3 0 b = b+1 := by simp +example : f3 1 (a+1) b = f3 2 a (b*2) := by simp +example (h : i ≠ 2) : f3 i (a+1) b = f3 (i+1) a (b*2) := by simp; done -- should work + +@[simp] def f4 (i : UInt16) (a b : Nat) : Nat := + match i, a with + | 2, _ => b + | _, 0 => b+1 + | _, a+1 => f4 (i+1) a (b*2) + +#check f4._eq_1 +#check f4._eq_2 + +example : f4 2 a b = b := by simp -- should work +example : f4 3 0 b = b+1 := by simp +example : f4 1 (a+1) b = f4 2 a (b*2) := by simp +example (h : i ≠ 2) : f4 i (a+1) b = f4 (i+1) a (b*2) := by simp -- should work + +@[simp] def f5 (i : BitVec 8) (a b : Nat) : Nat := + match i, a with + | 2, _ => b + | _, 0 => b+1 + | _, a+1 => f5 (i+1) a (b*2) + +#check f5._eq_1 +#check f5._eq_2 + +open BitVec + +example : f5 2 a b = b := by simp -- should work +example : f5 2#8 a b = b := by simp -- should work +example : f5 3 0 b = b+1 := by simp +example : f5 3#8 0 b = b+1 := by simp +example : f5 1 (a+1) b = f5 2 a (b*2) := by simp +example : f5 1#8 (a+1) b = f5 2 a (b*2) := by simp +example (h : i ≠ 2#8) : f5 i (a+1) b = f5 (i+1) a (b*2) := by simp -- should work + +@[simp] def f6 (i : BitVec 8) (a b : Nat) : Nat := + match i, a with + | 2#8, _ => b + | _, 0 => b+1 + | _, a+1 => f6 (i+1) a (b*2) + +#check f6._eq_1 +#check f6._eq_2 + +example : f6 2#8 a b = b := by simp -- should work +example : f6 2#8 a b = b := by simp -- should work +example : f6 3 0 b = b+1 := by simp +example : f6 3#8 0 b = b+1 := by simp +example : f6 1 (a+1) b = f6 2 a (b*2) := by simp +example : f6 1#8 (a+1) b = f6 2 a (b*2) := by simp +example (h : i ≠ 2#8) : f6 i (a+1) b = f6 (i+1) a (b*2) := by simp -- should work