diff --git a/tests/playground/expr_const_folding.lean b/tests/playground/expr_const_folding.lean new file mode 100644 index 0000000000..b7fd1d6a6c --- /dev/null +++ b/tests/playground/expr_const_folding.lean @@ -0,0 +1,73 @@ +inductive Expr +| Var : nat → Expr +| Val : nat → Expr +| Add : Expr → Expr → Expr +| Mul : Expr → Expr → Expr + +open Expr nat + +def mk_expr : nat → nat → Expr +| 0 v := if v = 0 then Var 1 else Val v +| (n+1) v := Add (mk_expr n (v+1)) (mk_expr n (v-1)) + +def append_add : Expr → Expr → Expr +| (Add e₁ e₂) e₃ := Add e₁ (append_add e₂ e₃) +| e₁ e₂ := Add e₁ e₂ + +def append_mul : Expr → Expr → Expr +| (Mul e₁ e₂) e₃ := Mul e₁ (append_mul e₁ e₂) +| e₁ e₂ := Mul e₁ e₂ + +def reassoc : Expr → Expr +| (Add e₁ e₂) := + let e₁' := reassoc e₁ in + let e₂' := reassoc e₂ in + append_add e₁' e₂' +| (Mul e₁ e₂) := + let e₁' := reassoc e₁ in + let e₂' := reassoc e₂ in + append_mul e₁' e₂' +| e := e + +def const_folding : Expr → Expr +| (Add e₁ e₂) := + let e₁ := const_folding e₁ in + let e₂ := const_folding e₂ in + (match e₁, e₂ with + | Val a, Val b := Val (a+b) + | Val a, Add e (Val b) := Add (Val (a+b)) e + | Val a, Add (Val b) e := Add (Val (a+b)) e + | _, _ := Add e₁ e₂) +| (Mul e₁ e₂) := + let e₁ := const_folding e₂ in + let e₁ := const_folding e₂ in + (match e₁, e₂ with + | Val a, Val b := Val (a*b) + | Val a, Mul e (Val b) := Mul (Val (a*b)) e + | Val a, Mul (Val b) e := Mul (Val (a*b)) e + | _, _ := Mul e₁ e₂) +| e := e + +def size : Expr → nat +| (Add l r) := size l + size r + 1 +| (Mul l r) := size l + size r + 1 +| e := 1 + +def to_string_aux : Expr → string → string +| (Var v) r := r ++ "#" ++ to_string v +| (Val v) r := r ++ to_string v +| (Add e₁ e₂) r := (to_string_aux e₂ ((to_string_aux e₁ (r ++ "(")) ++ " + ")) ++ ")" +| (Mul e₁ e₂) r := (to_string_aux e₂ ((to_string_aux e₁ (r ++ "(")) ++ " * ")) ++ ")" + +def eval : Expr → nat +| (Var x) := 0 +| (Val v) := v +| (Add l r) := eval l + eval r +| (Mul l r) := eval l * eval r + +def main : io uint32 := +let e := (mk_expr 23 1) in +let v₁ := eval e in +let v₂ := eval (const_folding (reassoc e)) in +io.println' (to_string v₁ ++ " " ++ to_string v₂) *> +pure 0 diff --git a/tests/playground/expr_const_folding.ml b/tests/playground/expr_const_folding.ml new file mode 100644 index 0000000000..2a61a0bad2 --- /dev/null +++ b/tests/playground/expr_const_folding.ml @@ -0,0 +1,72 @@ +type expr = +| Var of int +| Val of int +| Add of expr * expr +| Mul of expr * expr;; + +let dec n = + if n == 0 then 0 else n - 1;; + +let rec mk_expr n v = + if n == 0 then (if v == 0 then Var 1 else Val v) + else Add (mk_expr (n-1) (v+1), mk_expr (n-1) (dec v));; + +let rec append_add e1 e2 = +match (e1, e2) with +| (Add (e1, e2), e3) -> Add (e1, append_add e2 e3) +| (e1, e2) -> Add (e1, e2);; + +let rec append_mul e1 e2 = +match (e1, e2) with +| (Mul (e1, e2), e3) -> Mul (e1, append_mul e2 e3) +| (e1, e2) -> Mul (e1, e2);; + +let rec reassoc e = +match e with +| Add (e1, e2) -> + let e1' = reassoc e1 in + let e2' = reassoc e2 in + append_add e1' e2' +| Mul (e1, e2) -> + let e1' = reassoc e1 in + let e2' = reassoc e2 in + append_mul e1' e2' +| e -> e;; + +let rec const_folding e = +match e with +| Add (e1, e2) -> + let e1 = const_folding e1 in + let e2 = const_folding e2 in + (match (e1, e2) with + | (Val a, Val b) -> Val (a+b) + | (Val a, Add (e, Val b)) -> Add (Val (a+b), e) + | (Val a, Add (Val b, e)) -> Add (Val (a+b), e) + | _ -> Add (e1, e2)) +| Mul (e1, e2) -> + let e1 = const_folding e1 in + let e2 = const_folding e2 in + (match (e1, e2) with + | (Val a, Val b) -> Val (a*b) + | (Val a, Mul (e, Val b)) -> Mul (Val (a*b), e) + | (Val a, Mul (Val b, e)) -> Mul (Val (a*b), e) + | _ -> Mul (e1, e2)) +| e -> e;; + +let rec size e = +match e with +| Add (e1, e2) -> size e1 + size e2 + 1 +| Mul (e1, e2) -> size e1 + size e2 + 1 +| e -> 1;; + +let rec eeval e = + match e with + | Val n -> n + | Var x -> 0 + | Add (e1, e2) -> eeval e1 + eeval e2 + | Mul (e1, e2) -> eeval e1 * eeval e2;; + +let e = (mk_expr 23 1) in +let v1 = eeval e in +let v2 = eeval (const_folding (reassoc e)) in +Printf.printf "%8d %8d\n" v1 v2;;