chore: remove IR elim dead branches (#11576)
This PR removes the old ElimDeadBranches pass and shifts the new one past lambda lifting. The reason for dropping the old one is its general unsoundness and the fact that we want to do refactorings on the IR part. The reason for shifting the current pass past lambda lifting, is that its analysis is imprecise in the presence of local function symbols. I experimented with the exact placement for a while and it seems like it is optimal here. Overall we observe a slight regression in the amount of C code generated, likely because we don't propagate information into lambdas before lifting them anymore. But generally measure a slight performance improvement in general.
This commit is contained in:
parent
e7f4fc9baf
commit
b8c53b1d29
5 changed files with 18 additions and 352 deletions
|
|
@ -21,7 +21,6 @@ public import Lean.Compiler.IR.Boxing
|
|||
public import Lean.Compiler.IR.RC
|
||||
public import Lean.Compiler.IR.ExpandResetReuse
|
||||
public import Lean.Compiler.IR.UnboxResult
|
||||
public import Lean.Compiler.IR.ElimDeadBranches
|
||||
public import Lean.Compiler.IR.EmitC
|
||||
public import Lean.Compiler.IR.Sorry
|
||||
public import Lean.Compiler.IR.ToIR
|
||||
|
|
@ -46,8 +45,7 @@ register_builtin_option compiler.reuse : Bool := {
|
|||
def compile (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
logDecls `init decls
|
||||
checkDecls decls
|
||||
let mut decls ← elimDeadBranches decls
|
||||
logDecls `elim_dead_branches decls
|
||||
let mut decls := decls
|
||||
decls := decls.map Decl.pushProj
|
||||
logDecls `push_proj decls
|
||||
if compiler.reuse.get (← getOptions) then
|
||||
|
|
|
|||
|
|
@ -1,344 +0,0 @@
|
|||
/-
|
||||
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Compiler.IR.CompilerM
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.IR.UnreachableBranches
|
||||
|
||||
/-- Value used in the abstract interpreter -/
|
||||
inductive Value where
|
||||
| bot -- undefined
|
||||
| top -- any value
|
||||
| ctor (i : CtorInfo) (vs : Array Value)
|
||||
| choice (vs : List Value)
|
||||
deriving Inhabited, BEq, Repr
|
||||
|
||||
protected partial def Value.toFormat : Value → Format
|
||||
| Value.bot => "⊥"
|
||||
| Value.top => "⊤"
|
||||
| Value.ctor info vs =>
|
||||
if vs.isEmpty then
|
||||
format info.name
|
||||
else
|
||||
Format.paren <| format info.name ++ Std.Format.join (vs.toList.map fun v => " " ++ Value.toFormat v)
|
||||
| Value.choice vs =>
|
||||
Format.paren <| Std.Format.joinSep (vs.map Value.toFormat) " | "
|
||||
|
||||
instance : ToFormat Value where
|
||||
format := Value.toFormat
|
||||
|
||||
instance : ToString Value where
|
||||
toString v := toString (format v)
|
||||
|
||||
namespace Value
|
||||
|
||||
partial def addChoice (merge : Value → Value → Value) : List Value → Value → List Value
|
||||
| [], v => [v]
|
||||
| v₁@(ctor i₁ _) :: cs, v₂@(ctor i₂ _) =>
|
||||
if i₁ == i₂ then merge v₁ v₂ :: cs
|
||||
else v₁ :: addChoice merge cs v₂
|
||||
| _, _ => panic! "invalid addChoice"
|
||||
|
||||
partial def merge (v₁ v₂ : Value) : Value :=
|
||||
match v₁, v₂ with
|
||||
| bot, v => v
|
||||
| v, bot => v
|
||||
| top, _ => top
|
||||
| _, top => top
|
||||
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
|
||||
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i _ r => r.push (merge vs₁[i] vs₂[i]!)
|
||||
else choice [v₁, v₂]
|
||||
| choice vs₁, choice vs₂ => choice <| vs₁.foldl (addChoice merge) vs₂
|
||||
| choice vs, v => choice <| addChoice merge vs v
|
||||
| v, choice vs => choice <| addChoice merge vs v
|
||||
|
||||
/--
|
||||
In `truncate`, we approximate a value as `top` if depth > `truncateMaxDepth`.
|
||||
TODO: add option to control this parameter.
|
||||
-/
|
||||
def truncateMaxDepth := 8
|
||||
|
||||
/--
|
||||
Make sure constructors of recursive inductive datatypes can only occur once in each path.
|
||||
Values at depth > truncateMaxDepth are also approximated at `top`.
|
||||
We use this function this function to implement a simple widening operation for our abstract
|
||||
interpreter.
|
||||
Recall the widening functions is used to ensure termination in abstract interpreters.
|
||||
-/
|
||||
partial def truncate (env : Environment) (v : Value) (s : NameSet) : Value :=
|
||||
go v s truncateMaxDepth
|
||||
where
|
||||
go (v : Value) (s : NameSet) (depth : Nat) : Value :=
|
||||
match depth with
|
||||
| 0 => top
|
||||
| depth+1 =>
|
||||
match v, s with
|
||||
| ctor i vs, found =>
|
||||
let I := i.name.getPrefix
|
||||
if found.contains I then
|
||||
top
|
||||
else
|
||||
let cont (found' : NameSet) : Value :=
|
||||
ctor i (vs.map fun v => go v found' depth)
|
||||
match env.find? I with
|
||||
| some (ConstantInfo.inductInfo d) =>
|
||||
if d.isRec then cont (found.insert I)
|
||||
else cont found
|
||||
| _ => cont found
|
||||
| choice vs, found =>
|
||||
let newVs := vs.map fun v => go v found depth
|
||||
if newVs.elem top then top
|
||||
else choice newVs
|
||||
| v, _ => v
|
||||
|
||||
/-- Widening operator that guarantees termination in our abstract interpreter. -/
|
||||
def widening (env : Environment) (v₁ v₂ : Value) : Value :=
|
||||
truncate env (merge v₁ v₂) {}
|
||||
|
||||
end Value
|
||||
|
||||
abbrev FunctionSummaries := PHashMap FunId Value
|
||||
|
||||
private abbrev declLt (a b : FunId × Value) :=
|
||||
Name.quickLt a.1 b.1
|
||||
|
||||
private abbrev sortEntries (entries : Array (FunId × Value)) : Array (FunId × Value) :=
|
||||
entries.qsort declLt
|
||||
|
||||
private abbrev findAtSorted? (entries : Array (FunId × Value)) (fid : FunId) : Option Value :=
|
||||
if let some (_, value) := entries.binSearch (fid, default) declLt then
|
||||
some value
|
||||
else
|
||||
none
|
||||
|
||||
builtin_initialize functionSummariesExt : SimplePersistentEnvExtension (FunId × Value) FunctionSummaries ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
addImportedFn := fun _ => {}
|
||||
addEntryFn := fun s ⟨e, n⟩ => s.insert e n
|
||||
exportEntriesFnEx? := some fun _ s _ => fun
|
||||
-- preserved for non-modules, make non-persistent at some point?
|
||||
| .private => sortEntries s.toArray
|
||||
| _ => #[]
|
||||
asyncMode := .sync -- compilation is non-parallel anyway
|
||||
replay? := some <| SimplePersistentEnvExtension.replayOfFilter (!·.contains ·.1) (fun s ⟨e, n⟩ => s.insert e n)
|
||||
}
|
||||
|
||||
def addFunctionSummary (env : Environment) (fid : FunId) (v : Value) : Environment :=
|
||||
functionSummariesExt.addEntry env (fid, v)
|
||||
|
||||
def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value :=
|
||||
match env.getModuleIdxFor? fid with
|
||||
| some modIdx => findAtSorted? (functionSummariesExt.getModuleEntries env modIdx) fid
|
||||
| none => functionSummariesExt.getState env |>.find? fid
|
||||
|
||||
abbrev Assignment := Std.HashMap VarId Value
|
||||
|
||||
structure InterpContext where
|
||||
currFnIdx : Nat := 0
|
||||
decls : Array Decl
|
||||
env : Environment
|
||||
lctx : LocalContext := {}
|
||||
|
||||
structure InterpState where
|
||||
assignments : Array Assignment
|
||||
funVals : PArray Value -- we take snapshots during fixpoint computations
|
||||
visitedJps : Array (Std.HashSet JoinPointId)
|
||||
|
||||
abbrev M := ReaderT InterpContext (StateM InterpState)
|
||||
|
||||
open Value
|
||||
|
||||
def findVarValue (x : VarId) : M Value := do
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
let assignment := s.assignments[ctx.currFnIdx]!
|
||||
return assignment.getD x bot
|
||||
|
||||
def findArgValue (arg : Arg) : M Value :=
|
||||
match arg with
|
||||
| .var x => findVarValue x
|
||||
| .erased => pure top
|
||||
|
||||
def updateVarAssignment (x : VarId) (v : Value) : M Unit := do
|
||||
let v' ← findVarValue x
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') }
|
||||
|
||||
def resetVarAssignment (x : VarId) : M Unit := do
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot }
|
||||
|
||||
def resetParamAssignment (y : Param) : M Unit :=
|
||||
resetVarAssignment y.x
|
||||
|
||||
partial def projValue : Value → Nat → Value
|
||||
| ctor _ vs, i => vs.getD i bot
|
||||
| choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot
|
||||
| v, _ => v
|
||||
|
||||
def interpExpr : Expr → M Value
|
||||
| Expr.ctor i ys => return ctor i (← ys.mapM fun y => findArgValue y)
|
||||
| Expr.proj i x => return projValue (← findVarValue x) i
|
||||
| Expr.fap fid _ => do
|
||||
let ctx ← read
|
||||
match getFunctionSummary? ctx.env fid with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let s ← get
|
||||
match ctx.decls.findIdx? (fun decl => decl.name == fid) with
|
||||
| some idx => pure s.funVals[idx]!
|
||||
| none => pure top
|
||||
| _ => pure top
|
||||
|
||||
partial def containsCtor : Value → CtorInfo → Bool
|
||||
| top, _ => true
|
||||
| ctor i _, j => i == j
|
||||
| choice vs, j => vs.any fun v => containsCtor v j
|
||||
| _, _ => false
|
||||
|
||||
def updateCurrFnSummary (v : Value) : M Unit := do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
|
||||
|
||||
def markJPVisited (j : JoinPointId) : M Bool := do
|
||||
let currFnIdx := (← read).currFnIdx
|
||||
modifyGet fun s =>
|
||||
⟨!(s.visitedJps[currFnIdx]!.contains j),
|
||||
{ s with visitedJps := s.visitedJps.modify currFnIdx fun a => a.insert j }⟩
|
||||
|
||||
/-- Return true if the assignment of at least one parameter has been updated. -/
|
||||
def updateJPParamsAssignment (j : JoinPointId) (ys : Array Param) (xs : Array Arg) : M Bool := do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
let isFirstVisit ← markJPVisited j
|
||||
ys.size.foldM (init := isFirstVisit) fun i _ r => do
|
||||
let y := ys[i]
|
||||
let x := xs[i]!
|
||||
let yVal ← findVarValue y.x
|
||||
let xVal ← findArgValue x
|
||||
let newVal := merge yVal xVal
|
||||
if newVal == yVal then
|
||||
pure r
|
||||
else
|
||||
modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal }
|
||||
pure true
|
||||
|
||||
private partial def resetNestedJPParams : FnBody → M Unit
|
||||
| FnBody.jdecl _ ys _ k => do
|
||||
ys.forM resetParamAssignment
|
||||
/- Remark we don't need to reset the parameters of join points
|
||||
nested in `b` since they will be reset if this JP is used. -/
|
||||
resetNestedJPParams k
|
||||
| FnBody.case _ _ _ alts =>
|
||||
alts.forM fun alt => match alt with
|
||||
| Alt.ctor _ b => resetNestedJPParams b
|
||||
| Alt.default b => resetNestedJPParams b
|
||||
| e => do unless e.isTerminal do resetNestedJPParams e.body
|
||||
|
||||
partial def interpFnBody : FnBody → M Unit
|
||||
| FnBody.vdecl x _ e b => do
|
||||
let v ← interpExpr e
|
||||
updateVarAssignment x v
|
||||
interpFnBody b
|
||||
| FnBody.jdecl j ys v b =>
|
||||
withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) do
|
||||
interpFnBody b
|
||||
| FnBody.case _ x _ alts => do
|
||||
let v ← findVarValue x
|
||||
alts.forM fun alt => do
|
||||
match alt with
|
||||
| Alt.ctor i b => if containsCtor v i then interpFnBody b
|
||||
| Alt.default b => interpFnBody b
|
||||
| FnBody.ret x => do
|
||||
let v ← findArgValue x
|
||||
updateCurrFnSummary v
|
||||
| FnBody.jmp j xs => do
|
||||
let ctx ← read
|
||||
let ys := (ctx.lctx.getJPParams j).get!
|
||||
let b := (ctx.lctx.getJPBody j).get!
|
||||
let updated ← updateJPParamsAssignment j ys xs
|
||||
if updated then
|
||||
-- We must reset the value of nested join-point parameters since they depend on `ys` values
|
||||
resetNestedJPParams b
|
||||
interpFnBody b
|
||||
| e => do
|
||||
unless e.isTerminal do
|
||||
interpFnBody e.body
|
||||
|
||||
def inferStep : M Bool := do
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := ctx.decls.map fun _ => {},
|
||||
visitedJps := ctx.decls.map fun _ => {} }
|
||||
ctx.decls.size.foldM (init := false) fun idx _ modified => do
|
||||
match ctx.decls[idx] with
|
||||
| .fdecl (xs := ys) (body := b) .. => do
|
||||
let s ← get
|
||||
let currVals := s.funVals[idx]!
|
||||
withReader (fun ctx => { ctx with currFnIdx := idx }) do
|
||||
ys.forM fun y => updateVarAssignment y.x top
|
||||
interpFnBody b
|
||||
let s ← get
|
||||
let newVals := s.funVals[idx]!
|
||||
pure (modified || currVals != newVals)
|
||||
| .extern .. => do
|
||||
let currVals := (← get).funVals[idx]!
|
||||
updateCurrFnSummary .top
|
||||
let newVals := (← get).funVals[idx]!
|
||||
pure (modified || currVals != newVals)
|
||||
|
||||
partial def inferMain : M Unit := do
|
||||
let modified ← inferStep
|
||||
if modified then inferMain else pure ()
|
||||
|
||||
partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
|
||||
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
|
||||
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
|
||||
| FnBody.case tid x xType alts =>
|
||||
let v := assignment.getD x bot
|
||||
let alts := alts.map fun alt =>
|
||||
match alt with
|
||||
| Alt.ctor i b => Alt.ctor i <| if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
|
||||
| Alt.default b => Alt.default (elimDeadAux assignment b)
|
||||
FnBody.case tid x xType alts
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := elimDeadAux assignment b
|
||||
instr.setBody b
|
||||
|
||||
partial def elimDead (assignment : Assignment) (d : Decl) : Decl :=
|
||||
match d with
|
||||
| .fdecl (body := b) .. => d.updateBody! <| elimDeadAux assignment b
|
||||
| other => other
|
||||
|
||||
end UnreachableBranches
|
||||
|
||||
open UnreachableBranches
|
||||
|
||||
def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let env ← getEnv
|
||||
let assignments : Array Assignment := decls.map fun _ => {}
|
||||
let funVals := mkPArray decls.size Value.bot
|
||||
let visitedJps := decls.map fun _ => {}
|
||||
let ctx : InterpContext := { decls := decls, env := env }
|
||||
let s : InterpState := { assignments, funVals, visitedJps }
|
||||
let (_, s) := (inferMain ctx).run s
|
||||
let funVals := s.funVals
|
||||
let assignments := s.assignments
|
||||
modifyEnv fun env =>
|
||||
decls.size.fold (init := env) fun i _ env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]!
|
||||
return decls.mapIdx fun i decl => elimDead assignments[i]! decl
|
||||
|
||||
builtin_initialize registerTraceClass `compiler.ir.elim_dead_branches (inherited := true)
|
||||
|
||||
end Lean.IR
|
||||
|
|
@ -113,10 +113,10 @@ def builtinPassManager : PassManager := {
|
|||
commonJoinPointArgs,
|
||||
simp (occurrence := 4) (phase := .mono),
|
||||
floatLetIn (phase := .mono) (occurrence := 2),
|
||||
elimDeadBranches,
|
||||
lambdaLifting,
|
||||
extendJoinPointContext (phase := .mono) (occurrence := 1),
|
||||
simp (occurrence := 5) (phase := .mono),
|
||||
elimDeadBranches,
|
||||
cse (occurrence := 2) (phase := .mono),
|
||||
saveMono, -- End of mono phase
|
||||
inferVisibility (phase := .mono),
|
||||
|
|
|
|||
|
|
@ -85,6 +85,18 @@ def getDeclInfo? (declName : Name) : CoreM (Option ConstantInfo) := do
|
|||
let env ← getEnv
|
||||
return env.find? (mkUnsafeRecName declName) <|> env.find? declName
|
||||
|
||||
def declIsNotUnsafe (declName : Name) : CoreM Bool := do
|
||||
let env ← getEnv
|
||||
let some info := env.find? declName | return true
|
||||
if info.isUnsafe then
|
||||
return false
|
||||
else
|
||||
if info matches .opaqueInfo .. then
|
||||
-- check if its a partial def
|
||||
return env.find? (Compiler.mkUnsafeRecName declName) |>.isNone
|
||||
else
|
||||
return true
|
||||
|
||||
/--
|
||||
Convert the given declaration from the Lean environment into `Decl`.
|
||||
The steps for this are roughly:
|
||||
|
|
@ -97,7 +109,7 @@ The steps for this are roughly:
|
|||
def toDecl (declName : Name) : CompilerM Decl := do
|
||||
let declName := if let some name := isUnsafeRecName? declName then name else declName
|
||||
let some info ← getDeclInfo? declName | throwError "declaration `{.ofConstName declName}` not found"
|
||||
let safe := !info.isPartial && !info.isUnsafe
|
||||
let safe ← declIsNotUnsafe declName
|
||||
let env ← getEnv
|
||||
let inlineAttr? := getInlineAttribute? env declName
|
||||
let paramsFromTypeBinders (expr : Expr) : CompilerM (Array Param) := do
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@ class Semiring (α : Type u) where
|
|||
|
||||
/--
|
||||
trace: [Compiler.IR] [result]
|
||||
def instSemiringUInt8 : obj :=
|
||||
let x_1 : obj := pap instSemiringUInt8._lam_0._boxed;
|
||||
ret x_1
|
||||
def instSemiringUInt8._lam_0 (x_1 : @& tobj) (x_2 : u8) : u8 :=
|
||||
let x_3 : u8 := UInt8.ofNat x_1;
|
||||
let x_4 : u8 := UInt8.mul x_3 x_2;
|
||||
ret x_4
|
||||
def instSemiringUInt8 : obj :=
|
||||
let x_1 : obj := pap instSemiringUInt8._lam_0._boxed;
|
||||
ret x_1
|
||||
def instSemiringUInt8._lam_0._boxed (x_1 : tobj) (x_2 : tagged) : tagged :=
|
||||
let x_3 : u8 := unbox x_2;
|
||||
let x_4 : u8 := instSemiringUInt8._lam_0 x_1 x_3;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue