perf(library/init/lean/compiler/ir/boxing): create auxiliary constants for caching the value of boxed/unboxed literals and constants

For example, in the new test `qsort64.lean`, the new optimization
prevents the repeated execution of `box UInt64.inhabited`.
On my machine
```
./run.sh qsort64.lean 2000000
```
Goes from 1.22s to 0.355s
This commit is contained in:
Leonardo de Moura 2019-09-11 10:36:32 -07:00
parent 9e2200d0de
commit 61a3ea61c4
4 changed files with 104 additions and 13 deletions

View file

@ -15,6 +15,9 @@ inductive AssocList (α : Type u) (β : Type v)
namespace AssocList
variables {α : Type u} {β : Type v} {δ : Type w} {m : Type w → Type w} [Monad m]
def empty : AssocList α β :=
nil
@[specialize] def mfoldl (f : δ → α → β → m δ) : δ → AssocList α β → m δ
| d, nil => pure d
| d, cons a b es => do d ← f d a b; mfoldl d es

View file

@ -450,6 +450,11 @@ match ctx.find x.idx with
| some (LocalContextEntry.localVar t _) => some t
| other => none
def LocalContext.getValue (ctx : LocalContext) (x : VarId) : Option Expr :=
match ctx.find x.idx with
| some (LocalContextEntry.localVar _ v) => some v
| other => none
abbrev IndexRenaming := RBMap Index Index Index.lt
class HasAlphaEqv (α : Type) :=

View file

