feat: FloatLetIn compiler pass

This commit is contained in:
Henrik Böving 2022-10-08 15:02:46 +02:00
parent d132551829
commit dd3c0f77f1
6 changed files with 381 additions and 0 deletions

View file

@ -0,0 +1,311 @@
/-
Copyright (c) 2022 Henrik Böving. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.FVarUtil
import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.Types
namespace Lean.Compiler.LCNF
namespace FloatLetIn
/--
The decision of the float mechanism.
-/
inductive Decision where
|
/--
Push into the arm with name `name`.
-/
arm (name : Name)
| /--
Push into the default arm.
-/
default
|
/--
Dont move this declaration it is needed where it is right now.
-/
dont
|
/--
No decision has been made yet.
-/
unknown
deriving Hashable, BEq, Inhabited, Repr
def Decision.ofAlt : Alt → Decision
| .alt name _ _ => .arm name
| .default _ => .default
/--
The context for `BaseFloatM`.
-/
structure BaseFloatContext where
/--
All the declarations that were collected in the current LCNF basic
block up to the current statement (in reverse order for efficiency).
-/
decls : List CodeDecl := []
/--
The state for `FloatM`
-/
structure FloatState where
/--
A map from identifiers of declarations to their current decision.
-/
decision : HashMap FVarId Decision
/--
A map from decisions (excluding `unknown`) to the declarations with
these decisions (in correct order). Basically:
- Which declarations do we not move
- Which declarations do we move into a certain arm
- Which declarations do we move into the default arm
-/
newArms : HashMap Decision (List CodeDecl)
/--
Use to collect relevant declarations for the floating mechanism.
-/
abbrev BaseFloatM := ReaderT BaseFloatContext CompilerM
/--
Use to compute the actual floating.
-/
abbrev FloatM := StateRefT FloatState BaseFloatM
/--
Add `decl` to the list of declarations and run `x` with that updated context.
-/
def withNewCandidate (decl : CodeDecl) (x : BaseFloatM α) : BaseFloatM α :=
withReader (fun r => { r with decls := decl :: r.decls }) do
x
/--
Run `x` with an empty list of declarations.
-/
def withNewScope (x : BaseFloatM α) : BaseFloatM α := do
withReader (fun _ => {}) do
x
/--
Whether to ignore `decl` for the floating mechanism. We want to do this if:
- `decl`' is storing a typeclass instance
- `decl` is a projection from a variable that is storing a typeclass instance
-/
def ignore? (decl : LetDecl) : BaseFloatM Bool := do
if (← isArrowClass? decl.type).isSome then
return true
else if let .proj _ _ (.fvar fvarId) := decl.value then
return (← isArrowClass? (← getType fvarId)).isSome
else
return false
/--
Compute the initial decision for all declarations that `BaseFloatM` collected
up to this point, with respect to `cs`. The initial decisions are:
- `dont` if the declaration is detected by `ignore?`
- `dont` if the declaration is the discriminant of `cs` since we obviously need
the discriminant to be computed before the match.
- `dont` if we see the declaration being used in more than one cases arm
- `arm` or `default` if we see the declaration only being used in exactly one cases arm
- `unknown` otherwise
-/
def initialDecisions (cs : Cases) : BaseFloatM (HashMap FVarId Decision) := do
let mut map := mkHashMap (← read).decls.length
let folder val acc := do
if let .let decl := val then
if (← ignore? decl) then
return acc.insert decl.fvarId .dont
return acc.insert val.fvarId .unknown
map ← (← read).decls.foldrM (init := map) folder
if map.contains cs.discr then
map := map.insert cs.discr .dont
(_, map) ← goCases cs |>.run map
return map
where
goFVar (plannedDecision : Decision) (var : FVarId) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit := do
if let some decision := (← get).find? var then
if decision == .unknown then
modify fun s => s.insert var plannedDecision
else if decision != plannedDecision then
modify fun s => s.insert var .dont
-- otherwise we already have the proper decision
goAlt (alt : Alt) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
forFVarM (goFVar (.ofAlt alt)) alt
goCases (cs : Cases) : StateRefT (HashMap FVarId Decision) BaseFloatM Unit :=
cs.alts.forM goAlt
/--
Compute the initial new arms. This will just set up a map from all arms of
`cs` to empty `Array`s, plus one additional entry for `dont`.
-/
def initialNewArms (cs : Cases) : HashMap Decision (List CodeDecl) := Id.run do
let mut map := mkHashMap (cs.alts.size + 1)
map := map.insert .dont []
cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) []
/--
Will:
- put `decl` into the `dont` arm
- decide that any free variable that occurs in `decl` and is a declaration
of interest as not getting moved either.
```
let x := ...
let y := ...
let z := x + y
cases z with
| n => z * x
| m => z * y
```
Here `x` and `y` are originally marked as getting floated into `n` and `m`
respectively but since `z` can't be moved we don't want that to move `x` and `y`.
-/
def dontFloat (decl : CodeDecl) : FloatM Unit := do
forFVarM goFVar decl
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms.find! .dont) }
where
goFVar (fvar : FVarId) : FloatM Unit := do
if (← get).decision.contains fvar then
modify fun s => { s with decision := s.decision.insert fvar .dont }
/--
Will:
- put `decl` into the arm it is marked to be moved into
- for any variables that might occur in `decl` and are of interest:
- if they are already meant to be floated into the same arm or not at all leave them untouched:
```
let x := ...
let y := x + z
cases z with
| n => x * y
| m => z
```
If we are at `y` `x` is alreayd marked to be floated into `n` as well.
- if there hasn't be a decision yet, that is they are marked with `.unknown` we float
them into the same arm as the current value:
```
let x := ..
let y := x + 2
cases z with
| n => y
| m => z
```
Here `x` is initially marked as `.unknown` since it occurs in no branch, however
since we want to move `y` into the `n` branch we can also decide to move `x`
into the `n` branch. Note that this decision might be revoked later on in the case of:
```
let x := ..
let a := x + 1
let y := x + 2
cases z with
| n => y
| m => a
```
When we visit `a` `x` is now marked as getting moved into `n` but since it also occurs
in `a` which wants to be moved somewhere else we will instead decide to not move `x`
at all.
- if they are meant to be floated somewhere else decide that they wont get floated:
```
let x := ...
let y := x + z
cases z with
| n => y
| m => x
```
If we are at `y` `x` is still marked to be moved but we don't want that.
-/
def float (decl : CodeDecl) : FloatM Unit := do
let arm := (← get).decision.find! decl.fvarId
forFVarM (goFVar · arm) decl
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms.find! arm) }
where
goFVar (fvar : FVarId) (arm : Decision) : FloatM Unit := do
let some decision := (← get).decision.find? fvar | return ()
if decision != arm then
modify fun s => { s with decision := s.decision.insert fvar .dont }
else if decision == .unknown then
modify fun s => { s with decision := s.decision.insert fvar arm }
/--
Iterate throgh `decl`, pushing local declarations that are only used in one
control flow arm into said arm in order to avoid useless computations.
-/
partial def floatLetIn (decl : Decl) : CompilerM Decl := do
let newValue ← go decl.value |>.run {}
return { decl with value := newValue }
where
/--
Iterate through the collected declarations,
determining from the bottom up whether they (and the declarations they refer to)
should get moved down into the arms of the cases statement or not.
-/
goCases : FloatM Unit := do
for decl in (← read).decls do
let currentDecision := (← get).decision.find! decl.fvarId
if currentDecision == .unknown then
/-
If the decision is still unknown by now this means `decl` is
unused in its continuation and can hence be removed.
-/
eraseCodeDecl decl
else if currentDecision == .dont then
dontFloat decl
else
float decl
go (code : Code) : BaseFloatM Code := do
match code with
| .let decl k =>
withNewCandidate (.let decl) do
go k
| .jp decl k =>
let value ← withNewScope do
go decl.value
let decl ← decl.updateValue value
withNewCandidate (.jp decl) do
go k
| .fun decl k =>
let value ← withNewScope do
go decl.value
let decl ← decl.updateValue value
withNewCandidate (.fun decl) do
go k
| .cases cs =>
let base := {
decision := (← initialDecisions cs)
newArms := initialNewArms cs
}
let (_, res) ← goCases |>.run base
let remainders := res.newArms.find! .dont
let altMapper alt := do
let decision := .ofAlt alt
let newCode := res.newArms.find! decision
trace[Compiler.floatLetIn] s!"Size of code that was pushed into arm: {repr decision} {newCode.length}"
let fused ← withNewScope do
go (attachCodeDecls newCode.toArray alt.getCode)
return alt.updateCode fused
let newAlts ← cs.alts.mapM altMapper
let mut newCases := Code.updateCases! code cs.resultType cs.discr newAlts
return attachCodeDecls remainders.toArray newCases
| .jmp .. | .return .. | .unreach .. =>
return attachCodeDecls (← read).decls.toArray.reverse code
end FloatLetIn
def Decl.floatLetIn (decl : Decl) : CompilerM Decl := do
FloatLetIn.floatLetIn decl
def floatLetIn : Pass :=
.mkPerDeclaration `floatLetIn Decl.floatLetIn .base
builtin_initialize
registerTraceClass `Compiler.floatLetIn (inherited := true)
end Lean.Compiler.LCNF

