refactor: add MonadScope class

We are going to use it to implement the lambda lifting pass too.
This commit is contained in:
Leonardo de Moura 2022-10-07 14:59:59 -07:00
parent 1b8e310ada
commit e7a36f32f1
2 changed files with 45 additions and 6 deletions

View file

@ -0,0 +1,40 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.LCNF.Basic
namespace Lean.Compiler.LCNF
abbrev Scope := FVarIdSet
class MonadScope (m : Type → Type) where
getScope : m Scope
withScope : (Scope → Scope) → m α → m α
export MonadScope (getScope withScope)
abbrev ScopeT (m : Type → Type) := ReaderT Scope m
instance [Monad m] : MonadScope (ScopeT m) where
getScope := read
withScope := withReader
instance (m n) [MonadLift m n] [MonadFunctor m n] [MonadScope m] : MonadScope n where
getScope := liftM (getScope : m _)
withScope f := monadMap (m := m) (withScope f)
def inScope [MonadScope m] [Monad m] (fvarId : FVarId) : m Bool :=
return (← getScope).contains fvarId
@[inline] def withParams [MonadScope m] [Monad m] (ps : Array Param) (x : m α) : m α :=
withScope (fun s => ps.foldl (init := s) fun s p => s.insert p.fvarId) x
@[inline] def withFVar [MonadScope m] [Monad m] (fvarId : FVarId) (x : m α) : m α :=
withScope (fun s => s.insert fvarId) x
@[inline] def withNewScope [MonadScope m] [Monad m] (x : m α) : m α := do
withScope (fun _ => {}) x
end Lean.Compiler.LCNF

View file

@ -11,6 +11,7 @@ import Lean.Compiler.LCNF.PrettyPrinter
import Lean.Compiler.LCNF.ToExpr
import Lean.Compiler.LCNF.Level
import Lean.Compiler.LCNF.PhaseExt
import Lean.Compiler.LCNF.MonadScope
namespace Lean.Compiler.LCNF
namespace Specialize
@ -46,11 +47,9 @@ structure State where
abbrev SpecializeM := ReaderT Context $ StateRefT State CompilerM
@[inline] def withParams (ps : Array Param) (x : SpecializeM α) : SpecializeM α :=
withReader (fun ctx => { ctx with scope := ps.foldl (init := ctx.scope) fun s p => s.insert p.fvarId }) x
@[inline] def withFVar (fvarId : FVarId) (x : SpecializeM α) : SpecializeM α :=
withReader (fun ctx => { ctx with scope := ctx.scope.insert fvarId }) x
instance : MonadScope SpecializeM where
getScope := return (← read).scope
withScope f := withReader (fun ctx => { ctx with scope := f ctx.scope })
/--
Return `true` if `e` is a ground term. That is,
@ -190,7 +189,7 @@ mutual
partial def collectFVar (fvarId : FVarId) : CollectorM Unit := do
unless (← get).visited.contains fvarId do
markVisited fvarId
if (← read).scope.contains fvarId then
if (← inScope fvarId) then
/- We only collect the variables in the scope of the function application being specialized. -/
if let some funDecl ← findFunDecl? fvarId then
collectFunDecl funDecl