fix: match patterns containing int values and constructors (#3496)
This commit is contained in:
parent
9e5e0e23b2
commit
72d233d181
3 changed files with 63 additions and 12 deletions
|
|
@ -1037,6 +1037,14 @@ def getAppFn : Expr → Expr
|
|||
| app f _ => getAppFn f
|
||||
| e => e
|
||||
|
||||
/--
|
||||
Similar to `getAppFn`, but skips `mdata`
|
||||
-/
|
||||
def getAppFn' : Expr → Expr
|
||||
| app f _ => getAppFn' f
|
||||
| mdata _ a => getAppFn' a
|
||||
| e => e
|
||||
|
||||
/-- Given `f a₀ a₁ ... aₙ`, returns true if `f` is a constant with name `n`. -/
|
||||
def isAppOf (e : Expr) (n : Name) : Bool :=
|
||||
match e.getAppFn with
|
||||
|
|
@ -1207,10 +1215,21 @@ def getRevArg! : Expr → Nat → Expr
|
|||
| app f _, i+1 => getRevArg! f i
|
||||
| _, _ => panic! "invalid index"
|
||||
|
||||
/-- Similar to `getRevArg!` but skips `mdata` -/
|
||||
def getRevArg!' : Expr → Nat → Expr
|
||||
| mdata _ a, i => getRevArg!' a i
|
||||
| app _ a, 0 => a
|
||||
| app f _, i+1 => getRevArg!' f i
|
||||
| _, _ => panic! "invalid index"
|
||||
|
||||
/-- Given `f a₀ a₁ ... aₙ`, returns the `i`th argument or panics if out of bounds. -/
|
||||
@[inline] def getArg! (e : Expr) (i : Nat) (n := e.getAppNumArgs) : Expr :=
|
||||
getRevArg! e (n - i - 1)
|
||||
|
||||
/-- Similar to `getArg!`, but skips mdata -/
|
||||
@[inline] def getArg!' (e : Expr) (i : Nat) (n := e.getAppNumArgs) : Expr :=
|
||||
getRevArg!' e (n - i - 1)
|
||||
|
||||
/-- Given `f a₀ a₁ ... aₙ`, returns the `i`th argument or returns `v₀` if out of bounds. -/
|
||||
@[inline] def getArgD (e : Expr) (i : Nat) (v₀ : Expr) (n := e.getAppNumArgs) : Expr :=
|
||||
getRevArgD e (n - i - 1) v₀
|
||||
|
|
|
|||
|
|
@ -21,20 +21,22 @@ It also provides support for the following exceptional cases.
|
|||
|
||||
/-- Returns `some n` if `e` is a raw natural number, i.e., it is of the form `.lit (.natVal n)`. -/
|
||||
def getRawNatValue? (e : Expr) : Option Nat :=
|
||||
match e with
|
||||
match e.consumeMData with
|
||||
| .lit (.natVal n) => some n
|
||||
| _ => none
|
||||
|
||||
/-- Return `some (n, type)` if `e` is an `OfNat.ofNat`-application encoding `n` for a type with name `typeDeclName`. -/
|
||||
def getOfNatValue? (e : Expr) (typeDeclName : Name) : MetaM (Option (Nat × Expr)) := OptionT.run do
|
||||
let e := e.consumeMData
|
||||
guard <| e.isAppOfArity' ``OfNat.ofNat 3
|
||||
let type ← whnfD e.appFn!.appFn!.appArg!
|
||||
let type ← whnfD (e.getArg!' 0)
|
||||
guard <| type.getAppFn.isConstOf typeDeclName
|
||||
let .lit (.natVal n) := e.appFn!.appArg! | failure
|
||||
let .lit (.natVal n) := (e.getArg!' 1).consumeMData | failure
|
||||
return (n, type)
|
||||
|
||||
/-- Return `some n` if `e` is a raw natural number or an `OfNat.ofNat`-application encoding `n`. -/
|
||||
def getNatValue? (e : Expr) : MetaM (Option Nat) := do
|
||||
let e := e.consumeMData
|
||||
if let some n := getRawNatValue? e then
|
||||
return some n
|
||||
let some (n, _) ← getOfNatValue? e ``Nat | return none
|
||||
|
|
@ -45,14 +47,14 @@ def getIntValue? (e : Expr) : MetaM (Option Int) := do
|
|||
if let some (n, _) ← getOfNatValue? e ``Int then
|
||||
return some n
|
||||
if e.isAppOfArity' ``Neg.neg 3 then
|
||||
let some (n, _) ← getOfNatValue? e.appArg!.consumeMData ``Int | return none
|
||||
let some (n, _) ← getOfNatValue? (e.getArg!' 2) ``Int | return none
|
||||
return some (-n)
|
||||
return none
|
||||
|
||||
/-- Return `some c` if `e` is a `Char.ofNat`-application encoding character `c`. -/
|
||||
def getCharValue? (e : Expr) : MetaM (Option Char) := OptionT.run do
|
||||
guard <| e.isAppOfArity' ``Char.ofNat 1
|
||||
let n ← getNatValue? e.appArg!.consumeMData
|
||||
let n ← getNatValue? (e.getArg!' 0)
|
||||
return Char.ofNat n
|
||||
|
||||
/-- Return `some s` if `e` is of the form `.lit (.strVal s)`. -/
|
||||
|
|
@ -72,8 +74,8 @@ def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run
|
|||
/-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `BitVec n` with value `v` -/
|
||||
def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := OptionT.run do
|
||||
if e.isAppOfArity' ``BitVec.ofNat 2 then
|
||||
let n ← getNatValue? e.appFn!.appArg!.consumeMData
|
||||
let v ← getNatValue? e.appArg!.consumeMData
|
||||
let n ← getNatValue? (e.getArg!' 0)
|
||||
let v ← getNatValue? (e.getArg!' 1)
|
||||
return ⟨n, BitVec.ofNat n v⟩
|
||||
let (v, type) ← getOfNatValue? e ``BitVec
|
||||
IO.println v
|
||||
|
|
|
|||
|
|
@ -101,6 +101,12 @@ private def hasNatValPattern (p : Problem) : MetaM Bool :=
|
|||
| .val v :: _ => return (← getNatValue? v).isSome
|
||||
| _ => return false
|
||||
|
||||
private def hasIntValPattern (p : Problem) : MetaM Bool :=
|
||||
p.alts.anyM fun alt => do
|
||||
match alt.patterns with
|
||||
| .val v :: _ => return (← getIntValue? v).isSome
|
||||
| _ => return false
|
||||
|
||||
private def hasVarPattern (p : Problem) : Bool :=
|
||||
p.alts.any fun alt => match alt.patterns with
|
||||
| .var _ :: _ => true
|
||||
|
|
@ -148,13 +154,20 @@ private def isArrayLitTransition (p : Problem) : Bool :=
|
|||
| .var _ :: _ => true
|
||||
| _ => false
|
||||
|
||||
private def hasCtorOrInaccessible (p : Problem) : Bool :=
|
||||
!isNextVar p ||
|
||||
p.alts.any fun alt => match alt.patterns with
|
||||
| .ctor .. :: _ => true
|
||||
| .inaccessible _ :: _ => true
|
||||
| _ => false
|
||||
|
||||
private def isNatValueTransition (p : Problem) : MetaM Bool := do
|
||||
unless (← hasNatValPattern p) do return false
|
||||
return !isNextVar p ||
|
||||
p.alts.any fun alt => match alt.patterns with
|
||||
| .ctor .. :: _ => true
|
||||
| .inaccessible _ :: _ => true
|
||||
| _ => false
|
||||
return hasCtorOrInaccessible p
|
||||
|
||||
private def isIntValueTransition (p : Problem) : MetaM Bool := do
|
||||
unless (← hasIntValPattern p) do return false
|
||||
return hasCtorOrInaccessible p
|
||||
|
||||
private def processSkipInaccessible (p : Problem) : Problem := Id.run do
|
||||
let x :: xs := p.vars | unreachable!
|
||||
|
|
@ -606,6 +619,20 @@ private def expandNatValuePattern (p : Problem) : MetaM Problem := do
|
|||
| _ => return alt
|
||||
return { p with alts := alts }
|
||||
|
||||
private def expandIntValuePattern (p : Problem) : MetaM Problem := do
|
||||
let alts ← p.alts.mapM fun alt => do
|
||||
match alt.patterns with
|
||||
| .val n :: ps =>
|
||||
match (← getIntValue? n) with
|
||||
| some i =>
|
||||
if i >= 0 then
|
||||
return { alt with patterns := .ctor ``Int.ofNat [] [] [.val (toExpr i.toNat)] :: ps }
|
||||
else
|
||||
return { alt with patterns := .ctor ``Int.negSucc [] [] [.val (toExpr (-(i + 1)).toNat)] :: ps }
|
||||
| _ => return alt
|
||||
| _ => return alt
|
||||
return { p with alts := alts }
|
||||
|
||||
private def expandFinValuePattern (p : Problem) : MetaM Problem := do
|
||||
let alts ← p.alts.mapM fun alt => do
|
||||
match alt.patterns with
|
||||
|
|
@ -665,6 +692,9 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
|
|||
else if (← isNatValueTransition p) then
|
||||
traceStep ("nat value to constructor")
|
||||
process (← expandNatValuePattern p)
|
||||
else if (← isIntValueTransition p) then
|
||||
traceStep ("int value to constructor")
|
||||
process (← expandIntValuePattern p)
|
||||
else if (← isFinValueTransition p) then
|
||||
traceStep ("fin value to constructor")
|
||||
process (← expandFinValuePattern p)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue