/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.lean.expr init.platform
import init.lean.compiler.util

/-
Constant folding for primitives that have special runtime support.
-/

namespace Lean
namespace Compiler

-- Folding function for a binary primitive. The `Bool` argument is the
-- `beforeErasure` flag (see `toDecidableExpr` for how it changes the emitted
-- terms); the result is `none` when the application cannot be folded.
def BinFoldFn := Bool → Expr → Expr → Option Expr

-- Folding function for a unary primitive; same conventions as `BinFoldFn`.
def UnFoldFn := Bool → Expr → Option Expr

-- Name of the unsigned scalar type of the given width, e.g. `UInt32`.
-- NOTE: despite its name, `nbytes` receives a *bit* width
-- (`NumScalarTypeInfo` passes its `nbits` field).
def mkUIntTypeName (nbytes : Nat) : Name :=
mkSimpleName ("UInt" ++ toString nbytes)

-- Description of a fixed-width unsigned scalar type with runtime support.
structure NumScalarTypeInfo :=
(nbits   : Nat)
(id      : Name := mkUIntTypeName nbits)
(ofNatFn : Name := Name.mkString id "ofNat")
(toNatFn : Name := Name.mkString id "toNat")
(size    : Nat  := 2^nbits)  -- literals of this type are taken modulo `size`

-- All scalar types whose arithmetic is folded here. `USize` takes its width
-- from the target platform.
def numScalarTypes : List NumScalarTypeInfo :=
[{nbits := 8}, {nbits := 16}, {nbits := 32}, {nbits := 64},
 {id := `USize, nbits := System.platform.nbits}]

-- Is `fn` the `ofNat` constructor of one of `numScalarTypes`?
def isOfNat (fn : Name) : Bool :=
numScalarTypes.any (λ info, info.ofNatFn = fn)

-- Is `fn` the `toNat` projection of one of `numScalarTypes`?
def isToNat (fn : Name) : Bool :=
numScalarTypes.any (λ info, info.toNatFn = fn)

-- First entry of the list whose `ofNat` function is `fn`, if any.
def getInfoFromFn (fn : Name) : List NumScalarTypeInfo → Option NumScalarTypeInfo
| []            := none
| (info::infos) := if info.ofNatFn = fn then some info else getInfoFromFn infos

-- Scalar type of a value of the form `T.ofNat x`; `none` for anything else.
def getInfoFromVal : Expr → Option NumScalarTypeInfo
| (Expr.app (Expr.const fn _) _) := getInfoFromFn fn numScalarTypes
| _                              := none

-- Natural-number literal of `e`, looking through any `ofNat` wrappers of the
-- scalar types above. The literal is returned as written: it is NOT reduced
-- modulo the scalar type's size, so callers that need the scalar value must
-- reduce it themselves.
@[export lean.get_num_lit_core]
def getNumLit : Expr → Option Nat
| (Expr.lit (Literal.natVal n))  := some n
| (Expr.app (Expr.const fn _) a) := if isOfNat fn then getNumLit a else none
| _                              := none

-- Literal `info.ofNatFn (n % info.size)`.
def mkUIntLit (info : NumScalarTypeInfo) (n : Nat) : Expr :=
Expr.app (Expr.const info.ofNatFn []) (Expr.lit (Literal.natVal (n%info.size)))

def mkUInt32Lit (n : Nat) : Expr :=
mkUIntLit {nbits := 32} n

-- `n` reduced modulo the size of the given scalar type, or `n` unchanged when
-- no scalar type is known (e.g. for a plain `Nat` literal).
def truncNumLit (n : Nat) : Option NumScalarTypeInfo → Nat
| (some info) := n % info.size
| none        := n

-- Folds a binary operation on two scalar literals. The scalar type is taken
-- from `a₁`. Both operands are first reduced modulo the type's size, because
-- `getNumLit` returns the raw literal (e.g. 300 for `UInt8.ofNat 300`) and
-- operations such as `/`, `%` and `-` are not compatible with reducing only
-- the result.
def foldBinUInt (fn : NumScalarTypeInfo → Bool → Nat → Nat → Nat)
                (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr :=
do n₁   ← getNumLit a₁,
   n₂   ← getNumLit a₂,
   info ← getInfoFromVal a₁,
   pure $ mkUIntLit info (fn info beforeErasure (n₁ % info.size) (n₂ % info.size))

def foldUIntAdd := foldBinUInt $ λ _ _, (+)
def foldUIntMul := foldBinUInt $ λ _ _, (*)
def foldUIntDiv := foldBinUInt $ λ _ _, (/)
def foldUIntMod := foldBinUInt $ λ _ _, (%)
-- Wrapping subtraction: `a - b ≡ a + (size - b) (mod size)`. This relies on
-- `b < size`, which `foldBinUInt` guarantees by reducing the operands.
def foldUIntSub := foldBinUInt $ λ info _ a b, (a + (info.size - b))

-- Operation suffixes shared by every scalar type, e.g. `add` for `UInt8.add`.
def preUIntBinFoldFns : List (Name × BinFoldFn) :=
[(`add, foldUIntAdd), (`mul, foldUIntMul), (`div, foldUIntDiv),
 (`mod, foldUIntMod), (`sub, foldUIntSub)]

