/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
|
||
prelude
|
||
import init.lean.expr init.platform
|
||
import init.lean.compiler.util
|
||
|
||
/- Constant folding for primitives that have special runtime support. -/
|
||
|
||
namespace lean
|
||
namespace compiler
|
||
|
||
/-- Type of a binary constant folder: it receives the `before_erasure` flag and the two
argument expressions, and returns the folded expression, or `none` when it does not apply. -/
def bin_fold_fn : Type := bool → expr → expr → option expr
|
||
/-- Type of a unary constant folder: it receives the `before_erasure` flag and the single
argument expression, and returns the folded expression, or `none` when it does not apply. -/
def un_fold_fn : Type := bool → expr → option expr
|
||
|
||
/-- Builds the simple name `"uint" ++ to_string nbytes`.
Despite the parameter being called `nbytes`, `num_scalar_type_info` calls this function
with its `nbits` field, i.e. with a width in bits. -/
def mk_uint_type_name (nbytes : nat) : name :=
mk_simple_name $ "uint" ++ to_string nbytes
|
||
|
||
/-- Description of an unsigned scalar type handled by the constant folder.
Only `nbits` has to be supplied; the remaining fields have defaults computed from it. -/
structure num_scalar_type_info :=
(nbits     : nat)                                 -- width of the type in bits
(id        : name := mk_uint_type_name nbits)     -- name of the type, derived from `nbits` by default
(of_nat_fn : name := name.mk_string id "of_nat")  -- `id` extended with the string component "of_nat"
(size      : nat  := 2^nbits)                     -- modulus used when building literals of this type
|
||
|
||
/-- The scalar types the folder knows about: the 8, 16, 32 and 64 bit types with default
names, followed by `usize`, whose width is `system.platform.nbits`. -/
def num_scalar_types : list num_scalar_type_info :=
[ {nbits := 8},
  {nbits := 16},
  {nbits := 32},
  {nbits := 64},
  {id := `usize, nbits := system.platform.nbits} ]
|
||
|
||
/-- Returns `tt` iff `fn` is the `of_nat_fn` of some entry of `num_scalar_types`. -/
def is_of_nat (fn : name) : bool :=
list.any num_scalar_types (λ entry, entry.of_nat_fn = fn)
|
||
|
||
/-- Linear search: returns the first element of the given list whose `of_nat_fn` equals `fn`,
or `none` if there is no such element. -/
def get_info_from_fn (fn : name) : list num_scalar_type_info → option num_scalar_type_info
| []             := none
| (head :: rest) :=
  if head.of_nat_fn = fn
  then some head
  else get_info_from_fn rest
|
||
|
||
/-- For an expression of the form `c arg` with `c` a constant, looks `c` up among the
`of_nat_fn` names of `num_scalar_types`. Any other expression yields `none`. -/
def get_info_from_val : expr → option num_scalar_type_info
| (expr.app (expr.const c _) _) := get_info_from_fn c num_scalar_types
| _                             := none
|
||
|
||
/-- Extracts the natural number carried by a numeric literal expression.
A `nat` literal yields its value directly. An application `c arg` yields the literal found
in `arg` when `c` is one of the known `of_nat` functions (see `is_of_nat`).
Everything else yields `none`. The stored number is returned as is, without any reduction. -/
@[export lean.get_num_lit_core]
def get_num_lit : expr → option nat
| (expr.lit (literal.nat_val v))  := some v
| (expr.app (expr.const c _) arg) :=
  if is_of_nat c
  then get_num_lit arg
  else none
| _                               := none
|
||
|
||
/-- Builds the expression `info.of_nat_fn k`, where `k` is the `nat` literal `n % info.size`.
The constant is created with an empty list of universe levels. -/
def mk_uint_lit (info : num_scalar_type_info) (n : nat) : expr :=
let reduced := n % info.size in
expr.app (expr.const info.of_nat_fn []) (expr.lit (literal.nat_val reduced))
|
||
|
||
/-- `mk_uint_lit` specialised to the scalar type description with `nbits := 32`. -/
def mk_uint32_lit (n : nat) : expr :=
let uint32_info : num_scalar_type_info := {nbits := 32} in
mk_uint_lit uint32_info n
|
||
|
||
/-- Generic folder for binary operations on fixed-width unsigned integers.

Both arguments must be numeric literals (see `get_num_lit`) and `a₁` must be an application
of a known `of_nat` function, which determines the scalar type; otherwise the result is `none`.
`fn` receives the type description, the `before_erasure` flag and the two operand values,
and its result is turned back into a literal of the same type by `mk_uint_lit`.

`get_num_lit` returns the number stored in the literal unchanged, so it may be greater than
or equal to `info.size` (e.g. `of_nat 300` for an 8 bit type). The operands are therefore
reduced modulo `info.size` before `fn` is applied. Without this, operations that are not
compatible with a final reduction (division, remainder, and the truncating natural-number
subtraction used for `sub`) would fold to the wrong value. -/
def fold_bin_uint (fn : num_scalar_type_info → bool → nat → nat → nat) (before_erasure : bool) (a₁ a₂ : expr) : option expr :=
do n₁   ← get_num_lit a₁,
   n₂   ← get_num_lit a₂,
   info ← get_info_from_val a₁,
   pure $ mk_uint_lit info (fn info before_erasure (n₁ % info.size) (n₂ % info.size))
|
||
|
||
/- Folders for the unsigned integer primitives. Each one applies the corresponding
   natural-number operation; `fold_bin_uint` rebuilds the literal with `mk_uint_lit`,
   which reduces the result modulo the size of the type. -/
def fold_uint_add := fold_bin_uint (λ _ _ x y, x + y)
def fold_uint_mul := fold_bin_uint (λ _ _ x y, x * y)
def fold_uint_div := fold_bin_uint (λ _ _ x y, x / y)
def fold_uint_mod := fold_bin_uint (λ _ _ x y, x % y)
/- Subtraction is expressed as adding `info.size - y`, because subtraction on `nat` truncates at zero. -/
def fold_uint_sub := fold_bin_uint (λ info _ x y, x + (info.size - y))
|
||
|
||
/-- Operation suffixes paired with their unsigned integer folders. The suffixes are combined
with the name of each scalar type in `uint_bin_fold_fns`. -/
def pre_uint_bin_fold_fns : list (name × bin_fold_fn) :=
[ (`add, fold_uint_add),
  (`mul, fold_uint_mul),
  (`div, fold_uint_div),
  (`mod, fold_uint_mod),
  (`sub, fold_uint_sub) ]
