lean4-htt/library/init/lean/compiler/constfolding.lean
2019-03-21 15:06:44 -07:00

167 lines
5.5 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.lean.expr init.platform
import init.lean.compiler.util
/- Constant folding for primitives that have special runtime support. -/
namespace Lean
namespace Compiler
/- Type of a folder for a binary primitive. The `Bool` is the `beforeErasure`
   flag (see `foldBinOp`), followed by the two argument expressions. The result
   is `some e` when the application can be replaced by `e`, `none` otherwise. -/
def BinFoldFn := Bool → Expr → Expr → Option Expr
/- Type of a folder for a unary primitive: the `beforeErasure` flag (see
   `foldUnOp`) followed by the single argument expression. -/
def UnFoldFn := Bool → Expr → Option Expr
/- Build the simple name `UInt<k>` from the number `k`.
   NOTE(review): the parameter is called `nbytes`, but the only caller visible
   in this file (the default value of `NumScalarTypeInfo.id`) passes a bit
   count. The name is kept as is so callers are unaffected. -/
def mkUIntTypeName (nbytes : Nat) : Name :=
let widthText := toString nbytes in
mkSimpleName ("UInt" ++ widthText)
/- Description of one fixed-width unsigned scalar type handled by the constant
   folder. Only `nbits` lacks a default; every other field is derived from the
   preceding ones unless it is given explicitly. -/
structure NumScalarTypeInfo :=
-- Width in bits.
(nbits : Nat)
-- Name of the type; defaults to `UInt<nbits>`.
(id : Name := mkUIntTypeName nbits)
-- Name of the function recognised as the literal constructor: `<id>.ofNat`.
(ofNatFn : Name := Name.mkString id "ofNat")
-- Number of distinct values, used as the modulus when building literals.
(size : Nat := 2^nbits)
/- The scalar types the folder knows about: the 8/16/32/64-bit entries use the
   default names, and `Usize` takes its width from `System.platform.nbits`. -/
def numScalarTypes : List NumScalarTypeInfo :=
[{nbits := 8}, {nbits := 16}, {nbits := 32}, {nbits := 64},
{id := `Usize, nbits := System.platform.nbits}]
/- `true` iff `fn` is the `ofNat` function of one of the entries in
   `numScalarTypes`. -/
def isOfNat (fn : Name) : Bool :=
List.any numScalarTypes (λ entry, entry.ofNatFn = fn)
/- Linear search: return the first entry of the list whose `ofNatFn` equals
   `fn`, or `none` when the list contains no such entry. -/
def getInfoFromFn (fn : Name) : List NumScalarTypeInfo → Option NumScalarTypeInfo
| [] := none
| (candidate::remaining) :=
  if candidate.ofNatFn = fn then some candidate
  else getInfoFromFn remaining
/- Determine the scalar type of a literal expression. Only applications whose
   head is a constant are inspected; the constant's name is looked up among
   the `ofNat` functions of `numScalarTypes`. Anything else yields `none`. -/
def getInfoFromVal : Expr → Option NumScalarTypeInfo
| (Expr.app (Expr.const headName _) _) := getInfoFromFn headName numScalarTypes
| _ := none
/- Extract the natural number carried by a numeric literal expression.
   Accepts a raw `Nat` literal, or an application of one of the recognised
   `ofNat` functions (see `isOfNat`) whose argument is itself such a literal.
   The number is returned as written, without any reduction. -/
@[export lean.get_num_lit_core]
def getNumLit : Expr → Option Nat
| (Expr.lit (Literal.natVal value)) := some value
| (Expr.app (Expr.const headName _) arg) := if isOfNat headName then getNumLit arg else none
| _ := none
/- Build the literal `<info.ofNatFn> k`, where `k` is `n` reduced modulo
   `info.size`. -/
def mkUIntLit (info : NumScalarTypeInfo) (n : Nat) : Expr :=
let payload := Expr.lit (Literal.natVal (n % info.size)) in
let ctor := Expr.const info.ofNatFn [] in
Expr.app ctor payload
/- Shorthand for `mkUIntLit` at the default 32-bit entry. -/
def mkUInt32Lit (n : Nat) : Expr :=
let uint32Info : NumScalarTypeInfo := {nbits := 32} in
mkUIntLit uint32Info n
/- Fold a binary operation on fixed-width unsigned literals.

   `fn` receives the scalar type info, the `beforeErasure` flag and the two
   operand values, and returns the (possibly unreduced) result, which is then
   wrapped with `mkUIntLit` (and therefore reduced modulo `info.size`).
   The scalar type is taken from the head of `a₁`; folding is abandoned
   (`none`) if either operand is not a numeric literal or `a₁` is not an
   application of a known `ofNat` function.

   `getNumLit` returns the literal exactly as written, so an operand such as
   `UInt8.ofNat 300` would arrive as 300 although it denotes 44. Operations
   that are not compatible with modular reduction (division, remainder, and
   the `a + (size - b)` encoding of subtraction) would then fold to a wrong
   value, so both operands are reduced modulo `info.size` first. -/
def foldBinUInt (fn : NumScalarTypeInfo → Bool → Nat → Nat → Nat) (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr :=
do n₁ ← getNumLit a₁,
   n₂ ← getNumLit a₂,
   info ← getInfoFromVal a₁,
   -- Bring both operands into the range [0, info.size) before applying `fn`.
   pure $ mkUIntLit info (fn info beforeErasure (n₁ % info.size) (n₂ % info.size))
/- Folders for the unsigned arithmetic primitives, all built on `foldBinUInt`.
   The type info and the `beforeErasure` flag are ignored, except by
   subtraction, which needs the modulus. -/
def foldUIntAdd := foldBinUInt $ λ _ _ x y, x + y
def foldUIntMul := foldBinUInt $ λ _ _ x y, x * y
def foldUIntDiv := foldBinUInt $ λ _ _ x y, x / y
def foldUIntMod := foldBinUInt $ λ _ _ x y, x % y
-- Subtraction is expressed as `x + (size - y)`; `foldBinUInt` hands the sum
-- to `mkUIntLit`, which reduces it modulo `size`.
def foldUIntSub := foldBinUInt $ λ info _ x y, x + (info.size - y)
/- Operation suffixes paired with their folders. `uintBinFoldFns` appends each
   suffix to the name of every scalar type to obtain the full function names. -/
def preUIntBinFoldFns : List (Name × BinFoldFn) :=
[(`add, foldUIntAdd), (`mul, foldUIntMul), (`div, foldUIntDiv),
(`mod, foldUIntMod), (`sub, foldUIntSub)]
/- Table of folders for every (scalar type, operation) combination: for each
   entry of `numScalarTypes`, in order, the suffixes of `preUIntBinFoldFns`
   are appended to the type's `id` and the resulting pairs are collected. -/
