feat: add IR.DeclInfo
This commit is contained in:
parent
3d01327129
commit
72a8fb84b5
16 changed files with 86 additions and 67 deletions
|
|
@ -392,34 +392,47 @@ def reshape (bs : Array FnBody) (term : FnBody) : FnBody :=
|
|||
@[export lean_ir_mk_alt] def mkAlt (n : Name) (cidx : Nat) (size : Nat) (usize : Nat) (ssize : Nat) (b : FnBody) : Alt :=
|
||||
Alt.ctor ⟨n, cidx, size, usize, ssize⟩ b
|
||||
|
||||
/-- Extra information associated with a declaration. -/
|
||||
structure DeclInfo where
|
||||
/-- If `some <blame>`, then declaration depends on `<blame>` which uses a `sorry` axiom. -/
|
||||
sorryDep? : Option Name := none
|
||||
|
||||
inductive Decl where
|
||||
| fdecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody)
|
||||
| extern (f : FunId) (xs : Array Param) (ty : IRType) (ext : ExternAttrData)
|
||||
| fdecl (f : FunId) (xs : Array Param) (type : IRType) (body : FnBody) (info : DeclInfo)
|
||||
| extern (f : FunId) (xs : Array Param) (type : IRType) (ext : ExternAttrData)
|
||||
deriving Inhabited
|
||||
|
||||
namespace Decl
|
||||
|
||||
instance : Inhabited Decl :=
|
||||
⟨fdecl arbitrary arbitrary IRType.irrelevant arbitrary⟩
|
||||
|
||||
def name : Decl → FunId
|
||||
| Decl.fdecl f _ _ _ => f
|
||||
| Decl.extern f _ _ _ => f
|
||||
| Decl.fdecl f .. => f
|
||||
| Decl.extern f .. => f
|
||||
|
||||
def params : Decl → Array Param
|
||||
| Decl.fdecl _ xs _ _ => xs
|
||||
| Decl.extern _ xs _ _ => xs
|
||||
| Decl.fdecl (xs := xs) .. => xs
|
||||
| Decl.extern (xs := xs) .. => xs
|
||||
|
||||
def resultType : Decl → IRType
|
||||
| Decl.fdecl _ _ t _ => t
|
||||
| Decl.extern _ _ t _ => t
|
||||
| Decl.fdecl (type := t) .. => t
|
||||
| Decl.extern (type := t) .. => t
|
||||
|
||||
def isExtern : Decl → Bool
|
||||
| Decl.extern _ _ _ _ => true
|
||||
| Decl.extern .. => true
|
||||
| _ => false
|
||||
|
||||
def getInfo : Decl → DeclInfo
|
||||
| Decl.fdecl (info := info) .. => info
|
||||
| _ => {}
|
||||
|
||||
def updateBody! (d : Decl) (bNew : FnBody) : Decl :=
|
||||
match d with
|
||||
| Decl.fdecl f xs t b info => Decl.fdecl f xs t bNew info
|
||||
| _ => panic! "expected definition"
|
||||
|
||||
end Decl
|
||||
|
||||
@[export lean_ir_mk_decl] def mkDecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody) : Decl := Decl.fdecl f xs ty b
|
||||
@[export lean_ir_mk_decl] def mkDecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody) : Decl :=
|
||||
Decl.fdecl f xs ty b {}
|
||||
|
||||
@[export lean_ir_mk_extern_decl] def mkExternDecl (f : FunId) (xs : Array Param) (ty : IRType) (e : ExternAttrData) : Decl :=
|
||||
Decl.extern f xs ty e
|
||||
|
|
|
|||
|
|
@ -88,9 +88,9 @@ partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
|
|||
let (instr, b) := e.split
|
||||
visitFnBody fnid b
|
||||
|
||||
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
|
||||
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
|
||||
decls.forM fun decl => match decl with
|
||||
| Decl.fdecl f xs _ b => do
|
||||
| Decl.fdecl (f := f) (xs := xs) (body := b) .. => do
|
||||
let exported := isExport env f
|
||||
modify fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs)
|
||||
visitFnBody f b
|
||||
|
|
@ -122,10 +122,10 @@ partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
|
|||
|
||||
def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
|
||||
decls.map fun decl => match decl with
|
||||
| Decl.fdecl f xs ty b =>
|
||||
| Decl.fdecl f xs ty b info =>
|
||||
let b := visitFnBody f paramMap b
|
||||
match paramMap.find? (ParamMap.Key.decl f) with
|
||||
| some xs => Decl.fdecl f xs ty b
|
||||
| some xs => Decl.fdecl f xs ty b info
|
||||
| none => unreachable!
|
||||
| other => other
|
||||
|
||||
|
|
@ -284,7 +284,7 @@ partial def collectFnBody : FnBody → M Unit
|
|||
| e => do unless e.isTerminal do collectFnBody e.body
|
||||
|
||||
partial def collectDecl : Decl → M Unit
|
||||
| Decl.fdecl f ys _ b =>
|
||||
| Decl.fdecl (f := f) (xs := ys) (body := b) .. =>
|
||||
withReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) do
|
||||
collectFnBody b
|
||||
updateParamMap (ParamMap.Key.decl f)
|
||||
|
|
|
|||
|
|
@ -61,12 +61,12 @@ def mkBoxedVersionAux (decl : Decl) : N Decl := do
|
|||
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) arbitrary)
|
||||
let body ←
|
||||
if !decl.resultType.isScalar then
|
||||
pure $ reshape newVDecls (FnBody.ret (Arg.var r))
|
||||
pure <| reshape newVDecls (FnBody.ret (Arg.var r))
|
||||
else
|
||||
let newR ← N.mkFresh
|
||||
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) arbitrary)
|
||||
pure $ reshape newVDecls (FnBody.ret (Arg.var newR))
|
||||
pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body
|
||||
pure <| reshape newVDecls (FnBody.ret (Arg.var newR))
|
||||
return Decl.fdecl (mkBoxedName decl.name) qs IRType.object body decl.getInfo
|
||||
|
||||
def mkBoxedVersion (decl : Decl) : Decl :=
|
||||
(mkBoxedVersionAux decl).run' 1
|
||||
|
|
@ -199,11 +199,10 @@ def mkCast (x : VarId) (xType : IRType) (expectedType : IRType) : M Expr := do
|
|||
| none => do
|
||||
let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId)
|
||||
let auxConst := Expr.fap auxName #[]
|
||||
let auxDecl := Decl.fdecl auxName #[] expectedType body
|
||||
modify fun s => {
|
||||
s with
|
||||
auxDecls := s.auxDecls.push auxDecl,
|
||||
auxDeclCache := s.auxDeclCache.cons body auxConst,
|
||||
let auxDecl := Decl.fdecl auxName #[] expectedType body {}
|
||||
modify fun s => { s with
|
||||
auxDecls := s.auxDecls.push auxDecl
|
||||
auxDeclCache := s.auxDeclCache.cons body auxConst
|
||||
nextAuxId := s.nextAuxId + 1
|
||||
}
|
||||
pure auxConst
|
||||
|
|
@ -330,11 +329,11 @@ def run (env : Environment) (decls : Array Decl) : Array Decl :=
|
|||
let ctx : BoxingContext := { decls := decls, env := env }
|
||||
let decls := decls.foldl (init := #[]) fun newDecls decl =>
|
||||
match decl with
|
||||
| Decl.fdecl f xs t b =>
|
||||
| Decl.fdecl (f := f) (xs := xs) (type := t) (body := b) .. =>
|
||||
let nextIdx := decl.maxIndex + 1
|
||||
let (b, s) := (withParams xs (visitFnBody b) { ctx with f := f, resultType := t }).run { nextIdx := nextIdx }
|
||||
let newDecls := newDecls ++ s.auxDecls
|
||||
let newDecl := Decl.fdecl f xs t b
|
||||
let newDecl := decl.updateBody! b
|
||||
let newDecl := newDecl.elimDead
|
||||
newDecls.push newDecl
|
||||
| d => newDecls.push d
|
||||
|
|
|
|||
|
|
@ -149,8 +149,8 @@ partial def checkFnBody : FnBody → M Unit
|
|||
| FnBody.unreachable => pure ()
|
||||
|
||||
def checkDecl : Decl → M Unit
|
||||
| Decl.fdecl f xs t b => withParams xs (checkFnBody b)
|
||||
| Decl.extern f xs t _ => withParams xs (pure ())
|
||||
| Decl.fdecl (xs := xs) (body := b) .. => withParams xs (checkFnBody b)
|
||||
| Decl.extern (xs := xs) .. => withParams xs (pure ())
|
||||
|
||||
end Checker
|
||||
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ def inferStep : M Bool := do
|
|||
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
|
||||
ctx.decls.size.foldM (init := false) fun idx modified => do
|
||||
match ctx.decls[idx] with
|
||||
| Decl.fdecl fid ys _ b => do
|
||||
| Decl.fdecl (xs := ys) (body := b) .. => do
|
||||
let s ← get
|
||||
let currVals := s.funVals[idx]
|
||||
withReader (fun ctx => { ctx with currFnIdx := idx }) do
|
||||
|
|
@ -271,8 +271,9 @@ partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
|
|||
let b := elimDeadAux assignment b
|
||||
instr.setBody b
|
||||
|
||||
partial def elimDead (assignment : Assignment) : Decl → Decl
|
||||
| Decl.fdecl fid ys t b => Decl.fdecl fid ys t $ elimDeadAux assignment b
|
||||
partial def elimDead (assignment : Assignment) (d : Decl) : Decl :=
|
||||
match d with
|
||||
| Decl.fdecl (body := b) .. => d.updateBody! <| elimDeadAux assignment b
|
||||
| other => other
|
||||
|
||||
end UnreachableBranches
|
||||
|
|
|
|||
|
|
@ -39,8 +39,9 @@ partial def FnBody.elimDead (b : FnBody) : FnBody :=
|
|||
reshapeWithoutDead bs term
|
||||
|
||||
/-- Eliminate dead let-declarations and join points -/
|
||||
def Decl.elimDead : Decl → Decl
|
||||
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.elimDead
|
||||
| other => other
|
||||
def Decl.elimDead (d : Decl) : Decl :=
|
||||
match d with
|
||||
| Decl.fdecl (body := b) .. => d.updateBody! b.elimDead
|
||||
| other => other
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ def emitFnDecls : M Unit := do
|
|||
def emitMainFn : M Unit := do
|
||||
let d ← getDecl `main
|
||||
match d with
|
||||
| Decl.fdecl f xs t b => do
|
||||
| Decl.fdecl (f := f) (xs := xs) (type := t) (body := b) .. => do
|
||||
unless xs.size == 2 || xs.size == 1 do throw "invalid main function, incorrect arity when generating code"
|
||||
let env ← getEnv
|
||||
let usesLeanAPI := usesModuleFrom env `Lean
|
||||
|
|
@ -629,7 +629,7 @@ def emitDeclAux (d : Decl) : M Unit := do
|
|||
withReader (fun ctx => { ctx with jpMap := jpMap }) do
|
||||
unless hasInitAttr env d.name do
|
||||
match d with
|
||||
| Decl.fdecl f xs t b =>
|
||||
| Decl.fdecl (f := f) (xs := xs) (type := t) (body := b) .. =>
|
||||
let baseName ← toCName f;
|
||||
if xs.size == 0 then
|
||||
emit "static "
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ def collectInitDecl (fn : Name) : M Unit := do
|
|||
| _ => pure ()
|
||||
|
||||
def collectDecl : Decl → M NameSet
|
||||
| Decl.fdecl fn _ _ b => collectInitDecl fn *> CollectUsedDecls.collectFnBody b *> get
|
||||
| Decl.extern fn _ _ _ => collectInitDecl fn *> get
|
||||
| Decl.fdecl (f := f) (body := b) .. => collectInitDecl f *> CollectUsedDecls.collectFnBody b *> get
|
||||
| Decl.extern (f := f) .. => collectInitDecl f *> get
|
||||
|
||||
end CollectUsedDecls
|
||||
|
||||
|
|
@ -70,7 +70,7 @@ partial def collectFnBody : FnBody → Collector
|
|||
| e => if e.isTerminal then id else collectFnBody e.body
|
||||
|
||||
def collectDecl : Decl → Collector
|
||||
| Decl.fdecl _ xs _ b => collectParams xs ∘ collectFnBody b
|
||||
| Decl.fdecl (xs := xs) (body := b) .. => collectParams xs ∘ collectFnBody b
|
||||
| _ => id
|
||||
|
||||
end CollectMaps
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ end CollectProjMap
|
|||
This function assumes variable ids have been normalized -/
|
||||
def mkProjMap (d : Decl) : ProjMap :=
|
||||
match d with
|
||||
| Decl.fdecl _ _ _ b => CollectProjMap.collectFnBody b {}
|
||||
| Decl.fdecl (body := b) .. => CollectProjMap.collectFnBody b {}
|
||||
| _ => {}
|
||||
|
||||
structure Context where
|
||||
|
|
@ -265,11 +265,11 @@ partial def searchAndExpand : FnBody → Array FnBody → M FnBody
|
|||
|
||||
def main (d : Decl) : Decl :=
|
||||
match d with
|
||||
| (Decl.fdecl f xs t b) =>
|
||||
| Decl.fdecl (body := b) .. =>
|
||||
let m := mkProjMap d
|
||||
let nextIdx := d.maxIndex + 1
|
||||
let b := (searchAndExpand b #[] { projMap := m }).run' nextIdx
|
||||
Decl.fdecl f xs t b
|
||||
let bNew := (searchAndExpand b #[] { projMap := m }).run' nextIdx
|
||||
d.updateBody! bNew
|
||||
| d => d
|
||||
|
||||
end ExpandResetReuse
|
||||
|
|
|
|||
|
|
@ -121,8 +121,8 @@ instance : ToString FnBody := ⟨fun b => (format b).pretty⟩
|
|||
|
||||
def formatDecl (decl : Decl) (indent : Nat := 2) : Format :=
|
||||
match decl with
|
||||
| Decl.fdecl f xs ty b => "def " ++ format f ++ formatParams xs ++ format " : " ++ format ty ++ " :=" ++ Format.nest indent (Format.line ++ formatFnBody b indent)
|
||||
| Decl.extern f xs ty _ => "extern " ++ format f ++ formatParams xs ++ format " : " ++ format ty
|
||||
| Decl.fdecl f xs ty b _ => "def " ++ format f ++ formatParams xs ++ format " : " ++ format ty ++ " :=" ++ Format.nest indent (Format.line ++ formatFnBody b indent)
|
||||
| Decl.extern f xs ty _ => "extern " ++ format f ++ formatParams xs ++ format " : " ++ format ty
|
||||
|
||||
instance : ToFormat Decl := ⟨formatDecl⟩
|
||||
|
||||
|
|
|
|||
|
|
@ -72,8 +72,8 @@ partial def collectFnBody : FnBody → Collector
|
|||
| FnBody.unreachable => skip
|
||||
|
||||
partial def collectDecl : Decl → Collector
|
||||
| Decl.fdecl _ xs _ b => collectParams xs >> collectFnBody b
|
||||
| Decl.extern _ xs _ _ => collectParams xs
|
||||
| Decl.fdecl (xs := xs) (body := b) .. => collectParams xs >> collectFnBody b
|
||||
| Decl.extern (xs := xs) .. => collectParams xs
|
||||
|
||||
end MaxIndex
|
||||
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ partial def checkFnBody : FnBody → M Bool
|
|||
| b => if b.isTerminal then pure true else checkFnBody b.body
|
||||
|
||||
partial def checkDecl : Decl → M Bool
|
||||
| Decl.fdecl _ xs _ b => checkParams xs <&&> checkFnBody b
|
||||
| Decl.extern _ xs _ _ => checkParams xs
|
||||
| Decl.fdecl (xs := xs) (body := b) .. => checkParams xs <&&> checkFnBody b
|
||||
| Decl.extern (xs := xs) .. => checkParams xs
|
||||
|
||||
end UniqueIds
|
||||
|
||||
|
|
@ -112,9 +112,10 @@ partial def normFnBody : FnBody → N FnBody
|
|||
| FnBody.ret x => return FnBody.ret (← normArg x)
|
||||
| FnBody.unreachable => pure FnBody.unreachable
|
||||
|
||||
def normDecl : Decl → N Decl
|
||||
| Decl.fdecl f xs t b => withParams xs fun xs => Decl.fdecl f xs t <$> normFnBody b
|
||||
| other => pure other
|
||||
def normDecl (d : Decl) : N Decl :=
|
||||
match d with
|
||||
| Decl.fdecl (xs := xs) (body := b) .. => withParams xs fun xs => return d.updateBody! (← normFnBody b)
|
||||
| other => pure other
|
||||
|
||||
end NormalizeIds
|
||||
|
||||
|
|
|
|||
|
|
@ -49,8 +49,9 @@ partial def FnBody.pushProj (b : FnBody) : FnBody :=
|
|||
| other => reshape bs term
|
||||
|
||||
/-- Push projections inside `case` branches. -/
|
||||
def Decl.pushProj : Decl → Decl
|
||||
| Decl.fdecl f xs t b => (Decl.fdecl f xs t b.pushProj).normalizeIds
|
||||
| other => other
|
||||
def Decl.pushProj (d : Decl) : Decl :=
|
||||
match d with
|
||||
| Decl.fdecl (body := b) .. => d.updateBody! b.pushProj |>.normalizeIds
|
||||
| other => other
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -271,13 +271,14 @@ partial def visitFnBody : FnBody → Context → (FnBody × LiveVarSet)
|
|||
| FnBody.unreachable, _ => (FnBody.unreachable, {})
|
||||
| other, ctx => (other, {}) -- unreachable if well-formed
|
||||
|
||||
partial def visitDecl (env : Environment) (decls : Array Decl) : Decl → Decl
|
||||
| Decl.fdecl f xs t b =>
|
||||
partial def visitDecl (env : Environment) (decls : Array Decl) (d : Decl) : Decl :=
|
||||
match d with
|
||||
| Decl.fdecl (xs := xs) (body := b) .. =>
|
||||
let ctx : Context := { env := env, decls := decls }
|
||||
let ctx := updateVarInfoWithParams ctx xs
|
||||
let (b, bLiveVars) := visitFnBody b ctx
|
||||
let b := addDecForDeadParams ctx xs b bLiveVars
|
||||
Decl.fdecl f xs t b
|
||||
d.updateBody! b
|
||||
| other => other
|
||||
|
||||
end ExplicitRC
|
||||
|
|
|
|||
|
|
@ -148,11 +148,12 @@ end ResetReuse
|
|||
|
||||
open ResetReuse
|
||||
|
||||
def Decl.insertResetReuse : Decl → Decl
|
||||
| d@(Decl.fdecl f xs t b) =>
|
||||
def Decl.insertResetReuse (d : Decl) : Decl :=
|
||||
match d with
|
||||
| Decl.fdecl (body := b) ..=>
|
||||
let nextIndex := d.maxIndex + 1
|
||||
let b := (R b {}).run' nextIndex
|
||||
Decl.fdecl f xs t b
|
||||
let bNew := (R b {}).run' nextIndex
|
||||
d.updateBody! bNew
|
||||
| other => other
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -66,8 +66,9 @@ partial def FnBody.simpCase (b : FnBody) : FnBody :=
|
|||
- Remove unreachable branches.
|
||||
- Remove `case` if there is only one branch.
|
||||
- Merge most common branches using `Alt.default`. -/
|
||||
def Decl.simpCase : Decl → Decl
|
||||
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.simpCase
|
||||
| other => other
|
||||
def Decl.simpCase (d : Decl) : Decl :=
|
||||
match d with
|
||||
| Decl.fdecl (body := b) .. => d.updateBody! b.simpCase
|
||||
| other => other
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue