refactor: add MonadFVarSubst class

This commit is contained in:
Leonardo de Moura 2022-08-29 08:24:17 -07:00
parent 6a7ccb5797
commit 7b161d33d1
3 changed files with 60 additions and 32 deletions

View file

@ -16,8 +16,13 @@ structure State where
map : Std.PHashMap Expr FVarId := {}
subst : FVarSubst := {}
abbrev M := StateRefT State CompilerM
instance : MonadFVarSubst M where
getSubst := return (← get).subst
modifySubst f := modify fun s => { s with subst := f s.subst }
@[inline] def getSubst : M FVarSubst :=
return (← get).subst
@ -42,15 +47,15 @@ partial def Code.cse (code : Code) : CompilerM Code :=
go code |>.run' {}
where
goFunDecl (decl : FunDecl) : M FunDecl := do
let type := (← getSubst).applyToExpr decl.type
let params ← (← getSubst).applyToParams decl.params
let type ← normExpr decl.type
let params ← normParams decl.params
let value ← withNewScope do go decl.value
decl.update type params value
go (code : Code) : M Code := do
match code with
| .let decl k =>
let decl ← (← getSubst).applyToLetDecl decl
let decl ← normLetDecl decl
if decl.pure then
-- We only apply CSE to pure code
match (← get).map.find? decl.value with
@ -80,17 +85,16 @@ where
-/
return code.updateFun! decl (← go k)
| .cases c =>
let discr := (← getSubst).applyToFVar c.discr
let resultType := (← getSubst).applyToExpr c.resultType
let discr ← normFVar c.discr
let resultType ← normExpr c.resultType
let alts ← c.alts.mapMonoM fun alt => do
match alt with
| .alt _ ps k => withNewScope do
let ps ← (← getSubst).applyToParams ps
return alt.updateAlt! ps (← go k)
return alt.updateAlt! (← normParams ps) (← go k)
| .default k => withNewScope do return alt.updateCode (← go k)
return code.updateCases! resultType discr alts
| .return fvarId => return code.updateReturn! ((← getSubst).applyToFVar fvarId)
| .jmp fvarId args => return code.updateJmp! ((← getSubst).applyToFVar fvarId) (← args.mapMonoM fun arg => return (← getSubst).applyToExpr arg)
| .return fvarId => return code.updateReturn! (← normFVar fvarId)
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args)
| .unreach .. => return code
/--

View file

