-- TODO: remove `OptionM` after update stage0; see: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Do.20we.20still.20need.20OptionM.3F/near/279761084
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Expr
import Lean.Compiler.Util

/- Constant folding for primitives that have special runtime support. -/

namespace Lean.Compiler

/-- Signature of a binary folding routine. `foldBinOp` passes its `beforeErasure`
flag as the `Bool`, followed by the two argument expressions; `none` means
"could not fold". -/
abbrev BinFoldFn : Type := Bool → Expr → Expr → Option Expr
/-- Signature of a unary folding routine. `foldUnOp` passes its `beforeErasure`
flag as the `Bool`, followed by the single argument expression; `none` means
"could not fold". -/
abbrev UnFoldFn : Type := Bool → Expr → Option Expr

/-- Builds the simple name `UInt<k>` from the given number.
Despite the parameter name, `NumScalarTypeInfo` in this file calls this with a
bit width (`nbits`), yielding e.g. `UInt8` for `8`. -/
def mkUIntTypeName (nbytes : Nat) : Name :=
  Name.mkSimple s!"UInt{nbytes}"

/-- Description of one fixed-width unsigned scalar type handled by the folder.
Only `nbits` is mandatory; every other field has a default derived from it. -/
structure NumScalarTypeInfo where
  /-- Width of the type in bits. -/
  nbits   : Nat
  /-- Name of the type; defaults to `UInt<nbits>`. -/
  id      : Name := mkUIntTypeName nbits
  /-- Name of the `Nat → type` conversion; defaults to `<id>.ofNat`. -/
  ofNatFn : Name := Name.mkStr id "ofNat"
  /-- Name of the `type → Nat` conversion; defaults to `<id>.toNat`. -/
  toNatFn : Name := Name.mkStr id "toNat"
  /-- Number of distinct values; defaults to `2 ^ nbits`. -/
  size    : Nat := 2 ^ nbits

/-- The scalar types known to the folder: the four fixed widths plus `USize`,
whose width is taken from `System.Platform.numBits`. -/
def numScalarTypes : List NumScalarTypeInfo := [
  { nbits := 8 },
  { nbits := 16 },
  { nbits := 32 },
  { nbits := 64 },
  { nbits := System.Platform.numBits, id := ``USize }
]

/-- `true` iff `fn` is the `ofNatFn` of one of the entries in `numScalarTypes`. -/
def isOfNat (fn : Name) : Bool :=
  numScalarTypes.any (·.ofNatFn == fn)

/-- `true` iff `fn` is the `toNatFn` of one of the entries in `numScalarTypes`. -/
def isToNat (fn : Name) : Bool :=
  numScalarTypes.any (·.toNatFn == fn)

/-- Returns the first entry of the given list whose `ofNatFn` equals `fn`,
or `none` when no entry matches. -/
def getInfoFromFn (fn : Name) : List NumScalarTypeInfo → Option NumScalarTypeInfo :=
  fun candidates => candidates.find? (fun candidate => candidate.ofNatFn == fn)

/-- If the expression is an application whose head is a constant, looks that
constant up among the `ofNatFn`s of `numScalarTypes`; otherwise returns `none`. -/
def getInfoFromVal : Expr → Option NumScalarTypeInfo :=
  fun val =>
    match val with
    | Expr.app (Expr.const headName _ _) _ _ => getInfoFromFn headName numScalarTypes
    | _ => none

/-- Extracts the natural-number literal from an expression.
Accepts either a raw `Nat` literal or an application `C a` where `C` is one of
the scalar `ofNat` functions (see `isOfNat`) and `a` itself yields a literal.
The inner literal is returned as-is: no reduction modulo the scalar type's
size is applied here. -/
@[export lean_get_num_lit]
def getNumLit : Expr → Option Nat
  | Expr.lit (Literal.natVal value) _ => some value
  | Expr.app (Expr.const headName _ _) arg _ =>
    match isOfNat headName with
    | true  => getNumLit arg
    | false => none
  | _ => none

