lean4-htt/library/init/lean/compiler/const_folding.lean

167 lines
5.9 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.lean.expr init.platform
import init.lean.compiler.util
/- Constant folding for primitives that have special runtime support. -/
namespace lean
namespace compiler
/-- A folder for a binary primitive. It takes the `before_erasure` flag and the two
    arguments, and returns the folded expression; `none` means nothing was folded. -/
def bin_fold_fn : Type := bool → expr → expr → option expr
/-- A folder for a unary primitive. It takes the `before_erasure` flag and the argument,
    and returns the folded expression; `none` means nothing was folded. -/
def un_fold_fn : Type := bool → expr → option expr
/-- Name of the unsigned scalar type of a given width, e.g. `8` gives `uint8`.
    NOTE: despite its name, `nbytes` is a number of *bits*; `num_scalar_type_info`
    calls this with its `nbits` field. -/
def mk_uint_type_name (nbytes : nat) : name :=
let suffix := to_string nbytes in
mk_simple_name ("uint" ++ suffix)
/-- Static information about an unsigned fixed-width scalar type whose
    arithmetic can be constant-folded. Fields after `nbits` have defaults
    derived from the earlier fields. -/
structure num_scalar_type_info :=
(nbits : nat) -- width in bits
(id : name := mk_uint_type_name nbits) -- type name, e.g. `uint8` for 8 bits
(of_nat_fn : name := name.mk_string id "of_nat") -- conversion constant, e.g. `uint8.of_nat`
(size : nat := 2^nbits) -- number of values; `mk_uint_lit` reduces literals modulo this
/-- The unsigned scalar types handled by the folder: `uint8`, `uint16`, `uint32`,
    `uint64`, and `usize`. The width of `usize` is `system.platform.nbits`. -/
def num_scalar_types : list num_scalar_type_info :=
[ {nbits := 8},
  {nbits := 16},
  {nbits := 32},
  {nbits := 64},
  {nbits := system.platform.nbits, id := `usize} ]
/-- `tt` iff `fn` is the `of_nat` conversion of one of `num_scalar_types`. -/
def is_of_nat (fn : name) : bool :=
num_scalar_types.any $ λ t, fn = t.of_nat_fn
/-- First entry of the given list whose `of_nat_fn` is `fn`, if any. -/
def get_info_from_fn (fn : name) : list num_scalar_type_info → option num_scalar_type_info
| []        := none
| (t::rest) := if fn = t.of_nat_fn then some t else get_info_from_fn rest
/-- Scalar type information for a value of the shape `T.of_nat x`, where `T` is one of
    `num_scalar_types`; `none` for any other expression. -/
def get_info_from_val : expr → option num_scalar_type_info
| (expr.app (expr.const c _) _) := get_info_from_fn c num_scalar_types
| _                             := none
/-- Natural-number value of a literal: either a raw `nat` literal, or such a literal
    wrapped in (possibly nested) `of_nat` applications of the scalar types.
    The value inside an `of_nat` is returned as-is, i.e. NOT reduced modulo the type size.
    Exported as `lean.get_num_lit_core` (presumably for use by native code). -/
@[export lean.get_num_lit_core]
def get_num_lit : expr → option nat
| (expr.lit (literal.nat_val v))  := some v
| (expr.app (expr.const c _) arg) := if !is_of_nat c then none else get_num_lit arg
| _                               := none
/-- Build the literal `T.of_nat (n % size)` for the scalar type described by `info`. -/
def mk_uint_lit (info : num_scalar_type_info) (n : nat) : expr :=
let wrapped := n % info.size in
expr.app (expr.const info.of_nat_fn []) (expr.lit (literal.nat_val wrapped))
/-- Build a `uint32` literal for `n` (reduced modulo 2^32). -/
def mk_uint32_lit (n : nat) : expr :=
let u32 : num_scalar_type_info := {nbits := 32} in
mk_uint_lit u32 n
/-- Fold a binary operation `fn` on two unsigned scalar literals.
    Both arguments must be numeric literals (see `get_num_lit`), and the scalar type
    is taken from `a₁`, which must be an `of_nat` application. The result is
    rebuilt with `mk_uint_lit`, which reduces it modulo the type size.
    Both operands are reduced modulo `info.size` BEFORE calling `fn`:
    - `get_num_lit` returns the raw argument of `T.of_nat`, which may be out of range
      (e.g. `uint8.of_nat 300`).
    - `div`, `mod`, and the truncated `nat` subtraction used for `sub` give wrong
      answers if only the result is reduced. -/
def fold_bin_uint (fn : num_scalar_type_info → bool → nat → nat → nat) (before_erasure : bool) (a₁ a₂ : expr) : option expr :=
do n₁   ← get_num_lit a₁,
   n₂   ← get_num_lit a₂,
   info ← get_info_from_val a₁,
   -- normalize operands into `[0, info.size)` so `fn` sees the actual scalar values
   pure $ mk_uint_lit info (fn info before_erasure (n₁ % info.size) (n₂ % info.size))
/- Arithmetic folders for the unsigned scalar types. Results are wrapped modulo the
   type size by `mk_uint_lit` inside `fold_bin_uint`. -/
def fold_uint_add := fold_bin_uint $ λ _ _ x y, x + y
def fold_uint_mul := fold_bin_uint $ λ _ _ x y, x * y
def fold_uint_div := fold_bin_uint $ λ _ _ x y, x / y
def fold_uint_mod := fold_bin_uint $ λ _ _ x y, x % y
-- `x - y` (mod size) is computed as `x + (size - y)` to avoid truncated `nat` subtraction;
-- this is only correct when `y ≤ info.size`.
def fold_uint_sub := fold_bin_uint $ λ info _ x y, x + (info.size - y)
/-- Binary folders for the scalar types, keyed by the operation's name suffix
    (to be prefixed with each type name, e.g. `uint8` + `add`). -/
def pre_uint_bin_fold_fns : list (name × bin_fold_fn) :=
[ (`add, fold_uint_add),
  (`mul, fold_uint_mul),
  (`div, fold_uint_div),
  (`mod, fold_uint_mod),
  (`sub, fold_uint_sub) ]
