perf(library/init/lean/compiler/ir/boxing): create auxiliary constants for caching the value of boxed/unboxed literals and constants
For example, in the new test `qsort64.lean`, the new optimization prevents the repeated execution of `box UInt64.inhabited`. On my machine, `./run.sh qsort64.lean 2000000` goes from 1.22s to 0.355s.
This commit is contained in:
parent
9e2200d0de
commit
61a3ea61c4
4 changed files with 104 additions and 13 deletions
|
|
@ -15,6 +15,9 @@ inductive AssocList (α : Type u) (β : Type v)
|
|||
namespace AssocList
|
||||
variables {α : Type u} {β : Type v} {δ : Type w} {m : Type w → Type w} [Monad m]
|
||||
|
||||
def empty : AssocList α β :=
|
||||
nil
|
||||
|
||||
@[specialize] def mfoldl (f : δ → α → β → m δ) : δ → AssocList α β → m δ
|
||||
| d, nil => pure d
|
||||
| d, cons a b es => do d ← f d a b; mfoldl d es
|
||||
|
|
|
|||
|
|
@ -450,6 +450,11 @@ match ctx.find x.idx with
|
|||
| some (LocalContextEntry.localVar t _) => some t
|
||||
| other => none
|
||||
|
||||
def LocalContext.getValue (ctx : LocalContext) (x : VarId) : Option Expr :=
|
||||
match ctx.find x.idx with
|
||||
| some (LocalContextEntry.localVar _ v) => some v
|
||||
| other => none
|
||||
|
||||
abbrev IndexRenaming := RBMap Index Index Index.lt
|
||||
|
||||
class HasAlphaEqv (α : Type) :=
|
||||
|
|
|
|||
|
|
@ -4,9 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import init.data.assoclist
|
||||
import init.control.estate
|
||||
import init.control.reader
|
||||
import init.lean.runtime
|
||||
import init.lean.compiler.closedtermcache
|
||||
import init.lean.compiler.externattr
|
||||
import init.lean.compiler.ir.basic
|
||||
import init.lean.compiler.ir.compilerm
|
||||
|
|
@ -105,13 +107,30 @@ def eqvTypes (t₁ t₂ : IRType) : Bool :=
|
|||
(t₁.isScalar == t₂.isScalar) && (!t₁.isScalar || t₁ == t₂)
|
||||
|
||||
structure BoxingContext :=
|
||||
(localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment)
|
||||
(f : FunId := default _) (localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment)
|
||||
|
||||
abbrev M := ReaderT BoxingContext (StateT Index Id)
|
||||
structure BoxingState :=
|
||||
(nextIdx : Index)
|
||||
/- We create auxiliary declarations when boxing constants and literals.
|
||||
The idea is to avoid code such as
|
||||
```
|
||||
let x1 := UInt64.inhabited;
|
||||
let x2 := box x1;
|
||||
...
|
||||
```
|
||||
We currently do not cache these declarations in an environment extension, but
|
||||
we use auxDeclCache to avoid creating equivalent auxiliary declarations more than once when
|
||||
processing the same IR declaration.
|
||||
-/
|
||||
(auxDecls : Array Decl := Array.empty)
|
||||
(auxDeclCache : AssocList FnBody Expr := AssocList.empty)
|
||||
(nextAuxId : Nat := 1)
|
||||
|
||||
abbrev M := ReaderT BoxingContext (StateT BoxingState Id)
|
||||
|
||||
def mkFresh : M VarId :=
|
||||
do idx ← getModify (fun n => n + 1);
|
||||
pure { idx := idx }
|
||||
do oldS ← getModify (fun s => { nextIdx := s.nextIdx + 1, .. s });
|
||||
pure { idx := oldS.nextIdx }
|
||||
|
||||
def getEnv : M Environment := BoxingContext.env <$> read
|
||||
def getLocalContext : M LocalContext := BoxingContext.localCtx <$> read
|
||||
|
|
@ -142,18 +161,72 @@ adaptReader (fun (ctx : BoxingContext) => { localCtx := ctx.localCtx.addLocal x
|
|||
@[inline] def withJDecl {α : Type} (j : JoinPointId) (xs : Array Param) (v : FnBody) (k : M α) : M α :=
|
||||
adaptReader (fun (ctx : BoxingContext) => { localCtx := ctx.localCtx.addJP j xs v, .. ctx }) k
|
||||
|
||||
/- If the declaration of `x` is of the form `x := Expr.lit _` or `x := Expr.fap c Array.empty`, then
|
||||
return its value. -/
|
||||
private def isConstantValue (x : VarId) : M (Option Expr) :=
|
||||
do localCtx ← getLocalContext;
|
||||
match localCtx.getValue x with
|
||||
| some val =>
|
||||
match val with
|
||||
| Expr.lit _ => pure $ some val
|
||||
| Expr.fap _ args => pure $ if args.size == 0 then some val else none
|
||||
| _ => pure none
|
||||
| _ => pure none
|
||||
|
||||
/- Auxiliary function used by castVarIfNeeded.
|
||||
It is used when the expected type does not match `xType`.
|
||||
If `xType` is scalar, then we need to "box" it. Otherwise, we need to "unbox" it. -/
|
||||
def mkCast (x : VarId) (xType : IRType) : Expr :=
|
||||
if xType.isScalar then Expr.box xType x else Expr.unbox x
|
||||
def mkCast (x : VarId) (xType : IRType) (expectedType : IRType) : M Expr :=
|
||||
do
|
||||
optVal ← isConstantValue x;
|
||||
match optVal with
|
||||
| some v => do
|
||||
ctx ← read;
|
||||
/- Create auxiliary FnBody
|
||||
```
|
||||
let x_1 : xType := v;
|
||||
let x_2 : expectedType := Expr.box xType x_1;
|
||||
ret x_2
|
||||
```
|
||||
if `xType.isScalar`, and
|
||||
```
|
||||
let x_1 : xType := v;
|
||||
let x_2 : expectedType := Expr.unbox x_1;
|
||||
ret x_2
|
||||
```
|
||||
otherwise
|
||||
-/
|
||||
let body : FnBody :=
|
||||
if xType.isScalar then
|
||||
FnBody.vdecl { idx := 1 } xType v $
|
||||
FnBody.vdecl { idx := 2 } expectedType (Expr.box xType { idx := 1 }) $
|
||||
FnBody.ret (mkVarArg { idx := 2 })
|
||||
else
|
||||
FnBody.vdecl { idx := 1 } xType v $
|
||||
FnBody.vdecl { idx := 2 } expectedType (Expr.unbox { idx := 1 }) $
|
||||
FnBody.ret (mkVarArg { idx := 2 });
|
||||
s ← get;
|
||||
match s.auxDeclCache.find body with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId);
|
||||
let auxConst := Expr.fap auxName Array.empty;
|
||||
let auxDecl := Decl.fdecl auxName Array.empty expectedType body;
|
||||
modify $ fun s => {
|
||||
auxDecls := s.auxDecls.push auxDecl,
|
||||
auxDeclCache := s.auxDeclCache.cons body auxConst,
|
||||
nextAuxId := s.nextAuxId + 1,
|
||||
.. s
|
||||
};
|
||||
pure auxConst
|
||||
| none => pure $ if xType.isScalar then Expr.box xType x else Expr.unbox x
|
||||
|
||||
@[inline] def castVarIfNeeded (x : VarId) (expected : IRType) (k : VarId → M FnBody) : M FnBody :=
|
||||
do xType ← getVarType x;
|
||||
if eqvTypes xType expected then k x
|
||||
else do
|
||||
y ← mkFresh;
|
||||
let v := mkCast x xType;
|
||||
v ← mkCast x xType expected;
|
||||
FnBody.vdecl y expected v <$> k y
|
||||
|
||||
@[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody :=
|
||||
|
|
@ -172,7 +245,7 @@ xs.miterate (Array.empty, Array.empty) $ fun i (x : Arg) (r : Array Arg × Array
|
|||
if eqvTypes xType expected then pure (xs.push (Arg.var x), bs)
|
||||
else do
|
||||
y ← mkFresh;
|
||||
let v := mkCast x xType;
|
||||
v ← mkCast x xType expected;
|
||||
let b := FnBody.vdecl y expected v FnBody.nil;
|
||||
pure (xs.push (Arg.var y), bs.push b)
|
||||
|
||||
|
|
@ -197,7 +270,8 @@ def castResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (eType : IRType) (b
|
|||
if eqvTypes ty eType then pure $ FnBody.vdecl x ty e b
|
||||
else do
|
||||
y ← mkFresh;
|
||||
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty (mkCast y eType) b)
|
||||
v ← mkCast y eType ty;
|
||||
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty v b)
|
||||
|
||||
def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
|
||||
match e with
|
||||
|
|
@ -257,12 +331,15 @@ partial def visitFnBody : FnBody → M FnBody
|
|||
|
||||
def run (env : Environment) (decls : Array Decl) : Array Decl :=
|
||||
let ctx : BoxingContext := { decls := decls, env := env };
|
||||
let decls := decls.map (fun decl => match decl with
|
||||
let decls := decls.foldl (fun (newDecls : Array Decl) (decl : Decl) =>
|
||||
match decl with
|
||||
| Decl.fdecl f xs t b =>
|
||||
let nextIdx := decl.maxIndex + 1;
|
||||
let b := (withParams xs (visitFnBody b) { resultType := t, .. ctx }).run' nextIdx;
|
||||
Decl.fdecl f xs t b
|
||||
| d => d);
|
||||
let (b, s) := (withParams xs (visitFnBody b) { f := f, resultType := t, .. ctx }).run { nextIdx := nextIdx };
|
||||
let newDecls := newDecls ++ s.auxDecls;
|
||||
newDecls.push (Decl.fdecl f xs t b)
|
||||
| d => newDecls.push d)
|
||||
Array.empty;
|
||||
addBoxedVersions env decls
|
||||
|
||||
end ExplicitBoxing
|
||||
|
|
|
|||
6
tests/playground/qsort64.lean
Normal file
6
tests/playground/qsort64.lean
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- set_option trace.compiler.ir.boxing true
|
||||
|
||||
def main (xs : List String) : IO Unit :=
|
||||
do
|
||||
let a := xs.head.toNat.fold (fun i (a : Array UInt64) => a.push (UInt64.ofNat i)) Array.empty;
|
||||
IO.println $ (a.qsort (fun x y => x > y)).get 0
|
||||
Loading…
Add table
Reference in a new issue