|
||
|
||
/-- Folder table for all scalar types: for every entry of `num_scalar_types`, in order, each
`(suffix, folder)` pair of `pre_uint_bin_fold_fns` is re-keyed as `info.id ++ suffix`
and appended to the result. -/
def uint_bin_fold_fns : list (name × bin_fold_fn) :=
num_scalar_types.foldl
  (λ acc info,
    acc ++ (pre_uint_bin_fold_fns.map (λ ⟨op_suffix, folder⟩, (info.id ++ op_suffix, folder))))
  []
|
||
|
||
/-- Generic folder for binary operations on `nat`: when both arguments are numeric literals
(see `get_num_lit`), returns the `nat` literal holding `fn` applied to their values. -/
def fold_nat_bin_op (fn : nat → nat → nat) (a₁ a₂ : expr) : option expr :=
do lhs ← get_num_lit a₁,
   rhs ← get_num_lit a₂,
   pure (expr.lit (literal.nat_val (fn lhs rhs)))
|
||
|
||
/- Folders for the `nat` arithmetic primitives. The `bool` argument (the `before_erasure`
   flag expected by `bin_fold_fn`) is ignored. -/
def fold_nat_add (_ : bool) := fold_nat_bin_op (λ x y, x + y)
def fold_nat_mul (_ : bool) := fold_nat_bin_op (λ x y, x * y)
def fold_nat_div (_ : bool) := fold_nat_bin_op (λ x y, x / y)
def fold_nat_mod (_ : bool) := fold_nat_bin_op (λ x y, x % y)
|
||
|
||
/-- Builds the expression `eq nat a b`, i.e. the constant `eq` at universe level one
applied to the type `nat` and then to `a` and `b`. -/
def mk_nat_eq (a b : expr) : expr :=
let nat_type := expr.const `nat [] in
let eq_nat   := expr.app (expr.const `eq [level.one]) nat_type in
mk_bin_app eq_nat a b
|
||
|
||
/-- Builds the expression `has_lt.lt nat nat.has_lt a b`, with `has_lt.lt` at universe
level zero. -/
def mk_nat_lt (a b : expr) : expr :=
let nat_type := expr.const `nat [] in
let lt_inst  := expr.const `nat.has_lt [] in
let lt_nat   := mk_bin_app (expr.const `has_lt.lt [level.zero]) nat_type lt_inst in
mk_bin_app lt_nat a b
|
||
|
||
/-- Builds the expression `has_le.le nat nat.has_le a b`, with `has_le.le` at universe
level zero.

The relation constant is `has_le.le`: `le` is the field of the `has_le` class, matching the
`nat.has_le` instance passed as the second argument. The name `has_lt.le` used previously
mixes the `has_lt` class with the `le` field and does not match that instance. -/
def mk_nat_le (a b : expr) : expr :=
mk_bin_app (mk_bin_app (expr.const `has_le.le [level.zero]) (expr.const `nat []) (expr.const `nat.has_le [])) a b
|
||
|
||
/-- Turns the boolean outcome `r` of a folded predicate into a `decidable` value expression.
`r = tt` produces `mk_dec_is_true`, `r = ff` produces `mk_dec_is_false`.
When `before_erasure` is `tt` the constructor receives the proposition `pred` and
`mk_lc_proof pred`; when it is `ff` both arguments are `neutral_expr`.

