refactor: add Closure.lean

This module will also be used by the lambda lifter.
This commit is contained in:
Leonardo de Moura 2022-10-07 15:56:10 -07:00
parent e7a36f32f1
commit f11e44910b
2 changed files with 157 additions and 116 deletions

View file

@ -0,0 +1,142 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.ForEachExpr
import Lean.Compiler.LCNF.CompilerM
namespace Lean.Compiler.LCNF
namespace Closure
/-!
# Dependency collector for code specialization and lambda lifting.
During code specialization and lambda lifting, we have code `C` containing free variables. These free variables
are in a scope, and we say we are computing `C`'s closure.
This module is used to compute the closure.
-/
structure Context where
/--
`inScope x` returns `true` if `x` is a variable that is not in `C`.
-/
inScope : FVarId → Bool
/--
If `abstractLet x` returns `true`, we convert `x` into a closure parameter. Otherwise,
we collect the dependecies in the `let`-value too, and include the `let`-declaration in the closure.
Remark: the lambda lifting pass abstracts all `let`-declarations.
-/
abstractLet : LetDecl → Bool
/--
State for the `ClosureM` monad.
-/
structure State where
/--
Set of already visited free variables.
-/
visited : FVarIdSet := {}
/--
Free variables that must become new parameters of the code being specialized.
-/
params : Array Param := #[]
/--
Let-declarations and local function declarations that are going to be "copied" to the code
being processed. For example, when this module is used in the code specializer, the let-declarations
often contain the instance values. In the current specialization heuristic all let-declarations are ground values
(i.e., they do not contain free-variables).
However, local function declarations may contain free variables.
All customers of this module try to avoid work duplication. If a let-declaration is a ground value,
it most likely will be computed during compilation time, and work duplication is not an issue.
-/
decls : Array CodeDecl := #[]
/--
Monad for implementing the dependency collector.
-/
abbrev ClosureM := ReaderT Context $ StateRefT State CompilerM
/--
Mark a free variable as already visited.
We perform a topological sort over the dependencies.
-/
def markVisited (fvarId : FVarId) : ClosureM Unit :=
modify fun s => { s with visited := s.visited.insert fvarId }
mutual
/--
Collect dependencies in parameters. We need this because parameters may
contain other type parameters.
-/
partial def collectParams (params : Array Param) : ClosureM Unit :=
params.forM (collectExpr ·.type)
/--
Collect dependencies in the given code. We need this function to be able
to collect dependencies in a local function declaration.
-/
partial def collectCode (c : Code) : ClosureM Unit := do
match c with
| .let decl k => collectExpr decl.type; collectExpr decl.value; collectCode k
| .fun decl k | .jp decl k => collectFunDecl decl; collectCode k
| .cases c =>
collectExpr c.resultType
collectFVar c.discr
c.alts.forM fun alt => do
match alt with
| .default k => collectCode k
| .alt _ ps k => collectParams ps; collectCode k
| .jmp _ args => args.forM collectExpr
| .unreach type => collectExpr type
| .return fvarId => collectFVar fvarId
/-- Collect dependencies of a local function declaration. -/
partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do
collectExpr decl.type
collectParams decl.params
collectCode decl.value
/--
Process the given free variable.
If it has not already been visited and is in scope, we collect its dependencies.
-/
partial def collectFVar (fvarId : FVarId) : ClosureM Unit := do
unless (← get).visited.contains fvarId do
markVisited fvarId
if (← read).inScope fvarId then
/- We only collect the variables in the scope of the function application being specialized. -/
if let some funDecl ← findFunDecl? fvarId then
collectFunDecl funDecl
modify fun s => { s with decls := s.decls.push <| .fun funDecl }
else if let some param ← findParam? fvarId then
collectExpr param.type
modify fun s => { s with params := s.params.push param }
else if let some letDecl ← findLetDecl? fvarId then
collectExpr letDecl.type
if (← read).abstractLet letDecl then
-- It is a ground value, thus we keep collecting dependencies
collectExpr letDecl.value
modify fun s => { s with decls := s.decls.push <| .let letDecl }
else
-- It is not a ground value, we convert declaration into a parameter
modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } }
else
unreachable!
/-- Collect dependencies of the given expression. -/
partial def collectExpr (e : Expr) : ClosureM Unit := do
e.forEach fun e => do
match e with
| .fvar fvarId => collectFVar fvarId
| _ => pure ()
end
def run (x : ClosureM α) (inScope : FVarId → Bool) (abstractLet : LetDecl → Bool := fun _ => false) : CompilerM (α × Array Param × Array CodeDecl) := do
let (a, s) ← x { inScope, abstractLet } |>.run {}
return (a, s.params, s.decls)
end Closure
end Lean.Compiler.LCNF

View file

