refactor: add MonadFVarSubst class
This commit is contained in:
parent
6a7ccb5797
commit
7b161d33d1
3 changed files with 60 additions and 32 deletions
|
|
@ -16,8 +16,13 @@ structure State where
|
|||
map : Std.PHashMap Expr FVarId := {}
|
||||
subst : FVarSubst := {}
|
||||
|
||||
|
||||
abbrev M := StateRefT State CompilerM
|
||||
|
||||
instance : MonadFVarSubst M where
|
||||
getSubst := return (← get).subst
|
||||
modifySubst f := modify fun s => { s with subst := f s.subst }
|
||||
|
||||
@[inline] def getSubst : M FVarSubst :=
|
||||
return (← get).subst
|
||||
|
||||
|
|
@ -42,15 +47,15 @@ partial def Code.cse (code : Code) : CompilerM Code :=
|
|||
go code |>.run' {}
|
||||
where
|
||||
goFunDecl (decl : FunDecl) : M FunDecl := do
|
||||
let type := (← getSubst).applyToExpr decl.type
|
||||
let params ← (← getSubst).applyToParams decl.params
|
||||
let type ← normExpr decl.type
|
||||
let params ← normParams decl.params
|
||||
let value ← withNewScope do go decl.value
|
||||
decl.update type params value
|
||||
|
||||
go (code : Code) : M Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let decl ← (← getSubst).applyToLetDecl decl
|
||||
let decl ← normLetDecl decl
|
||||
if decl.pure then
|
||||
-- We only apply CSE to pure code
|
||||
match (← get).map.find? decl.value with
|
||||
|
|
@ -80,17 +85,16 @@ where
|
|||
-/
|
||||
return code.updateFun! decl (← go k)
|
||||
| .cases c =>
|
||||
let discr := (← getSubst).applyToFVar c.discr
|
||||
let resultType := (← getSubst).applyToExpr c.resultType
|
||||
let discr ← normFVar c.discr
|
||||
let resultType ← normExpr c.resultType
|
||||
let alts ← c.alts.mapMonoM fun alt => do
|
||||
match alt with
|
||||
| .alt _ ps k => withNewScope do
|
||||
let ps ← (← getSubst).applyToParams ps
|
||||
return alt.updateAlt! ps (← go k)
|
||||
return alt.updateAlt! (← normParams ps) (← go k)
|
||||
| .default k => withNewScope do return alt.updateCode (← go k)
|
||||
return code.updateCases! resultType discr alts
|
||||
| .return fvarId => return code.updateReturn! ((← getSubst).applyToFVar fvarId)
|
||||
| .jmp fvarId args => return code.updateJmp! ((← getSubst).applyToFVar fvarId) (← args.mapMonoM fun arg => return (← getSubst).applyToExpr arg)
|
||||
| .return fvarId => return code.updateReturn! (← normFVar fvarId)
|
||||
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args)
|
||||
| .unreach .. => return code
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ So, when inlining we often want to replace a free variable with a type or type f
|
|||
-/
|
||||
abbrev FVarSubst := Std.HashMap FVarId Expr
|
||||
|
||||
partial def FVarSubst.applyToExpr (s : FVarSubst) (e : Expr) : Expr :=
|
||||
private partial def normExprImp (s : FVarSubst) (e : Expr) : Expr :=
|
||||
go e
|
||||
where
|
||||
go (e : Expr) : Expr :=
|
||||
|
|
@ -75,21 +75,40 @@ where
|
|||
else
|
||||
e
|
||||
|
||||
def FVarSubst.applyToFVar (s : FVarSubst) (fvarId : FVarId) : FVarId :=
|
||||
private def normFVarImp (s : FVarSubst) (fvarId : FVarId) : FVarId :=
|
||||
match s.find? fvarId with
|
||||
| some (.fvar fvarId') => fvarId'
|
||||
| some _ => panic! "invalid LCNF substitution of free variable with expression"
|
||||
| none => fvarId
|
||||
|
||||
class MonadFVarSubst (m : Type → Type) where
|
||||
getSubst : m FVarSubst
|
||||
modifySubst : (FVarSubst → FVarSubst) → m Unit
|
||||
|
||||
export MonadFVarSubst (getSubst modifySubst)
|
||||
|
||||
@[inline] def addSubst [MonadFVarSubst m] (fvarId : FVarId) (e : Expr) : m Unit :=
|
||||
modifySubst fun s => s.insert fvarId e
|
||||
|
||||
@[inline] def addFVarSubst [MonadFVarSubst m] (fvarId fvarId' : FVarId) : m Unit :=
|
||||
addSubst fvarId (.fvar fvarId')
|
||||
|
||||
@[inline] def normFVar [MonadFVarSubst m] [Monad m] (fvarId : FVarId) : m FVarId :=
|
||||
return normFVarImp (← getSubst) fvarId
|
||||
|
||||
@[inline] def normExpr [MonadFVarSubst m] [Monad m] (e : Expr) : m Expr :=
|
||||
return normExprImp (← getSubst) e
|
||||
|
||||
def normExprs [MonadFVarSubst m] [Monad m] (es : Array Expr) : m (Array Expr) :=
|
||||
es.mapMonoM normExpr
|
||||
|
||||
namespace Internalize
|
||||
|
||||
abbrev M := StateRefT FVarSubst CompilerM
|
||||
|
||||
@[inline] private abbrev translateFVarId (fvarId : FVarId) : M FVarId := do
|
||||
return (← get).applyToFVar fvarId
|
||||
|
||||
@[inline] private abbrev translateExpr (e : Expr) : M Expr :=
|
||||
return (← get).applyToExpr e
|
||||
instance : MonadFVarSubst M where
|
||||
getSubst := get
|
||||
modifySubst := modify
|
||||
|
||||
private def mkNewFVarId (fvarId : FVarId) : M FVarId := do
|
||||
let fvarId' ← Lean.mkFreshFVarId
|
||||
|
|
@ -97,7 +116,7 @@ private def mkNewFVarId (fvarId : FVarId) : M FVarId := do
|
|||
return fvarId'
|
||||
|
||||
private def addParam (p : Param) : M Param := do
|
||||
let type ← translateExpr p.type
|
||||
let type ← normExpr p.type
|
||||
let fvarId ← mkNewFVarId p.fvarId
|
||||
modifyLCtx fun lctx => lctx.addLocalDecl fvarId p.binderName type
|
||||
return { p with fvarId, type }
|
||||
|
|
@ -105,7 +124,7 @@ private def addParam (p : Param) : M Param := do
|
|||
mutual
|
||||
|
||||
partial def internalizeFunDecl (decl : FunDecl) : M FunDecl := do
|
||||
let type ← translateExpr decl.type
|
||||
let type ← normExpr decl.type
|
||||
let params ← decl.params.mapM addParam
|
||||
let value ← internalizeCode decl.value
|
||||
let fvarId ← mkNewFVarId decl.fvarId
|
||||
|
|
@ -116,8 +135,8 @@ partial def internalizeFunDecl (decl : FunDecl) : M FunDecl := do
|
|||
partial def internalizeCode (code : Code) : M Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let type ← translateExpr decl.type
|
||||
let value ← translateExpr decl.value
|
||||
let type ← normExpr decl.type
|
||||
let value ← normExpr decl.value
|
||||
let fvarId ← mkNewFVarId decl.fvarId
|
||||
modifyLCtx fun lctx => lctx.addLetDecl fvarId decl.binderName type value
|
||||
let k ← internalizeCode k
|
||||
|
|
@ -126,11 +145,11 @@ partial def internalizeCode (code : Code) : M Code := do
|
|||
return .fun (← internalizeFunDecl decl) (← internalizeCode k)
|
||||
| .jp decl k =>
|
||||
return .jp (← internalizeFunDecl decl) (← internalizeCode k)
|
||||
| .return fvarId => return .return (← translateFVarId fvarId)
|
||||
| .jmp fvarId args => return .jmp (← translateFVarId fvarId) (← args.mapM translateExpr)
|
||||
| .unreach type => return .unreach (← translateExpr type)
|
||||
| .return fvarId => return .return (← normFVar fvarId)
|
||||
| .jmp fvarId args => return .jmp (← normFVar fvarId) (← args.mapM normExpr)
|
||||
| .unreach type => return .unreach (← normExpr type)
|
||||
| .cases c =>
|
||||
let discr ← translateFVarId c.discr
|
||||
let discr ← normFVar c.discr
|
||||
let alts ← c.alts.mapM fun
|
||||
| .alt ctorName params k => return .alt ctorName (← params.mapM addParam) (← internalizeCode k)
|
||||
| .default k => return .default (← internalizeCode k)
|
||||
|
|
@ -151,7 +170,7 @@ def Decl.internalize (decl : Decl) (s : FVarSubst := {}): CompilerM Decl :=
|
|||
go decl |>.run' s
|
||||
where
|
||||
go (decl : Decl) : M Decl := do
|
||||
let type ← translateExpr decl.type
|
||||
let type ← normExpr decl.type
|
||||
let params ← decl.params.mapM addParam
|
||||
let value ← internalizeCode decl.value
|
||||
return { decl with type, params, value }
|
||||
|
|
@ -212,14 +231,14 @@ abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : Compi
|
|||
abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
|
||||
decl.update decl.type decl.params value
|
||||
|
||||
def FVarSubst.applyToParam (s : FVarSubst) (p : Param) : CompilerM Param :=
|
||||
p.update (s.applyToExpr p.type)
|
||||
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (p : Param) : m Param := do
|
||||
p.update (← normExpr p.type)
|
||||
|
||||
def FVarSubst.applyToParams (s : FVarSubst) (ps : Array Param) : CompilerM (Array Param) :=
|
||||
ps.mapMonoM s.applyToParam
|
||||
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (ps : Array Param) : m (Array Param) :=
|
||||
ps.mapMonoM normParam
|
||||
|
||||
def FVarSubst.applyToLetDecl (s : FVarSubst) (decl : LetDecl) : CompilerM LetDecl :=
|
||||
decl.update (s.applyToExpr decl.type) (s.applyToExpr decl.value)
|
||||
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (decl : LetDecl) : m LetDecl := do
|
||||
decl.update (← normExpr decl.type) (← normExpr decl.value)
|
||||
|
||||
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
|
||||
let declName := .num binderName (← get).nextIdx
|
||||
|
|
|
|||
|
|
@ -132,6 +132,10 @@ structure State where
|
|||
|
||||
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
instance : MonadFVarSubst SimpM where
|
||||
getSubst := return (← get).subst
|
||||
modifySubst f := modify fun s => { s with subst := f s.subst }
|
||||
|
||||
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit :=
|
||||
go code
|
||||
where
|
||||
|
|
@ -202,8 +206,9 @@ partial def simp (code : Code) : SimpM Code := do
|
|||
incVisited
|
||||
match code with
|
||||
| .return fvarId =>
|
||||
let fvarId ← normFVar fvarId
|
||||
markUsedFVar fvarId
|
||||
return code
|
||||
return code.updateReturn! fvarId
|
||||
| .unreach .. => return code
|
||||
| _ => return code
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue