feat: improve fixed parameter analyzer

This commit is contained in:
Leonardo de Moura 2022-10-09 08:01:57 -07:00
parent d4219c9d70
commit 37a61568bc
2 changed files with 154 additions and 36 deletions

View file

@ -7,38 +7,152 @@ import Lean.Compiler.LCNF.Basic
import Lean.Compiler.LCNF.Types
namespace Lean.Compiler.LCNF
namespace FixedParams
private abbrev Visitor := NameMap (Array Bool) → NameMap (Array Bool)
/-! # Fixed Parameter Static Analyzer -/
private partial def updateMap (decls : Array Decl) (code : Code) : Visitor :=
go code
where
goLetDecl (letDecl : LetDecl) : Visitor := fun s => Id.run do
let .const declName _ := letDecl.value.getAppFn | return s
let some mask := s.find? declName | return s
for decl in decls do
if decl.name == declName then
-- Recall that mask.size == decl.params.size
let mut mask := mask
let args := letDecl.value.getAppArgs
let sz := Nat.min args.size decl.params.size
for i in [:sz] do
let arg := args[i]!
let param := decl.params[i]!
unless arg.isFVarOf param.fvarId || (arg.isErased && param.type.isErased) do
mask := mask.set! i false
-- If the declaration is partially applied, we assume the missing arguments are not fixed
for i in [args.size:decl.params.size] do
mask := mask.set! i false
return s.insert decl.name mask
return s
/-
If function is partially applied, we assume missing parameters are not fixed.
Note that, since LCNF is in A-normal form, if function is used as an argument, none of its parameters
will be considered fixed.
go (code : Code) : Visitor :=
match code with
| .let decl k => go k ∘ goLetDecl decl
| .fun decl k | .jp decl k => go k ∘ go decl.value
| .cases c => fun s => c.alts.foldl (init := s) fun s alt => go alt.getCode s
| .unreach .. | .jmp .. | .return .. => id
We assume functions that are not in mutual block do not invoke the function being
analyzed.
We track fixed arguments using "abstract values". They are just a "token" or .
We can view it as a form of constant propagation.
When analyzing a function call to another function in the same mutual block,
we visit its body after binding its parameters to "abstract values". We keep a cache
of the already visited pairs ("declName", "abstract values").
Whenever, we find a recursive call to the function being analyzed, we check whether
the arguments match the initial "abstract values".
We interrupt the analysis if all parameters of the function being analyzed have been
marked as not fixed.
-/
/-- Abstract value for the "fixed parameter" analysis. -/
inductive Value where
| top
| erased
| val (i : Nat)
deriving Inhabited, BEq, Hashable
structure Context where
/-- Declaration in the same mutual block. -/
decls : Array Decl
/--
Function being analyzed. We check every recursive call to this function.
Remark: `main` is in `decls`.
-/
main : Decl
/--
The assignment maps free variable ids in the current code being analyzed to abstract values.
We only track the abstract value assigned to parameters.
-/
assignment : FVarIdMap Value
structure State where
/--
Set of calls that have been already analyzed.
Recall that we assume that only functions in `decls` may have recursive calls to the function being analyzed (i.e., `main`).
Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and
we visit with the abstract values assigned to `aᵢ`, but first we record the visit here.
-/
visited : HashSet (Name × Array Value) := {}
/--
Bitmask containing the result, i.e., which parameters of `main` are fixed.
We initialize it with `true` everywhere.
-/
fixed : Array Bool
/-- Monad for the fixed parameter static analyzer. We use the unit-exception to interrupt the analysis. -/
abbrev FixParamM := ReaderT Context <| EStateM Unit State
/-- Stop the analysis and mark all parameters as non-fixed. -/
abbrev abort : FixParamM α := do
modify fun s => { s with fixed := s.fixed.map fun _ => false }
throw ()
def evalArg (arg : Expr) : FixParamM Value := do
if arg.isErased then
return .erased
let .fvar fvarId := arg | return .top
let some val := (← read).assignment.find? fvarId | return .top
return val
def inMutualBlock (declName : Name) : FixParamM Bool :=
return (← read).decls.any (·.name == declName)
def mkAssignment (decl : Decl) (values : Array Value) : FVarIdMap Value := Id.run do
let mut assignment := {}
for param in decl.params, value in values do
assignment := assignment.insert param.fvarId value
return assignment
mutual
partial def evalExpr (e : Expr) : FixParamM Unit := do
match e with
| .const declName _ => evalApp declName #[]
| .app .. =>
let .const declName _ := e.getAppFn | return ()
if (← inMutualBlock declName) then
evalApp declName e.getAppArgs
| _ => return ()
partial def evalCode (code : Code) : FixParamM Unit := do
match code with
| .let decl k => evalExpr decl.value; evalCode k
| .fun decl k | .jp decl k => evalCode decl.value; evalCode k
| .cases c => c.alts.forM fun alt => evalCode alt.getCode
| .unreach .. | .jmp .. | .return .. => return ()
partial def evalApp (declName : Name) (args : Array Expr) : FixParamM Unit := do
let main := (← read).main
if declName == main.name then
-- Recursive call to the function being analyzed
for h : i in [:main.params.size] do
if _h : i < args.size then
have : i < main.params.size := h.upper
let param := main.params[i]
let val ← evalArg args[i]
unless val == .val i || (val == .erased && param.type.isErased) do
-- Found non fixed argument
-- Remark: if the argument is erased and the type of the parameter is erased we assume it is a fixed "propositonal" parameter.
modify fun s => { s with fixed := s.fixed.set! i false }
else
-- Partial application mark argument as not fixed
modify fun s => { s with fixed := s.fixed.set! i false }
unless (← get).fixed.contains true do
throw () -- stop analysis, none of the arguments are fixed.
for decl in (← read).decls do
if declName == decl.name then
-- Call to another function in the same mutual block.
let mut values := #[]
for i in [:decl.params.size] do
if h : i < args.size then
values := values.push (← evalArg args[i])
else
values := values.push .top
let key := (declName, values)
unless (← get).visited.contains key do
modify fun s => { s with visited := s.visited.insert key }
let assignment := mkAssignment decl values
withReader (fun ctx => { ctx with assignment }) <| evalCode decl.value
end
def mkInitialValues (numParams : Nat) : Array Value := Id.run do
let mut values := #[]
for i in [:numParams] do
values := values.push <| .val i
return values
end FixedParams
open FixedParams
/--
Given the (potentially mutually) recursive declarations `decls`,
@ -48,12 +162,14 @@ applications.
The function assumes that if a function `f` was declared in a mutual block, then `decls`
contains all (computationally relevant) functions in the mutual block.
-/
def mkFixedArgMap (decls : Array Decl) : NameMap (Array Bool) := Id.run do
let mut m := {}
def mkFixedParamsMap (decls : Array Decl) : NameMap (Array Bool) := Id.run do
let mut result := {}
for decl in decls do
m := m.insert decl.name (mkArray decl.params.size true)
for decl in decls do
m := updateMap decls decl.value m
return m
let values := mkInitialValues decl.params.size
let assignment := mkAssignment decl values
let fixed := Array.mkArray decl.params.size true
match evalCode decl.value |>.run { main := decl, decls, assignment } |>.run { fixed } with
| .ok _ s | .error _ s => result := result.insert decl.name s.fixed
return result
end Lean.Compiler.LCNF

View file

@ -133,13 +133,15 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do
pure .other
paramsInfo := paramsInfo.push info
pure ()
trace[Compiler.specialize.info] ">> {decl.name} {paramsInfo}"
declsInfo := declsInfo.push paramsInfo
if declsInfo.any fun paramsInfo => paramsInfo.any (· matches .user | .fixedInst | .fixedHO) then
let m := mkFixedArgMap decls
let m := mkFixedParamsMap decls
for i in [:decls.size] do
let decl := decls[i]!
let paramsInfo := declsInfo[i]!
let some mask := m.find? decl.name | unreachable!
trace[Compiler.specialize.info] "{decl.name} {mask}"
let paramsInfo := paramsInfo.zipWith mask fun info mask => if mask || info matches .user then info else .other
if paramsInfo.any fun info => info matches .fixedInst | .fixedHO | .user then
trace[Compiler.specialize.info] "{decl.name} {paramsInfo}"