feat: add IR.DeclInfo

This commit is contained in:
Leonardo de Moura 2021-01-26 12:40:31 -08:00
parent 3d01327129
commit 72a8fb84b5
16 changed files with 86 additions and 67 deletions

View file

@ -392,34 +392,47 @@ def reshape (bs : Array FnBody) (term : FnBody) : FnBody :=
@[export lean_ir_mk_alt] def mkAlt (n : Name) (cidx : Nat) (size : Nat) (usize : Nat) (ssize : Nat) (b : FnBody) : Alt :=
Alt.ctor ⟨n, cidx, size, usize, ssize⟩ b
/-- Extra information associated with a declaration. -/
structure DeclInfo where
/-- If `some <blame>`, then declaration depends on `<blame>` which uses a `sorry` axiom. -/
sorryDep? : Option Name := none
inductive Decl where
| fdecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody)
| extern (f : FunId) (xs : Array Param) (ty : IRType) (ext : ExternAttrData)
| fdecl (f : FunId) (xs : Array Param) (type : IRType) (body : FnBody) (info : DeclInfo)
| extern (f : FunId) (xs : Array Param) (type : IRType) (ext : ExternAttrData)
deriving Inhabited
namespace Decl
instance : Inhabited Decl :=
⟨fdecl arbitrary arbitrary IRType.irrelevant arbitrary⟩
def name : Decl → FunId
| Decl.fdecl f _ _ _ => f
| Decl.extern f _ _ _ => f
| Decl.fdecl f .. => f
| Decl.extern f .. => f
def params : Decl → Array Param
| Decl.fdecl _ xs _ _ => xs
| Decl.extern _ xs _ _ => xs
| Decl.fdecl (xs := xs) .. => xs
| Decl.extern (xs := xs) .. => xs
def resultType : Decl → IRType
| Decl.fdecl _ _ t _ => t
| Decl.extern _ _ t _ => t
| Decl.fdecl (type := t) .. => t
| Decl.extern (type := t) .. => t
def isExtern : Decl → Bool
| Decl.extern _ _ _ _ => true
| Decl.extern .. => true
| _ => false
def getInfo : Decl → DeclInfo
| Decl.fdecl (info := info) .. => info
| _ => {}
def updateBody! (d : Decl) (bNew : FnBody) : Decl :=
match d with
| Decl.fdecl f xs t b info => Decl.fdecl f xs t bNew info
| _ => panic! "expected definition"
end Decl
@[export lean_ir_mk_decl] def mkDecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody) : Decl := Decl.fdecl f xs ty b
@[export lean_ir_mk_decl] def mkDecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody) : Decl :=
Decl.fdecl f xs ty b {}
@[export lean_ir_mk_extern_decl] def mkExternDecl (f : FunId) (xs : Array Param) (ty : IRType) (e : ExternAttrData) : Decl :=
Decl.extern f xs ty e

View file

@ -88,9 +88,9 @@ partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
let (instr, b) := e.split
visitFnBody fnid b
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
decls.forM fun decl => match decl with
| Decl.fdecl f xs _ b => do
| Decl.fdecl (f := f) (xs := xs) (body := b) .. => do
let exported := isExport env f
modify fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs)
visitFnBody f b
@ -122,10 +122,10 @@ partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
decls.map fun decl => match decl with
| Decl.fdecl f xs ty b =>
| Decl.fdecl f xs ty b info =>
let b := visitFnBody f paramMap b
match paramMap.find? (ParamMap.Key.decl f) with
| some xs => Decl.fdecl f xs ty b
| some xs => Decl.fdecl f xs ty b info
| none => unreachable!
| other => other
@ -284,7 +284,7 @@ partial def collectFnBody : FnBody → M Unit
| e => do unless e.isTerminal do collectFnBody e.body
partial def collectDecl : Decl → M Unit
| Decl.fdecl f ys _ b =>
| Decl.fdecl (f := f) (xs := ys) (body := b) .. =>
withReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) do
collectFnBody b
updateParamMap (ParamMap.Key.decl f)

View file

@ -61,12 +61,12 @@ def mkBoxedVersionAux (decl : Decl) : N Decl := do
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) arbitrary)
let body ←
if !decl.resultType.isScalar then
pure $ reshape newVDecls (FnBody.ret (Arg.var r))
pure <| reshape newVDecls (FnBody.ret (Arg.var r))
else
let newR ← N.mkFresh
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) arbitrary)
pure $ reshape newVDecls (FnBody.ret (Arg.var newR))
pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body
pure <| reshape newVDecls (FnBody.ret (Arg.var newR))
return Decl.fdecl (mkBoxedName decl.name) qs IRType.object body decl.getInfo
def mkBoxedVersion (decl : Decl) : Decl :=
(mkBoxedVersionAux decl).run' 1
@ -199,11 +199,10 @@ def mkCast (x : VarId) (xType : IRType) (expectedType : IRType) : M Expr := do
| none => do
let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId)
let auxConst := Expr.fap auxName #[]
let auxDecl := Decl.fdecl auxName #[] expectedType body
modify fun s => {
s with
auxDecls := s.auxDecls.push auxDecl,
auxDeclCache := s.auxDeclCache.cons body auxConst,
let auxDecl := Decl.fdecl auxName #[] expectedType body {}
modify fun s => { s with
auxDecls := s.auxDecls.push auxDecl
auxDeclCache := s.auxDeclCache.cons body auxConst
nextAuxId := s.nextAuxId + 1
}
pure auxConst
@ -330,11 +329,11 @@ def run (env : Environment) (decls : Array Decl) : Array Decl :=
let ctx : BoxingContext := { decls := decls, env := env }
let decls := decls.foldl (init := #[]) fun newDecls decl =>
match decl with
| Decl.fdecl f xs t b =>
| Decl.fdecl (f := f) (xs := xs) (type := t) (body := b) .. =>
let nextIdx := decl.maxIndex + 1
let (b, s) := (withParams xs (visitFnBody b) { ctx with f := f, resultType := t }).run { nextIdx := nextIdx }
let newDecls := newDecls ++ s.auxDecls
let newDecl := Decl.fdecl f xs t b
let newDecl := decl.updateBody! b
let newDecl := newDecl.elimDead
newDecls.push newDecl
| d => newDecls.push d

View file

@ -149,8 +149,8 @@ partial def checkFnBody : FnBody → M Unit
| FnBody.unreachable => pure ()
def checkDecl : Decl → M Unit
| Decl.fdecl f xs t b => withParams xs (checkFnBody b)
| Decl.extern f xs t _ => withParams xs (pure ())
| Decl.fdecl (xs := xs) (body := b) .. => withParams xs (checkFnBody b)
| Decl.extern (xs := xs) .. => withParams xs (pure ())
end Checker

View file

@ -239,7 +239,7 @@ def inferStep : M Bool := do
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
ctx.decls.size.foldM (init := false) fun idx modified => do
match ctx.decls[idx] with
| Decl.fdecl fid ys _ b => do
| Decl.fdecl (xs := ys) (body := b) .. => do
let s ← get
let currVals := s.funVals[idx]
withReader (fun ctx => { ctx with currFnIdx := idx }) do
@ -271,8 +271,9 @@ partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
let b := elimDeadAux assignment b
instr.setBody b
partial def elimDead (assignment : Assignment) : Decl → Decl
| Decl.fdecl fid ys t b => Decl.fdecl fid ys t $ elimDeadAux assignment b
partial def elimDead (assignment : Assignment) (d : Decl) : Decl :=
match d with
| Decl.fdecl (body := b) .. => d.updateBody! <| elimDeadAux assignment b
| other => other
end UnreachableBranches

View file

@ -39,8 +39,9 @@ partial def FnBody.elimDead (b : FnBody) : FnBody :=
reshapeWithoutDead bs term
/-- Eliminate dead let-declarations and join points -/
def Decl.elimDead : Decl → Decl
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.elimDead
| other => other
def Decl.elimDead (d : Decl) : Decl :=
match d with
| Decl.fdecl (body := b) .. => d.updateBody! b.elimDead
| other => other
end Lean.IR

View file

@ -133,7 +133,7 @@ def emitFnDecls : M Unit := do
def emitMainFn : M Unit := do
let d ← getDecl `main
match d with
| Decl.fdecl f xs t b => do
| Decl.fdecl (f := f) (xs := xs) (type := t) (body := b) .. => do
unless xs.size == 2 || xs.size == 1 do throw "invalid main function, incorrect arity when generating code"
let env ← getEnv
let usesLeanAPI := usesModuleFrom env `Lean
@ -629,7 +629,7 @@ def emitDeclAux (d : Decl) : M Unit := do
withReader (fun ctx => { ctx with jpMap := jpMap }) do
unless hasInitAttr env d.name do
match d with
| Decl.fdecl f xs t b =>
| Decl.fdecl (f := f) (xs := xs) (type := t) (body := b) .. =>
let baseName ← toCName f;
if xs.size == 0 then
emit "static "

View file

@ -42,8 +42,8 @@ def collectInitDecl (fn : Name) : M Unit := do
| _ => pure ()
def collectDecl : Decl → M NameSet
| Decl.fdecl fn _ _ b => collectInitDecl fn *> CollectUsedDecls.collectFnBody b *> get
| Decl.extern fn _ _ _ => collectInitDecl fn *> get
| Decl.fdecl (f := f) (body := b) .. => collectInitDecl f *> CollectUsedDecls.collectFnBody b *> get
| Decl.extern (f := f) .. => collectInitDecl f *> get
end CollectUsedDecls
@ -70,7 +70,7 @@ partial def collectFnBody : FnBody → Collector
| e => if e.isTerminal then id else collectFnBody e.body
def collectDecl : Decl → Collector
| Decl.fdecl _ xs _ b => collectParams xs ∘ collectFnBody b
| Decl.fdecl (xs := xs) (body := b) .. => collectParams xs ∘ collectFnBody b
| _ => id
end CollectMaps

View file

@ -30,7 +30,7 @@ end CollectProjMap
This function assumes variable ids have been normalized -/
def mkProjMap (d : Decl) : ProjMap :=
match d with
| Decl.fdecl _ _ _ b => CollectProjMap.collectFnBody b {}
| Decl.fdecl (body := b) .. => CollectProjMap.collectFnBody b {}
| _ => {}
structure Context where
@ -265,11 +265,11 @@ partial def searchAndExpand : FnBody → Array FnBody → M FnBody
def main (d : Decl) : Decl :=
match d with
| (Decl.fdecl f xs t b) =>
| Decl.fdecl (body := b) .. =>
let m := mkProjMap d
let nextIdx := d.maxIndex + 1
let b := (searchAndExpand b #[] { projMap := m }).run' nextIdx
Decl.fdecl f xs t b
let bNew := (searchAndExpand b #[] { projMap := m }).run' nextIdx
d.updateBody! bNew
| d => d
end ExpandResetReuse

View file

@ -121,8 +121,8 @@ instance : ToString FnBody := ⟨fun b => (format b).pretty⟩
def formatDecl (decl : Decl) (indent : Nat := 2) : Format :=
match decl with
| Decl.fdecl f xs ty b => "def " ++ format f ++ formatParams xs ++ format " : " ++ format ty ++ " :=" ++ Format.nest indent (Format.line ++ formatFnBody b indent)
| Decl.extern f xs ty _ => "extern " ++ format f ++ formatParams xs ++ format " : " ++ format ty
| Decl.fdecl f xs ty b _ => "def " ++ format f ++ formatParams xs ++ format " : " ++ format ty ++ " :=" ++ Format.nest indent (Format.line ++ formatFnBody b indent)
| Decl.extern f xs ty _ => "extern " ++ format f ++ formatParams xs ++ format " : " ++ format ty
instance : ToFormat Decl := ⟨formatDecl⟩

View file

@ -72,8 +72,8 @@ partial def collectFnBody : FnBody → Collector
| FnBody.unreachable => skip
partial def collectDecl : Decl → Collector
| Decl.fdecl _ xs _ b => collectParams xs >> collectFnBody b
| Decl.extern _ xs _ _ => collectParams xs
| Decl.fdecl (xs := xs) (body := b) .. => collectParams xs >> collectFnBody b
| Decl.extern (xs := xs) .. => collectParams xs
end MaxIndex

View file

@ -24,8 +24,8 @@ partial def checkFnBody : FnBody → M Bool
| b => if b.isTerminal then pure true else checkFnBody b.body
partial def checkDecl : Decl → M Bool
| Decl.fdecl _ xs _ b => checkParams xs <&&> checkFnBody b
| Decl.extern _ xs _ _ => checkParams xs
| Decl.fdecl (xs := xs) (body := b) .. => checkParams xs <&&> checkFnBody b
| Decl.extern (xs := xs) .. => checkParams xs
end UniqueIds
@ -112,9 +112,10 @@ partial def normFnBody : FnBody → N FnBody
| FnBody.ret x => return FnBody.ret (← normArg x)
| FnBody.unreachable => pure FnBody.unreachable
def normDecl : Decl → N Decl
| Decl.fdecl f xs t b => withParams xs fun xs => Decl.fdecl f xs t <$> normFnBody b
| other => pure other
def normDecl (d : Decl) : N Decl :=
match d with
| Decl.fdecl (xs := xs) (body := b) .. => withParams xs fun xs => return d.updateBody! (← normFnBody b)
| other => pure other
end NormalizeIds

View file

@ -49,8 +49,9 @@ partial def FnBody.pushProj (b : FnBody) : FnBody :=
| other => reshape bs term
/-- Push projections inside `case` branches. -/
def Decl.pushProj : Decl → Decl
| Decl.fdecl f xs t b => (Decl.fdecl f xs t b.pushProj).normalizeIds
| other => other
def Decl.pushProj (d : Decl) : Decl :=
match d with
| Decl.fdecl (body := b) .. => d.updateBody! b.pushProj |>.normalizeIds
| other => other
end Lean.IR

View file

@ -271,13 +271,14 @@ partial def visitFnBody : FnBody → Context → (FnBody × LiveVarSet)
| FnBody.unreachable, _ => (FnBody.unreachable, {})
| other, ctx => (other, {}) -- unreachable if well-formed
partial def visitDecl (env : Environment) (decls : Array Decl) : Decl → Decl
| Decl.fdecl f xs t b =>
partial def visitDecl (env : Environment) (decls : Array Decl) (d : Decl) : Decl :=
match d with
| Decl.fdecl (xs := xs) (body := b) .. =>
let ctx : Context := { env := env, decls := decls }
let ctx := updateVarInfoWithParams ctx xs
let (b, bLiveVars) := visitFnBody b ctx
let b := addDecForDeadParams ctx xs b bLiveVars
Decl.fdecl f xs t b
d.updateBody! b
| other => other
end ExplicitRC

View file

@ -148,11 +148,12 @@ end ResetReuse
open ResetReuse
def Decl.insertResetReuse : Decl → Decl
| d@(Decl.fdecl f xs t b) =>
def Decl.insertResetReuse (d : Decl) : Decl :=
match d with
| Decl.fdecl (body := b) ..=>
let nextIndex := d.maxIndex + 1
let b := (R b {}).run' nextIndex
Decl.fdecl f xs t b
let bNew := (R b {}).run' nextIndex
d.updateBody! bNew
| other => other
end Lean.IR

View file

@ -66,8 +66,9 @@ partial def FnBody.simpCase (b : FnBody) : FnBody :=
- Remove unreachable branches.
- Remove `case` if there is only one branch.
- Merge most common branches using `Alt.default`. -/
def Decl.simpCase : Decl → Decl
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.simpCase
| other => other
def Decl.simpCase (d : Decl) : Decl :=
match d with
| Decl.fdecl (body := b) .. => d.updateBody! b.simpCase
| other => other
end Lean.IR