/-- For every scalar type `T` in `num_scalar_types` and every suffix `op` in
    `pre_uint_bin_fold_fns`, registers the folder under `T.op`. Entries are grouped
    by type, in the order of `num_scalar_types`. -/
def uint_bin_fold_fns : list (name × bin_fold_fn) :=
num_scalar_types.foldr
  (λ info acc, pre_uint_bin_fold_fns.map (λ ⟨suffix, fn⟩, (info.id ++ suffix, fn)) ++ acc)
  []
/-- Fold `fn` over two numeric literals, producing a raw `nat` literal. -/
def fold_nat_bin_op (fn : nat → nat → nat) (a₁ a₂ : expr) : option expr :=
do x ← get_num_lit a₁,
   y ← get_num_lit a₂,
   pure (expr.lit (literal.nat_val (fn x y)))
/- `nat` arithmetic folders; the `before_erasure` flag does not affect them. -/
def fold_nat_add (_ : bool) := fold_nat_bin_op (λ x y, x + y)
def fold_nat_mul (_ : bool) := fold_nat_bin_op (λ x y, x * y)
def fold_nat_div (_ : bool) := fold_nat_bin_op (λ x y, x / y)
def fold_nat_mod (_ : bool) := fold_nat_bin_op (λ x y, x % y)
/-- Build the proposition `@eq nat a b`. -/
def mk_nat_eq (a b : expr) : expr :=
let nat_ty := expr.const `nat [] in
let eq_nat := expr.app (expr.const `eq [level.one]) nat_ty in
mk_bin_app eq_nat a b
/-- Build the proposition `@has_lt.lt nat nat.has_lt a b`. -/
def mk_nat_lt (a b : expr) : expr :=
let lt_nat := mk_bin_app (expr.const `has_lt.lt [level.zero])
                         (expr.const `nat [])
                         (expr.const `nat.has_lt []) in
mk_bin_app lt_nat a b
/-- Build the proposition `@has_le.le nat nat.has_le a b`.
    (The previous version referenced the nonexistent constant `has_lt.le`;
    the `≤` projection lives in the `has_le` class.) -/
def mk_nat_le (a b : expr) : expr :=
mk_bin_app (mk_bin_app (expr.const `has_le.le [level.zero]) (expr.const `nat []) (expr.const `nat.has_le [])) a b
/-- Build a `decidable` value (`is_true`/`is_false` according to `r`) for `pred`.
    Before erasure, the proposition is `pred` and the proof is `mk_lc_proof pred`.
    After erasure, both are replaced by `neutral_expr`. -/
def to_decidable_expr (before_erasure : bool) (pred : expr) (r : bool) : expr :=
let prop  := if before_erasure then pred else neutral_expr in
let proof := if before_erasure then mk_lc_proof pred else neutral_expr in
if r then mk_dec_is_true prop proof else mk_dec_is_false prop proof
/-- Fold a decidable binary `nat` predicate on two numeric literals.
    `mk_pred` builds the proposition; `fn` computes its truth value. -/
