lean4-htt/src/Lean/Compiler/LCNF/Specialize.lean
2022-09-15 19:55:37 -07:00

344 lines
13 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.ForEachExpr
import Lean.Compiler.Specialize
import Lean.Compiler.LCNF.Simp
import Lean.Compiler.LCNF.SpecInfo
import Lean.Compiler.LCNF.PrettyPrinter
import Lean.Compiler.LCNF.ToExpr
namespace Lean.Compiler.LCNF
namespace Specialize
/--
Read-only context threaded through the specialization pass.
-/
structure Context where
  /--
  Free variables that are currently in scope. The dependency "collector" consults this set
  when gathering the dependencies of the code being specialized.
  -/
  scope : FVarIdSet := {}
  /--
  Let-declarations currently in scope whose values do not depend on parameters.
  -/
  ground : FVarIdSet := {}
/--
Mutable state of the specialization pass.
-/
structure State where
  /--
  Extra declarations recorded by the pass. `Decl.specialize` returns this array
  with the processed declaration appended at the end.
  -/
  decls : Array Decl := #[]
/--
Monad used by the specializer: a read-only `Context` and a mutable `State`
layered on top of `CompilerM`.
-/
abbrev SpecializeM := ReaderT Context (StateRefT State CompilerM)
/--
Run `x` in a context whose `scope` additionally contains the free variables of
all parameters in `ps`.
-/
@[inline] def withParams (ps : Array Param) (x : SpecializeM α) : SpecializeM α :=
  withReader (fun ctx =>
    let scope := ps.foldl (init := ctx.scope) fun acc p => acc.insert p.fvarId
    { ctx with scope }) x
/--
Run `x` in a context whose `scope` additionally contains `fvarId`.
-/
@[inline] def withFVar (fvarId : FVarId) (x : SpecializeM α) : SpecializeM α :=
  withReader (fun ctx =>
    let scope := ctx.scope.insert fvarId
    { ctx with scope }) x
/--
Return `true` if `e` is a ground term. That is, every free variable occurring
in `e` belongs to the `ground` set of the current context.
-/
def isGround (e : Expr) : SpecializeM Bool := do
  let ground := (← read).ground
  -- `e` is ground iff it has no free variable outside of `ground`.
  return !e.hasAnyFVar (fun fvarId => !ground.contains fvarId)
/--
Run `x` in a context where the variable of `decl` is in scope.
The variable is also added to the `ground` set when the value of `decl` is a ground term.
-/
@[inline] def withLetDecl (decl : LetDecl) (x : SpecializeM α) : SpecializeM α := do
  let valueIsGround ← isGround decl.value
  let extend (ctx : Context) : Context :=
    let scope := ctx.scope.insert decl.fvarId
    let ground := if valueIsGround then ctx.ground.insert decl.fvarId else ctx.ground
    { scope, ground }
  withReader extend x
namespace Collector
/-!
# Dependency collector for the code specialization function.
During code specialization, we select which arguments are going to be used during the specialization.
Then, we have to collect their dependencies. For example, suppose we are trying to specialize the `IO.println`
and `List.forM` applications in the following example:
```
def f xs a.1 :=
let _x.2 := @instMonadEIO IO.Error
let _x.5 := instToStringString
let _x.9 := instToStringNat
let _x.6 := "hello"
let _x.61 := @IO.println String _x.5 _x.6 a.1 -- (*)
cases _x.61
| EStateM.Result.ok a.6 a.7 =>
fun _f.72 _y.69 _y.70 :=
let _x.71 := @IO.println Nat _x.9 _y.69 _y.70 -- (*)
_x.71
let _x.65 := @List.forM (fun α => PUnit → EStateM.Result IO.Error PUnit α) _x.2 Nat xs _f.72 a.7 -- (*)
...
...
```
For `IO.println` the `SpecArgInfo` is `[N, I, O, O]`, i.e., only the first two arguments are considered
for code specialization. The first one is computationally neutral, and the second one is an instance.
For `List.forM`, we have `[N, I, N, O, H]`. In this case, the fifth argument (tagged as `H`) is a function.
Note that the actual `List.forM` application has 6 arguments, the extra argument comes from the `IO` monad.
For the first `IO.println` application, the collector collects `_x.5`.
For the `List.forM`, it collects `_x.2`, `_x.9`, and `_f.72`.
The collected values are used to construct a key to identify the specialization. Arguments that were not considered are
replaced with `lcErased`. The key is used to make sure we don't keep generating the same specialization over and over again.
This is not an optimization, it is essential to prevent the code specializer from looping while specializing recursive functions.
The keys for these two applications are the terms.
```
@IO.println Nat instToStringNat lcErased lcErased
```
and
```
@List.forM
(fun α => PUnit → EStateM.Result IO.Error PUnit α)
(@instMonadEIO IO.Error) Nat lcErased
(fun _y.69 _y.70 =>
let _x.71 := @IO.println Nat instToStringNat _y.69 _y.70;
_x.71)
```
The keys never contain free variables or loose bound variables.
-/
/--
State of the `CollectorM` monad.
-/
structure State where
  /--
  Free variables that have already been visited.
  -/
  visited : FVarIdSet := {}
  /--
  Free variables that have to become new parameters of the code being specialized.
  -/
  params : Array Param := #[]
  /--
  Let-declarations and local function declarations that will be "copied" into the code
  being specialized. The let-declarations often contain the instance values, for example.

  With the current specialization heuristic every collected let-declaration is a ground value
  (i.e., it contains no free variables), whereas local function declarations may contain
  free variables.

  The heuristic tries to avoid work duplication: a let-declaration that is a ground value
  will most likely be computed at compilation time, so duplicating it is not an issue.
  -/
  decls : Array CodeDecl := #[]
