From 680ee2116163cd43aca89774bd28e426b55df28a Mon Sep 17 00:00:00 2001
From: Leonardo de Moura
Date: Wed, 2 Oct 2019 19:57:43 -0700
Subject: [PATCH] feat(library/init/lean/compiler/ir): add `unreachbranches.lean`

New optimization pass for eliminating unreachable branches that occur
very often when using `ExceptT` and `EState`.

The current commit implements an abstract interpreter for computing
an approximation of the kinds of values returned by a function.

TODO:
- Implement `FnBody.jmp`.
- Implement `interpDecl`
- Remove unreachable branches in `FnBody.case`
---
 library/init/data/array/basic.lean            |   3 +
 library/init/lean/compiler/ir/format.lean     |   2 +-
 .../lean/compiler/ir/unreachbranches.lean     | 216 ++++++++++++++++++
 3 files changed, 220 insertions(+), 1 deletion(-)
 create mode 100644 library/init/lean/compiler/ir/unreachbranches.lean

diff --git a/library/init/data/array/basic.lean b/library/init/data/array/basic.lean
index efc7d9120d..c610d89048 100644
--- a/library/init/data/array/basic.lean
+++ b/library/init/data/array/basic.lean
@@ -86,6 +86,9 @@ a.get! (a.size - 1)
 def get? (a : Array α) (i : Nat) : Option α :=
 if h : i < a.size then some (a.get ⟨i, h⟩) else none
 