-- `preUIntBinFoldFns` instantiated for every type in `numScalarTypes`.
def uintBinFoldFns : List (Name × BinFoldFn) :=
numScalarTypes.foldl (λ r info, r ++ (preUIntBinFoldFns.map (λ ⟨suffix, fn⟩, (info.id ++ suffix, fn)))) []

-- Folds a binary `Nat` operation on two literals into a `Nat` literal.
def foldNatBinOp (fn : Nat → Nat → Nat) (a₁ a₂ : Expr) : Option Expr :=
do n₁ ← getNumLit a₁,
   n₂ ← getNumLit a₂,
   pure $ Expr.lit (Literal.natVal (fn n₁ n₂))

def foldNatAdd (_ : Bool) := foldNatBinOp (+)
def foldNatMul (_ : Bool) := foldNatBinOp (*)
def foldNatDiv (_ : Bool) := foldNatBinOp (/)
def foldNatMod (_ : Bool) := foldNatBinOp (%)
def foldNatPow (_ : Bool) := foldNatBinOp (^)

-- The proposition `@Eq Nat a b`.
def mkNatEq (a b : Expr) : Expr :=
mkBinApp (Expr.app (Expr.const `Eq [Level.one]) (Expr.const `Nat [])) a b

-- The proposition `@HasLt.lt Nat Nat.HasLt a b`.
def mkNatLt (a b : Expr) : Expr :=
mkBinApp (mkBinApp (Expr.const `HasLt.lt [Level.zero]) (Expr.const `Nat []) (Expr.const `Nat.HasLt [])) a b

-- The proposition `@HasLe.le Nat Nat.HasLe a b`.
-- (Previously referenced the nonexistent constant `HasLt.le`.)
def mkNatLe (a b : Expr) : Expr :=
mkBinApp (mkBinApp (Expr.const `HasLe.le [Level.zero]) (Expr.const `Nat []) (Expr.const `Nat.HasLe [])) a b

-- `Decidable` value for a predicate whose truth value `r` is known.
-- Before erasure the proposition and an `lcProof` for it are supplied;
-- afterwards both arguments are `neutralExpr`.
def toDecidableExpr (beforeErasure : Bool) (pred : Expr) (r : Bool) : Expr :=
match beforeErasure, r with
| false, true  := mkDecIsTrue  neutralExpr neutralExpr
| false, false := mkDecIsFalse neutralExpr neutralExpr
| true,  true  := mkDecIsTrue  pred (mkLcProof pred)
| true,  false := mkDecIsFalse pred (mkLcProof pred)