NOTE(review): the `is_false` case passes `mk_lc_proof pred` rather than a proof term for the
negation of `pred`. Presumably this is harmless because the proof is only a placeholder.
Confirm against the definition of `mk_lc_proof`. -/
def to_decidable_expr (before_erasure : bool) (pred : expr) (r : bool) : expr :=
match before_erasure, r with
| tt, tt := mk_dec_is_true  pred (mk_lc_proof pred)
| tt, ff := mk_dec_is_false pred (mk_lc_proof pred)
| ff, tt := mk_dec_is_true  neutral_expr neutral_expr
| ff, ff := mk_dec_is_false neutral_expr neutral_expr
|
||
|
||
/-- Generic folder for decidable binary predicates on `nat`.
When both arguments are numeric literals, `fn` decides the predicate on their values and the
outcome is packaged by `to_decidable_expr`. `mk_pred` builds the proposition from the
original argument expressions `a₁` and `a₂`. -/
def fold_nat_bin_pred (mk_pred : expr → expr → expr) (fn : nat → nat → bool)
                      (before_erasure : bool) (a₁ a₂ : expr) : option expr :=
do lhs ← get_num_lit a₁,
   rhs ← get_num_lit a₂,
   pure (to_decidable_expr before_erasure (mk_pred a₁ a₂) (fn lhs rhs))
|
||
|
||
/- Folders for the decidable comparisons on `nat`: each pairs the builder of the proposition
   with the boolean test used to decide it. -/