/--
Monad in which the dependency collector of the code specializer runs:
the collector `State` layered on top of `SpecializeM`.
See `collect`.
-/
abbrev CollectorM := StateRefT State SpecializeM
/--
Record `fvarId` in the set of visited free variables.
The collector performs a topological sort over the dependencies, and this set
is what prevents a variable from being processed twice.
-/
def markVisited (fvarId : FVarId) : CollectorM Unit :=
  modify fun st =>
    let visited := st.visited.insert fvarId
    { st with visited }
mutual
/--
Collect the dependencies occurring in the types of the given parameters.
This is required because a parameter type may mention other (type) parameters.
-/
partial def collectParams (params : Array Param) : CollectorM Unit :=
  params.forM fun param => collectExpr param.type
/--
Collect the dependencies occurring in the code `c`.
This function is what allows the collector to traverse the body of a
local function declaration.
-/
partial def collectCode (c : Code) : CollectorM Unit := do
  match c with
  | .let decl k =>
    collectExpr decl.type
    collectExpr decl.value
    collectCode k
  | .fun decl k | .jp decl k =>
    collectFunDecl decl
    collectCode k
  | .cases cs =>
    collectExpr cs.resultType
    collectFVar cs.discr
    for alt in cs.alts do
      match alt with
      | .default k =>
        collectCode k
      | .alt _ ps k =>
        collectParams ps
        collectCode k
  | .jmp _ args =>
    for arg in args do
      collectExpr arg
  | .unreach type =>
    collectExpr type
  | .return fvarId =>
    collectFVar fvarId
/--
Collect the dependencies of a local function declaration: first its type,
then the types of its parameters, and finally its body.
-/
partial def collectFunDecl (decl : FunDecl) : CollectorM Unit := do
  collectExpr decl.type
  collectParams decl.params
  collectCode decl.value
/--
Visit the free variable `fvarId`.
Nothing happens if it has already been visited. Otherwise it is marked as visited and,
provided it is in scope, its dependencies are collected before the variable itself is
recorded either as a declaration to be copied or as a new parameter.
-/
partial def collectFVar (fvarId : FVarId) : CollectorM Unit := do
  if (← get).visited.contains fvarId then
    return
  markVisited fvarId
  -- Only variables in the scope of the function application being specialized are collected.
  unless (← read).scope.contains fvarId do
    return
  if let some funDecl ← findFunDecl? fvarId then
    collectFunDecl funDecl
    modify fun st => { st with decls := st.decls.push (.fun funDecl) }
  else if let some param ← findParam? fvarId then
    collectExpr param.type
    modify fun st => { st with params := st.params.push param }
  else if let some letDecl ← findLetDecl? fvarId then
    collectExpr letDecl.type
    if (← isGround letDecl.value) then
      -- Ground value: keep following its dependencies and copy the declaration.
      collectExpr letDecl.value
      modify fun st => { st with decls := st.decls.push (.let letDecl) }
    else
      -- Not a ground value: the declaration is turned into a parameter.
      modify fun st => { st with params := st.params.push { letDecl with borrow := false } }
  else
    unreachable!
