diff --git a/library/init/data/assoclist.lean b/library/init/data/assoclist.lean index 03f645386e..8ec939f354 100644 --- a/library/init/data/assoclist.lean +++ b/library/init/data/assoclist.lean @@ -15,6 +15,9 @@ inductive AssocList (α : Type u) (β : Type v) namespace AssocList variables {α : Type u} {β : Type v} {δ : Type w} {m : Type w → Type w} [Monad m] +def empty : AssocList α β := +nil + @[specialize] def mfoldl (f : δ → α → β → m δ) : δ → AssocList α β → m δ | d, nil => pure d | d, cons a b es => do d ← f d a b; mfoldl d es diff --git a/library/init/lean/compiler/ir/basic.lean b/library/init/lean/compiler/ir/basic.lean index 2718789166..dfaa1cfbe9 100644 --- a/library/init/lean/compiler/ir/basic.lean +++ b/library/init/lean/compiler/ir/basic.lean @@ -450,6 +450,11 @@ match ctx.find x.idx with | some (LocalContextEntry.localVar t _) => some t | other => none +def LocalContext.getValue (ctx : LocalContext) (x : VarId) : Option Expr := +match ctx.find x.idx with +| some (LocalContextEntry.localVar _ v) => some v +| other => none + abbrev IndexRenaming := RBMap Index Index Index.lt class HasAlphaEqv (α : Type) := diff --git a/library/init/lean/compiler/ir/boxing.lean b/library/init/lean/compiler/ir/boxing.lean index 0bf7daf199..d4c7a42c31 100644 --- a/library/init/lean/compiler/ir/boxing.lean +++ b/library/init/lean/compiler/ir/boxing.lean @@ -4,9 +4,11 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude +import init.data.assoclist import init.control.estate import init.control.reader import init.lean.runtime +import init.lean.compiler.closedtermcache import init.lean.compiler.externattr import init.lean.compiler.ir.basic import init.lean.compiler.ir.compilerm @@ -105,13 +107,30 @@ def eqvTypes (t₁ t₂ : IRType) : Bool := (t₁.isScalar == t₂.isScalar) && (!t₁.isScalar || t₁ == t₂) structure BoxingContext := -(localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment) +(f : FunId := default _) (localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment) -abbrev M := ReaderT BoxingContext (StateT Index Id) +structure BoxingState := +(nextIdx : Index) +/- We create auxiliary declarations when boxing constant and literals. + The idea is to avoid code such as + ``` + let x1 := Uint64.inhabited; + let x2 := box x1; + ... + ``` + We currently do not cache these declarations in an environment extension, but + we use auxDeclCache to avoid creating equivalent auxiliary declarations more than once when + processing the same IR declaration. +-/ +(auxDecls : Array Decl := Array.empty) +(auxDeclCache : AssocList FnBody Expr := AssocList.empty) +(nextAuxId : Nat := 1) + +abbrev M := ReaderT BoxingContext (StateT BoxingState Id) def mkFresh : M VarId := -do idx ← getModify (fun n => n + 1); - pure { idx := idx } +do oldS ← getModify (fun s => { nextIdx := s.nextIdx + 1, .. s }); + pure { idx := oldS.nextIdx } def getEnv : M Environment := BoxingContext.env <$> read def getLocalContext : M LocalContext := BoxingContext.localCtx <$> read @@ -142,18 +161,72 @@ adaptReader (fun (ctx : BoxingContext) => { localCtx := ctx.localCtx.addLocal x @[inline] def withJDecl {α : Type} (j : JoinPointId) (xs : Array Param) (v : FnBody) (k : M α) : M α := adaptReader (fun (ctx : BoxingContext) => { localCtx := ctx.localCtx.addJP j xs v, .. ctx }) k +/- If `x` declaration is of the form `x := Expr.lit _` or `x := Expr.fap c Array.empty`, then + return its value. -/ +private def isConstantValue (x : VarId) : M (Option Expr) := +do localCtx ← getLocalContext; + match localCtx.getValue x with + | some val => + match val with + | Expr.lit _ => pure $ some val + | Expr.fap _ args => pure $ if args.size == 0 then some val else none + | _ => pure none + | _ => pure none + /- Auxiliary function used by castVarIfNeeded. It is used when the expected type does not match `xType`. If `xType` is scalar, then we need to "box" it. Otherwise, we need to "unbox" it. -/ -def mkCast (x : VarId) (xType : IRType) : Expr := -if xType.isScalar then Expr.box xType x else Expr.unbox x +def mkCast (x : VarId) (xType : IRType) (expectedType : IRType) : M Expr := +do +optVal ← isConstantValue x; +match optVal with +| some v => do + ctx ← read; + /- Create auxiliary FnBody + ``` + let x_1 : xType := v; + let x_2 : expectedType := Expr.box xType x_1; + ret x_2 + ``` + if `xType.isScalar`, and + ``` + let x_1 : xType := v; + let x_2 : expectedType := Expr.unbox x_1; + ret x_2 + ``` + otherwise + -/ + let body : FnBody := + if xType.isScalar then + FnBody.vdecl { idx := 1 } xType v $ + FnBody.vdecl { idx := 2 } expectedType (Expr.box xType { idx := 1 }) $ + FnBody.ret (mkVarArg { idx := 2 }) + else + FnBody.vdecl { idx := 1 } xType v $ + FnBody.vdecl { idx := 2 } expectedType (Expr.unbox { idx := 1 }) $ + FnBody.ret (mkVarArg { idx := 2 }); + s ← get; + match s.auxDeclCache.find body with + | some v => pure v + | none => do + let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId); + let auxConst := Expr.fap auxName Array.empty; + let auxDecl := Decl.fdecl auxName Array.empty expectedType body; + modify $ fun s => { + auxDecls := s.auxDecls.push auxDecl, + auxDeclCache := s.auxDeclCache.cons body auxConst, + nextAuxId := s.nextAuxId + 1, + .. s + }; + pure auxConst +| none => pure $ if xType.isScalar then Expr.box xType x else Expr.unbox x @[inline] def castVarIfNeeded (x : VarId) (expected : IRType) (k : VarId → M FnBody) : M FnBody := do xType ← getVarType x; if eqvTypes xType expected then k x else do y ← mkFresh; - let v := mkCast x xType; + v ← mkCast x xType expected; FnBody.vdecl y expected v <$> k y @[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody := @@ -172,7 +245,7 @@ xs.miterate (Array.empty, Array.empty) $ fun i (x : Arg) (r : Array Arg × Array if eqvTypes xType expected then pure (xs.push (Arg.var x), bs) else do y ← mkFresh; - let v := mkCast x xType; + v ← mkCast x xType expected; let b := FnBody.vdecl y expected v FnBody.nil; pure (xs.push (Arg.var y), bs.push b) @@ -197,7 +270,8 @@ def castResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (eType : IRType) (b if eqvTypes ty eType then pure $ FnBody.vdecl x ty e b else do y ← mkFresh; - pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty (mkCast y eType) b) + v ← mkCast y eType ty; + pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty v b) def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody := match e with @@ -257,12 +331,15 @@ partial def visitFnBody : FnBody → M FnBody def run (env : Environment) (decls : Array Decl) : Array Decl := let ctx : BoxingContext := { decls := decls, env := env }; -let decls := decls.map (fun decl => match decl with +let decls := decls.foldl (fun (newDecls : Array Decl) (decl : Decl) => + match decl with | Decl.fdecl f xs t b => let nextIdx := decl.maxIndex + 1; - let b := (withParams xs (visitFnBody b) { resultType := t, .. ctx }).run' nextIdx; - Decl.fdecl f xs t b - | d => d); + let (b, s) := (withParams xs (visitFnBody b) { f := f, resultType := t, .. ctx }).run { nextIdx := nextIdx }; + let newDecls := newDecls ++ s.auxDecls; + newDecls.push (Decl.fdecl f xs t b) + | d => newDecls.push d) + Array.empty; addBoxedVersions env decls end ExplicitBoxing diff --git a/tests/playground/qsort64.lean b/tests/playground/qsort64.lean new file mode 100644 index 0000000000..6f731d8136 --- /dev/null +++ b/tests/playground/qsort64.lean @@ -0,0 +1,6 @@ +-- set_option trace.compiler.ir.boxing true + +def main (xs : List String) : IO Unit := +do +let a := xs.head.toNat.fold (fun i (a : Array UInt64) => a.push (UInt64.ofNat i)) Array.empty; +IO.println $ (a.qsort (fun x y => x > y)).get 0