From 5286c2b5aa30403cbf75f9e67d406edbfca12b67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Sun, 27 Nov 2022 13:32:43 +0100 Subject: [PATCH] feat: optimize mul/div into shift operations --- src/Lean/Compiler/LCNF/Simp/ConstantFold.lean | 72 +++++++++++++------ tests/lean/CompilerConstantFold.lean | 13 ++++ .../CompilerConstantFold.lean.expected.out | 35 +++++++++ 3 files changed, 99 insertions(+), 21 deletions(-) create mode 100644 tests/lean/CompilerConstantFold.lean create mode 100644 tests/lean/CompilerConstantFold.lean.expected.out diff --git a/src/Lean/Compiler/LCNF/Simp/ConstantFold.lean b/src/Lean/Compiler/LCNF/Simp/ConstantFold.lean index 9654b91f47..86d0978813 100644 --- a/src/Lean/Compiler/LCNF/Simp/ConstantFold.lean +++ b/src/Lean/Compiler/LCNF/Simp/ConstantFold.lean @@ -245,6 +245,33 @@ def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) unless arg == annihilator do return none mkLit zero +def Folder.divShift [Literal α] [BEq α] (shiftRight : Name) (pow2 : α → α) (log2 : α → α) : Folder := fun args => do + unless (← getEnv).contains shiftRight do return none + let #[lhs, .fvar fvarId] := args | return none + let some rhs ← getLit fvarId | return none + let exponent := log2 rhs + unless pow2 exponent == rhs do return none + let shiftLit ← mkAuxLit exponent + return some <| .const shiftRight [] #[lhs, .fvar shiftLit] + +def Folder.mulRhsShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : α → α) (log2 : α → α) : Folder := fun args => do + unless (← getEnv).contains shiftLeft do return none + let #[lhs, .fvar fvarId] := args | return none + let some rhs ← getLit fvarId | return none + let exponent := log2 rhs + unless pow2 exponent == rhs do return none + let shiftLit ← mkAuxLit exponent + return some <| .const shiftLeft [] #[lhs, .fvar shiftLit] + +def Folder.mulLhsShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : α → α) (log2 : α → α) : Folder := fun args => do + unless (← getEnv).contains shiftLeft do return none + let #[.fvar fvarId, rhs] := args | return none + let some lhs ← getLit fvarId | return none + let exponent := log2 lhs + unless pow2 exponent == lhs do return none + let shiftLit ← mkAuxLit exponent + return some <| .const shiftLeft [] #[rhs, .fvar shiftLit] + /-- Pick the first folder out of `folders` that succeeds. -/ @@ -276,31 +303,34 @@ def higherOrderLiteralFolders : List (Name × Folder) := [ (``List.toArray, foldArrayLiteral) ] +def Folder.mulShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : α → α) (log2 : α → α) : Folder := + Folder.first #[Folder.mulLhsShift shiftLeft pow2 log2, Folder.mulRhsShift shiftLeft pow2 log2] + /-- All arithmetic folders. -/ def arithmeticFolders : List (Name × Folder) := [ - (``Nat.succ, Folder.mkUnary Nat.succ), - (``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0]), - (``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8)]), - (``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16)]), - (``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32)]), - (``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64)]), - (``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.leftRightNeutral 0]), - (``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.leftRightNeutral (0 : UInt8)]), - (``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.leftRightNeutral (0 : UInt16)]), - (``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.leftRightNeutral (0 : UInt32)]), - (``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.leftRightNeutral (0 : UInt64)]), - (``Nat.mul, Folder.first #[Folder.mkBinary Nat.mul, Folder.leftRightNeutral 1, Folder.leftRightAnnihilator 0 0]), - (``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8), Folder.leftRightAnnihilator (0 : UInt8) 0]), - (``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16), Folder.leftRightAnnihilator (0 : UInt16) 0]), - (``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32), Folder.leftRightAnnihilator (0 : UInt32) 0]), - (``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64), Folder.leftRightAnnihilator (0 : UInt64) 0]), - (``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1]), - (``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8)]), - (``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16)]), - (``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32)]), - (``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64)]) + (``Nat.succ, Folder.mkUnary Nat.succ), + (``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0]), + (``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8)]), + (``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16)]), + (``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32)]), + (``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64)]), + (``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.leftRightNeutral 0]), + (``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.leftRightNeutral (0 : UInt8)]), + (``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.leftRightNeutral (0 : UInt16)]), + (``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.leftRightNeutral (0 : UInt32)]), + (``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.leftRightNeutral (0 : UInt64)]), + (``Nat.mul, Folder.first #[Folder.mkBinary Nat.mul, Folder.leftRightNeutral 1, Folder.leftRightAnnihilator 0 0, Folder.mulShift ``Nat.shiftLeft (Nat.pow 2) Nat.log2]), + (``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8), Folder.leftRightAnnihilator (0 : UInt8) 0, Folder.mulShift ``UInt8.shiftLeft (UInt8.shiftLeft 1 ·) UInt8.log2]), + (``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16), Folder.leftRightAnnihilator (0 : UInt16) 0, Folder.mulShift ``UInt16.shiftLeft (UInt16.shiftLeft 1 ·) UInt16.log2]), + (``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32), Folder.leftRightAnnihilator (0 : UInt32) 0, Folder.mulShift ``UInt32.shiftLeft (UInt32.shiftLeft 1 ·) UInt32.log2]), + (``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64), Folder.leftRightAnnihilator (0 : UInt64) 0, Folder.mulShift ``UInt64.shiftLeft (UInt64.shiftLeft 1 ·) UInt64.log2]), + (``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1, Folder.divShift ``Nat.shiftRight (Nat.pow 2) Nat.log2]), + (``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8), Folder.divShift ``UInt8.shiftRight (UInt8.shiftLeft 1 ·) UInt8.log2]), + (``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16), Folder.divShift ``UInt16.shiftRight (UInt16.shiftLeft 1 ·) UInt16.log2]), + (``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32), Folder.divShift ``UInt32.shiftRight (UInt32.shiftLeft 1 ·) UInt32.log2]), + (``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64), Folder.divShift ``UInt64.shiftRight (UInt64.shiftLeft 1 ·) UInt64.log2]) ] def relationFolders : List (Name × Folder) := [ diff --git a/tests/lean/CompilerConstantFold.lean b/tests/lean/CompilerConstantFold.lean new file mode 100644 index 0000000000..1dcd0da5ab --- /dev/null +++ b/tests/lean/CompilerConstantFold.lean @@ -0,0 +1,13 @@ +set_option trace.Compiler.result true in +def mulDivShift (a : Nat) (b : UInt8) (c : UInt16) (d : UInt32) (e : UInt64) : Nat := + let a1 := a / 32 + let a2 := a * 32 + let b1 := b / 32 + let b2 := b * 32 + let c1 := c / 32 + let c2 := c * 32 + let d1 := d / 32 + let d2 := d * 32 + let e1 := e / 32 + let e2 := e * 32 + a1 + a2 + b1.val.val + b2.val.val + c1.val.val + c2.val.val + d1.val.val + d2.val.val + e1.val.val + e2.val.val diff --git a/tests/lean/CompilerConstantFold.lean.expected.out b/tests/lean/CompilerConstantFold.lean.expected.out new file mode 100644 index 0000000000..28ae0165ac --- /dev/null +++ b/tests/lean/CompilerConstantFold.lean.expected.out @@ -0,0 +1,35 @@ +[Compiler.result] size: 32 + def mulDivShift a b c d e : Nat := + let _x.1 := 5; + let a1 := Nat.shiftRight a _x.1; + let a2 := Nat.shiftLeft a _x.1; + let _x.2 := UInt8.ofNat _x.1; + let b1 := UInt8.shiftRight b _x.2; + let b2 := UInt8.shiftLeft b _x.2; + let _x.3 := UInt16.ofNat _x.1; + let c1 := UInt16.shiftRight c _x.3; + let c2 := UInt16.shiftLeft c _x.3; + let _x.4 := UInt32.ofNat _x.1; + let d1 := UInt32.shiftRight d _x.4; + let d2 := UInt32.shiftLeft d _x.4; + let _x.5 := UInt64.ofNat _x.1; + let e1 := UInt64.shiftRight e _x.5; + let e2 := UInt64.shiftLeft e _x.5; + let _x.6 := Nat.add a1 a2; + let _x.7 := UInt8.val b1; + let _x.8 := Nat.add _x.6 _x.7; + let _x.9 := UInt8.val b2; + let _x.10 := Nat.add _x.8 _x.9; + let _x.11 := UInt16.val c1; + let _x.12 := Nat.add _x.10 _x.11; + let _x.13 := UInt16.val c2; + let _x.14 := Nat.add _x.12 _x.13; + let _x.15 := UInt32.val d1; + let _x.16 := Nat.add _x.14 _x.15; + let _x.17 := UInt32.val d2; + let _x.18 := Nat.add _x.16 _x.17; + let _x.19 := UInt64.val e1; + let _x.20 := Nat.add _x.18 _x.19; + let _x.21 := UInt64.val e2; + let _x.22 := Nat.add _x.20 _x.21; + return _x.22