This PR enables the specializer to also specialize recursively in some
non-trivial higher-order situations.
The main motivation for this change is the upcoming work on `do`
notation by sgraf, which uses combinators such as
```lean
@[specialize, expose]
def List.newForIn {α β γ} (l : List α) (b : β) (kcons : α → (β → γ) → β → γ) (knil : β → γ) : γ :=
match l with
| [] => knil b
| a :: l => kcons a (l.newForIn · kcons knil) b
```
in programs such as
```lean
def testing :=
let x := 42;
List.newForIn (β := Nat) (γ := Id Nat)
[1,2,3]
x
(fun i kcontinue s =>
let x := s;
List.newForIn
[i:10].toList x
(fun j kcontinue s =>
let x := s;
let x := x + i + j;
kcontinue x)
kcontinue)
pure
```
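For readers new to this continuation-passing style, here is a minimal standalone copy of the combinator. It is not part of the PR: it is renamed to the hypothetical `sketchForIn` and restricted to `Type` so that it compiles on its own. `kcons` receives the current element, a continuation for the rest of the list (the `kcontinue` of `testing`), and the current state. Calling the continuation moves on to the next element, while returning without calling it exits the loop early.

```lean
/-- Standalone copy of `List.newForIn`, restricted to `Type` (not part of the PR). -/
def sketchForIn {α β γ : Type} (l : List α) (b : β) (kcons : α → (β → γ) → β → γ)
    (knil : β → γ) : γ :=
  match l with
  | [] => knil b
  | a :: l => kcons a (sketchForIn l · kcons knil) b

-- Calling the continuation `k` moves on to the next element: 1 + 2 + 3.
#eval sketchForIn (β := Nat) (γ := Nat) [1, 2, 3] 0 (fun a k s => k (s + a)) id
-- 6

-- Returning `s` without calling `k` exits early, after adding only `1`.
#eval sketchForIn (β := Nat) (γ := Nat) [1, 2, 3] 0
  (fun a k s => if a > 1 then s else k (s + a)) id
-- 1
```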
Inspecting the IR of `testing` right before it reaches the specializer in
the current compiler, we get:
```
[Compiler.eagerLambdaLifting] size: 22
def testing : Nat :=
fun _f.1 _y.2 : Nat :=
return _y.2;
let x := 42;
let _x.3 := 1;
fun _f.4 i kcontinue s : Nat :=
fun _f.5 j kcontinue s : Nat :=
let _x.6 := Nat.add s i;
let x := Nat.add _x.6 j;
let _x.7 := kcontinue x;
return _x.7;
let _x.8 := 10;
let _x.9 := Nat.sub _x.8 i;
let _x.10 := Nat.add _x.9 _x.3;
let _x.11 := 1;
let _x.12 := Nat.sub _x.10 _x.11;
let _x.13 := Nat.mul _x.3 _x.12;
let _x.14 := Nat.add i _x.13;
let _x.15 := @List.nil _;
let _x.16 := List.range'TR.go _x.3 _x.12 _x.14 _x.15;
let _x.17 := @List.newForIn _ _ _ _x.16 s _f.5 kcontinue;
return _x.17;
let _x.18 := 2;
let _x.19 := 3;
let _x.20 := @List.nil _;
let _x.21 := @List.cons _ _x.19 _x.20;
let _x.22 := @List.cons _ _x.18 _x.21;
let _x.23 := @List.cons _ _x.3 _x.22;
let _x.24 := @List.newForIn _ _ _ _x.23 x _f.4 _f.1;
return _x.24
```
Here the `kcontinue` higher-order functions pose a special challenge
because they delay the discovery of new specialization opportunities.
Inspecting the IR after the current specializer (and a cleanup simp
step), we get functions that look as follows:
```
[simp] size: 7
def List.newForIn._at_.testing.spec_0 i kcontinue l b : Nat :=
cases l : Nat
| List.nil =>
let _x.1 := kcontinue b;
return _x.1
| List.cons head.2 tail.3 =>
let _x.4 := Nat.add b i;
let x := Nat.add _x.4 head.2;
let _x.5 := List.newForIn._at_.testing.spec_0 i kcontinue tail.3 x;
return _x.5
[simp] size: 14
def List.newForIn._at_.List.newForIn._at_.testing.spec_1.spec_1 _x.1 l b : Nat :=
cases l : Nat
| List.nil =>
return b
| List.cons head.2 tail.3 =>
fun _f.4 x.5 : Nat :=
let _x.6 := List.newForIn._at_.List.newForIn._at_.testing.spec_1.spec_1 _x.1 tail.3 x.5;
return _x.6;
let _x.7 := 10;
let _x.8 := Nat.sub _x.7 head.2;
let _x.9 := Nat.add _x.8 _x.1;
let _x.10 := 1;
let _x.11 := Nat.sub _x.9 _x.10;
let _x.12 := Nat.mul _x.1 _x.11;
let _x.13 := Nat.add head.2 _x.12;
let _x.14 := @List.nil _;
let _x.15 := List.range'TR.go _x.1 _x.11 _x.13 _x.14;
let _x.16 := List.newForIn._at_.testing.spec_0 head.2 _f.4 _x.15 b;
return _x.16
```
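Observe that the specializer decided to abstract over `kcontinue`
instead of specializing further recursively: `spec_1` passes the local
closure `_f.4` to `spec_0`, which invokes it through its `kcontinue`
parameter. Thus this tight loop now goes through an indirect call.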
This PR changes the specializer somewhat fundamentally to handle
situations like this. The most notable change is a fixpoint
loop of:
1. Specialize all current declarations in the worklist.
2. If a declaration
   - was successfully specialized, run the simplifier on it and put it
     back onto the worklist;
   - was not, do not put it back onto the worklist.
3. Put all newly generated specializations on the worklist.
4. Recompute fixed parameters for the current SCC.
5. Repeat until the worklist is empty.
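The loop above can be modeled as a small standalone Lean program. This is a toy sketch under stated assumptions, not the compiler's implementation: `ToyDecl`, `SpecResult`, `specializeOnce`, `simpDecl` and `specializeFixpoint` are hypothetical names, a declaration simply makes progress for a fixed number of rounds, and step 4 (recomputing fixed parameters) is not modeled.

```lean
/-- Toy stand-in for a declaration; `pending` is how many more rounds
will still make progress on it (hypothetical model). -/
structure ToyDecl where
  name    : String
  pending : Nat
  deriving Repr

/-- Result of one specialization attempt on a declaration (hypothetical). -/
structure SpecResult where
  progress : Bool           -- did this round specialize anything?
  decl     : ToyDecl        -- the declaration after the round
  newSpecs : Array ToyDecl  -- freshly generated specializations

/-- Toy specializer: progresses while `pending > 0`, emitting one new
specialization per successful round that needs no further work. -/
def specializeOnce (d : ToyDecl) : SpecResult :=
  match d.pending with
  | 0 => { progress := false, decl := d, newSpecs := #[] }
  | n + 1 =>
    { progress := true
      decl := { d with pending := n }
      newSpecs := #[{ name := s!"{d.name}.spec_{n}", pending := 0 }] }

/-- Placeholder for the cleanup simp pass (identity in this model). -/
def simpDecl (d : ToyDecl) : ToyDecl := d

/-- Worklist fixpoint following steps 1, 2, 3 and 5 of the list above. -/
partial def specializeFixpoint (worklist : Array ToyDecl)
    (done : Array ToyDecl := #[]) : Array ToyDecl := Id.run do
  if worklist.isEmpty then
    return done
  let mut next : Array ToyDecl := #[]
  let mut finished := done
  for d in worklist do
    let r := specializeOnce d                 -- 1. specialize
    if r.progress then
      next := next.push (simpDecl r.decl)     -- 2. simplify and requeue
    else
      finished := finished.push r.decl        -- 2. drop from the worklist
    next := next ++ r.newSpecs                -- 3. queue new specializations
  -- 4. recomputing fixed parameters for the SCC is not modeled here
  return specializeFixpoint next finished     -- 5. repeat until empty

#eval (specializeFixpoint #[{ name := "f", pending := 2 }]).map (·.name)
-- #["f.spec_1", "f", "f.spec_0"]
```

In this model, each declaration leaves the worklist after its first unproductive round and fresh specializations are processed in the following round, so the loop stops once a full pass makes no progress.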
Furthermore, declarations that were already specialized:
- only consider `fixedHO` parameters for specialization, in order to
  avoid termination issues with repeated specialization and abstraction of
  type class parameters under binders;
- only allow specialization if at least one of their `fixedHO` arguments
  is not itself a parameter. We do not impose this restriction on
  first-generation specialization because we refrain from specializing
  inside the body of a declaration marked as `@[specialize]`. Thus we need
  to specialize such declarations even if their arguments don't contain
  anything of interest, in order to ensure that type classes etc. are
  correctly cleaned up within their bodies.
There is one last trade-off to consider. When specializing code
generated by the new `do` elaborator, we sometimes generate intermediate
specializations that are no longer part of any call graph once we are
done specializing. We could in principle detect and delete these
functions, but keeping them in the cache may help further
specializations later. Once the new `do` elaborator lands, we plan to
evaluate this trade-off.
Closes #10924
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module

prelude
public import Lean.Compiler.LCNF.FixedParams
public import Lean.Compiler.LCNF.InferType

public section

namespace Lean.Compiler.LCNF

/--
Each parameter is associated with a `SpecParamInfo`. This information is used by `LCNF/Specialize.lean`.
-/
inductive SpecParamInfo where
  /--
  A parameter that is a type class instance (or an arrow that produces a type class instance),
  and is fixed in recursive declarations. By default, Lean always specializes this kind of argument.
  -/
  | fixedInst
  /--
  A parameter that is a function and is fixed in recursive declarations. If the user tags a declaration
  with `@[specialize]` without specifying which arguments should be specialized, Lean will specialize
  `.fixedHO` arguments in addition to `.fixedInst`.
  -/
  | fixedHO
  /--
  A computationally irrelevant parameter that is fixed in recursive declarations,
  *and* on which some `fixedInst`, `fixedHO`, or `user` parameter depends.
  -/
  | fixedNeutral
  /--
  An argument that has been specified in the `@[specialize]` attribute. Lean specializes it even if it is
  not fixed in recursive declarations. Non-termination can happen, and Lean interrupts it with an error message
  based on the stack depth.
  -/
  | user
  /--
  Parameter is not going to be specialized.
  -/
  | other
  deriving Inhabited, Repr

namespace SpecParamInfo

@[inline]
def causesSpecialization : SpecParamInfo → Bool
  | .fixedInst | .fixedHO | .user => true
  | .fixedNeutral | .other => false

end SpecParamInfo

instance : ToMessageData SpecParamInfo where
  toMessageData
    | .fixedInst => "I"
    | .fixedHO => "H"
    | .fixedNeutral => "N"
    | .user => "U"
    | .other => "O"

structure SpecEntry where
  /--
  The name of the declaration.
  -/
  declName : Name
  /--
  Information about which parameters of the declaration qualify for specialization.
  -/
  paramsInfo : Array SpecParamInfo
  /--
  True if `declName` was already specialized before. This is relevant because we specialize
  declarations that have already been specialized less aggressively than declarations that have not.
  -/
  alreadySpecialized : Bool
  deriving Inhabited

instance : ToMessageData SpecEntry where
  toMessageData := fun { declName, paramsInfo, alreadySpecialized } =>
    m!"{declName}, alreadySpecialized? {alreadySpecialized}, info: {paramsInfo}"

structure SpecState where
  specInfo : PHashMap Name SpecEntry := {}
  deriving Inhabited

namespace SpecState

def addEntry (s : SpecState) (e : SpecEntry) : SpecState :=
  match s with
  | { specInfo } => { specInfo := specInfo.insert e.declName e }

end SpecState

private abbrev declLt (a b : SpecEntry) :=
  Name.quickLt a.declName b.declName

private abbrev sortEntries (entries : Array SpecEntry) : Array SpecEntry :=
  entries.qsort declLt

private abbrev findAtSorted? (entries : Array SpecEntry) (declName : Name) : Option SpecEntry :=
  entries.binSearch { declName, paramsInfo := #[], alreadySpecialized := false } declLt

/--
Extension for storing `SpecParamInfo` for declarations being compiled.
Remark: we only store information for declarations that will be specialized.
-/
builtin_initialize specExtension : SimplePersistentEnvExtension SpecEntry SpecState ←
  registerSimplePersistentEnvExtension {
    addEntryFn := SpecState.addEntry
    addImportedFn := fun _ => {}
    toArrayFn := fun s => sortEntries s.toArray
    asyncMode := .sync
    replay? := some <| SimplePersistentEnvExtension.replayOfFilter
      (!·.specInfo.contains ·.declName) SpecState.addEntry
  }

/--
Return `true` if `type` is a type tagged with `@[nospecialize]` or an arrow that produces this kind of type.
For example, this function returns true for `Inhabited Nat`, and `Nat → Inhabited Nat`.
-/
private def isNoSpecType (env : Environment) (type : Expr) : Bool :=
  match type with
  | .forallE _ _ b _ => isNoSpecType env b
  | _ =>
    if let .const declName _ := type.getAppFn then
      hasNospecializeAttribute env declName
    else
      false

/-!
*Note*: `fixedNeutral` must have forward dependencies.

The code specializer considers a `fixedNeutral` parameter during code specialization
only if it has forward dependencies that are tagged as `.user`, `.fixedHO`, or `.fixedInst`.
The motivation is to minimize the number of code specializations that have little or no impact on
performance. For example, consider the following function:
```
def liftMacroM
    {α : Type} {m : Type → Type}
    [Monad m] [MonadMacroAdapter m] [MonadEnv m] [MonadRecDepth m] [MonadError m]
    [MonadResolveName m] [MonadTrace m] [MonadOptions m] [AddMessageContext m] [MonadLiftT IO m] (x : MacroM α) : m α := do
```
The parameter `α` does not occur in any local instance, and `x` is marked as `.other` since the function
is not tagged as `[specialize]`. There is little value in considering `α` during code specialization,
but if we did, many copies of this function would be generated.
Recall that users may still force the code specializer to take `α` into account by using `[specialize α]` (`α` has `.user` info),
or `[specialize x]` (`α` has `.fixedNeutral` since `x` is a forward dependency tagged as `.user`),
or `[specialize]` (`α` has `.fixedNeutral` since `x` is a forward dependency tagged as `.fixedHO`).
-/

/--
Return `true` if some later parameter `k` (with `k > j`) of the given declaration depends on
parameter `j`, and `k` is tagged as `.user`, `.fixedHO`, or `.fixedInst`.

See the note on `fixedNeutral` forward dependencies above.
-/
private def hasFwdDeps (decl : Decl) (paramsInfo : Array SpecParamInfo) (j : Nat) : Bool := Id.run do
  let param := decl.params[j]!
  for h : k in (j+1)...decl.params.size do
    if paramsInfo[k]!.causesSpecialization then
      let param' := decl.params[k]
      if param'.type.containsFVar param.fvarId then
        return true
  return false

/--
Compute specialization information for `decls`. We assume that `decls` contains a full SCC of
computationally relevant declarations. Furthermore, this function takes:
- `autoSpecialize`, which determines whether we apply "automated" specialization to a decl, that is,
  whether we automatically specialize for all fixedHO parameters. It receives both the name and
  the array of arguments mentioned in `@[specialize]`, if any.
- `alreadySpecialized`, a mask that says whether each decl is already a specialized declaration
  itself.
-/
def computeSpecEntries (decls : Array Decl) (autoSpecialize : Name → Option (Array Nat) → Bool)
    (alreadySpecialized : Array Bool) : CompilerM (Array SpecEntry) := do
  let mut declsInfo := #[]
  for decl in decls do
    if hasNospecializeAttribute (← getEnv) decl.name then
      declsInfo := declsInfo.push (.replicate decl.params.size .other)
    else
      let specArgs? := getSpecializationArgs? (← getEnv) decl.name
      let contains (i : Nat) : Bool := specArgs?.getD #[] |>.contains i
      let mut paramsInfo : Array SpecParamInfo := #[]
      for h : i in *...decl.params.size do
        let param := decl.params[i]
        let info ←
          if contains i then
            pure .user
          /-
          If the user tagged a class (e.g., `Inhabited`) with the `@[nospecialize]` attribute,
          then parameters of this type should not be considered for specialization.
          -/
          else if isNoSpecType (← getEnv) param.type then
            pure .other
          else if isTypeFormerType param.type then
            pure .fixedNeutral
          else if (← isArrowClass? param.type).isSome then
            pure .fixedInst
          /-
          Recall that if `specArgs? == some #[]`, then the user annotated the function with `@[specialize]`
          but did not specify which arguments must be specialized besides instances. In this case, we try
          to specialize any "fixed higher-order argument".
          -/
          else if autoSpecialize decl.name specArgs? && param.type matches .forallE .. then
            pure .fixedHO
          else
            pure .other
        paramsInfo := paramsInfo.push info
        pure ()
      declsInfo := declsInfo.push paramsInfo
  if declsInfo.any fun paramsInfo => paramsInfo.any SpecParamInfo.causesSpecialization then
    let m := mkFixedParamsMap decls
    let mut entries := Array.emptyWithCapacity decls.size
    for hi : i in *...decls.size do
      let decl := decls[i]
      let mut paramsInfo := declsInfo[i]!
      let some mask := m.find? decl.name | unreachable!
      paramsInfo := Array.zipWith (as := paramsInfo) (bs := mask) fun info fixed =>
        if fixed || info matches .user then
          info
        else
          .other
      for j in *...paramsInfo.size do
        let info := paramsInfo[j]!
        if info matches .fixedNeutral && !hasFwdDeps decl paramsInfo j then
          paramsInfo := paramsInfo.set! j .other
      entries := entries.push {
        declName := decl.name,
        paramsInfo,
        alreadySpecialized := alreadySpecialized[i]!
      }
    return entries
  else
    return decls.mapIdx fun i decl => {
      declName := decl.name,
      paramsInfo := Array.replicate decl.params.size .other,
      alreadySpecialized := alreadySpecialized[i]!
    }

/--
Compute and save specialization information for `decls`. Assumes that `decls` is an SCC of
user-defined declarations.
-/
def saveSpecEntries (decls : Array Decl) : CompilerM Unit := do
  let entries ← computeSpecEntries
    decls
    (fun _ specArgs? => specArgs? == some #[])
    (Array.replicate decls.size false)
  for entry in entries do
    if entry.paramsInfo.any SpecParamInfo.causesSpecialization then
      trace[Compiler.specialize.info] "{entry.declName} {entry.paramsInfo}"
      modifyEnv fun env => specExtension.addEntry env entry

def getSpecEntryCore? (env : Environment) (declName : Name) : Option SpecEntry :=
  match env.getModuleIdxFor? declName with
  | some modIdx => findAtSorted? (specExtension.getModuleEntries env modIdx) declName
  | none => (specExtension.getState env).specInfo.find? declName

def getSpecEntry? [Monad m] [MonadEnv m] (declName : Name) : m (Option SpecEntry) :=
  return getSpecEntryCore? (← getEnv) declName

def isSpecCandidate [Monad m] [MonadEnv m] (declName : Name) : m Bool := do
  return getSpecEntryCore? (← getEnv) declName |>.isSome

builtin_initialize
  registerTraceClass `Compiler.specialize.info

end Lean.Compiler.LCNF
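
To make the classification in `computeSpecEntries` and `hasFwdDeps` concrete, here is a standalone model applied to the `List.newForIn` combinator from the PR description. All names in it (`Info`, `ParamDesc`, `classify`, `computeInfos`) are hypothetical. The model replaces the compiler's `Expr`-based checks (`isTypeFormerType`, `isArrowClass?`, `mkFixedParamsMap`, `containsFVar`) with precomputed booleans, ignores `@[nospecialize]` and `.user` arguments, and the fixedness and dependency facts for `List.newForIn` are read off its definition by hand.

```lean
/-- Hypothetical copy of the `SpecParamInfo` constructors. -/
inductive Info where
  | fixedInst | fixedHO | fixedNeutral | user | other
  deriving Repr, DecidableEq, Inhabited

def Info.causesSpecialization : Info → Bool
  | .fixedInst | .fixedHO | .user => true
  | .fixedNeutral | .other => false

/-- Same letters as the `ToMessageData SpecParamInfo` instance. -/
def Info.letter : Info → String
  | .fixedInst => "I"
  | .fixedHO => "H"
  | .fixedNeutral => "N"
  | .user => "U"
  | .other => "O"

/-- Hypothetical precomputed facts about a parameter (the real code inspects `Expr`s). -/
structure ParamDesc where
  isTypeFormer : Bool := false   -- e.g. `α : Type`
  isInstance   : Bool := false   -- e.g. `[Monad m]`
  isArrow      : Bool := false   -- the type is a function type
  fixed        : Bool := false   -- passed unchanged to every recursive call
  mentions     : List Nat := []  -- earlier parameters occurring in its type
  deriving Inhabited

/-- Classification for `@[specialize]` without explicit arguments. -/
def classify (p : ParamDesc) : Info :=
  if p.isTypeFormer then .fixedNeutral
  else if p.isInstance then .fixedInst
  else if p.isArrow then .fixedHO
  else .other

/-- Demote non-fixed parameters, then demote each `fixedNeutral` parameter
that no later specializing parameter depends on (cf. `hasFwdDeps`). -/
def computeInfos (ps : Array ParamDesc) : Array Info := Id.run do
  let mut infos : Array Info := ps.map fun p => if p.fixed then classify p else .other
  for j in List.range ps.size do
    if infos[j]! == .fixedNeutral then
      let hasFwdDep := (List.range ps.size).any fun k =>
        k > j && infos[k]!.causesSpecialization && ps[k]!.mentions.contains j
      if !hasFwdDep then
        infos := infos.set! j .other
  return infos

-- `List.newForIn {α β γ} (l : List α) (b : β) (kcons : α → (β → γ) → β → γ) (knil : β → γ)`:
-- the recursive call changes `l` and `b` and passes everything else unchanged.
#eval (computeInfos #[
    { isTypeFormer := true, fixed := true },                    -- α
    { isTypeFormer := true, fixed := true },                    -- β
    { isTypeFormer := true, fixed := true },                    -- γ
    { mentions := [0] },                                        -- l
    { mentions := [1] },                                        -- b
    { isArrow := true, fixed := true, mentions := [0, 1, 2] },  -- kcons
    { isArrow := true, fixed := true, mentions := [1, 2] }      -- knil
  ]).map Info.letter
-- #["N", "N", "N", "O", "O", "H", "H"]

-- A `liftMacroM`-style case: `α` is fixed, but no specializing parameter mentions it.
#eval (computeInfos #[{ isTypeFormer := true, fixed := true },
                      { isArrow := true, fixed := true }]).map Info.letter
-- #["O", "H"]
```

Under these assumptions, `α`, `β` and `γ` stay `N` because `kcons` and `knil` mention them, matching the note on `fixedNeutral` forward dependencies, while in the second case `α` is demoted to `O`.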