/-- Builds `info.ofNatFn k` where `k` is the raw literal `n % info.size`. -/
def mkUIntLit (info : NumScalarTypeInfo) (n : Nat) : Expr :=
  let ofNat   := mkConst info.ofNatFn
  let wrapped := mkRawNatLit (n % info.size)
  mkApp ofNat wrapped

/-- `mkUIntLit` specialised to the 32-bit scalar type. -/
def mkUInt32Lit (n : Nat) : Expr :=
  let u32 : NumScalarTypeInfo := { nbits := 32 }
  mkUIntLit u32 n

/-- Folds a binary operation on two scalar literals.

Both arguments must yield a literal via `getNumLit`, and `a₁` must be an
application of a known scalar `ofNat` function (this determines the type info).
`fn` receives the type info, the `beforeErasure` flag and the two operand
values; its result is wrapped with `mkUIntLit`, which reduces it modulo
`info.size`.

`getNumLit` returns the literal under `ofNat` unreduced, so a term such as
`UInt8.ofNat 300` would otherwise reach `fn` as `300`. The operands are
therefore reduced modulo `info.size` first; without this, non-homomorphic
operations (division, modulo, the subtraction encoding) fold to a wrong value.
For operands already in range this is a no-op. -/
def foldBinUInt (fn : NumScalarTypeInfo → Bool → Nat → Nat → Nat) (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr := do
  let n₁ ← getNumLit a₁
  let n₂ ← getNumLit a₂
  let info ← getInfoFromVal a₁
  -- Normalise both operands to the value they denote in the scalar type.
  let v₁ := n₁ % info.size
  let v₂ := n₂ % info.size
  return mkUIntLit info (fn info beforeErasure v₁ v₂)

/-- Folds scalar addition (the result is wrapped by `foldBinUInt`/`mkUIntLit`). -/
def foldUIntAdd := foldBinUInt fun _ _ x y => x + y
/-- Folds scalar multiplication. -/
def foldUIntMul := foldBinUInt fun _ _ x y => x * y
/-- Folds scalar division, using `Nat` division on the literal values. -/
def foldUIntDiv := foldBinUInt fun _ _ x y => x / y
/-- Folds scalar modulo, using `Nat` modulo on the literal values. -/
def foldUIntMod := foldBinUInt fun _ _ x y => x % y
/-- Folds scalar subtraction as `x + (size - y)`; the final reduction modulo
`size` is performed by `mkUIntLit`. -/
def foldUIntSub := foldBinUInt fun info _ x y => x + (info.size - y)

/-- Operation suffixes paired with their folding routines; `uintBinFoldFns`
prefixes each suffix with a scalar type name. -/
def preUIntBinFoldFns : List (Name × BinFoldFn) := [
  (`add, foldUIntAdd),
  (`mul, foldUIntMul),
  (`div, foldUIntDiv),
  (`mod, foldUIntMod),
  (`sub, foldUIntSub)
]

/-- For every entry of `numScalarTypes` (in order), the entries of
`preUIntBinFoldFns` with the name qualified by the type's `id`. -/
def uintBinFoldFns : List (Name × BinFoldFn) :=
  let qualify (info : NumScalarTypeInfo) : List (Name × BinFoldFn) :=
    preUIntBinFoldFns.map fun entry => (info.id ++ entry.1, entry.2)
  numScalarTypes.foldl (fun acc info => acc ++ qualify info) []

/-- Applies `fn` to the literals extracted from both arguments and returns the
result as a raw `Nat` literal; `none` if either argument is not a literal. -/
def foldNatBinOp (fn : Nat → Nat → Nat) (a₁ a₂ : Expr) : Option Expr :=
  match getNumLit a₁, getNumLit a₂ with
  | some lhs, some rhs => some (mkRawNatLit (fn lhs rhs))
  | _, _ => none

/-- Folds `Nat` addition; the `Bool` flag is ignored. -/
def foldNatAdd (_ : Bool) := foldNatBinOp (· + ·)
/-- Folds `Nat` multiplication; the `Bool` flag is ignored. -/
def foldNatMul (_ : Bool) := foldNatBinOp (· * ·)
/-- Folds `Nat` division; the `Bool` flag is ignored. -/
def foldNatDiv (_ : Bool) := foldNatBinOp (· / ·)
/-- Folds `Nat` modulo; the `Bool` flag is ignored. -/
def foldNatMod (_ : Bool) := foldNatBinOp (· % ·)

-- TODO: add option for controlling the limit
/-- Exclusive upper bound on the exponent for which `foldNatPow` folds. -/
def natPowThreshold : Nat := 256

/-- Folds `Nat` exponentiation of two literals, but only when the exponent is
below `natPowThreshold`; otherwise (or if an argument is not a literal) the
result is `none`. The `Bool` flag is ignored. -/
def foldNatPow (_ : Bool) (a₁ a₂ : Expr) : Option Expr :=
  match getNumLit a₁, getNumLit a₂ with
  | some base, some exponent =>
    if exponent < natPowThreshold then
      some (mkRawNatLit (base ^ exponent))
    else
      none
  | _, _ => none

/-- Builds the application of `Eq` (at universe level one) to `Nat`, `a`, `b`. -/
def mkNatEq (a b : Expr) : Expr :=
  let eqConst  := mkConst ``Eq [levelOne]
  let natConst := mkConst ``Nat
  mkAppN eqConst #[natConst, a, b]

/-- Builds the application of `LT.lt` (at level zero) to the constants `Nat`
and `Nat.lt`, followed by `a` and `b`. -/
def mkNatLt (a b : Expr) : Expr :=
  let ltConst := mkConst ``LT.lt [levelZero]
  mkAppN ltConst #[mkConst ``Nat, mkConst ``Nat.lt, a, b]

/-- Builds the application of `LE.le` (at level zero) to the constants `Nat`
and `Nat.le`, followed by `a` and `b`. -/
def mkNatLe (a b : Expr) : Expr :=
  let leConst := mkConst ``LE.le [levelZero]
  mkAppN leConst #[mkConst ``Nat, mkConst ``Nat.le, a, b]

/-- Turns the Boolean outcome `r` of deciding `pred` into an expression.
When `beforeErasure` is `false` the result is the constant `Bool.true` or
`Bool.false`. When it is `true` the result is built with `mkDecIsTrue` /
`mkDecIsFalse`, passing `pred` and `mkLcProof pred` in both cases.
NOTE(review): the `mkDecIsFalse` branch passes a placeholder built from `pred`
itself, not from its negation — confirm this is intended. -/
def toDecidableExpr (beforeErasure : Bool) (pred : Expr) (r : Bool) : Expr :=
  if beforeErasure then
    let proof := mkLcProof pred
    if r then mkDecIsTrue pred proof else mkDecIsFalse pred proof
  else if r then
    mkConst ``Bool.true
  else
    mkConst ``Bool.false

/-- Folds a decidable binary predicate on two `Nat` literals.
`fn` computes the truth value, `mkPred` builds the proposition from the
original argument expressions, and `toDecidableExpr` packages the result.
Returns `none` if either argument is not a literal. -/
def foldNatBinPred (mkPred : Expr → Expr → Expr) (fn : Nat → Nat → Bool)
    (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr :=
  match getNumLit a₁, getNumLit a₂ with
  | some lhs, some rhs =>
    some (toDecidableExpr beforeErasure (mkPred a₁ a₂) (fn lhs rhs))
  | _, _ => none

/-- Folds the decision of `a = b` on `Nat` literals. -/
def foldNatDecEq := foldNatBinPred mkNatEq (fun x y => decide (x = y))
/-- Folds the decision of `a < b` on `Nat` literals. -/
def foldNatDecLt := foldNatBinPred mkNatLt (fun x y => decide (x < y))
/-- Folds the decision of `a ≤ b` on `Nat` literals. -/
def foldNatDecLe := foldNatBinPred mkNatLe (fun x y => decide (x ≤ y))

/-- Folds a Boolean-valued binary test on two `Nat` literals into the constant
`Bool.true` or `Bool.false`; `none` if either argument is not a literal. -/
def foldNatBinBoolPred (fn : Nat → Nat → Bool) (a₁ a₂ : Expr) : Option Expr :=
  match getNumLit a₁, getNumLit a₂ with
  | some lhs, some rhs =>
    match fn lhs rhs with
    | true  => some (mkConst ``Bool.true)
    | false => some (mkConst ``Bool.false)
  | _, _ => none

/-- Folds `Nat.beq` on literals: Boolean equality. The `Bool` flag is ignored. -/
def foldNatBeq := fun _ : Bool => foldNatBinBoolPred (fun a b => a == b)
/-- Folds `Nat.ble` on literals: Boolean "less than or equal".
(Previously this used `<`, which folded `Nat.ble n n` to `false`.) -/
def foldNatBle := fun _ : Bool => foldNatBinBoolPred (fun a b => a ≤ b)
/-- Folds `Nat.blt` on literals: Boolean strict "less than".
(Previously this used `≤`, which folded `Nat.blt n n` to `true`.) -/
def foldNatBlt := fun _ : Bool => foldNatBinBoolPred (fun a b => a < b)

/-- Folding routines for `Nat` primitives, keyed by the primitive's name. -/
def natFoldFns : List (Name × BinFoldFn) := [
  -- arithmetic
  (``Nat.add, foldNatAdd), (``Nat.mul, foldNatMul),
  (``Nat.div, foldNatDiv), (``Nat.mod, foldNatMod),
  (``Nat.pow, foldNatPow),
  -- decidable comparisons
  (``Nat.decEq, foldNatDecEq), (``Nat.decLt, foldNatDecLt),
  (``Nat.decLe, foldNatDecLe),
  -- Boolean comparisons
  (``Nat.beq, foldNatBeq), (``Nat.blt, foldNatBlt),
  (``Nat.ble, foldNatBle)
]

/-- Recognises the constants `Bool.true` and `Bool.false`; any other
expression yields `none`. -/
def getBoolLit : Expr → Option Bool
  | Expr.const declName _ _ =>
    if declName == ``Bool.true then some true
    else if declName == ``Bool.false then some false
    else none
  | _ => none

/-- Folds `strictAnd` when at least one argument is a Boolean constant.
A `true` operand yields the other operand; a `false` operand yields itself.
The first argument is inspected first. Returns `none` when neither argument
is a constant. The `Bool` flag is ignored. -/
def foldStrictAnd (_ : Bool) (a₁ a₂ : Expr) : Option Expr :=
  match getBoolLit a₁, getBoolLit a₂ with
  | some lhs, _    => some (if lhs then a₂ else a₁)
  | none, some rhs => some (if rhs then a₁ else a₂)
  | none, none     => none

/-- Folds `strictOr` when at least one argument is a Boolean constant.
A `true` operand yields itself; a `false` operand yields the other operand.
The first argument is inspected first. Returns `none` when neither argument
is a constant. The `Bool` flag is ignored. -/
def foldStrictOr (_ : Bool) (a₁ a₂ : Expr) : Option Expr :=
  match getBoolLit a₁, getBoolLit a₂ with
  | some lhs, _    => some (if lhs then a₁ else a₂)
  | none, some rhs => some (if rhs then a₂ else a₁)
  | none, none     => none

/-- Folding routines for the strict Boolean connectives. -/
def boolFoldFns : List (Name × BinFoldFn) := [
  (``strictOr, foldStrictOr),
  (``strictAnd, foldStrictAnd)
]

/-- All binary folding routines: Boolean first, then scalar, then `Nat`. -/
def binFoldFns : List (Name × BinFoldFn) :=
  boolFoldFns ++ (uintBinFoldFns ++ natFoldFns)

/-- Folds `Nat.succ` of a literal `n` into the raw literal `n+1`.
The `Bool` flag is ignored. -/
def foldNatSucc (_ : Bool) (a : Expr) : Option Expr :=
  (getNumLit a).map fun pred => mkRawNatLit (pred + 1)

/-- Folds `Char.ofNat` applied to a literal into a 32-bit scalar literal.

Only fires when `beforeErasure` is `false`. A valid code point folds to
itself; anything else folds to `0`.

Validity is tested on the `Nat` itself. The previous test went through
`n.toUInt32`, which wraps modulo `2^32`, so an out-of-range literal such as
`2^32 + 65` was accepted and folded to code point `65` instead of `0`.
For every `n < 2^32` the two tests agree. -/
def foldCharOfNat (beforeErasure : Bool) (a : Expr) : Option Expr := do
  guard (!beforeErasure)
  let n ← getNumLit a
  if n.isValidChar then
    return mkUInt32Lit n
  else
    return mkUInt32Lit 0

/-- Folds a scalar `toNat` applied to a literal into a raw `Nat` literal.

`getNumLit` returns the literal under `ofNat` unreduced, so for an argument
like `UInt8.ofNat 300` the value must be reduced modulo the type's size to
obtain what `toNat` actually returns (`44`, not `300`). When the argument's
scalar type cannot be determined, the literal is returned unchanged, as
before. The `Bool` flag is ignored. -/
def foldToNat (_ : Bool) (a : Expr) : Option Expr := do
  let n ← getNumLit a
  match getInfoFromVal a with
  | some info => return mkRawNatLit (n % info.size)
  | none      => return mkRawNatLit n

/-- Pairs every scalar type's `toNatFn` with `foldToNat`. The entries appear
in the reverse of the `numScalarTypes` order. -/
def uintFoldToNatFns : List (Name × UnFoldFn) :=
  let entries : List (Name × UnFoldFn) :=
    numScalarTypes.map fun info => (info.toNatFn, foldToNat)
  entries.reverse

/-- All unary folding routines: `Nat.succ`, `Char.ofNat`, then the scalar
`toNat` conversions. -/
def unFoldFns : List (Name × UnFoldFn) :=
  let core : List (Name × UnFoldFn) := [
    (``Nat.succ, foldNatSucc),
    (``Char.ofNat, foldCharOfNat)
  ]
  core ++ uintFoldToNatFns

/-- Looks up the binary folding routine registered under `fn` in `binFoldFns`. -/
def findBinFoldFn (fn : Name) : Option BinFoldFn :=
  List.lookup fn binFoldFns

/-- Looks up the unary folding routine registered under `fn` in `unFoldFns`. -/
def findUnFoldFn (fn : Name) : Option UnFoldFn :=
  List.lookup fn unFoldFns

/-- Entry point for binary constant folding (exported as `lean_fold_bin_op`).
If `f` is a constant with a registered routine in `binFoldFns`, runs that
routine on `a` and `b`; in every other case the result is `none`. -/
@[export lean_fold_bin_op]
def foldBinOp (beforeErasure : Bool) (f : Expr) (a : Expr) (b : Expr) : Option Expr :=
  match f with
  | Expr.const declName _ _ =>
    match findBinFoldFn declName with
    | some folder => folder beforeErasure a b
    | none        => none
  | _ => none

/-- Entry point for unary constant folding (exported as `lean_fold_un_op`).
If `f` is a constant with a registered routine in `unFoldFns`, runs that
routine on `a`; in every other case the result is `none`. -/
@[export lean_fold_un_op]
def foldUnOp (beforeErasure : Bool) (f : Expr) (a : Expr) : Option Expr :=
  match f with
  | Expr.const declName _ _ =>
    match findUnFoldFn declName with
    | some folder => folder beforeErasure a
    | none        => none
  | _ => none

end Lean.Compiler