refactor: add Closure.lean
This module will also be used by the lambda lifter.
This commit is contained in:
parent
e7a36f32f1
commit
f11e44910b
2 changed files with 157 additions and 116 deletions
142
src/Lean/Compiler/LCNF/Closure.lean
Normal file
142
src/Lean/Compiler/LCNF/Closure.lean
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Util.ForEachExpr
|
||||
import Lean.Compiler.LCNF.CompilerM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Closure
|
||||
|
||||
/-!
|
||||
# Dependency collector for code specialization and lambda lifting.
|
||||
|
||||
During code specialization and lambda lifting, we have code `C` containing free variables. These free variables
|
||||
are in a scope, and we say we are computing `C`'s closure.
|
||||
This module is used to compute the closure.
|
||||
-/
|
||||
|
||||
structure Context where
|
||||
/--
|
||||
`inScope x` returns `true` if `x` is a variable that is not in `C`.
|
||||
-/
|
||||
inScope : FVarId → Bool
|
||||
/--
|
||||
If `abstractLet x` returns `true`, we convert `x` into a closure parameter. Otherwise,
|
||||
we collect the dependecies in the `let`-value too, and include the `let`-declaration in the closure.
|
||||
Remark: the lambda lifting pass abstracts all `let`-declarations.
|
||||
-/
|
||||
abstractLet : LetDecl → Bool
|
||||
|
||||
/--
|
||||
State for the `ClosureM` monad.
|
||||
-/
|
||||
structure State where
|
||||
/--
|
||||
Set of already visited free variables.
|
||||
-/
|
||||
visited : FVarIdSet := {}
|
||||
/--
|
||||
Free variables that must become new parameters of the code being specialized.
|
||||
-/
|
||||
params : Array Param := #[]
|
||||
/--
|
||||
Let-declarations and local function declarations that are going to be "copied" to the code
|
||||
being processed. For example, when this module is used in the code specializer, the let-declarations
|
||||
often contain the instance values. In the current specialization heuristic all let-declarations are ground values
|
||||
(i.e., they do not contain free-variables).
|
||||
However, local function declarations may contain free variables.
|
||||
|
||||
All customers of this module try to avoid work duplication. If a let-declaration is a ground value,
|
||||
it most likely will be computed during compilation time, and work duplication is not an issue.
|
||||
-/
|
||||
decls : Array CodeDecl := #[]
|
||||
|
||||
/--
|
||||
Monad for implementing the dependency collector.
|
||||
-/
|
||||
abbrev ClosureM := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
/--
|
||||
Mark a free variable as already visited.
|
||||
We perform a topological sort over the dependencies.
|
||||
-/
|
||||
def markVisited (fvarId : FVarId) : ClosureM Unit :=
|
||||
modify fun s => { s with visited := s.visited.insert fvarId }
|
||||
|
||||
mutual
|
||||
/--
|
||||
Collect dependencies in parameters. We need this because parameters may
|
||||
contain other type parameters.
|
||||
-/
|
||||
partial def collectParams (params : Array Param) : ClosureM Unit :=
|
||||
params.forM (collectExpr ·.type)
|
||||
|
||||
/--
|
||||
Collect dependencies in the given code. We need this function to be able
|
||||
to collect dependencies in a local function declaration.
|
||||
-/
|
||||
partial def collectCode (c : Code) : ClosureM Unit := do
|
||||
match c with
|
||||
| .let decl k => collectExpr decl.type; collectExpr decl.value; collectCode k
|
||||
| .fun decl k | .jp decl k => collectFunDecl decl; collectCode k
|
||||
| .cases c =>
|
||||
collectExpr c.resultType
|
||||
collectFVar c.discr
|
||||
c.alts.forM fun alt => do
|
||||
match alt with
|
||||
| .default k => collectCode k
|
||||
| .alt _ ps k => collectParams ps; collectCode k
|
||||
| .jmp _ args => args.forM collectExpr
|
||||
| .unreach type => collectExpr type
|
||||
| .return fvarId => collectFVar fvarId
|
||||
|
||||
/-- Collect dependencies of a local function declaration. -/
|
||||
partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do
|
||||
collectExpr decl.type
|
||||
collectParams decl.params
|
||||
collectCode decl.value
|
||||
|
||||
/--
|
||||
Process the given free variable.
|
||||
If it has not already been visited and is in scope, we collect its dependencies.
|
||||
-/
|
||||
partial def collectFVar (fvarId : FVarId) : ClosureM Unit := do
|
||||
unless (← get).visited.contains fvarId do
|
||||
markVisited fvarId
|
||||
if (← read).inScope fvarId then
|
||||
/- We only collect the variables in the scope of the function application being specialized. -/
|
||||
if let some funDecl ← findFunDecl? fvarId then
|
||||
collectFunDecl funDecl
|
||||
modify fun s => { s with decls := s.decls.push <| .fun funDecl }
|
||||
else if let some param ← findParam? fvarId then
|
||||
collectExpr param.type
|
||||
modify fun s => { s with params := s.params.push param }
|
||||
else if let some letDecl ← findLetDecl? fvarId then
|
||||
collectExpr letDecl.type
|
||||
if (← read).abstractLet letDecl then
|
||||
-- It is a ground value, thus we keep collecting dependencies
|
||||
collectExpr letDecl.value
|
||||
modify fun s => { s with decls := s.decls.push <| .let letDecl }
|
||||
else
|
||||
-- It is not a ground value, we convert declaration into a parameter
|
||||
modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } }
|
||||
else
|
||||
unreachable!
|
||||
|
||||
/-- Collect dependencies of the given expression. -/
|
||||
partial def collectExpr (e : Expr) : ClosureM Unit := do
|
||||
e.forEach fun e => do
|
||||
match e with
|
||||
| .fvar fvarId => collectFVar fvarId
|
||||
| _ => pure ()
|
||||
end
|
||||
|
||||
def run (x : ClosureM α) (inScope : FVarId → Bool) (abstractLet : LetDecl → Bool := fun _ => false) : CompilerM (α × Array Param × Array CodeDecl) := do
|
||||
let (a, s) ← x { inScope, abstractLet } |>.run {}
|
||||
return (a, s.params, s.decls)
|
||||
|
||||
end Closure
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
@ -3,7 +3,6 @@ Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
|||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Util.ForEachExpr
|
||||
import Lean.Compiler.Specialize
|
||||
import Lean.Compiler.LCNF.Simp
|
||||
import Lean.Compiler.LCNF.SpecInfo
|
||||
|
|
@ -12,6 +11,7 @@ import Lean.Compiler.LCNF.ToExpr
|
|||
import Lean.Compiler.LCNF.Level
|
||||
import Lean.Compiler.LCNF.PhaseExt
|
||||
import Lean.Compiler.LCNF.MonadScope
|
||||
import Lean.Compiler.LCNF.Closure
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Specialize
|
||||
|
|
@ -113,110 +113,6 @@ and
|
|||
The keys never contain free variables or loose bound variables.
|
||||
-/
|
||||
|
||||
/--
|
||||
State for the `CollectorM` monad.
|
||||
-/
|
||||
structure State where
|
||||
/--
|
||||
Set of already visited free variables.
|
||||
-/
|
||||
visited : FVarIdSet := {}
|
||||
/--
|
||||
Free variables that must become new parameters of the code being specialized.
|
||||
-/
|
||||
params : Array Param := #[]
|
||||
/--
|
||||
Let-declarations and local function declarations that are going to be "copied" to the code
|
||||
being specialized. For example, the let-declarations often contain the instance values.
|
||||
In the current specialization heuristic all let-declarations are ground values (i.e., they do not contain free-variables).
|
||||
However, local function declarations may contain free variables.
|
||||
|
||||
The current heuristic tries to avoid work duplication. If a let-declaration is a ground value,
|
||||
it most likely will be computed during compilation time, and work duplication is not an issue.
|
||||
-/
|
||||
decls : Array CodeDecl := #[]
|
||||
|
||||
/--
|
||||
Monad for implementing the code specializer dependency collector.
|
||||
See `collect`
|
||||
-/
|
||||
abbrev CollectorM := StateRefT State SpecializeM
|
||||
|
||||
/--
|
||||
Mark a free variable as already visited.
|
||||
We perform a topological sort over the dependencies.
|
||||
-/
|
||||
def markVisited (fvarId : FVarId) : CollectorM Unit :=
|
||||
modify fun s => { s with visited := s.visited.insert fvarId }
|
||||
|
||||
mutual
|
||||
/--
|
||||
Collect dependencies in parameters. We need this because parameters may
|
||||
contain other type parameters.
|
||||
-/
|
||||
partial def collectParams (params : Array Param) : CollectorM Unit :=
|
||||
params.forM (collectExpr ·.type)
|
||||
|
||||
/--
|
||||
Collect dependencies in the given code. We need this function to be able
|
||||
to collect dependencies in a local function declaration.
|
||||
-/
|
||||
partial def collectCode (c : Code) : CollectorM Unit := do
|
||||
match c with
|
||||
| .let decl k => collectExpr decl.type; collectExpr decl.value; collectCode k
|
||||
| .fun decl k | .jp decl k => collectFunDecl decl; collectCode k
|
||||
| .cases c =>
|
||||
collectExpr c.resultType
|
||||
collectFVar c.discr
|
||||
c.alts.forM fun alt => do
|
||||
match alt with
|
||||
| .default k => collectCode k
|
||||
| .alt _ ps k => collectParams ps; collectCode k
|
||||
| .jmp _ args => args.forM collectExpr
|
||||
| .unreach type => collectExpr type
|
||||
| .return fvarId => collectFVar fvarId
|
||||
|
||||
/-- Collect dependencies of a local function declaration. -/
|
||||
partial def collectFunDecl (decl : FunDecl) : CollectorM Unit := do
|
||||
collectExpr decl.type
|
||||
collectParams decl.params
|
||||
collectCode decl.value
|
||||
|
||||
/--
|
||||
Process the given free variable.
|
||||
If it has not already been visited and is in scope, we collect its dependencies.
|
||||
-/
|
||||
partial def collectFVar (fvarId : FVarId) : CollectorM Unit := do
|
||||
unless (← get).visited.contains fvarId do
|
||||
markVisited fvarId
|
||||
if (← inScope fvarId) then
|
||||
/- We only collect the variables in the scope of the function application being specialized. -/
|
||||
if let some funDecl ← findFunDecl? fvarId then
|
||||
collectFunDecl funDecl
|
||||
modify fun s => { s with decls := s.decls.push <| .fun funDecl }
|
||||
else if let some param ← findParam? fvarId then
|
||||
collectExpr param.type
|
||||
modify fun s => { s with params := s.params.push param }
|
||||
else if let some letDecl ← findLetDecl? fvarId then
|
||||
collectExpr letDecl.type
|
||||
if (← isGround letDecl.value) then
|
||||
-- It is a ground value, thus we keep collecting dependencies
|
||||
collectExpr letDecl.value
|
||||
modify fun s => { s with decls := s.decls.push <| .let letDecl }
|
||||
else
|
||||
-- It is not a ground value, we convert declaration into a parameter
|
||||
modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } }
|
||||
else
|
||||
unreachable!
|
||||
|
||||
/-- Collect dependencies of the given expression. -/
|
||||
partial def collectExpr (e : Expr) : CollectorM Unit := do
|
||||
e.forEach fun e => do
|
||||
match e with
|
||||
| .fvar fvarId => collectFVar fvarId
|
||||
| _ => pure ()
|
||||
end
|
||||
|
||||
/--
|
||||
Given the specialization mask `paramsInfo` and the arguments `args`,
|
||||
collect their dependencies, and return an array `mask` of size `paramsInfo.size` s.t.
|
||||
|
|
@ -226,16 +122,19 @@ That is, `mask` contains only the arguments that are contributing to the code sp
|
|||
We use this information to compute a "key" to uniquely identify the code specialization, and
|
||||
creating the specialized code.
|
||||
-/
|
||||
def collect (paramsInfo : Array SpecParamInfo) (args : Array Expr) : CollectorM (Array (Option Expr)) := do
|
||||
let mut argMask := #[]
|
||||
for paramInfo in paramsInfo, arg in args do
|
||||
match paramInfo with
|
||||
| .other =>
|
||||
argMask := argMask.push none
|
||||
| .fixedNeutral | .user | .fixedInst | .fixedHO =>
|
||||
argMask := argMask.push (some arg)
|
||||
collectExpr arg
|
||||
return argMask
|
||||
def collect (paramsInfo : Array SpecParamInfo) (args : Array Expr) : SpecializeM (Array (Option Expr) × Array Param × Array CodeDecl) := do
|
||||
let ctx ← read
|
||||
let isGround decl := ctx.ground.contains decl.fvarId
|
||||
Closure.run (inScope := ctx.scope.contains) (abstractLet := isGround) do
|
||||
let mut argMask := #[]
|
||||
for paramInfo in paramsInfo, arg in args do
|
||||
match paramInfo with
|
||||
| .other =>
|
||||
argMask := argMask.push none
|
||||
| .fixedNeutral | .user | .fixedInst | .fixedHO =>
|
||||
argMask := argMask.push (some arg)
|
||||
Closure.collectExpr arg
|
||||
return argMask
|
||||
|
||||
end Collector
|
||||
|
||||
|
|
@ -356,7 +255,7 @@ mutual
|
|||
unless (← shouldSpecialize paramsInfo args) do return none
|
||||
let some decl ← getDecl? declName | return none
|
||||
trace[Compiler.specialize.candidate] "{e}, {paramsInfo}"
|
||||
let (argMask, { params, decls, .. }) ← Collector.collect paramsInfo args |>.run {}
|
||||
let (argMask, params, decls) ← Collector.collect paramsInfo args
|
||||
let keyBody := mkAppN f (argMask.filterMap id)
|
||||
let (key, levelParamsNew) ← mkKey params decls keyBody
|
||||
trace[Compiler.specialize.candidate] "key: {key}"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue