From 37a61568bce57df4ec87c598d3684bfe737468f6 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 9 Oct 2022 08:01:57 -0700 Subject: [PATCH] feat: improve fixed parameter analyzer --- src/Lean/Compiler/LCNF/FixedArgs.lean | 186 +++++++++++++++++++++----- src/Lean/Compiler/LCNF/SpecInfo.lean | 4 +- 2 files changed, 154 insertions(+), 36 deletions(-) diff --git a/src/Lean/Compiler/LCNF/FixedArgs.lean b/src/Lean/Compiler/LCNF/FixedArgs.lean index d34c103106..c6e8e21e93 100644 --- a/src/Lean/Compiler/LCNF/FixedArgs.lean +++ b/src/Lean/Compiler/LCNF/FixedArgs.lean @@ -7,38 +7,152 @@ import Lean.Compiler.LCNF.Basic import Lean.Compiler.LCNF.Types namespace Lean.Compiler.LCNF +namespace FixedParams -private abbrev Visitor := NameMap (Array Bool) → NameMap (Array Bool) +/-! # Fixed Parameter Static Analyzer -/ -private partial def updateMap (decls : Array Decl) (code : Code) : Visitor := - go code -where - goLetDecl (letDecl : LetDecl) : Visitor := fun s => Id.run do - let .const declName _ := letDecl.value.getAppFn | return s - let some mask := s.find? declName | return s - for decl in decls do - if decl.name == declName then - -- Recall that mask.size == decl.params.size - let mut mask := mask - let args := letDecl.value.getAppArgs - let sz := Nat.min args.size decl.params.size - for i in [:sz] do - let arg := args[i]! - let param := decl.params[i]! - unless arg.isFVarOf param.fvarId || (arg.isErased && param.type.isErased) do - mask := mask.set! i false - -- If the declaration is partially applied, we assume the missing arguments are not fixed - for i in [args.size:decl.params.size] do - mask := mask.set! i false - return s.insert decl.name mask - return s +/- +If function is partially applied, we assume missing parameters are not fixed. +Note that, since LCNF is in A-normal form, if function is used as an argument, none of its parameters +will be considered fixed. - go (code : Code) : Visitor := - match code with - | .let decl k => go k ∘ goLetDecl decl - | .fun decl k | .jp decl k => go k ∘ go decl.value - | .cases c => fun s => c.alts.foldl (init := s) fun s alt => go alt.getCode s - | .unreach .. | .jmp .. | .return .. => id +We assume functions that are not in mutual block do not invoke the function being +analyzed. + +We track fixed arguments using "abstract values". They are just a "token" or ⊤. +We can view it as a form of constant propagation. + +When analyzing a function call to another function in the same mutual block, +we visit its body after binding its parameters to "abstract values". We keep a cache +of the already visited pairs ("declName", "abstract values"). + +Whenever, we find a recursive call to the function being analyzed, we check whether +the arguments match the initial "abstract values". + +We interrupt the analysis if all parameters of the function being analyzed have been +marked as not fixed. +-/ + +/-- Abstract value for the "fixed parameter" analysis. -/ +inductive Value where + | top + | erased + | val (i : Nat) + deriving Inhabited, BEq, Hashable + +structure Context where + /-- Declaration in the same mutual block. -/ + decls : Array Decl + /-- + Function being analyzed. We check every recursive call to this function. + Remark: `main` is in `decls`. + -/ + main : Decl + /-- + The assignment maps free variable ids in the current code being analyzed to abstract values. + We only track the abstract value assigned to parameters. + -/ + assignment : FVarIdMap Value + +structure State where + /-- + Set of calls that have been already analyzed. + Recall that we assume that only functions in `decls` may have recursive calls to the function being analyzed (i.e., `main`). + Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and + we visit with the abstract values assigned to `aᵢ`, but first we record the visit here. + -/ + visited : HashSet (Name × Array Value) := {} + /-- + Bitmask containing the result, i.e., which parameters of `main` are fixed. + We initialize it with `true` everywhere. + -/ + fixed : Array Bool + +/-- Monad for the fixed parameter static analyzer. We use the unit-exception to interrupt the analysis. -/ +abbrev FixParamM := ReaderT Context <| EStateM Unit State + +/-- Stop the analysis and mark all parameters as non-fixed. -/ +abbrev abort : FixParamM α := do + modify fun s => { s with fixed := s.fixed.map fun _ => false } + throw () + +def evalArg (arg : Expr) : FixParamM Value := do + if arg.isErased then + return .erased + let .fvar fvarId := arg | return .top + let some val := (← read).assignment.find? fvarId | return .top + return val + +def inMutualBlock (declName : Name) : FixParamM Bool := + return (← read).decls.any (·.name == declName) + +def mkAssignment (decl : Decl) (values : Array Value) : FVarIdMap Value := Id.run do + let mut assignment := {} + for param in decl.params, value in values do + assignment := assignment.insert param.fvarId value + return assignment + +mutual + +partial def evalExpr (e : Expr) : FixParamM Unit := do + match e with + | .const declName _ => evalApp declName #[] + | .app .. => + let .const declName _ := e.getAppFn | return () + if (← inMutualBlock declName) then + evalApp declName e.getAppArgs + | _ => return () + +partial def evalCode (code : Code) : FixParamM Unit := do + match code with + | .let decl k => evalExpr decl.value; evalCode k + | .fun decl k | .jp decl k => evalCode decl.value; evalCode k + | .cases c => c.alts.forM fun alt => evalCode alt.getCode + | .unreach .. | .jmp .. | .return .. => return () + +partial def evalApp (declName : Name) (args : Array Expr) : FixParamM Unit := do + let main := (← read).main + if declName == main.name then + -- Recursive call to the function being analyzed + for h : i in [:main.params.size] do + if _h : i < args.size then + have : i < main.params.size := h.upper + let param := main.params[i] + let val ← evalArg args[i] + unless val == .val i || (val == .erased && param.type.isErased) do + -- Found non fixed argument + -- Remark: if the argument is erased and the type of the parameter is erased we assume it is a fixed "propositonal" parameter. + modify fun s => { s with fixed := s.fixed.set! i false } + else + -- Partial application mark argument as not fixed + modify fun s => { s with fixed := s.fixed.set! i false } + unless (← get).fixed.contains true do + throw () -- stop analysis, none of the arguments are fixed. + for decl in (← read).decls do + if declName == decl.name then + -- Call to another function in the same mutual block. + let mut values := #[] + for i in [:decl.params.size] do + if h : i < args.size then + values := values.push (← evalArg args[i]) + else + values := values.push .top + let key := (declName, values) + unless (← get).visited.contains key do + modify fun s => { s with visited := s.visited.insert key } + let assignment := mkAssignment decl values + withReader (fun ctx => { ctx with assignment }) <| evalCode decl.value + +end + +def mkInitialValues (numParams : Nat) : Array Value := Id.run do + let mut values := #[] + for i in [:numParams] do + values := values.push <| .val i + return values + +end FixedParams +open FixedParams /-- Given the (potentially mutually) recursive declarations `decls`, @@ -48,12 +162,14 @@ applications. The function assumes that if a function `f` was declared in a mutual block, then `decls` contains all (computationally relevant) functions in the mutual block. -/ -def mkFixedArgMap (decls : Array Decl) : NameMap (Array Bool) := Id.run do - let mut m := {} +def mkFixedParamsMap (decls : Array Decl) : NameMap (Array Bool) := Id.run do + let mut result := {} for decl in decls do - m := m.insert decl.name (mkArray decl.params.size true) - for decl in decls do - m := updateMap decls decl.value m - return m + let values := mkInitialValues decl.params.size + let assignment := mkAssignment decl values + let fixed := Array.mkArray decl.params.size true + match evalCode decl.value |>.run { main := decl, decls, assignment } |>.run { fixed } with + | .ok _ s | .error _ s => result := result.insert decl.name s.fixed + return result end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/SpecInfo.lean b/src/Lean/Compiler/LCNF/SpecInfo.lean index 669904fb2d..18da8e9795 100644 --- a/src/Lean/Compiler/LCNF/SpecInfo.lean +++ b/src/Lean/Compiler/LCNF/SpecInfo.lean @@ -133,13 +133,15 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do pure .other paramsInfo := paramsInfo.push info pure () + trace[Compiler.specialize.info] ">> {decl.name} {paramsInfo}" declsInfo := declsInfo.push paramsInfo if declsInfo.any fun paramsInfo => paramsInfo.any (· matches .user | .fixedInst | .fixedHO) then - let m := mkFixedArgMap decls + let m := mkFixedParamsMap decls for i in [:decls.size] do let decl := decls[i]! let paramsInfo := declsInfo[i]! let some mask := m.find? decl.name | unreachable! + trace[Compiler.specialize.info] "{decl.name} {mask}" let paramsInfo := paramsInfo.zipWith mask fun info mask => if mask || info matches .user then info else .other if paramsInfo.any fun info => info matches .fixedInst | .fixedHO | .user then trace[Compiler.specialize.info] "{decl.name} {paramsInfo}"