From 72d233d181baa183a5288a3a52defd219015e952 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 25 Feb 2024 09:44:08 -0800 Subject: [PATCH] fix: `match` patterns containing int values and constructors (#3496) --- src/Lean/Expr.lean | 19 ++++++++++++++++ src/Lean/Meta/LitValues.lean | 16 ++++++++------ src/Lean/Meta/Match/Match.lean | 40 +++++++++++++++++++++++++++++----- 3 files changed, 63 insertions(+), 12 deletions(-) diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 74c9df1348..0c67801208 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -1037,6 +1037,14 @@ def getAppFn : Expr → Expr | app f _ => getAppFn f | e => e +/-- +Similar to `getAppFn`, but skips `mdata` +-/ +def getAppFn' : Expr → Expr + | app f _ => getAppFn' f + | mdata _ a => getAppFn' a + | e => e + /-- Given `f a₀ a₁ ... aₙ`, returns true if `f` is a constant with name `n`. -/ def isAppOf (e : Expr) (n : Name) : Bool := match e.getAppFn with @@ -1207,10 +1215,21 @@ def getRevArg! : Expr → Nat → Expr | app f _, i+1 => getRevArg! f i | _, _ => panic! "invalid index" +/-- Similar to `getRevArg!` but skips `mdata` -/ +def getRevArg!' : Expr → Nat → Expr + | mdata _ a, i => getRevArg!' a i + | app _ a, 0 => a + | app f _, i+1 => getRevArg!' f i + | _, _ => panic! "invalid index" + /-- Given `f a₀ a₁ ... aₙ`, returns the `i`th argument or panics if out of bounds. -/ @[inline] def getArg! (e : Expr) (i : Nat) (n := e.getAppNumArgs) : Expr := getRevArg! e (n - i - 1) +/-- Similar to `getArg!`, but skips mdata -/ +@[inline] def getArg!' (e : Expr) (i : Nat) (n := e.getAppNumArgs) : Expr := + getRevArg!' e (n - i - 1) + /-- Given `f a₀ a₁ ... aₙ`, returns the `i`th argument or returns `v₀` if out of bounds. -/ @[inline] def getArgD (e : Expr) (i : Nat) (v₀ : Expr) (n := e.getAppNumArgs) : Expr := getRevArgD e (n - i - 1) v₀ diff --git a/src/Lean/Meta/LitValues.lean b/src/Lean/Meta/LitValues.lean index 912cea0a24..8986210e67 100644 --- a/src/Lean/Meta/LitValues.lean +++ b/src/Lean/Meta/LitValues.lean @@ -21,20 +21,22 @@ It also provides support for the following exceptional cases. /-- Returns `some n` if `e` is a raw natural number, i.e., it is of the form `.lit (.natVal n)`. -/ def getRawNatValue? (e : Expr) : Option Nat := - match e with + match e.consumeMData with | .lit (.natVal n) => some n | _ => none /-- Return `some (n, type)` if `e` is an `OfNat.ofNat`-application encoding `n` for a type with name `typeDeclName`. -/ def getOfNatValue? (e : Expr) (typeDeclName : Name) : MetaM (Option (Nat × Expr)) := OptionT.run do + let e := e.consumeMData guard <| e.isAppOfArity' ``OfNat.ofNat 3 - let type ← whnfD e.appFn!.appFn!.appArg! + let type ← whnfD (e.getArg!' 0) guard <| type.getAppFn.isConstOf typeDeclName - let .lit (.natVal n) := e.appFn!.appArg! | failure + let .lit (.natVal n) := (e.getArg!' 1).consumeMData | failure return (n, type) /-- Return `some n` if `e` is a raw natural number or an `OfNat.ofNat`-application encoding `n`. -/ def getNatValue? (e : Expr) : MetaM (Option Nat) := do + let e := e.consumeMData if let some n := getRawNatValue? e then return some n let some (n, _) ← getOfNatValue? e ``Nat | return none @@ -45,14 +47,14 @@ def getIntValue? (e : Expr) : MetaM (Option Int) := do if let some (n, _) ← getOfNatValue? e ``Int then return some n if e.isAppOfArity' ``Neg.neg 3 then - let some (n, _) ← getOfNatValue? e.appArg!.consumeMData ``Int | return none + let some (n, _) ← getOfNatValue? (e.getArg!' 2) ``Int | return none return some (-n) return none /-- Return `some c` if `e` is a `Char.ofNat`-application encoding character `c`. -/ def getCharValue? (e : Expr) : MetaM (Option Char) := OptionT.run do guard <| e.isAppOfArity' ``Char.ofNat 1 - let n ← getNatValue? e.appArg!.consumeMData + let n ← getNatValue? (e.getArg!' 0) return Char.ofNat n /-- Return `some s` if `e` is of the form `.lit (.strVal s)`. -/ @@ -72,8 +74,8 @@ def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run /-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `BitVec n` with value `v` -/ def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := OptionT.run do if e.isAppOfArity' ``BitVec.ofNat 2 then - let n ← getNatValue? e.appFn!.appArg!.consumeMData - let v ← getNatValue? e.appArg!.consumeMData + let n ← getNatValue? (e.getArg!' 0) + let v ← getNatValue? (e.getArg!' 1) return ⟨n, BitVec.ofNat n v⟩ let (v, type) ← getOfNatValue? e ``BitVec IO.println v diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index 75a81ae03e..3e659d186c 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -101,6 +101,12 @@ private def hasNatValPattern (p : Problem) : MetaM Bool := | .val v :: _ => return (← getNatValue? v).isSome | _ => return false +private def hasIntValPattern (p : Problem) : MetaM Bool := + p.alts.anyM fun alt => do + match alt.patterns with + | .val v :: _ => return (← getIntValue? v).isSome + | _ => return false + private def hasVarPattern (p : Problem) : Bool := p.alts.any fun alt => match alt.patterns with | .var _ :: _ => true @@ -148,13 +154,20 @@ private def isArrayLitTransition (p : Problem) : Bool := | .var _ :: _ => true | _ => false +private def hasCtorOrInaccessible (p : Problem) : Bool := + !isNextVar p || + p.alts.any fun alt => match alt.patterns with + | .ctor .. :: _ => true + | .inaccessible _ :: _ => true + | _ => false + private def isNatValueTransition (p : Problem) : MetaM Bool := do unless (← hasNatValPattern p) do return false - return !isNextVar p || - p.alts.any fun alt => match alt.patterns with - | .ctor .. :: _ => true - | .inaccessible _ :: _ => true - | _ => false + return hasCtorOrInaccessible p + +private def isIntValueTransition (p : Problem) : MetaM Bool := do + unless (← hasIntValPattern p) do return false + return hasCtorOrInaccessible p private def processSkipInaccessible (p : Problem) : Problem := Id.run do let x :: xs := p.vars | unreachable! @@ -606,6 +619,20 @@ private def expandNatValuePattern (p : Problem) : MetaM Problem := do | _ => return alt return { p with alts := alts } +private def expandIntValuePattern (p : Problem) : MetaM Problem := do + let alts ← p.alts.mapM fun alt => do + match alt.patterns with + | .val n :: ps => + match (← getIntValue? n) with + | some i => + if i >= 0 then + return { alt with patterns := .ctor ``Int.ofNat [] [] [.val (toExpr i.toNat)] :: ps } + else + return { alt with patterns := .ctor ``Int.negSucc [] [] [.val (toExpr (-(i + 1)).toNat)] :: ps } + | _ => return alt + | _ => return alt + return { p with alts := alts } + private def expandFinValuePattern (p : Problem) : MetaM Problem := do let alts ← p.alts.mapM fun alt => do match alt.patterns with @@ -665,6 +692,9 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do else if (← isNatValueTransition p) then traceStep ("nat value to constructor") process (← expandNatValuePattern p) + else if (← isIntValueTransition p) then + traceStep ("int value to constructor") + process (← expandIntValuePattern p) else if (← isFinValueTransition p) then traceStep ("fin value to constructor") process (← expandFinValuePattern p)