/--
Collect the dependencies of the expression `e` by visiting every free variable
found while traversing its subterms.
-/
partial def collectExpr (e : Expr) : CollectorM Unit := do
  e.forEach fun sub => do
    if let .fvar fvarId := sub then
      collectFVar fvarId
end
/--
Given the specialization mask `paramsInfo` and the arguments `args`, collect the
dependencies of the arguments that take part in the specialization and return an
array `mask` such that, for every pair `(paramsInfo[i], args[i])` that is processed,
- `mask[i] = args[i]` if `paramsInfo[i] != .other`
- `mask[i] = lcErased`, otherwise.

In other words, `mask` keeps only the arguments that contribute to the code specialization.
It is used to compute the "key" that uniquely identifies the code specialization.
-/
def collect (paramsInfo : Array SpecParamInfo) (args : Array Expr) : CollectorM (Array Expr) := do
  let mut mask : Array Expr := #[]
  for info in paramsInfo, arg in args do
    match info with
    | .fixedNeutral | .user | .fixedInst | .fixedHO =>
      -- The argument is relevant: keep it in the mask and collect what it depends on.
      mask := mask.push arg
      collectExpr arg
    | .other =>
      -- The argument is irrelevant for the specialization: erase it.
      mask := mask.push (mkConst ``lcErased)
  return mask
end Collector
/--
Decide whether the arguments `args` are worth specializing for, given the
parameter specialization information `paramsInfo`.
-/
def shouldSpecialize (paramsInfo : Array SpecParamInfo) (args : Array Expr) : SpecializeM Bool := do
  for info in paramsInfo, arg in args do
    match info with
    | .fixedHO =>
      return true -- TODO: check whether this is too aggressive
    | .fixedInst | .user =>
      if (← isGround arg) then
        return true
    | .fixedNeutral =>
      -- If we want to monomorphize types such as `Array`, we need to change here
      pure ()
    | .other =>
      pure ()
  return false
/--
Convert the given declarations into `Expr`, and "zeta-reduce" them into body.
This function is used to compute the key that uniquely identifies a code specialization.

The free variable of each declaration in `decls` is replaced in `body` by the value of that
declaration, where the value itself has first been expanded using the declarations before it.
-/
def expandCodeDecls (decls : Array CodeDecl) (body : Expr) : CompilerM Expr := do
-- `xs[i]` is the free variable introduced by `decls[i]`.
let xs := decls.map (mkFVar ·.fvarId)
-- `values[i]` is the value of `decls[i]` as an `Expr`; function declarations are converted with `toExpr`.
let values := decls.map fun
| .let decl => decl.value
| .fun decl | .jp decl => decl.toExpr
-- `subst` holds the already expanded values of the first `i` declarations.
let rec go (i : Nat) (subst : Array Expr) : Expr :=
if h : i < values.size then
-- NOTE(review): `abstractRange i xs` presumably abstracts only the first `i` variables of `xs`,
-- i.e. the declarations preceding `decls[i]` — confirm against `Expr.abstractRange`.
let value := values[i].abstractRange i xs
let value := value.instantiateRev subst
go (i+1) (subst.push value)
else
-- All declarations processed: substitute the expanded values into `body`.
(body.abstract xs).instantiateRev subst
return go 0 #[]
termination_by go => values.size - i
/--
Build the "key" that uniquely identifies a code specialization.
`params` and `decls` are the parameters and declarations gathered by the `collect` function above,
and `body` is the expression the key is built from. The declarations are first expanded into
`body`, and the result is then abstracted over `params`.
-/
def mkKey (params : Array Param) (decls : Array CodeDecl) (body : Expr) : CompilerM Expr := do
  let expanded ← expandCodeDecls decls body
  -- TODO: normalize universe level parameters
  return ToExpr.run do
    ToExpr.withParams params do
      let abstracted ← ToExpr.abstractM expanded
      ToExpr.mkLambdaM params abstracted
/--
Try to specialize the function application in the given let-declaration.
`k` is the continuation for the let-declaration.