def fold_nat_bin_pred (mk_pred : expr → expr → expr) (fn : nat → nat → bool)
                      (before_erasure : bool) (a₁ a₂ : expr) : option expr :=
do x ← get_num_lit a₁,
   y ← get_num_lit a₂,
   pure (to_decidable_expr before_erasure (mk_pred a₁ a₂) (fn x y))
/- Folders for `nat.dec_eq`, `nat.dec_lt` and `nat.dec_le` on literal arguments. -/
def fold_nat_dec_eq := fold_nat_bin_pred mk_nat_eq (λ x y, x = y)
def fold_nat_dec_lt := fold_nat_bin_pred mk_nat_lt (λ x y, x < y)
def fold_nat_dec_le := fold_nat_bin_pred mk_nat_le (λ x y, x ≤ y)
/-- Binary folders for `nat` primitives, keyed by the fully qualified constant name. -/
def nat_fold_fns : list (name × bin_fold_fn) :=
[ (`nat.add,    fold_nat_add),
  (`nat.mul,    fold_nat_mul),
  (`nat.div,    fold_nat_div),
  (`nat.mod,    fold_nat_mod),
  (`nat.dec_eq, fold_nat_dec_eq),
  (`nat.dec_lt, fold_nat_dec_lt),
  (`nat.dec_le, fold_nat_dec_le) ]
/-- All registered binary folders: scalar-type ones first, then `nat` ones. -/
def bin_fold_fns : list (name × bin_fold_fn) :=
list.append uint_bin_fold_fns nat_fold_fns
/-- Fold `nat.succ` applied to a numeric literal into a raw `nat` literal. -/
def fold_nat_succ (_ : bool) (a : expr) : option expr :=
do v ← get_num_lit a,
   pure (expr.lit (literal.nat_val (nat.succ v)))
/-- Fold `char.of_nat` applied to a numeric literal into a `uint32` literal.
    Only applies after erasure (`before_erasure = ff`). If `uint32.of_nat n` is not a
    valid character, `0` is used instead (presumably mirroring `char.of_nat` itself;
    confirm against its definition). -/
def fold_char_of_nat (before_erasure : bool) (a : expr) : option expr :=
do guard (!before_erasure),
   code ← get_num_lit a,
   pure (mk_uint32_lit (if is_valid_char (uint32.of_nat code) then code else 0))
/-- All registered unary folders, keyed by the fully qualified constant name. -/
def un_fold_fns : list (name × un_fold_fn) :=
[ (`nat.succ,    fold_nat_succ),
  (`char.of_nat, fold_char_of_nat) ]
-- TODO(Leo): move
/-- Value paired with the first occurrence of key `n` in an association list,
    or `none` if the key is absent. -/
private def {u} alist_find {α : Type u} (n : name) : list (name × α) → option α
| []                 := none
| ((key, val)::rest) := if key = n then some val else alist_find rest
/-- The binary folder registered for constant `fn` in `bin_fold_fns`, if any. -/
def find_bin_fold_fn (fn : name) : option bin_fold_fn :=
@alist_find bin_fold_fn fn bin_fold_fns
/-- The unary folder registered for constant `fn` in `un_fold_fns`, if any. -/
def find_un_fold_fn (fn : name) : option un_fold_fn :=
@alist_find un_fold_fn fn un_fold_fns
/-- Try to constant-fold the binary application `f a b`. Succeeds only when `f` is a
    constant with a registered binary folder and that folder succeeds. Exported as
    `lean.fold_bin_op_core` (presumably called from native compiler code). -/
@[export lean.fold_bin_op_core]
def fold_bin_op (before_erasure : bool) (f : expr) (a : expr) (b : expr) : option expr :=
match f with
| expr.const c _ := find_bin_fold_fn c >>= λ folder, folder before_erasure a b
| _              := none
/-- Try to constant-fold the unary application `f a`. Succeeds only when `f` is a
    constant with a registered unary folder and that folder succeeds. Exported as
    `lean.fold_un_op_core` (presumably called from native compiler code). -/
@[export lean.fold_un_op_core]
def fold_un_op (before_erasure : bool) (f : expr) (a : expr) : option expr :=
match f with
| expr.const c _ := find_un_fold_fn c >>= λ folder, folder before_erasure a
| _              := none
end compiler
end lean