@ -58,7 +58,7 @@ So, when inlining we often want to replace a free variable with a type or type f
-/
abbrev FVarSubst := Std.HashMap FVarId Expr
partial def FVarSubst.applyToExpr (s : FVarSubst) (e : Expr) : Expr :=
private partial def normExprImp (s : FVarSubst) (e : Expr) : Expr :=
go e
where
go (e : Expr) : Expr :=
@ -75,21 +75,40 @@ where
else
e
def FVarSubst.applyToFVar (s : FVarSubst) (fvarId : FVarId) : FVarId :=
private def normFVarImp (s : FVarSubst) (fvarId : FVarId) : FVarId :=
match s.find? fvarId with
| some (.fvar fvarId') => fvarId'
| some _ => panic! "invalid LCNF substitution of free variable with expression"
| none => fvarId
class MonadFVarSubst (m : Type → Type) where
getSubst : m FVarSubst
modifySubst : (FVarSubst → FVarSubst) → m Unit
export MonadFVarSubst (getSubst modifySubst)
@[inline] def addSubst [MonadFVarSubst m] (fvarId : FVarId) (e : Expr) : m Unit :=
modifySubst fun s => s.insert fvarId e
@[inline] def addFVarSubst [MonadFVarSubst m] (fvarId fvarId' : FVarId) : m Unit :=
addSubst fvarId (.fvar fvarId')
@[inline] def normFVar [MonadFVarSubst m] [Monad m] (fvarId : FVarId) : m FVarId :=
return normFVarImp (← getSubst) fvarId
@[inline] def normExpr [MonadFVarSubst m] [Monad m] (e : Expr) : m Expr :=
return normExprImp (← getSubst) e
def normExprs [MonadFVarSubst m] [Monad m] (es : Array Expr) : m (Array Expr) :=
es.mapMonoM normExpr
namespace Internalize
abbrev M := StateRefT FVarSubst CompilerM
@[inline] private abbrev translateFVarId (fvarId : FVarId) : M FVarId := do
return (← get).applyToFVar fvarId
@[inline] private abbrev translateExpr (e : Expr) : M Expr :=
return (← get).applyToExpr e
instance : MonadFVarSubst M where
getSubst := get
modifySubst := modify
private def mkNewFVarId (fvarId : FVarId) : M FVarId := do
let fvarId' ← Lean.mkFreshFVarId
@ -97,7 +116,7 @@ private def mkNewFVarId (fvarId : FVarId) : M FVarId := do
return fvarId'
private def addParam (p : Param) : M Param := do
let type ← translateExpr p.type
let type ← normExpr p.type
let fvarId ← mkNewFVarId p.fvarId
modifyLCtx fun lctx => lctx.addLocalDecl fvarId p.binderName type
return { p with fvarId, type }
@ -105,7 +124,7 @@ private def addParam (p : Param) : M Param := do
mutual
partial def internalizeFunDecl (decl : FunDecl) : M FunDecl := do
let type ← translateExpr decl.type
let type ← normExpr decl.type
let params ← decl.params.mapM addParam
let value ← internalizeCode decl.value
let fvarId ← mkNewFVarId decl.fvarId
@ -116,8 +135,8 @@ partial def internalizeFunDecl (decl : FunDecl) : M FunDecl := do
partial def internalizeCode (code : Code) : M Code := do
match code with
| .let decl k =>
let type ← translateExpr decl.type
let value ← translateExpr decl.value
let type ← normExpr decl.type
let value ← normExpr decl.value
let fvarId ← mkNewFVarId decl.fvarId
modifyLCtx fun lctx => lctx.addLetDecl fvarId decl.binderName type value
let k ← internalizeCode k
@ -126,11 +145,11 @@ partial def internalizeCode (code : Code) : M Code := do
return .fun (← internalizeFunDecl decl) (← internalizeCode k)
| .jp decl k =>
return .jp (← internalizeFunDecl decl) (← internalizeCode k)
| .return fvarId => return .return (← translateFVarId fvarId)
| .jmp fvarId args => return .jmp (← translateFVarId fvarId) (← args.mapM translateExpr)
| .unreach type => return .unreach (← translateExpr type)
| .return fvarId => return .return (← normFVar fvarId)
| .jmp fvarId args => return .jmp (← normFVar fvarId) (← args.mapM normExpr)
| .unreach type => return .unreach (← normExpr type)
| .cases c =>
let discr ← translateFVarId c.discr
let discr ← normFVar c.discr
let alts ← c.alts.mapM fun
| .alt ctorName params k => return .alt ctorName (← params.mapM addParam) (← internalizeCode k)
| .default k => return .default (← internalizeCode k)
@ -151,7 +170,7 @@ def Decl.internalize (decl : Decl) (s : FVarSubst := {}): CompilerM Decl :=
go decl |>.run' s
where
go (decl : Decl) : M Decl := do
let type ← translateExpr decl.type
let type ← normExpr decl.type
let params ← decl.params.mapM addParam
let value ← internalizeCode decl.value
return { decl with type, params, value }
@ -212,14 +231,14 @@ abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : Compi
abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
decl.update decl.type decl.params value
def FVarSubst.applyToParam (s : FVarSubst) (p : Param) : CompilerM Param :=
p.update (s.applyToExpr p.type)
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (p : Param) : m Param := do
p.update (← normExpr p.type)
def FVarSubst.applyToParams (s : FVarSubst) (ps : Array Param) : CompilerM (Array Param) :=
ps.mapMonoM s.applyToParam
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (ps : Array Param) : m (Array Param) :=
ps.mapMonoM normParam
def FVarSubst.applyToLetDecl (s : FVarSubst) (decl : LetDecl) : CompilerM LetDecl :=
decl.update (s.applyToExpr decl.type) (s.applyToExpr decl.value)
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (decl : LetDecl) : m LetDecl := do
decl.update (← normExpr decl.type) (← normExpr decl.value)
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
let declName := .num binderName (← get).nextIdx

View file

@ -132,6 +132,10 @@ structure State where
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
instance : MonadFVarSubst SimpM where
getSubst := return (← get).subst
modifySubst f := modify fun s => { s with subst := f s.subst }
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit :=
go code
where
@ -202,8 +206,9 @@ partial def simp (code : Code) : SimpM Code := do
incVisited
match code with
| .return fvarId =>
let fvarId ← normFVar fvarId
markUsedFVar fvarId
return code
return code.updateReturn! fvarId
| .unreach .. => return code
| _ => return code