def uintBinFoldFns : List (Name × BinFoldFn) :=
numScalarTypes.foldl
  (λ acc info,
    acc ++ (preUIntBinFoldFns.map (λ ⟨opSuffix, folder⟩, (info.id ++ opSuffix, folder))))
  []
/- Fold a binary operation on `Nat` literals: when both arguments are numeric
   literals, produce the `Nat` literal holding `fn` applied to their values;
   otherwise return `none`. -/
def foldNatBinOp (fn : Nat → Nat → Nat) (a₁ a₂ : Expr) : Option Expr :=
match getNumLit a₁, getNumLit a₂ with
| some lhs, some rhs := some (Expr.lit (Literal.natVal (fn lhs rhs)))
| _, _ := none
/- Folders for the `Nat` arithmetic primitives. The leading `Bool` (the
   `beforeErasure` flag of `BinFoldFn`) is ignored. -/
def foldNatAdd (_ : Bool) := foldNatBinOp (λ x y, x + y)
def foldNatMul (_ : Bool) := foldNatBinOp (λ x y, x * y)
def foldNatDiv (_ : Bool) := foldNatBinOp (λ x y, x / y)
def foldNatMod (_ : Bool) := foldNatBinOp (λ x y, x % y)
/- Build the expression for the proposition `a = b` at type `Nat`
   (`Eq` instantiated at universe level one and applied to `Nat`, `a`, `b`). -/
def mkNatEq (a b : Expr) : Expr :=
let natType := Expr.const `Nat [] in
let eqAtNat := Expr.app (Expr.const `Eq [Level.one]) natType in
mkBinApp eqAtNat a b
/- Build the expression for the proposition `a < b` at type `Nat`:
   `HasLt.lt` applied to `Nat`, the instance `Nat.HasLt`, `a` and `b`. -/
def mkNatLt (a b : Expr) : Expr :=
let natType := Expr.const `Nat [] in
let ltInst := Expr.const `Nat.HasLt [] in
let ltFn := Expr.const `HasLt.lt [Level.zero] in
mkBinApp (mkBinApp ltFn natType ltInst) a b
/- Build the expression for the proposition `a ≤ b` at type `Nat`:
   `HasLe.le` applied to `Nat`, the instance `Nat.HasLe`, `a` and `b`.
   The head constant must be `HasLe.le`: `le` is the field of the `HasLe`
   class (whose instance is passed here), not of `HasLt`. -/
def mkNatLe (a b : Expr) : Expr :=
mkBinApp (mkBinApp (Expr.const `HasLe.le [Level.zero]) (Expr.const `Nat []) (Expr.const `Nat.HasLe [])) a b
/- Turn the Boolean outcome `r` of a folded predicate into a `Decidable`
   value expression. After erasure (`beforeErasure = false`) both arguments
   of the constructor are `neutralExpr`. Before erasure the proposition
   `pred` is kept and the proof slot is filled with `mkLcProof pred`.
   NOTE(review): in the `isFalse` case the proof argument is also
   `mkLcProof pred`; a proof of the negation of `pred` would normally be
   expected there. Confirm this is intentional (`mkDecIsFalse` and
   `mkLcProof` are defined elsewhere). -/