The implementation is not finished yet: it selects candidates, computes and traces their
specialization key, and then always returns `none`. The continuation `_k` is currently unused.
-/
def specializeApp? (letDecl : LetDecl) (_k : Code) : SpecializeM (Option Code) := do
-- Only applications whose head is a constant are candidates.
unless letDecl.value.isApp do return none
let f := letDecl.value.getAppFn
let .const declName _us := f | return none
-- Applications of instances are never specialized.
if (← Meta.isInstance declName) then return none
-- Give up if there is no parameter specialization information for `declName`.
let some paramsInfo ← getSpecParamInfo? declName | return none
let args := letDecl.value.getAppArgs
unless (← shouldSpecialize paramsInfo args) do return none
-- Give up if no stage-1 declaration is found for `declName`; the declaration itself is not used yet.
let some _decl ← getStage1Decl? declName | return none
trace[Compiler.specialize.candidate] "{letDecl.value}, {paramsInfo}"
-- Run the dependency collector starting from an empty collector state.
let (argMask, { params, decls, .. }) ← Collector.collect paramsInfo args |>.run {}
let keyBody := mkAppN f argMask
let key ← mkKey params decls keyBody
trace[Compiler.specialize.candidate] "key: {key}"
-- The key must be a closed term: no loose bound variables and no free variables.
assert! !key.hasLooseBVars
assert! !key.hasFVar
trace[Compiler.specialize.candidate] "decls: {decls.map fun | .let decl | .jp decl | .fun decl => decl.binderName}"
trace[Compiler.specialize.candidate] "params: {params.map fun p => p.binderName}"
-- TODO
return none
mutual
/--
Visit the body of a local function declaration with its parameters in scope,
and update the declaration with the resulting code.
-/
partial def visitFunDecl (funDecl : FunDecl) : SpecializeM FunDecl := do
  let newValue ← withParams funDecl.params (visitCode funDecl.value)
  funDecl.update' funDecl.type newValue
/--
Traverse `code`, trying to specialize the application in every let-declaration,
while keeping the `scope` and `ground` sets of the context up to date.
-/
partial def visitCode (code : Code) : SpecializeM Code := do
  match code with
  | .unreach .. | .jmp .. | .return .. =>
    return code
  | .let decl k =>
    match (← specializeApp? decl k) with
    | some newCode =>
      -- The let-declaration was specialized: continue with the code that replaces it.
      visitCode newCode
    | none =>
      let newK ← withLetDecl decl (visitCode k)
      return code.updateLet! decl newK
  | .fun decl k | .jp decl k =>
    let newDecl ← visitFunDecl decl
    let newK ← withFVar newDecl.fvarId (visitCode k)
    return code.updateFun! newDecl newK
  | .cases cs =>
    let newAlts ← cs.alts.mapMonoM fun alt => do
      match alt with
      | .default k =>
        let newK ← visitCode k
        return alt.updateCode newK
      | .alt _ ps k =>
        -- The constructor fields are in scope while visiting the alternative.
        withParams ps do
          let newK ← visitCode k
          return alt.updateCode newK
    return code.updateAlts! newAlts
end
/--
Entry point of the specializer for a single declaration.
Declarations for which `isTemplateLike` holds are returned unchanged; for all others
the body is visited with the declaration parameters in scope.
-/
def main (decl : Decl) : SpecializeM Decl := do
  if (← isTemplateLike decl) then
    return decl
  let value ← withParams decl.params (visitCode decl.value)
  return { decl with value }
end Specialize
/--
Run the specializer on `decl`, starting from an empty context and an empty state.
The result contains the declarations recorded in the final state followed by the
processed `decl`.
-/
partial def Decl.specialize (decl : Decl) : CompilerM (Array Decl) := do
  let (newDecl, finalState) ← ((Specialize.main decl).run {}).run {}
  return finalState.decls.push newDecl
/--
The `specialize` compiler pass. It saves the parameter specialization information
for `decls`, and then specializes each declaration in order, concatenating the results.
-/
def specialize : Pass where
  phase := .base
  name := `specialize
  run := fun decls => do
    saveSpecParamInfo decls
    let mut result : Array Decl := #[]
    for decl in decls do
      result := result ++ (← decl.specialize)
    return result
-- Trace classes used by the specialization pass.
builtin_initialize
  registerTraceClass `Compiler.specialize (inherited := true)
  -- Used by `specializeApp?` to report specialization candidates and their keys.
  registerTraceClass `Compiler.specialize.candidate
end Lean.Compiler.LCNF