+def getD (a : Array α) (i : Nat) (v₀ : α) : α :=
+if h : i < a.size then a.get ⟨i, h⟩ else v₀
+
 @[extern c inline "lean_array_fset(#2, #3, #4)"]
 def set (a : Array α) (i : @& Fin a.size) (v : α) : Array α :=
 { sz := a.sz,
diff --git a/library/init/lean/compiler/ir/format.lean b/library/init/lean/compiler/ir/format.lean
index 5b269e6fa2..e38123ef41 100644
--- a/library/init/lean/compiler/ir/format.lean
+++ b/library/init/lean/compiler/ir/format.lean
@@ -15,7 +15,7 @@ private def formatArg : Arg → Format
 
 instance argHasFormat : HasFormat Arg := ⟨formatArg⟩
 
-private def formatArray {α : Type} [HasFormat α] (args : Array α) : Format :=
+def formatArray {α : Type} [HasFormat α] (args : Array α) : Format :=
 args.foldl (fun r a => r ++ " " ++ format a) Format.nil
 
 private def formatLitVal : LitVal → Format
diff --git a/library/init/lean/compiler/ir/unreachbranches.lean b/library/init/lean/compiler/ir/unreachbranches.lean
new file mode 100644
index 0000000000..9a725b5308
--- /dev/null
+++ b/library/init/lean/compiler/ir/unreachbranches.lean
@@ -0,0 +1,216 @@
+/-
+Copyright (c) 2019 Microsoft Corporation. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Leonardo de Moura
+-/
+prelude
+import init.control.reader
+import init.lean.compiler.ir.format
+import init.lean.compiler.ir.basic
+
+namespace Lean
+namespace IR
+namespace UnreachableBranches
+
+/-- Value used in the abstract interpreter -/
+inductive Value
+| bot -- undefined
+| top -- any value
+| ctor (i : CtorInfo) (vs : Array Value) -- constructor `i` applied to the abstract fields `vs`
+| choice (vs : List Value) -- any one of the alternatives in `vs`
+
+namespace Value
+
+instance : Inhabited Value := ⟨top⟩
+
+/-- Equality test on abstract values. Two `choice`s are equal when every
+    alternative of one has an equal alternative in the other (order is ignored). -/
+protected partial def beq : Value → Value → Bool
+| bot, bot => true
+| top, top => true
+| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ beq
+| choice vs₁, choice vs₂ =>
+  vs₁.all (fun v₁ => vs₂.any $ fun v₂ => beq v₁ v₂)
+  &&
+  vs₂.all (fun v₂ => vs₁.any $ fun v₁ => beq v₁ v₂)
+| _, _ => false
+
+instance : HasBeq Value := ⟨Value.beq⟩
+
+/-- Add the value `v` to a list of `ctor` alternatives. The first alternative with
+    the same `CtorInfo` as `v` is replaced by its `merge` with `v`; if there is none,
+    `v` is appended. With a non-empty list, a non-`ctor` head or a non-`ctor` `v`
+    hits the `panic!` branch. -/
+partial def addChoice (merge : Value → Value → Value) : List Value → Value → List Value
+| [], v => [v]
+| v₁@(ctor i₁ vs₁) :: cs, v₂@(ctor i₂ vs₂) =>
+  if i₁ == i₂ then merge v₁ v₂ :: cs
+  else v₁ :: addChoice cs v₂
+| _, _ => panic! "invalid addChoice"
+
+/-- Join of two abstract values: `bot` is neutral, `top` absorbs, two `ctor`s with the
+    same `CtorInfo` are merged field by field, and other combinations produce a `choice`. -/
+partial def merge : Value → Value → Value
+| bot, v => v
+| v, bot => v
+| top, _ => top
+| _, top => top
+| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
+  -- `vs₂` is indexed with `get!` up to `vs₁.size`: both arrays are expected to have the same size
+  if i₁ == i₂ then ctor i₁ $ vs₁.size.fold (fun i r => r.push (merge (vs₁.get! i) (vs₂.get! i))) Array.empty
+  else choice [v₁, v₂]
+| choice vs₁, choice vs₂ => choice $ vs₁.foldl (addChoice merge) vs₂
+| choice vs, v => choice $ addChoice merge vs v
+| v, choice vs => choice $ addChoice merge vs v
+
+/-- Pretty-print a value: `top` and `bot` by name, a `choice` as `@` followed by its
+    alternatives, a `ctor` as the constructor name followed by its fields. -/
+protected partial def format : Value → Format
+| top => "top"
+| bot => "bot"
+| choice vs => fmt "@" ++ @List.format _ ⟨format⟩ vs
+| ctor i vs => fmt i.name ++ @formatArray _ ⟨format⟩ vs
+
+instance : HasFormat Value := ⟨Value.format⟩
+instance : HasToString Value := ⟨Format.pretty ∘ Value.format⟩
+
+end Value
+
+-- Summary value stored per function; `interpExpr` uses it as the result of a call.
+abbrev FunctionSummaries := SMap FunId Value
+
+/-- Register the environment extension holding the function summaries. Imported
+    entries are inserted into a fresh map which is then `switch`ed; entries added
+    later are inserted directly. -/
+def mkFunctionSummariesExtension : IO (SimplePersistentEnvExtension (FunId × Value) FunctionSummaries) :=
+registerSimplePersistentEnvExtension {
+  name := `unreachBranchesFunSummary,
+  addImportedFn := fun as =>
+    let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as;
+    cache.switch,
+  addEntryFn := fun s ⟨e, n⟩ => s.insert e n
+}
+
+@[init mkFunctionSummariesExtension]
+constant functionSummariesExt : SimplePersistentEnvExtension (FunId × Value) FunctionSummaries := default _
+
+/-- Record `v` as the summary of function `fid` in `env`. -/
+def addFunctionSummary (env : Environment) (fid : FunId) (v : Value) : Environment :=
+functionSummariesExt.addEntry env (fid, v)
+
+/-- Summary recorded for `fid` in `env`, if any. -/
+def getFunctionSummary (env : Environment) (fid : FunId) : Option Value :=
+(functionSummariesExt.getState env).find fid
+
+-- Abstract value recorded for each local variable.
+def Assignment := PHashMap VarId Value
+
+/-- Read-only data of the interpreter: the function being analysed, the environment
+    (source of function summaries) and the local context (join point declarations). -/
+structure InterpContext :=
+(currFn : Name)
+(env : Environment)
+(lctx : LocalContext)
+
+-- Summaries computed by the interpreter, updated by `updateCurrFnSummary`.
+abbrev FunMap := HashMap FunId Value
+
+/-- Mutable data of the interpreter: variable values and function summaries. -/
+structure InterpState :=
+(assignment : Assignment)
+(funMap : FunMap)
+
+abbrev M := ReaderT InterpContext (State InterpState)
+
+open Value
+
+/-- Value recorded for `x`, or `top` when nothing is recorded for it. -/
+def findVarValue (x : VarId) : M Value :=
+do s ← get;
+   match s.assignment.find x with
+   | some v => pure v
+   | none => pure top
+
+/-- Value of an argument: variables are looked up, anything else is `top`. -/
+def findArgValue (arg : Arg) : M Value :=
+match arg with
+| Arg.var x => findVarValue x
+| _ => pure top
+
+/-- Merge `v` into the value recorded for `x`.
+    A variable with no recorded value starts from `bot`, the neutral element of
+    `merge`. `findVarValue` must not be used here: it answers `top` for such a
+    variable, and `merge v top = top` would discard `v`. -/
+def updateVarAssignment (x : VarId) (v : Value) : M Unit :=
+do s ← get;
+   let v' := (s.assignment.find x).getOrElse bot;
+   modify $ fun s => { assignment := s.assignment.insert x (merge v v'), .. s }
+
+/-- Abstract projection of field `i`: a `ctor` gives its `i`-th field (`bot` when out
+    of range), a `choice` merges the projections of its alternatives, and `bot`/`top`
+    are returned unchanged. -/
+partial def projValue : Value → Nat → Value
+| ctor _ vs, i => vs.getD i bot
+| choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot
+| v, _ => v
+
+/-- Abstract value of an expression: `ctor` builds a constructor value from its
+    arguments, `proj` uses `projValue`, `fap` uses the callee's recorded summary
+    when there is one, and everything else is `top`. -/
+def interpExpr : Expr → M Value
+| Expr.ctor i ys => ctor i <$> ys.mmap (fun y => findArgValue y)
+| Expr.proj i x => do v ← findVarValue x; pure $ projValue v i
+| Expr.fap fid ys => do
+  ctx ← read;
+  match getFunctionSummary ctx.env fid with
+  | some v => pure v
+  | none => pure top
+| _ => pure top
+
+/-- Whether a value may have been built with constructor `j`: always for `top`,
+    never for `bot`. -/
+partial def containsCtor : Value → CtorInfo → Bool
+| top, _ => true
+| ctor i _, j => i == j
+| choice vs, j => vs.any $ fun v => containsCtor v j
+| _, _ => false
+
+/-- Merge `v` into the summary recorded in `funMap` for the current function. -/
+def updateCurrFnSummary (v : Value) : M Unit :=
+do ctx ← read;
+   s ← get;
+   let currFn := ctx.currFn;
+   let v' := (s.funMap.find currFn).getOrElse bot;
+   modify $ fun s => { funMap := s.funMap.insert currFn (merge v v'), .. s }
+
+/-- Abstract interpretation of a function body: values of declared variables are
+    recorded, join points are added to the local context, a `case` alternative is
+    visited only if its constructor may occur in the scrutinee (`default` always),
+    and returned values are merged into the current function's summary.
+    `jmp` is not interpreted yet; any other non-terminal instruction is skipped
+    and interpretation continues with its body. -/
+partial def interpFnBody : FnBody → M Unit
+| FnBody.vdecl x _ e b => do
+  v ← interpExpr e;
+  updateVarAssignment x v;
+  interpFnBody b
+| FnBody.jdecl j ys v b =>
+  adaptReader (fun (ctx : InterpContext) => { lctx := ctx.lctx.addJP j ys v, .. ctx }) $
+  interpFnBody b
+| FnBody.case _ x _ alts => do
+  v ← findVarValue x;
+  alts.mfor $ fun alt =>
+    match alt with
+    | Alt.ctor i b => when (containsCtor v i) $ interpFnBody b
+    | Alt.default b => interpFnBody b
+| FnBody.ret x => do
+  v ← findArgValue x;
+  updateCurrFnSummary v
+| FnBody.jmp j xs => do
+  ctx ← read;
+  -- TODO
+  pure ()
+| e => unless (e.isTerminal) $ interpFnBody e.body
+
+end UnreachableBranches
+end IR
+end Lean