From dac6127810609cb7d500dfc6116d2f451c858f79 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Henrik=20B=C3=B6ving?=
Date: Thu, 20 Oct 2022 00:51:32 +0200
Subject: [PATCH] feat: Compiler pass for reducing common jp args

---
 src/Lean/Compiler/LCNF/FVarUtil.lean   |  15 +++
 src/Lean/Compiler/LCNF/JoinPoints.lean | 168 +++++++++++++++++++++++++
 src/Lean/Compiler/LCNF/Passes.lean     |   1 +
 3 files changed, 184 insertions(+)

diff --git a/src/Lean/Compiler/LCNF/FVarUtil.lean b/src/Lean/Compiler/LCNF/FVarUtil.lean
index 530ce03b90..c46d207b86 100644
--- a/src/Lean/Compiler/LCNF/FVarUtil.lean
+++ b/src/Lean/Compiler/LCNF/FVarUtil.lean
@@ -167,5 +167,20 @@ instance : TraverseFVar Alt where
       Code.forFVarM f c
     | .default c => Code.forFVarM f c
 
+/-- `true` iff `f` returned `true` for at least one fvar visited in `x`; all fvars are visited (no early exit). -/
+def anyFVarM [Monad m] [TraverseFVar α] (f : FVarId → m Bool) (x : α) : m Bool := do
+  return (← (TraverseFVar.forFVarM go x).run false).2
+where
+  -- TODO: StateRefT, early return?
+  go (fvar : FVarId) : StateT Bool m Unit := do
+    if (← f fvar) then set true
+
+/-- `true` iff `f` returned `true` for every fvar visited in `x` (so `true` if none is visited); no early exit. -/
+def allFVarM [Monad m] [TraverseFVar α] (f : FVarId → m Bool) (x : α) : m Bool := do
+  return (← (TraverseFVar.forFVarM go x).run true).2
+where
+  -- TODO: StateRefT, early return?
+  go (fvar : FVarId) : StateT Bool m Unit := do
+    unless (← f fvar) do set false
 
 end Lean.Compiler.LCNF
diff --git a/src/Lean/Compiler/LCNF/JoinPoints.lean b/src/Lean/Compiler/LCNF/JoinPoints.lean
index 64ab4b92c5..821c3c8e26 100644
--- a/src/Lean/Compiler/LCNF/JoinPoints.lean
+++ b/src/Lean/Compiler/LCNF/JoinPoints.lean
@@ -8,6 +8,7 @@ import Lean.Compiler.LCNF.PassManager
 import Lean.Compiler.LCNF.PullFunDecls
 import Lean.Compiler.LCNF.FVarUtil
 import Lean.Compiler.LCNF.ScopeM
+import Lean.Compiler.LCNF.InferType
 
 namespace Lean.Compiler.LCNF
 
@@ -439,6 +440,164 @@ where
 
 end JoinPointContextExtender
 
