From 7b161d33d17bc18c5f95df0210c27dca17929b49 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 29 Aug 2022 08:24:17 -0700 Subject: [PATCH] refactor: add `MonadFVarSubst` class --- src/Lean/Compiler/LCNF/CSE.lean | 22 ++++++---- src/Lean/Compiler/LCNF/CompilerM.lean | 63 +++++++++++++++++---------- src/Lean/Compiler/LCNF/Simp.lean | 7 ++- 3 files changed, 60 insertions(+), 32 deletions(-) diff --git a/src/Lean/Compiler/LCNF/CSE.lean b/src/Lean/Compiler/LCNF/CSE.lean index b7a4d1193c..a31400f514 100644 --- a/src/Lean/Compiler/LCNF/CSE.lean +++ b/src/Lean/Compiler/LCNF/CSE.lean @@ -16,8 +16,13 @@ structure State where map : Std.PHashMap Expr FVarId := {} subst : FVarSubst := {} + abbrev M := StateRefT State CompilerM +instance : MonadFVarSubst M where + getSubst := return (← get).subst + modifySubst f := modify fun s => { s with subst := f s.subst } + @[inline] def getSubst : M FVarSubst := return (← get).subst @@ -42,15 +47,15 @@ partial def Code.cse (code : Code) : CompilerM Code := go code |>.run' {} where goFunDecl (decl : FunDecl) : M FunDecl := do - let type := (← getSubst).applyToExpr decl.type - let params ← (← getSubst).applyToParams decl.params + let type ← normExpr decl.type + let params ← normParams decl.params let value ← withNewScope do go decl.value decl.update type params value go (code : Code) : M Code := do match code with | .let decl k => - let decl ← (← getSubst).applyToLetDecl decl + let decl ← normLetDecl decl if decl.pure then -- We only apply CSE to pure code match (← get).map.find? decl.value with @@ -80,17 +85,16 @@ where -/ return code.updateFun! decl (← go k) | .cases c => - let discr := (← getSubst).applyToFVar c.discr - let resultType := (← getSubst).applyToExpr c.resultType + let discr ← normFVar c.discr + let resultType ← normExpr c.resultType let alts ← c.alts.mapMonoM fun alt => do match alt with | .alt _ ps k => withNewScope do - let ps ← (← getSubst).applyToParams ps - return alt.updateAlt! ps (← go k) + return alt.updateAlt! (← normParams ps) (← go k) | .default k => withNewScope do return alt.updateCode (← go k) return code.updateCases! resultType discr alts - | .return fvarId => return code.updateReturn! ((← getSubst).applyToFVar fvarId) - | .jmp fvarId args => return code.updateJmp! ((← getSubst).applyToFVar fvarId) (← args.mapMonoM fun arg => return (← getSubst).applyToExpr arg) + | .return fvarId => return code.updateReturn! (← normFVar fvarId) + | .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args) | .unreach .. => return code /-- diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index ee32be99f9..7e7f73dc24 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -58,7 +58,7 @@ So, when inlining we often want to replace a free variable with a type or type f -/ abbrev FVarSubst := Std.HashMap FVarId Expr -partial def FVarSubst.applyToExpr (s : FVarSubst) (e : Expr) : Expr := +private partial def normExprImp (s : FVarSubst) (e : Expr) : Expr := go e where go (e : Expr) : Expr := @@ -75,21 +75,40 @@ where else e -def FVarSubst.applyToFVar (s : FVarSubst) (fvarId : FVarId) : FVarId := +private def normFVarImp (s : FVarSubst) (fvarId : FVarId) : FVarId := match s.find? fvarId with | some (.fvar fvarId') => fvarId' | some _ => panic! "invalid LCNF substitution of free variable with expression" | none => fvarId +class MonadFVarSubst (m : Type → Type) where + getSubst : m FVarSubst + modifySubst : (FVarSubst → FVarSubst) → m Unit + +export MonadFVarSubst (getSubst modifySubst) + +@[inline] def addSubst [MonadFVarSubst m] (fvarId : FVarId) (e : Expr) : m Unit := + modifySubst fun s => s.insert fvarId e + +@[inline] def addFVarSubst [MonadFVarSubst m] (fvarId fvarId' : FVarId) : m Unit := + addSubst fvarId (.fvar fvarId') + +@[inline] def normFVar [MonadFVarSubst m] [Monad m] (fvarId : FVarId) : m FVarId := + return normFVarImp (← getSubst) fvarId + +@[inline] def normExpr [MonadFVarSubst m] [Monad m] (e : Expr) : m Expr := + return normExprImp (← getSubst) e + +def normExprs [MonadFVarSubst m] [Monad m] (es : Array Expr) : m (Array Expr) := + es.mapMonoM normExpr + namespace Internalize abbrev M := StateRefT FVarSubst CompilerM -@[inline] private abbrev translateFVarId (fvarId : FVarId) : M FVarId := do - return (← get).applyToFVar fvarId - -@[inline] private abbrev translateExpr (e : Expr) : M Expr := - return (← get).applyToExpr e +instance : MonadFVarSubst M where + getSubst := get + modifySubst := modify private def mkNewFVarId (fvarId : FVarId) : M FVarId := do let fvarId' ← Lean.mkFreshFVarId @@ -97,7 +116,7 @@ private def mkNewFVarId (fvarId : FVarId) : M FVarId := do return fvarId' private def addParam (p : Param) : M Param := do - let type ← translateExpr p.type + let type ← normExpr p.type let fvarId ← mkNewFVarId p.fvarId modifyLCtx fun lctx => lctx.addLocalDecl fvarId p.binderName type return { p with fvarId, type } @@ -105,7 +124,7 @@ private def addParam (p : Param) : M Param := do mutual partial def internalizeFunDecl (decl : FunDecl) : M FunDecl := do - let type ← translateExpr decl.type + let type ← normExpr decl.type let params ← decl.params.mapM addParam let value ← internalizeCode decl.value let fvarId ← mkNewFVarId decl.fvarId @@ -116,8 +135,8 @@ partial def internalizeFunDecl (decl : FunDecl) : M FunDecl := do partial def internalizeCode (code : Code) : M Code := do match code with | .let decl k => - let type ← translateExpr decl.type - let value ← translateExpr decl.value + let type ← normExpr decl.type + let value ← normExpr decl.value let fvarId ← mkNewFVarId decl.fvarId modifyLCtx fun lctx => lctx.addLetDecl fvarId decl.binderName type value let k ← internalizeCode k @@ -126,11 +145,11 @@ partial def internalizeCode (code : Code) : M Code := do return .fun (← internalizeFunDecl decl) (← internalizeCode k) | .jp decl k => return .jp (← internalizeFunDecl decl) (← internalizeCode k) - | .return fvarId => return .return (← translateFVarId fvarId) - | .jmp fvarId args => return .jmp (← translateFVarId fvarId) (← args.mapM translateExpr) - | .unreach type => return .unreach (← translateExpr type) + | .return fvarId => return .return (← normFVar fvarId) + | .jmp fvarId args => return .jmp (← normFVar fvarId) (← args.mapM normExpr) + | .unreach type => return .unreach (← normExpr type) | .cases c => - let discr ← translateFVarId c.discr + let discr ← normFVar c.discr let alts ← c.alts.mapM fun | .alt ctorName params k => return .alt ctorName (← params.mapM addParam) (← internalizeCode k) | .default k => return .default (← internalizeCode k) @@ -151,7 +170,7 @@ def Decl.internalize (decl : Decl) (s : FVarSubst := {}): CompilerM Decl := go decl |>.run' s where go (decl : Decl) : M Decl := do - let type ← translateExpr decl.type + let type ← normExpr decl.type let params ← decl.params.mapM addParam let value ← internalizeCode decl.value return { decl with type, params, value } @@ -212,14 +231,14 @@ abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : Compi abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl := decl.update decl.type decl.params value -def FVarSubst.applyToParam (s : FVarSubst) (p : Param) : CompilerM Param := - p.update (s.applyToExpr p.type) +@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (p : Param) : m Param := do + p.update (← normExpr p.type) -def FVarSubst.applyToParams (s : FVarSubst) (ps : Array Param) : CompilerM (Array Param) := - ps.mapMonoM s.applyToParam +def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (ps : Array Param) : m (Array Param) := + ps.mapMonoM normParam -def FVarSubst.applyToLetDecl (s : FVarSubst) (decl : LetDecl) : CompilerM LetDecl := - decl.update (s.applyToExpr decl.type) (s.applyToExpr decl.value) +def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m] (decl : LetDecl) : m LetDecl := do + decl.update (← normExpr decl.type) (← normExpr decl.value) def mkFreshBinderName (binderName := `_x): CompilerM Name := do let declName := .num binderName (← get).nextIdx diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index 5dd00a9173..c460aa257d 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -132,6 +132,10 @@ structure State where abbrev SimpM := ReaderT Context $ StateRefT State CompilerM +instance : MonadFVarSubst SimpM where + getSubst := return (← get).subst + modifySubst f := modify fun s => { s with subst := f s.subst } + partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit := go code where @@ -202,8 +206,9 @@ partial def simp (code : Code) : SimpM Code := do incVisited match code with | .return fvarId => + let fvarId ← normFVar fvarId markUsedFVar fvarId - return code + return code.updateReturn! fvarId | .unreach .. => return code | _ => return code