feat: simprocs for folding numeric literals (#3586)
This PR folds exposed `BitVec` (`Fin`, `UInt??`, and `Int`) ground literals. cc @shigoel
This commit is contained in:
parent
3ad078fec9
commit
bba4ef3728
7 changed files with 129 additions and 57 deletions
|
|
@ -18,9 +18,10 @@ private abbrev withInstantiatedMVars (e : Expr) (k : Expr → OptionT MetaM α)
|
|||
k eNew
|
||||
|
||||
def isNatProjInst (declName : Name) (numArgs : Nat) : Bool :=
|
||||
(numArgs == 4 && (declName == ``Add.add || declName == ``Sub.sub || declName == ``Mul.mul))
|
||||
|| (numArgs == 6 && (declName == ``HAdd.hAdd || declName == ``HSub.hSub || declName == ``HMul.hMul))
|
||||
|| (numArgs == 3 && declName == ``OfNat.ofNat)
|
||||
(numArgs == 4 && (declName == ``Add.add || declName == ``Sub.sub || declName == ``Mul.mul || declName == ``Div.div || declName == ``Mod.mod || declName == ``NatPow.pow))
|
||||
|| (numArgs == 5 && (declName == ``Pow.pow))
|
||||
|| (numArgs == 6 && (declName == ``HAdd.hAdd || declName == ``HSub.hSub || declName == ``HMul.hMul || declName == ``HDiv.hDiv || declName == ``HMod.hMod || declName == ``HPow.hPow))
|
||||
|| (numArgs == 3 && declName == ``OfNat.ofNat)
|
||||
|
||||
/--
|
||||
Evaluate simple `Nat` expressions.
|
||||
|
|
@ -35,31 +36,21 @@ partial def evalNat (e : Expr) : OptionT MetaM Nat := do
|
|||
| _ => failure
|
||||
where
|
||||
visit e := do
|
||||
let f := e.getAppFn
|
||||
match f with
|
||||
| .mvar .. => withInstantiatedMVars e evalNat
|
||||
| .const c _ =>
|
||||
let nargs := e.getAppNumArgs
|
||||
if c == ``Nat.succ && nargs == 1 then
|
||||
let v ← evalNat (e.getArg! 0)
|
||||
return v+1
|
||||
else if c == ``Nat.add && nargs == 2 then
|
||||
let v₁ ← evalNat (e.getArg! 0)
|
||||
let v₂ ← evalNat (e.getArg! 1)
|
||||
return v₁ + v₂
|
||||
else if c == ``Nat.sub && nargs == 2 then
|
||||
let v₁ ← evalNat (e.getArg! 0)
|
||||
let v₂ ← evalNat (e.getArg! 1)
|
||||
return v₁ - v₂
|
||||
else if c == ``Nat.mul && nargs == 2 then
|
||||
let v₁ ← evalNat (e.getArg! 0)
|
||||
let v₂ ← evalNat (e.getArg! 1)
|
||||
return v₁ * v₂
|
||||
else if isNatProjInst c nargs then
|
||||
match_expr e with
|
||||
| Nat.succ a => return (← evalNat a) + 1
|
||||
| Nat.add a b => return (← evalNat a) + (← evalNat b)
|
||||
| Nat.sub a b => return (← evalNat a) - (← evalNat b)
|
||||
| Nat.mul a b => return (← evalNat a) * (← evalNat b)
|
||||
| Nat.div a b => return (← evalNat a) / (← evalNat b)
|
||||
| Nat.mod a b => return (← evalNat a) % (← evalNat b)
|
||||
| Nat.pow a b => return (← evalNat a) ^ (← evalNat b)
|
||||
| _ =>
|
||||
let e ← instantiateMVarsIfMVarApp e
|
||||
let f := e.getAppFn
|
||||
if f.isConst && isNatProjInst f.constName! e.getAppNumArgs then
|
||||
evalNat (← unfoldProjInst? e)
|
||||
else
|
||||
failure
|
||||
| _ => failure
|
||||
|
||||
mutual
|
||||
|
||||
|
|
|
|||
|
|
@ -74,42 +74,31 @@ def addAsVar (e : Expr) : M LinearExpr := do
|
|||
|
||||
partial def toLinearExpr (e : Expr) : M LinearExpr := do
|
||||
match e with
|
||||
| Expr.lit (Literal.natVal n) => return num n
|
||||
| Expr.mdata _ e => toLinearExpr e
|
||||
| Expr.const ``Nat.zero .. => return num 0
|
||||
| Expr.app .. => visit e
|
||||
| Expr.mvar .. => visit e
|
||||
| _ => addAsVar e
|
||||
| .lit (.natVal n) => return num n
|
||||
| .mdata _ e => toLinearExpr e
|
||||
| .const ``Nat.zero .. => return num 0
|
||||
| .app .. => visit e
|
||||
| .mvar .. => visit e
|
||||
| _ => addAsVar e
|
||||
where
|
||||
visit (e : Expr) : M LinearExpr := do
|
||||
let f := e.getAppFn
|
||||
match f with
|
||||
| Expr.mvar .. =>
|
||||
let eNew ← instantiateMVars e
|
||||
if eNew != e then
|
||||
toLinearExpr eNew
|
||||
match_expr e with
|
||||
| Nat.succ a => return inc (← toLinearExpr a)
|
||||
| Nat.add a b => return add (← toLinearExpr a) (← toLinearExpr b)
|
||||
| Nat.mul a b =>
|
||||
match (← evalNat a |>.run) with
|
||||
| some k => return mulL k (← toLinearExpr b)
|
||||
| none => match (← evalNat b |>.run) with
|
||||
| some k => return mulR (← toLinearExpr a) k
|
||||
| none => addAsVar e
|
||||
| _ =>
|
||||
let e ← instantiateMVarsIfMVarApp e
|
||||
let f := e.getAppFn
|
||||
if f.isConst && isNatProjInst f.constName! e.getAppNumArgs then
|
||||
let some e ← unfoldProjInst? e | addAsVar e
|
||||
toLinearExpr e
|
||||
else
|
||||
addAsVar e
|
||||
| Expr.const declName .. =>
|
||||
let numArgs := e.getAppNumArgs
|
||||
if declName == ``Nat.succ && numArgs == 1 then
|
||||
return inc (← toLinearExpr e.appArg!)
|
||||
else if declName == ``Nat.add && numArgs == 2 then
|
||||
return add (← toLinearExpr (e.getArg! 0)) (← toLinearExpr (e.getArg! 1))
|
||||
else if declName == ``Nat.mul && numArgs == 2 then
|
||||
match (← evalNat (e.getArg! 0) |>.run) with
|
||||
| some k => return mulL k (← toLinearExpr (e.getArg! 1))
|
||||
| none => match (← evalNat (e.getArg! 1) |>.run) with
|
||||
| some k => return mulR (← toLinearExpr (e.getArg! 0)) k
|
||||
| none => addAsVar e
|
||||
else if isNatProjInst declName numArgs then
|
||||
if let some e ← unfoldProjInst? e then
|
||||
toLinearExpr e
|
||||
else
|
||||
addAsVar e
|
||||
else
|
||||
addAsVar e
|
||||
| _ => addAsVar e
|
||||
|
||||
partial def toLinearCnstr? (e : Expr) : M (Option LinearCnstr) := do
|
||||
let f := e.getAppFn
|
||||
|
|
|
|||
|
|
@ -268,4 +268,15 @@ builtin_simproc [simp, seval] reduceAllOnes (allOnes _) := fun e => do
|
|||
let some n ← Nat.fromExpr? n | return .continue
|
||||
return .done { expr := toExpr (allOnes n) }
|
||||
|
||||
builtin_simproc [simp, seval] reduceBitVecOfFin (BitVec.ofFin _) := fun e => do
|
||||
let_expr BitVec.ofFin w v ← e | return .continue
|
||||
let some w ← evalNat w |>.run | return .continue
|
||||
let some ⟨_, v⟩ ← getFinValue? v | return .continue
|
||||
return .done { expr := toExpr (BitVec.ofNat w v.val) }
|
||||
|
||||
builtin_simproc [simp, seval] reduceBitVecToFin (BitVec.toFin _) := fun e => do
|
||||
let_expr BitVec.toFin _ v ← e | return .continue
|
||||
let some ⟨_, v⟩ ← getBitVecValue? v | return .continue
|
||||
return .done { expr := toExpr v.toFin }
|
||||
|
||||
end BitVec
|
||||
|
|
|
|||
|
|
@ -71,4 +71,13 @@ builtin_simproc [simp, seval] isValue ((OfNat.ofNat _ : Fin _)) := fun e => do
|
|||
return .done { expr := e }
|
||||
return .done { expr := toExpr v }
|
||||
|
||||
builtin_simproc [simp, seval] reduceFinMk (Fin.mk _ _) := fun e => do
|
||||
let_expr Fin.mk n v _ ← e | return .continue
|
||||
let some n ← evalNat n |>.run | return .continue
|
||||
let some v ← getNatValue? v | return .continue
|
||||
if h : n > 0 then
|
||||
return .done { expr := toExpr (Fin.ofNat' v h) }
|
||||
else
|
||||
return .continue
|
||||
|
||||
end Fin
|
||||
|
|
|
|||
|
|
@ -89,4 +89,14 @@ builtin_simproc [simp, seval] reduceBNe (( _ : Int) != _) := reduceBoolPred ``
|
|||
builtin_simproc [simp, seval] reduceAbs (natAbs _) := reduceNatCore ``natAbs natAbs
|
||||
builtin_simproc [simp, seval] reduceToNat (Int.toNat _) := reduceNatCore ``Int.toNat Int.toNat
|
||||
|
||||
builtin_simproc [simp, seval] reduceNegSucc (Int.negSucc _) := fun e => do
|
||||
let_expr Int.negSucc a ← e | return .continue
|
||||
let some a ← getNatValue? a | return .continue
|
||||
return .done { expr := toExpr (-(Int.ofNat a + 1)) }
|
||||
|
||||
builtin_simproc [simp, seval] reduceOfNat (Int.ofNat _) := fun e => do
|
||||
let_expr Int.ofNat a ← e | return .continue
|
||||
let some a ← getNatValue? a | return .continue
|
||||
return .done { expr := toExpr (Int.ofNat a) }
|
||||
|
||||
end Int
|
||||
|
|
|
|||
|
|
@ -60,6 +60,12 @@ builtin_simproc [simp, seval] $(mkIdent `reduceOfNatCore):ident ($ofNatCore _ _)
|
|||
let value := $(mkIdent ofNat) value
|
||||
return .done { expr := toExpr value }
|
||||
|
||||
builtin_simproc [simp, seval] $(mkIdent `reduceOfNat):ident ($(mkIdent ofNat) _) := fun e => do
|
||||
unless e.isAppOfArity $(quote ofNat) 1 do return .continue
|
||||
let some value ← Nat.fromExpr? e.appArg! | return .continue
|
||||
let value := $(mkIdent ofNat) value
|
||||
return .done { expr := toExpr value }
|
||||
|
||||
builtin_simproc [simp, seval] $(mkIdent `reduceToNat):ident ($toNat _) := fun e => do
|
||||
unless e.isAppOfArity $(quote toNat.getId) 1 do return .continue
|
||||
let some v ← ($fromExpr e.appArg!) | return .continue
|
||||
|
|
|
|||
56
tests/lean/run/foldLits.lean
Normal file
56
tests/lean/run/foldLits.lean
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
open BitVec
|
||||
|
||||
example : (Fin.mk 5 (by decide) : Fin 10) + 2 = x := by
|
||||
simp
|
||||
guard_target =ₛ 7 = x
|
||||
sorry
|
||||
|
||||
example : (Fin.mk 5 (by decide) : Fin 10) + 2 = x := by
|
||||
simp (config := { ground := true }) only
|
||||
guard_target =ₛ 7 = x
|
||||
sorry
|
||||
|
||||
example : (BitVec.ofFin (Fin.mk 2 (by decide)) : BitVec 32) + 2 = x := by
|
||||
simp
|
||||
guard_target =ₛ 4#32 = x
|
||||
sorry
|
||||
|
||||
example : (BitVec.ofFin (Fin.mk 2 (by decide)) : BitVec 32) + 2 = x := by
|
||||
simp (config := { ground := true }) only
|
||||
guard_target =ₛ 4#32 = x
|
||||
sorry
|
||||
|
||||
example : (BitVec.ofFin 2 : BitVec 32) + 2 = x := by
|
||||
simp
|
||||
guard_target =ₛ 4#32 = x
|
||||
sorry
|
||||
|
||||
example (h : -2 = x) : Int.negSucc 3 + 2 = x := by
|
||||
simp
|
||||
guard_target =ₛ -2 = x
|
||||
assumption
|
||||
|
||||
example (h : -2 = x) : Int.negSucc 3 + 2 = x := by
|
||||
simp (config := { ground := true }) only
|
||||
guard_target =ₛ -2 = x
|
||||
assumption
|
||||
|
||||
example : Int.ofNat 3 + 2 = x := by
|
||||
simp
|
||||
guard_target =ₛ 5 = x
|
||||
sorry
|
||||
|
||||
example : Int.ofNat 3 + 2 = x := by
|
||||
simp (config := { ground := true }) only
|
||||
guard_target =ₛ 5 = x
|
||||
sorry
|
||||
|
||||
example (h : 5 = x) : UInt32.ofNat 2 + 3 = x := by
|
||||
simp
|
||||
guard_target =ₛ 5 = x
|
||||
assumption
|
||||
|
||||
example (h : 5 = x) : UInt32.ofNat 2 + 3 = x := by
|
||||
simp (config := { ground := true }) only
|
||||
guard_target =ₛ 5 = x
|
||||
assumption
|
||||
Loading…
Add table
Reference in a new issue