+namespace JoinPointCommonArgs
+
+/--
+Context for `ReduceAnalysisM`.
+-/
+structure AnalysisCtx where
+  /--
+  For every join point: the set of variables that were in scope
+  at the point where that join point was declared.
+  -/
+  jpScopes : FVarIdMap FVarIdSet := {}
+
+/-- State for `ReduceAnalysisM`, updated whenever the analysis visits a `jmp`. -/
+structure AnalysisState where
+  /-- For every join point: the parameters whose argument was identical at all `jmp`s seen so far. -/
+  jpJmpArgs : FVarIdMap FVarSubst := {}
+
+/-- Monad of the analysis phase: reads `AnalysisCtx`, updates `AnalysisState`, on top of `ScopeM`. -/
+abbrev ReduceAnalysisM := ReaderT AnalysisCtx StateRefT AnalysisState ScopeM
+/-- Monad of the rewrite phase: reads the `AnalysisState` produced by the analysis phase. -/
+abbrev ReduceActionM := ReaderT AnalysisState CompilerM
+
+/-- Whether `var` is contained in the scope recorded for join point `jp`.
+Uses `find!`, so `jp` is expected to already be registered in `jpScopes`. -/
+def isInJpScope (jp : FVarId) (var : FVarId) : ReduceAnalysisM Bool := do
+  return ((← read).jpScopes.find! jp).contains var
+
+open ScopeM
+
+/--
+Take a look at each join point and each of its call sites. If all
+call sites of a join point have one or more arguments in common, for example:
+```
+jp _jp.1 a b c => ...
+...
+cases foo
+| n1 => jmp _jp.1 d e f
+| n2 => jmp _jp.1 g e h
+```
+we can get rid of the common argument in favour of inlining it directly
+into the join point (in this case the `e`). This drastically reduces the number of
+arguments we have to pass around, for example in `ReaderT`-based
+monad stacks.
+
+Note 1: In certain niche cases this transformation could obtain better results.
+For example:
+```
+jp foo a b => ..
+let x := ...
+cases discr
+| n1 => jmp foo x y
+| n2 => jmp foo x z
+```
+Here we will not collapse the `x` since it is defined after the join point `foo`
+and thus not accessible for substitution yet. We could, however, reorder the code
+so that this becomes possible. This is currently not done since we observe
+that in practice most applications of this transformation occur naturally
+without reordering.
+
+Note 2: This transformation is kind of the opposite of `JoinPointContextExtender`.
+However, we still benefit from the extender because in the `simp` run after it
+we might be able to pull join point declarations further up in the hierarchy
+of nested functions/join points, which in turn might enable additional optimizations.
+After we have performed all of these optimizations we can take away the
+(remaining) common arguments and end up with nicely floated and optimized
+code that has as few arguments as possible in the join points.
+-/
+partial def reduce (decl : Decl) : CompilerM Decl := do
+  let (_, analysis) ← goAnalyze decl.value |>.run {} |>.run {} |>.run' {} -- phase 1: collect common jmp args
+  let newValue ← goReduce decl.value |>.run analysis -- phase 2: rewrite jps and jmps accordingly
+  return { decl with value := newValue }
+where
+  goAnalyzeFunDecl (fn : FunDecl) : ReduceAnalysisM Unit := do
+    withNewScope do
+      fn.params.forM (addToScope ·.fvarId)
+      goAnalyze fn.value
+
+  goAnalyze (code : Code) : ReduceAnalysisM Unit := do
+    match code with
+    | .let decl k =>
+      addToScope decl.fvarId
+      goAnalyze k
+    | .jp decl k =>
+      goAnalyzeFunDecl decl
+      let scope ← getScope -- scope recorded for this jp; queried later through `isInJpScope`
+      withReader (fun ctx => { ctx with jpScopes := ctx.jpScopes.insert decl.fvarId scope }) do
+        addToScope decl.fvarId
+        goAnalyze k
+    | .fun decl k =>
+      goAnalyzeFunDecl decl
+      addToScope decl.fvarId
+      goAnalyze k
+    | .cases cs =>
+      let visitor alt := do
+        withNewScope do
+          alt.getParams.forM (addToScope ·.fvarId)
+          goAnalyze alt.getCode
+      cs.alts.forM visitor
+    | .jmp fn args =>
+      let decl ← getFunDecl fn
+      if let some knownArgs := (← get).jpJmpArgs.find? fn then -- later jmp: keep only params whose arg is unchanged
+        let mut newArgs := knownArgs
+        for (param, arg) in decl.params.zip args do
+          if let some knownVal := newArgs.find? param.fvarId then
+            if arg != knownVal then
+              newArgs := newArgs.erase param.fvarId
+        modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
+      else -- first jmp to `fn`: record the args whose fvars all lie in the scope recorded for `fn`
+        let folder := fun acc (param, arg) => do
+          if (← allFVarM (isInJpScope fn) arg) then
+            return acc.insert param.fvarId arg
+          else
+            return acc
+        let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder
+        modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn interestingArgs }
+    | .return .. | .unreach .. => return ()
+
+  goReduce (code : Code) : ReduceActionM Code := do
+    match code with
+    | .jp decl k =>
+      if let some reducibleArgs := (← read).jpJmpArgs.find? decl.fvarId then
+        let filter param := do
+          let erasable := reducibleArgs.contains param.fvarId
+          if erasable then
+            eraseParam param
+          return !erasable
+        let newParams ← decl.params.filterM filter
+        let mut newValue ← goReduce decl.value
+        newValue ← replaceFVars newValue reducibleArgs false -- substitute the recorded arg for each dropped param
+        let newType ←
+          if newParams.size != decl.params.size then
+            mkForallParams newParams (← newValue.inferType)
+          else
+            pure decl.type
+        let k ← goReduce k
+        let decl ← decl.update newType newParams newValue
+        return Code.updateFun! code decl k
+      else -- no jmp recorded for this jp: still rewrite its body, it may jmp to jps whose params were dropped
+        return Code.updateFun! code (← decl.updateValue (← goReduce decl.value)) (← goReduce k)
+    | .jmp fn args =>
+      let reducibleArgs := (← read).jpJmpArgs.find! fn -- `goAnalyze` inserts an entry at every jmp it visits
+      let decl ← getFunDecl fn
+      let newParams := decl.params.zip args
+        |>.filter (!reducibleArgs.contains ·.fst.fvarId)
+        |>.map Prod.snd
+      return Code.updateJmp! code fn newParams
+    | .let decl k =>
+      return Code.updateLet! code decl (← goReduce k)
+    | .fun decl k =>
+      let decl ← decl.updateValue (← goReduce decl.value)
+      return Code.updateFun! code decl (← goReduce k)
+    | .cases cs =>
+      let alts ← cs.alts.mapM (·.mapCodeM goReduce)
+      return Code.updateCases! code cs.resultType cs.discr alts
+    | .return .. | .unreach .. => return code
+
+end JoinPointCommonArgs
+
 /--
 Find all `fun` declarations in `decl` that qualify as join points then replace
 their definitions and call sites with `jp`/`jmp`.
@@ -463,4 +622,13 @@ def extendJoinPointContext : Pass :=
 builtin_initialize
   registerTraceClass `Compiler.extendJoinPointContext (inherited := true)
 
+/-- Per-declaration entry point of the `commonJoinPointArgs` pass: runs `JoinPointCommonArgs.reduce`. -/
+def Decl.commonJoinPointArgs (decl : Decl) : CompilerM Decl := JoinPointCommonArgs.reduce decl
+
+def commonJoinPointArgs : Pass :=
+  .mkPerDeclaration `commonJoinPointArgs Decl.commonJoinPointArgs .mono
+
+builtin_initialize
+  registerTraceClass `Compiler.commonJoinPointArgs (inherited := true)
+
 end Lean.Compiler.LCNF
diff --git a/src/Lean/Compiler/LCNF/Passes.lean b/src/Lean/Compiler/LCNF/Passes.lean
index 2bcde2b867..1bf81ae7ca 100644
--- a/src/Lean/Compiler/LCNF/Passes.lean
+++ b/src/Lean/Compiler/LCNF/Passes.lean
@@ -62,6 +62,7 @@ def builtinPassManager : PassManager := {
     extendJoinPointContext,
     floatLetIn (phase := .mono) (occurrence := 1),
     reduceArity,
+    commonJoinPointArgs,
     simp (occurrence := 4) (phase := .mono),
     floatLetIn (phase := .mono) (occurrence := 2),
     lambdaLifting,