def fold_nat_dec_eq := fold_nat_bin_pred mk_nat_eq (λ x y, x = y)
def fold_nat_dec_lt := fold_nat_bin_pred mk_nat_lt (λ x y, x < y)
def fold_nat_dec_le := fold_nat_bin_pred mk_nat_le (λ x y, x ≤ y)
|
||
|
||
/-- Table of the `nat` primitives that have a binary folder, keyed by constant name. -/
def nat_fold_fns : list (name × bin_fold_fn) :=
[ (`nat.add, fold_nat_add), (`nat.mul, fold_nat_mul),
  (`nat.div, fold_nat_div), (`nat.mod, fold_nat_mod),
  (`nat.dec_eq, fold_nat_dec_eq),
  (`nat.dec_lt, fold_nat_dec_lt),
  (`nat.dec_le, fold_nat_dec_le) ]
|
||
|
||
/-- All binary folders: the unsigned integer table followed by the `nat` table. -/
def bin_fold_fns : list (name × bin_fold_fn) :=
list.append uint_bin_fold_fns nat_fold_fns
|
||
|
||
/-- Folds the successor of a numeric literal into the `nat` literal holding its value plus one.
The `bool` argument (the `before_erasure` flag expected by `un_fold_fn`) is ignored. -/
def fold_nat_succ (_ : bool) (a : expr) : option expr :=
do v ← get_num_lit a,
   pure (expr.lit (literal.nat_val v.succ))
|
||
|
||
/-- Folder for `char.of_nat` applied to a numeric literal.
It declines (returns `none`) while `before_erasure` is set. Otherwise the result is the
32 bit literal built by `mk_uint32_lit` from the literal's value when
`is_valid_char (uint32.of_nat n)` holds, and from `0` when it does not. -/
def fold_char_of_nat (before_erasure : bool) (a : expr) : option expr :=
do guard (!before_erasure),
   code ← get_num_lit a,
   pure (mk_uint32_lit (if is_valid_char (uint32.of_nat code) then code else 0))
|
||
|
||
/-- Table of the primitives that have a unary folder, keyed by constant name. -/
def un_fold_fns : list (name × un_fold_fn) :=
[ (`nat.succ,    fold_nat_succ),
  (`char.of_nat, fold_char_of_nat) ]
|
||
|
||
-- TODO(Leo): move
/-- Association-list lookup: returns the value paired with the first key equal to `n`,
or `none` when no key matches. -/
private def {u} alist_find {α : Type u} (n : name) : list (name × α) → option α
| []                    := none
| ((key, val) :: rest)  :=
  if n = key
  then some val
  else alist_find rest
|
||
|
||
/-- Looks up the binary folder registered for the constant `fn` in `bin_fold_fns`. -/
def find_bin_fold_fn (fn : name) : option bin_fold_fn :=
@alist_find bin_fold_fn fn bin_fold_fns
|
||
|
||
/-- Looks up the unary folder registered for the constant `fn` in `un_fold_fns`. -/
def find_un_fold_fn (fn : name) : option un_fold_fn :=
@alist_find un_fold_fn fn un_fold_fns
|
||
|
||
/-- Entry point for folding a binary application `f a b`.
When `f` is a constant with a registered binary folder, that folder is run on
`before_erasure`, `a` and `b`. The result is `none` when `f` is not a constant, when no
folder is registered for it, or when the folder itself declines. -/
@[export lean.fold_bin_op_core]
def fold_bin_op (before_erasure : bool) (f a b : expr) : option expr :=
match f with
| expr.const c _ := find_bin_fold_fn c >>= λ folder, folder before_erasure a b
| _              := none
|
||
|
||
/-- Entry point for folding a unary application `f a`.
When `f` is a constant with a registered unary folder, that folder is run on
`before_erasure` and `a`. The result is `none` when `f` is not a constant, when no folder
is registered for it, or when the folder itself declines. -/
@[export lean.fold_un_op_core]
def fold_un_op (before_erasure : bool) (f a : expr) : option expr :=
match f with
| expr.const c _ := find_un_fold_fn c >>= λ folder, folder before_erasure a
| _              := none
|
||
|
||
end compiler
|
||
end lean
|