feat: optimize mul/div into shift operations

This commit is contained in:
Henrik Böving 2022-11-27 13:32:43 +01:00 committed by Gabriel Ebner
parent 24cc6eae6d
commit 5286c2b5aa
3 changed files with 99 additions and 21 deletions

View file

@ -245,6 +245,33 @@ def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α)
unless arg == annihilator do return none
mkLit zero
def Folder.divShift [Literal α] [BEq α] (shiftRight : Name) (pow2 : αα) (log2 : αα) : Folder := fun args => do
unless (← getEnv).contains shiftRight do return none
let #[lhs, .fvar fvarId] := args | return none
let some rhs ← getLit fvarId | return none
let exponent := log2 rhs
unless pow2 exponent == rhs do return none
let shiftLit ← mkAuxLit exponent
return some <| .const shiftRight [] #[lhs, .fvar shiftLit]
def Folder.mulRhsShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : αα) (log2 : αα) : Folder := fun args => do
unless (← getEnv).contains shiftLeft do return none
let #[lhs, .fvar fvarId] := args | return none
let some rhs ← getLit fvarId | return none
let exponent := log2 rhs
unless pow2 exponent == rhs do return none
let shiftLit ← mkAuxLit exponent
return some <| .const shiftLeft [] #[lhs, .fvar shiftLit]
def Folder.mulLhsShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : αα) (log2 : αα) : Folder := fun args => do
unless (← getEnv).contains shiftLeft do return none
let #[.fvar fvarId, rhs] := args | return none
let some lhs ← getLit fvarId | return none
let exponent := log2 lhs
unless pow2 exponent == lhs do return none
let shiftLit ← mkAuxLit exponent
return some <| .const shiftLeft [] #[rhs, .fvar shiftLit]
/--
Pick the first folder out of `folders` that succeeds.
-/
@ -276,31 +303,34 @@ def higherOrderLiteralFolders : List (Name × Folder) := [
(``List.toArray, foldArrayLiteral)
]
def Folder.mulShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : αα) (log2 : αα) : Folder :=
Folder.first #[Folder.mulLhsShift shiftLeft pow2 log2, Folder.mulRhsShift shiftLeft pow2 log2]
/--
All arithmetic folders.
-/
def arithmeticFolders : List (Name × Folder) := [
(``Nat.succ, Folder.mkUnary Nat.succ),
(``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0]),
(``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8)]),
(``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16)]),
(``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32)]),
(``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64)]),
(``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.leftRightNeutral 0]),
(``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.leftRightNeutral (0 : UInt8)]),
(``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.leftRightNeutral (0 : UInt16)]),
(``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.leftRightNeutral (0 : UInt32)]),
(``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.leftRightNeutral (0 : UInt64)]),
(``Nat.mul, Folder.first #[Folder.mkBinary Nat.mul, Folder.leftRightNeutral 1, Folder.leftRightAnnihilator 0 0]),
(``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8), Folder.leftRightAnnihilator (0 : UInt8) 0]),
(``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16), Folder.leftRightAnnihilator (0 : UInt16) 0]),
(``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32), Folder.leftRightAnnihilator (0 : UInt32) 0]),
(``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64), Folder.leftRightAnnihilator (0 : UInt64) 0]),
(``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1]),
(``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8)]),
(``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16)]),
(``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32)]),
(``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64)])
(``Nat.succ, Folder.mkUnary Nat.succ),
(``Nat.add, Folder.first #[Folder.mkBinary Nat.add, Folder.leftRightNeutral 0]),
(``UInt8.add, Folder.first #[Folder.mkBinary UInt8.add, Folder.leftRightNeutral (0 : UInt8)]),
(``UInt16.add, Folder.first #[Folder.mkBinary UInt16.add, Folder.leftRightNeutral (0 : UInt16)]),
(``UInt32.add, Folder.first #[Folder.mkBinary UInt32.add, Folder.leftRightNeutral (0 : UInt32)]),
(``UInt64.add, Folder.first #[Folder.mkBinary UInt64.add, Folder.leftRightNeutral (0 : UInt64)]),
(``Nat.sub, Folder.first #[Folder.mkBinary Nat.sub, Folder.leftRightNeutral 0]),
(``UInt8.sub, Folder.first #[Folder.mkBinary UInt8.sub, Folder.leftRightNeutral (0 : UInt8)]),
(``UInt16.sub, Folder.first #[Folder.mkBinary UInt16.sub, Folder.leftRightNeutral (0 : UInt16)]),
(``UInt32.sub, Folder.first #[Folder.mkBinary UInt32.sub, Folder.leftRightNeutral (0 : UInt32)]),
(``UInt64.sub, Folder.first #[Folder.mkBinary UInt64.sub, Folder.leftRightNeutral (0 : UInt64)]),
(``Nat.mul, Folder.first #[Folder.mkBinary Nat.mul, Folder.leftRightNeutral 1, Folder.leftRightAnnihilator 0 0, Folder.mulShift ``Nat.shiftLeft (Nat.pow 2) Nat.log2]),
(``UInt8.mul, Folder.first #[Folder.mkBinary UInt8.mul, Folder.leftRightNeutral (1 : UInt8), Folder.leftRightAnnihilator (0 : UInt8) 0, Folder.mulShift ``UInt8.shiftLeft (UInt8.shiftLeft 1 ·) UInt8.log2]),
(``UInt16.mul, Folder.first #[Folder.mkBinary UInt16.mul, Folder.leftRightNeutral (1 : UInt16), Folder.leftRightAnnihilator (0 : UInt16) 0, Folder.mulShift ``UInt16.shiftLeft (UInt16.shiftLeft 1 ·) UInt16.log2]),
(``UInt32.mul, Folder.first #[Folder.mkBinary UInt32.mul, Folder.leftRightNeutral (1 : UInt32), Folder.leftRightAnnihilator (0 : UInt32) 0, Folder.mulShift ``UInt32.shiftLeft (UInt32.shiftLeft 1 ·) UInt32.log2]),
(``UInt64.mul, Folder.first #[Folder.mkBinary UInt64.mul, Folder.leftRightNeutral (1 : UInt64), Folder.leftRightAnnihilator (0 : UInt64) 0, Folder.mulShift ``UInt64.shiftLeft (UInt64.shiftLeft 1 ·) UInt64.log2]),
(``Nat.div, Folder.first #[Folder.mkBinary Nat.div, Folder.rightNeutral 1, Folder.divShift ``Nat.shiftRight (Nat.pow 2) Nat.log2]),
(``UInt8.div, Folder.first #[Folder.mkBinary UInt8.div, Folder.rightNeutral (1 : UInt8), Folder.divShift ``UInt8.shiftRight (UInt8.shiftLeft 1 ·) UInt8.log2]),
(``UInt16.div, Folder.first #[Folder.mkBinary UInt16.div, Folder.rightNeutral (1 : UInt16), Folder.divShift ``UInt16.shiftRight (UInt16.shiftLeft 1 ·) UInt16.log2]),
(``UInt32.div, Folder.first #[Folder.mkBinary UInt32.div, Folder.rightNeutral (1 : UInt32), Folder.divShift ``UInt32.shiftRight (UInt32.shiftLeft 1 ·) UInt32.log2]),
(``UInt64.div, Folder.first #[Folder.mkBinary UInt64.div, Folder.rightNeutral (1 : UInt64), Folder.divShift ``UInt64.shiftRight (UInt64.shiftLeft 1 ·) UInt64.log2])
]
def relationFolders : List (Name × Folder) := [

View file

@ -0,0 +1,13 @@
set_option trace.Compiler.result true in
def mulDivShift (a : Nat) (b : UInt8) (c : UInt16) (d : UInt32) (e : UInt64) : Nat :=
let a1 := a / 32
let a2 := a * 32
let b1 := b / 32
let b2 := b * 32
let c1 := c / 32
let c2 := c * 32
let d1 := d / 32
let d2 := d * 32
let e1 := e / 32
let e2 := e * 32
a1 + a2 + b1.val.val + b2.val.val + c1.val.val + c2.val.val + d1.val.val + d2.val.val + e1.val.val + e2.val.val

View file

@ -0,0 +1,35 @@
[Compiler.result] size: 32
def mulDivShift a b c d e : Nat :=
let _x.1 := 5;
let a1 := Nat.shiftRight a _x.1;
let a2 := Nat.shiftLeft a _x.1;
let _x.2 := UInt8.ofNat _x.1;
let b1 := UInt8.shiftRight b _x.2;
let b2 := UInt8.shiftLeft b _x.2;
let _x.3 := UInt16.ofNat _x.1;
let c1 := UInt16.shiftRight c _x.3;
let c2 := UInt16.shiftLeft c _x.3;
let _x.4 := UInt32.ofNat _x.1;
let d1 := UInt32.shiftRight d _x.4;
let d2 := UInt32.shiftLeft d _x.4;
let _x.5 := UInt64.ofNat _x.1;
let e1 := UInt64.shiftRight e _x.5;
let e2 := UInt64.shiftLeft e _x.5;
let _x.6 := Nat.add a1 a2;
let _x.7 := UInt8.val b1;
let _x.8 := Nat.add _x.6 _x.7;
let _x.9 := UInt8.val b2;
let _x.10 := Nat.add _x.8 _x.9;
let _x.11 := UInt16.val c1;
let _x.12 := Nat.add _x.10 _x.11;
let _x.13 := UInt16.val c2;
let _x.14 := Nat.add _x.12 _x.13;
let _x.15 := UInt32.val d1;
let _x.16 := Nat.add _x.14 _x.15;
let _x.17 := UInt32.val d2;
let _x.18 := Nat.add _x.16 _x.17;
let _x.19 := UInt64.val e1;
let _x.20 := Nat.add _x.18 _x.19;
let _x.21 := UInt64.val e2;
let _x.22 := Nat.add _x.20 _x.21;
return _x.22