@ -4,9 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.data.assoclist
import init.control.estate
import init.control.reader
import init.lean.runtime
import init.lean.compiler.closedtermcache
import init.lean.compiler.externattr
import init.lean.compiler.ir.basic
import init.lean.compiler.ir.compilerm
@ -105,13 +107,30 @@ def eqvTypes (t₁ t₂ : IRType) : Bool :=
(t₁.isScalar == t₂.isScalar) && (!t₁.isScalar || t₁ == t₂)
structure BoxingContext :=
(localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment)
(f : FunId := default _) (localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment)
abbrev M := ReaderT BoxingContext (StateT Index Id)
structure BoxingState :=
(nextIdx : Index)
/- We create auxiliary declarations when boxing constant and literals.
The idea is to avoid code such as
```
let x1 := Uint64.inhabited;
let x2 := box x1;
...
```
We currently do not cache these declarations in an environment extension, but
we use auxDeclCache to avoid creating equivalent auxiliary declarations more than once when
processing the same IR declaration.
-/
(auxDecls : Array Decl := Array.empty)
(auxDeclCache : AssocList FnBody Expr := AssocList.empty)
(nextAuxId : Nat := 1)
abbrev M := ReaderT BoxingContext (StateT BoxingState Id)
def mkFresh : M VarId :=
do idx ← getModify (fun n => n + 1);
pure { idx := idx }
do oldS ← getModify (fun s => { nextIdx := s.nextIdx + 1, .. s });
pure { idx := oldS.nextIdx }
def getEnv : M Environment := BoxingContext.env <$> read
def getLocalContext : M LocalContext := BoxingContext.localCtx <$> read
@ -142,18 +161,72 @@ adaptReader (fun (ctx : BoxingContext) => { localCtx := ctx.localCtx.addLocal x
@[inline] def withJDecl {α : Type} (j : JoinPointId) (xs : Array Param) (v : FnBody) (k : M α) : M α :=
adaptReader (fun (ctx : BoxingContext) => { localCtx := ctx.localCtx.addJP j xs v, .. ctx }) k
/- If `x` declaration is of the form `x := Expr.lit _` or `x := Expr.fap c Array.empty`, then
return its value. -/
private def isConstantValue (x : VarId) : M (Option Expr) :=
do localCtx ← getLocalContext;
match localCtx.getValue x with
| some val =>
match val with
| Expr.lit _ => pure $ some val
| Expr.fap _ args => pure $ if args.size == 0 then some val else none
| _ => pure none
| _ => pure none
/- Auxiliary function used by castVarIfNeeded.
It is used when the expected type does not match `xType`.
If `xType` is scalar, then we need to "box" it. Otherwise, we need to "unbox" it. -/
def mkCast (x : VarId) (xType : IRType) : Expr :=
if xType.isScalar then Expr.box xType x else Expr.unbox x
def mkCast (x : VarId) (xType : IRType) (expectedType : IRType) : M Expr :=
do
optVal ← isConstantValue x;
match optVal with
| some v => do
ctx ← read;
/- Create auxiliary FnBody
```
let x_1 : xType := v;
let x_2 : expectedType := Expr.box xType x_1;
ret x_2
```
if `xType.isScalar`, and
```
let x_1 : xType := v;
let x_2 : expectedType := Expr.unbox x_1;
ret x_2
```
otherwise
-/
let body : FnBody :=
if xType.isScalar then
FnBody.vdecl { idx := 1 } xType v $
FnBody.vdecl { idx := 2 } expectedType (Expr.box xType { idx := 1 }) $
FnBody.ret (mkVarArg { idx := 2 })
else
FnBody.vdecl { idx := 1 } xType v $
FnBody.vdecl { idx := 2 } expectedType (Expr.unbox { idx := 1 }) $
FnBody.ret (mkVarArg { idx := 2 });
s ← get;
match s.auxDeclCache.find body with
| some v => pure v
| none => do
let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId);
let auxConst := Expr.fap auxName Array.empty;
let auxDecl := Decl.fdecl auxName Array.empty expectedType body;
modify $ fun s => {
auxDecls := s.auxDecls.push auxDecl,
auxDeclCache := s.auxDeclCache.cons body auxConst,
nextAuxId := s.nextAuxId + 1,
.. s
};
pure auxConst
| none => pure $ if xType.isScalar then Expr.box xType x else Expr.unbox x
@[inline] def castVarIfNeeded (x : VarId) (expected : IRType) (k : VarId → M FnBody) : M FnBody :=
do xType ← getVarType x;
if eqvTypes xType expected then k x
else do
y ← mkFresh;
let v := mkCast x xType;
v ← mkCast x xType expected;
FnBody.vdecl y expected v <$> k y
@[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody :=
@ -172,7 +245,7 @@ xs.miterate (Array.empty, Array.empty) $ fun i (x : Arg) (r : Array Arg × Array
if eqvTypes xType expected then pure (xs.push (Arg.var x), bs)
else do
y ← mkFresh;
let v := mkCast x xType;
v ← mkCast x xType expected;
let b := FnBody.vdecl y expected v FnBody.nil;
pure (xs.push (Arg.var y), bs.push b)
@ -197,7 +270,8 @@ def castResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (eType : IRType) (b
if eqvTypes ty eType then pure $ FnBody.vdecl x ty e b
else do
y ← mkFresh;
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty (mkCast y eType) b)
v ← mkCast y eType ty;
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty v b)
def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
match e with
@ -257,12 +331,15 @@ partial def visitFnBody : FnBody → M FnBody
def run (env : Environment) (decls : Array Decl) : Array Decl :=
let ctx : BoxingContext := { decls := decls, env := env };
let decls := decls.map (fun decl => match decl with
let decls := decls.foldl (fun (newDecls : Array Decl) (decl : Decl) =>
match decl with
| Decl.fdecl f xs t b =>
let nextIdx := decl.maxIndex + 1;
let b := (withParams xs (visitFnBody b) { resultType := t, .. ctx }).run' nextIdx;
Decl.fdecl f xs t b
| d => d);
let (b, s) := (withParams xs (visitFnBody b) { f := f, resultType := t, .. ctx }).run { nextIdx := nextIdx };
let newDecls := newDecls ++ s.auxDecls;
newDecls.push (Decl.fdecl f xs t b)
| d => newDecls.push d)
Array.empty;
addBoxedVersions env decls
end ExplicitBoxing

View file

@ -0,0 +1,6 @@
-- set_option trace.compiler.ir.boxing true
def main (xs : List String) : IO Unit :=
do
let a := xs.head.toNat.fold (fun i (a : Array UInt64) => a.push (UInt64.ofNat i)) Array.empty;
IO.println $ (a.qsort (fun x y => x > y)).get 0