-- Folds a decidable binary `Nat` predicate on two literals.
def foldNatBinPred (mkPred : Expr → Expr → Expr) (fn : Nat → Nat → Bool)
                   (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr :=
do n₁ ← getNumLit a₁,
   n₂ ← getNumLit a₂,
   pure $ toDecidableExpr beforeErasure (mkPred a₁ a₂) (fn n₁ n₂)

def foldNatDecEq := foldNatBinPred mkNatEq (λ a b, a = b)
def foldNatDecLt := foldNatBinPred mkNatLt (λ a b, a < b)
def foldNatDecLe := foldNatBinPred mkNatLe (λ a b, a ≤ b)

def natFoldFns : List (Name × BinFoldFn) :=
[(`Nat.add, foldNatAdd),
 (`Nat.mul, foldNatMul),
 (`Nat.div, foldNatDiv),
 (`Nat.mod, foldNatMod),
 (`Nat.pow, foldNatPow),
 (`Nat.pow._main, foldNatPow),
 (`Nat.decEq, foldNatDecEq),
 (`Nat.decLt, foldNatDecLt),
 (`Nat.decLe, foldNatDecLe)]

-- `some b` if `e` is the constant `Bool.true` / `Bool.false`.
def getBoolLit : Expr → Option Bool
| (Expr.const `Bool.true _)  := some true
| (Expr.const `Bool.false _) := some false
| _                          := none

-- Simplifies `strictAnd a₁ a₂` when either side is a Boolean literal.
def foldStrictAnd (_ : Bool) (a₁ a₂ : Expr) : Option Expr :=
let v₁ := getBoolLit a₁ in
let v₂ := getBoolLit a₂ in
match v₁, v₂ with
| some true,  _ := some a₂  -- true && x = x
| some false, _ := some a₁  -- false && x = false
| _, some true  := some a₁  -- x && true = x
| _, some false := some a₂  -- x && false = false
| _, _          := none

-- Simplifies `strictOr a₁ a₂` when either side is a Boolean literal.
def foldStrictOr (_ : Bool) (a₁ a₂ : Expr) : Option Expr :=
let v₁ := getBoolLit a₁ in
let v₂ := getBoolLit a₂ in
match v₁, v₂ with
| some true,  _ := some a₁  -- true || x = true
| some false, _ := some a₂  -- false || x = x
| _, some true  := some a₂  -- x || true = true
| _, some false := some a₁  -- x || false = x
| _, _          := none

def boolFoldFns : List (Name × BinFoldFn) :=
[(`strictOr, foldStrictOr), (`strictAnd, foldStrictAnd)]

-- Every binary folding rule, keyed by the folded function's name.
def binFoldFns : List (Name × BinFoldFn) :=
boolFoldFns ++ uintBinFoldFns ++ natFoldFns

def foldNatSucc (_ : Bool) (a : Expr) : Option Expr :=
do n ← getNumLit a,
   pure $ Expr.lit (Literal.natVal (n+1))

-- Folds `Char.ofNat n` to a `UInt32` literal. It only fires after erasure and
-- maps code points rejected by `isValidChar` to 0.
def foldCharOfNat (beforeErasure : Bool) (a : Expr) : Option Expr :=
do guard (!beforeErasure),
   n ← getNumLit a,
   pure $
     if isValidChar (UInt32.ofNat n) then mkUInt32Lit n
     else mkUInt32Lit 0

-- Folds `T.toNat v` for a scalar literal `v`. When `v` is `T.ofNat n`, the
-- result is `n` reduced modulo `T`'s size. `getNumLit` alone would return the
-- unreduced `n` (e.g. 300 instead of 44 for `UInt8.ofNat 300`).
def foldToNat (_ : Bool) (a : Expr) : Option Expr :=
do n ← getNumLit a,
   pure $ Expr.lit (Literal.natVal (truncNumLit n (getInfoFromVal a)))

def uintFoldToNatFns : List (Name × UnFoldFn) :=
numScalarTypes.foldl (λ r info, (info.toNatFn, foldToNat) :: r) []

-- Every unary folding rule, keyed by the folded function's name.
def unFoldFns : List (Name × UnFoldFn) :=
[(`Nat.succ, foldNatSucc),
 (`Char.ofNat, foldCharOfNat)]
++ uintFoldToNatFns

def findBinFoldFn (fn : Name) : Option BinFoldFn :=
binFoldFns.lookup fn

def findUnFoldFn (fn : Name) : Option UnFoldFn :=
unFoldFns.lookup fn

-- Entry point: folds `f a b` when `f` is a constant with a binary folding rule.
@[export lean.fold_bin_op_core]
def foldBinOp (beforeErasure : Bool) (f : Expr) (a : Expr) (b : Expr) : Option Expr :=
match f with
| Expr.const fn _ :=
  do foldFn ← findBinFoldFn fn,
     foldFn beforeErasure a b
| _ := none

-- Entry point: folds `f a` when `f` is a constant with a unary folding rule.
@[export lean.fold_un_op_core]
def foldUnOp (beforeErasure : Bool) (f : Expr) (a : Expr) : Option Expr :=
match f with
| Expr.const fn _ :=
  do foldFn ← findUnFoldFn fn,
     foldFn beforeErasure a
| _ := none

end Compiler
end Lean