@ -3,7 +3,6 @@ Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.ForEachExpr
import Lean.Compiler.Specialize
import Lean.Compiler.LCNF.Simp
import Lean.Compiler.LCNF.SpecInfo
@ -12,6 +11,7 @@ import Lean.Compiler.LCNF.ToExpr
import Lean.Compiler.LCNF.Level
import Lean.Compiler.LCNF.PhaseExt
import Lean.Compiler.LCNF.MonadScope
import Lean.Compiler.LCNF.Closure
namespace Lean.Compiler.LCNF
namespace Specialize
@ -113,110 +113,6 @@ and
The keys never contain free variables or loose bound variables.
-/
/--
State for the `CollectorM` monad.
-/
structure State where
/--
Set of already visited free variables.
-/
visited : FVarIdSet := {}
/--
Free variables that must become new parameters of the code being specialized.
-/
params : Array Param := #[]
/--
Let-declarations and local function declarations that are going to be "copied" to the code
being specialized. For example, the let-declarations often contain the instance values.
In the current specialization heuristic all let-declarations are ground values (i.e., they do not contain free-variables).
However, local function declarations may contain free variables.
The current heuristic tries to avoid work duplication. If a let-declaration is a ground value,
it most likely will be computed during compilation time, and work duplication is not an issue.
-/
decls : Array CodeDecl := #[]
/--
Monad for implementing the code specializer dependency collector.
See `collect`
-/
abbrev CollectorM := StateRefT State SpecializeM
/--
Mark a free variable as already visited.
We perform a topological sort over the dependencies.
-/
def markVisited (fvarId : FVarId) : CollectorM Unit :=
modify fun s => { s with visited := s.visited.insert fvarId }
mutual
/--
Collect dependencies in parameters. We need this because parameters may
contain other type parameters.
-/
partial def collectParams (params : Array Param) : CollectorM Unit :=
params.forM (collectExpr ·.type)
/--
Collect dependencies in the given code. We need this function to be able
to collect dependencies in a local function declaration.
-/
partial def collectCode (c : Code) : CollectorM Unit := do
match c with
| .let decl k => collectExpr decl.type; collectExpr decl.value; collectCode k
| .fun decl k | .jp decl k => collectFunDecl decl; collectCode k
| .cases c =>
collectExpr c.resultType
collectFVar c.discr
c.alts.forM fun alt => do
match alt with
| .default k => collectCode k
| .alt _ ps k => collectParams ps; collectCode k
| .jmp _ args => args.forM collectExpr
| .unreach type => collectExpr type
| .return fvarId => collectFVar fvarId
/-- Collect dependencies of a local function declaration. -/
partial def collectFunDecl (decl : FunDecl) : CollectorM Unit := do
collectExpr decl.type
collectParams decl.params
collectCode decl.value
/--
Process the given free variable.
If it has not already been visited and is in scope, we collect its dependencies.
-/
partial def collectFVar (fvarId : FVarId) : CollectorM Unit := do
unless (← get).visited.contains fvarId do
markVisited fvarId
if (← inScope fvarId) then
/- We only collect the variables in the scope of the function application being specialized. -/
if let some funDecl ← findFunDecl? fvarId then
collectFunDecl funDecl
modify fun s => { s with decls := s.decls.push <| .fun funDecl }
else if let some param ← findParam? fvarId then
collectExpr param.type
modify fun s => { s with params := s.params.push param }
else if let some letDecl ← findLetDecl? fvarId then
collectExpr letDecl.type
if (← isGround letDecl.value) then
-- It is a ground value, thus we keep collecting dependencies
collectExpr letDecl.value
modify fun s => { s with decls := s.decls.push <| .let letDecl }
else
-- It is not a ground value, we convert declaration into a parameter
modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } }
else
unreachable!
/-- Collect dependencies of the given expression. -/
partial def collectExpr (e : Expr) : CollectorM Unit := do
e.forEach fun e => do
match e with
| .fvar fvarId => collectFVar fvarId
| _ => pure ()
end
/--
Given the specialization mask `paramsInfo` and the arguments `args`,
collect their dependencies, and return an array `mask` of size `paramsInfo.size` s.t.
@ -226,16 +122,19 @@ That is, `mask` contains only the arguments that are contributing to the code sp
We use this information to compute a "key" to uniquely identify the code specialization, and
creating the specialized code.
-/
def collect (paramsInfo : Array SpecParamInfo) (args : Array Expr) : CollectorM (Array (Option Expr)) := do
let mut argMask := #[]
for paramInfo in paramsInfo, arg in args do
match paramInfo with
| .other =>
argMask := argMask.push none
| .fixedNeutral | .user | .fixedInst | .fixedHO =>
argMask := argMask.push (some arg)
collectExpr arg
return argMask
def collect (paramsInfo : Array SpecParamInfo) (args : Array Expr) : SpecializeM (Array (Option Expr) × Array Param × Array CodeDecl) := do
let ctx ← read
let isGround decl := ctx.ground.contains decl.fvarId
Closure.run (inScope := ctx.scope.contains) (abstractLet := isGround) do
let mut argMask := #[]
for paramInfo in paramsInfo, arg in args do
match paramInfo with
| .other =>
argMask := argMask.push none
| .fixedNeutral | .user | .fixedInst | .fixedHO =>
argMask := argMask.push (some arg)
Closure.collectExpr arg
return argMask
end Collector
@ -356,7 +255,7 @@ mutual
unless (← shouldSpecialize paramsInfo args) do return none
let some decl ← getDecl? declName | return none
trace[Compiler.specialize.candidate] "{e}, {paramsInfo}"
let (argMask, { params, decls, .. }) ← Collector.collect paramsInfo args |>.run {}
let (argMask, params, decls) ← Collector.collect paramsInfo args
let keyBody := mkAppN f (argMask.filterMap id)
let (key, levelParamsNew) ← mkKey params decls keyBody
trace[Compiler.specialize.candidate] "key: {key}"