From f11e44910bbd90be7164cc671eb80aac07febdfc Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 7 Oct 2022 15:56:10 -0700 Subject: [PATCH] refactor: add `Closure.lean` This module will also be used by the lambda lifter. --- src/Lean/Compiler/LCNF/Closure.lean | 142 +++++++++++++++++++++++++ src/Lean/Compiler/LCNF/Specialize.lean | 131 +++-------------------- 2 files changed, 157 insertions(+), 116 deletions(-) create mode 100644 src/Lean/Compiler/LCNF/Closure.lean diff --git a/src/Lean/Compiler/LCNF/Closure.lean b/src/Lean/Compiler/LCNF/Closure.lean new file mode 100644 index 0000000000..39a3c0feaf --- /dev/null +++ b/src/Lean/Compiler/LCNF/Closure.lean @@ -0,0 +1,142 @@ +/- +Copyright (c) 2022 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Util.ForEachExpr +import Lean.Compiler.LCNF.CompilerM + +namespace Lean.Compiler.LCNF +namespace Closure + +/-! +# Dependency collector for code specialization and lambda lifting. + +During code specialization and lambda lifting, we have code `C` containing free variables. These free variables +are in a scope, and we say we are computing `C`'s closure. +This module is used to compute the closure. +-/ + +structure Context where + /-- + `inScope x` returns `true` if `x` is a variable that is not in `C`. + -/ + inScope : FVarId → Bool + /-- + If `abstractLet x` returns `true`, we convert `x` into a closure parameter. Otherwise, + we collect the dependecies in the `let`-value too, and include the `let`-declaration in the closure. + Remark: the lambda lifting pass abstracts all `let`-declarations. + -/ + abstractLet : LetDecl → Bool + +/-- +State for the `ClosureM` monad. +-/ +structure State where + /-- + Set of already visited free variables. + -/ + visited : FVarIdSet := {} + /-- + Free variables that must become new parameters of the code being specialized. + -/ + params : Array Param := #[] + /-- + Let-declarations and local function declarations that are going to be "copied" to the code + being processed. For example, when this module is used in the code specializer, the let-declarations + often contain the instance values. In the current specialization heuristic all let-declarations are ground values + (i.e., they do not contain free-variables). + However, local function declarations may contain free variables. + + All customers of this module try to avoid work duplication. If a let-declaration is a ground value, + it most likely will be computed during compilation time, and work duplication is not an issue. + -/ + decls : Array CodeDecl := #[] + +/-- +Monad for implementing the dependency collector. +-/ +abbrev ClosureM := ReaderT Context $ StateRefT State CompilerM + +/-- +Mark a free variable as already visited. +We perform a topological sort over the dependencies. +-/ +def markVisited (fvarId : FVarId) : ClosureM Unit := + modify fun s => { s with visited := s.visited.insert fvarId } + +mutual + /-- + Collect dependencies in parameters. We need this because parameters may + contain other type parameters. + -/ + partial def collectParams (params : Array Param) : ClosureM Unit := + params.forM (collectExpr ·.type) + + /-- + Collect dependencies in the given code. We need this function to be able + to collect dependencies in a local function declaration. + -/ + partial def collectCode (c : Code) : ClosureM Unit := do + match c with + | .let decl k => collectExpr decl.type; collectExpr decl.value; collectCode k + | .fun decl k | .jp decl k => collectFunDecl decl; collectCode k + | .cases c => + collectExpr c.resultType + collectFVar c.discr + c.alts.forM fun alt => do + match alt with + | .default k => collectCode k + | .alt _ ps k => collectParams ps; collectCode k + | .jmp _ args => args.forM collectExpr + | .unreach type => collectExpr type + | .return fvarId => collectFVar fvarId + + /-- Collect dependencies of a local function declaration. -/ + partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do + collectExpr decl.type + collectParams decl.params + collectCode decl.value + + /-- + Process the given free variable. + If it has not already been visited and is in scope, we collect its dependencies. + -/ + partial def collectFVar (fvarId : FVarId) : ClosureM Unit := do + unless (← get).visited.contains fvarId do + markVisited fvarId + if (← read).inScope fvarId then + /- We only collect the variables in the scope of the function application being specialized. -/ + if let some funDecl ← findFunDecl? fvarId then + collectFunDecl funDecl + modify fun s => { s with decls := s.decls.push <| .fun funDecl } + else if let some param ← findParam? fvarId then + collectExpr param.type + modify fun s => { s with params := s.params.push param } + else if let some letDecl ← findLetDecl? fvarId then + collectExpr letDecl.type + if (← read).abstractLet letDecl then + -- It is a ground value, thus we keep collecting dependencies + collectExpr letDecl.value + modify fun s => { s with decls := s.decls.push <| .let letDecl } + else + -- It is not a ground value, we convert declaration into a parameter + modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } } + else + unreachable! + + /-- Collect dependencies of the given expression. -/ + partial def collectExpr (e : Expr) : ClosureM Unit := do + e.forEach fun e => do + match e with + | .fvar fvarId => collectFVar fvarId + | _ => pure () +end + +def run (x : ClosureM α) (inScope : FVarId → Bool) (abstractLet : LetDecl → Bool := fun _ => false) : CompilerM (α × Array Param × Array CodeDecl) := do + let (a, s) ← x { inScope, abstractLet } |>.run {} + return (a, s.params, s.decls) + +end Closure + +end Lean.Compiler.LCNF \ No newline at end of file diff --git a/src/Lean/Compiler/LCNF/Specialize.lean b/src/Lean/Compiler/LCNF/Specialize.lean index 0d30d6221c..8fd5227e6b 100644 --- a/src/Lean/Compiler/LCNF/Specialize.lean +++ b/src/Lean/Compiler/LCNF/Specialize.lean @@ -3,7 +3,6 @@ Copyright (c) 2022 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ -import Lean.Util.ForEachExpr import Lean.Compiler.Specialize import Lean.Compiler.LCNF.Simp import Lean.Compiler.LCNF.SpecInfo @@ -12,6 +11,7 @@ import Lean.Compiler.LCNF.ToExpr import Lean.Compiler.LCNF.Level import Lean.Compiler.LCNF.PhaseExt import Lean.Compiler.LCNF.MonadScope +import Lean.Compiler.LCNF.Closure namespace Lean.Compiler.LCNF namespace Specialize @@ -113,110 +113,6 @@ and The keys never contain free variables or loose bound variables. -/ -/-- -State for the `CollectorM` monad. --/ -structure State where - /-- - Set of already visited free variables. - -/ - visited : FVarIdSet := {} - /-- - Free variables that must become new parameters of the code being specialized. - -/ - params : Array Param := #[] - /-- - Let-declarations and local function declarations that are going to be "copied" to the code - being specialized. For example, the let-declarations often contain the instance values. - In the current specialization heuristic all let-declarations are ground values (i.e., they do not contain free-variables). - However, local function declarations may contain free variables. - - The current heuristic tries to avoid work duplication. If a let-declaration is a ground value, - it most likely will be computed during compilation time, and work duplication is not an issue. - -/ - decls : Array CodeDecl := #[] - -/-- -Monad for implementing the code specializer dependency collector. -See `collect` --/ -abbrev CollectorM := StateRefT State SpecializeM - -/-- -Mark a free variable as already visited. -We perform a topological sort over the dependencies. --/ -def markVisited (fvarId : FVarId) : CollectorM Unit := - modify fun s => { s with visited := s.visited.insert fvarId } - -mutual - /-- - Collect dependencies in parameters. We need this because parameters may - contain other type parameters. - -/ - partial def collectParams (params : Array Param) : CollectorM Unit := - params.forM (collectExpr ·.type) - - /-- - Collect dependencies in the given code. We need this function to be able - to collect dependencies in a local function declaration. - -/ - partial def collectCode (c : Code) : CollectorM Unit := do - match c with - | .let decl k => collectExpr decl.type; collectExpr decl.value; collectCode k - | .fun decl k | .jp decl k => collectFunDecl decl; collectCode k - | .cases c => - collectExpr c.resultType - collectFVar c.discr - c.alts.forM fun alt => do - match alt with - | .default k => collectCode k - | .alt _ ps k => collectParams ps; collectCode k - | .jmp _ args => args.forM collectExpr - | .unreach type => collectExpr type - | .return fvarId => collectFVar fvarId - - /-- Collect dependencies of a local function declaration. -/ - partial def collectFunDecl (decl : FunDecl) : CollectorM Unit := do - collectExpr decl.type - collectParams decl.params - collectCode decl.value - - /-- - Process the given free variable. - If it has not already been visited and is in scope, we collect its dependencies. - -/ - partial def collectFVar (fvarId : FVarId) : CollectorM Unit := do - unless (← get).visited.contains fvarId do - markVisited fvarId - if (← inScope fvarId) then - /- We only collect the variables in the scope of the function application being specialized. -/ - if let some funDecl ← findFunDecl? fvarId then - collectFunDecl funDecl - modify fun s => { s with decls := s.decls.push <| .fun funDecl } - else if let some param ← findParam? fvarId then - collectExpr param.type - modify fun s => { s with params := s.params.push param } - else if let some letDecl ← findLetDecl? fvarId then - collectExpr letDecl.type - if (← isGround letDecl.value) then - -- It is a ground value, thus we keep collecting dependencies - collectExpr letDecl.value - modify fun s => { s with decls := s.decls.push <| .let letDecl } - else - -- It is not a ground value, we convert declaration into a parameter - modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } } - else - unreachable! - - /-- Collect dependencies of the given expression. -/ - partial def collectExpr (e : Expr) : CollectorM Unit := do - e.forEach fun e => do - match e with - | .fvar fvarId => collectFVar fvarId - | _ => pure () -end - /-- Given the specialization mask `paramsInfo` and the arguments `args`, collect their dependencies, and return an array `mask` of size `paramsInfo.size` s.t. @@ -226,16 +122,19 @@ That is, `mask` contains only the arguments that are contributing to the code sp We use this information to compute a "key" to uniquely identify the code specialization, and creating the specialized code. -/ -def collect (paramsInfo : Array SpecParamInfo) (args : Array Expr) : CollectorM (Array (Option Expr)) := do - let mut argMask := #[] - for paramInfo in paramsInfo, arg in args do - match paramInfo with - | .other => - argMask := argMask.push none - | .fixedNeutral | .user | .fixedInst | .fixedHO => - argMask := argMask.push (some arg) - collectExpr arg - return argMask +def collect (paramsInfo : Array SpecParamInfo) (args : Array Expr) : SpecializeM (Array (Option Expr) × Array Param × Array CodeDecl) := do + let ctx ← read + let isGround decl := ctx.ground.contains decl.fvarId + Closure.run (inScope := ctx.scope.contains) (abstractLet := isGround) do + let mut argMask := #[] + for paramInfo in paramsInfo, arg in args do + match paramInfo with + | .other => + argMask := argMask.push none + | .fixedNeutral | .user | .fixedInst | .fixedHO => + argMask := argMask.push (some arg) + Closure.collectExpr arg + return argMask end Collector @@ -356,7 +255,7 @@ mutual unless (← shouldSpecialize paramsInfo args) do return none let some decl ← getDecl? declName | return none trace[Compiler.specialize.candidate] "{e}, {paramsInfo}" - let (argMask, { params, decls, .. }) ← Collector.collect paramsInfo args |>.run {} + let (argMask, params, decls) ← Collector.collect paramsInfo args let keyBody := mkAppN f (argMask.filterMap id) let (key, levelParamsNew) ← mkKey params decls keyBody trace[Compiler.specialize.candidate] "key: {key}"