feat(library/init/lean/compiler/constfolding): constant folding for strictAnd and strictOr

This commit is contained in:
Leonardo de Moura 2019-04-05 16:51:29 -07:00
parent 2cd3954198
commit 1bb920322d
3 changed files with 48 additions and 9 deletions

View file

@ -117,8 +117,36 @@ def natFoldFns : List (Name × BinFoldFn) :=
(`Nat.decLt, foldNatDecLt),
(`Nat.decLe, foldNatDecLe)]
def getBoolLit : Expr → Option Bool
| (Expr.const `Bool.true _) := some true
| (Expr.const `Bool.false _) := some false
| _ := none
def foldStrictAnd (_ : Bool) (a₁ a₂ : Expr) : Option Expr :=
let v₁ := getBoolLit a₁ in
let v₂ := getBoolLit a₂ in
match v₁, v₂ with
| some true, _ := a₂
| some false, _ := a₁
| _, some true := a₁
| _, some false := a₂
| _, _ := none
def foldStrictOr (_ : Bool) (a₁ a₂ : Expr) : Option Expr :=
let v₁ := getBoolLit a₁ in
let v₂ := getBoolLit a₂ in
match v₁, v₂ with
| some true, _ := a₁
| some false, _ := a₂
| _, some true := a₂
| _, some false := a₁
| _, _ := none
def boolFoldFns : List (Name × BinFoldFn) :=
[(`strictOr, foldStrictOr), (`strictAnd, foldStrictAnd)]
def binFoldFns : List (Name × BinFoldFn) :=
uintBinFoldFns ++ natFoldFns
boolFoldFns ++ uintBinFoldFns ++ natFoldFns
def foldNatSucc (_ : Bool) (a : Expr) : Option Expr :=
do n ← getNumLit a,
@ -135,17 +163,11 @@ def unFoldFns : List (Name × UnFoldFn) :=
[(`Nat.succ, foldNatSucc),
(`Char.ofNat, foldCharOfNat)]
-- TODO(Leo): move
private def {u} alistFind {α : Type u} (n : Name) : List (Name × α) → Option α
| [] := none
| ((k, v)::r) :=
if n = k then some v else alistFind r
def findBinFoldFn (fn : Name) : Option BinFoldFn :=
alistFind fn binFoldFns
binFoldFns.lookup fn
def findUnFoldFn (fn : Name) : Option UnFoldFn :=
alistFind fn unFoldFns
unFoldFns.lookup fn
@[export lean.fold_bin_op_core]
def foldBinOp (beforeErasure : Bool) (f : Expr) (a : Expr) (b : Expr) : Option Expr :=

View file

@ -0,0 +1,9 @@
def main : IO Unit :=
IO.println (strictOr false false) *>
IO.println (strictOr false true) *>
IO.println (strictOr true false) *>
IO.println (strictOr true true) *>
IO.println (strictAnd false false) *>
IO.println (strictAnd false true) *>
IO.println (strictAnd true false) *>
IO.println (strictAnd true true)

View file

@ -0,0 +1,8 @@
false
true
true
true
false
false
false
true