View file

@ -14,6 +14,7 @@ import Lean.Compiler.LCNF.Specialize
import Lean.Compiler.LCNF.PhaseExt
import Lean.Compiler.LCNF.ToMono
import Lean.Compiler.LCNF.LambdaLifting
import Lean.Compiler.LCNF.FloatLetIn
namespace Lean.Compiler.LCNF
@ -52,6 +53,7 @@ def builtinPassManager : PassManager := {
pullInstances,
cse,
simp,
floatLetIn,
findJoinPoints,
pullFunDecls,
reduceJpArity,

View file

@ -0,0 +1,17 @@
/-
Copyright (c) 2022 Henrik Böving. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.PassManager
namespace Lean.Compiler.LCNF
namespace Probing
--abbrev DeclFilter (m : Type → Type) [MonadLiftT] := Decl → OptionT m Decl
end Probing
end Lean.Compiler.LCNF

View file

@ -0,0 +1,12 @@
set_option trace.Compiler.floatLetIn true in
def provokeFloatLet (x y : Nat) (cond : Bool) : Nat :=
let a := x ^ y
let b := x + y
let c := x - y
let dual := x * y
if cond then
match dual with
| 0 => a
| _ + 1 => c
else
b + dual

View file

@ -0,0 +1,20 @@
[Compiler.floatLetIn] Size of code that was pushed into arm: Lean.Compiler.LCNF.FloatLetIn.Decision.arm `Bool.false 1
[Compiler.floatLetIn] Size of code that was pushed into arm: Lean.Compiler.LCNF.FloatLetIn.Decision.arm `Bool.true 2
[Compiler.floatLetIn] Size of code that was pushed into arm: Lean.Compiler.LCNF.FloatLetIn.Decision.arm `Nat.zero 1
[Compiler.floatLetIn] Size of code that was pushed into arm: Lean.Compiler.LCNF.FloatLetIn.Decision.arm `Nat.succ 1
[Compiler.floatLetIn] size: 11
def provokeFloatLet x y cond : Nat :=
let dual := Nat.mul x y
cases cond : Nat
| Bool.false =>
let b := Nat.add x y
let _x.1 := Nat.add b dual
_x.1
| Bool.true =>
cases dual : Nat
| Nat.zero =>
let a := Nat.pow x y
a
| Nat.succ n.2 =>
let c := Nat.sub x y
c

View file

@ -0,0 +1,19 @@
import Lean.Compiler.Main
import Lean.Compiler.LCNF.Testing
import Lean.Elab.Do
open Lean
open Lean.Compiler.LCNF
-- Run compilation twice to avoid the output caused by the inliner
#eval Compiler.compile #[``Lean.Meta.synthInstance, ``Lean.Elab.Term.Do.elabDo]
@[cpass]
def floatLetInFixTest : PassInstaller := Testing.assertIsAtFixPoint |>.install `floatLetIn `floatLetInFix
@[cpass]
def floatLetInSizeTest : PassInstaller :=
Testing.assertReducesOrPreservesSize "FloatLetIn increased size of declaration" |>.install `floatLetIn `floatLetInSizeEq
set_option trace.Compiler.test true in
#eval Compiler.compile #[``Lean.Meta.synthInstance, ``Lean.Elab.Term.Do.elabDo]