def toDecidableExpr (beforeErasure : Bool) (pred : Expr) (r : Bool) : Expr :=
if beforeErasure then
  (if r then mkDecIsTrue pred (mkLcProof pred)
   else mkDecIsFalse pred (mkLcProof pred))
else
  (if r then mkDecIsTrue neutralExpr neutralExpr
   else mkDecIsFalse neutralExpr neutralExpr)
/- Fold a decidable binary predicate on `Nat` literals. When both arguments
   are numeric literals, `fn` decides the predicate on their values and the
   outcome is packaged by `toDecidableExpr`, with `mkPred a₁ a₂` as the
   proposition. Returns `none` if either argument is not a literal. -/
def foldNatBinPred (mkPred : Expr → Expr → Expr) (fn : Nat → Nat → Bool)
  (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr :=
match getNumLit a₁, getNumLit a₂ with
| some lhs, some rhs := some (toDecidableExpr beforeErasure (mkPred a₁ a₂) (fn lhs rhs))
| _, _ := none
/- Folders for the decidable `Nat` comparisons: each pairs the builder of the
   proposition with the corresponding test on the literal values. -/
def foldNatDecEq := foldNatBinPred mkNatEq (λ x y, x = y)
def foldNatDecLt := foldNatBinPred mkNatLt (λ x y, x < y)
def foldNatDecLe := foldNatBinPred mkNatLe (λ x y, x ≤ y)
/- Table mapping the `Nat` primitives that can be folded to their folders:
   four arithmetic operations followed by three decidable comparisons. -/
def natFoldFns : List (Name × BinFoldFn) :=
[(`Nat.add, foldNatAdd),
(`Nat.mul, foldNatMul),
(`Nat.div, foldNatDiv),
(`Nat.mod, foldNatMod),
(`Nat.decEq, foldNatDecEq),
(`Nat.decLt, foldNatDecLt),
(`Nat.decLe, foldNatDecLe)]
/- All binary folders: the unsigned-integer table followed by the `Nat` table. -/
def binFoldFns : List (Name × BinFoldFn) :=
List.append uintBinFoldFns natFoldFns
/- Fold `Nat.succ` applied to a numeric literal into the literal one larger.
   The `beforeErasure` flag is ignored. -/
def foldNatSucc (_ : Bool) (a : Expr) : Option Expr :=
match getNumLit a with
| some n := some (Expr.lit (Literal.natVal (n+1)))
| none := none
/- Fold `Char.ofNat` applied to a numeric literal into a 32-bit unsigned
   literal. This is done only after erasure (`beforeErasure = false`).
   If the value, viewed as a `UInt32`, passes `isValidChar` it is kept;
   otherwise the result is the literal 0. -/
def foldCharOfNat (beforeErasure : Bool) (a : Expr) : Option Expr :=
if beforeErasure then none
else getNumLit a >>= λ n,
  pure (mkUInt32Lit (if isValidChar (UInt32.ofNat n) then n else 0))
/- Table mapping the unary primitives that can be folded to their folders. -/
def unFoldFns : List (Name × UnFoldFn) :=
[(`Nat.succ, foldNatSucc),
(`Char.ofNat, foldCharOfNat)]
-- TODO(Leo): move
/- Association-list lookup keyed by `Name`: returns the value paired with the
   first key equal to `n`, or `none` if no key matches. -/
private def {u} alistFind {α : Type u} (n : Name) : List (Name × α) → Option α
| [] := none
| ((key, value)::rest) :=
  if n = key then some value else alistFind rest
/- Look up the folder registered for the binary primitive `fn` in
   `binFoldFns`; `none` if `fn` has no folder. -/
def findBinFoldFn (fn : Name) : Option BinFoldFn :=
alistFind fn binFoldFns
/- Look up the folder registered for the unary primitive `fn` in
   `unFoldFns`; `none` if `fn` has no folder. -/
def findUnFoldFn (fn : Name) : Option UnFoldFn :=
alistFind fn unFoldFns
/- Entry point for folding a binary application `f a b` (exported under the
   name `lean.fold_bin_op_core`). Folding is attempted only when `f` is a
   constant with a registered folder; the folder itself may still decline.
   Returns the replacement expression, or `none` when nothing can be folded. -/
@[export lean.fold_bin_op_core]
def foldBinOp (beforeErasure : Bool) (f : Expr) (a : Expr) (b : Expr) : Option Expr :=
match f with
| Expr.const fn _ := (findBinFoldFn fn >>= λ folder, folder beforeErasure a b)
| _ := none
/- Entry point for folding a unary application `f a` (exported under the
   name `lean.fold_un_op_core`). Folding is attempted only when `f` is a
   constant with a registered folder; the folder itself may still decline.
   Returns the replacement expression, or `none` when nothing can be folded. -/
@[export lean.fold_un_op_core]
def foldUnOp (beforeErasure : Bool) (f : Expr) (a : Expr) : Option Expr :=
match f with
| Expr.const fn _ := (findUnFoldFn fn >>= λ folder, folder beforeErasure a)
| _ := none
end Compiler
end Lean