refactor: add MonadScope class
We are going to use it to implement the lambda lifting pass too.
This commit is contained in:
parent
1b8e310ada
commit
e7a36f32f1
2 changed files with 45 additions and 6 deletions
40
src/Lean/Compiler/LCNF/MonadScope.lean
Normal file
40
src/Lean/Compiler/LCNF/MonadScope.lean
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.LCNF.Basic
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
abbrev Scope := FVarIdSet
|
||||
|
||||
class MonadScope (m : Type → Type) where
|
||||
getScope : m Scope
|
||||
withScope : (Scope → Scope) → m α → m α
|
||||
|
||||
export MonadScope (getScope withScope)
|
||||
|
||||
abbrev ScopeT (m : Type → Type) := ReaderT Scope m
|
||||
|
||||
instance [Monad m] : MonadScope (ScopeT m) where
|
||||
getScope := read
|
||||
withScope := withReader
|
||||
|
||||
instance (m n) [MonadLift m n] [MonadFunctor m n] [MonadScope m] : MonadScope n where
|
||||
getScope := liftM (getScope : m _)
|
||||
withScope f := monadMap (m := m) (withScope f)
|
||||
|
||||
def inScope [MonadScope m] [Monad m] (fvarId : FVarId) : m Bool :=
|
||||
return (← getScope).contains fvarId
|
||||
|
||||
@[inline] def withParams [MonadScope m] [Monad m] (ps : Array Param) (x : m α) : m α :=
|
||||
withScope (fun s => ps.foldl (init := s) fun s p => s.insert p.fvarId) x
|
||||
|
||||
@[inline] def withFVar [MonadScope m] [Monad m] (fvarId : FVarId) (x : m α) : m α :=
|
||||
withScope (fun s => s.insert fvarId) x
|
||||
|
||||
@[inline] def withNewScope [MonadScope m] [Monad m] (x : m α) : m α := do
|
||||
withScope (fun _ => {}) x
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
@ -11,6 +11,7 @@ import Lean.Compiler.LCNF.PrettyPrinter
|
|||
import Lean.Compiler.LCNF.ToExpr
|
||||
import Lean.Compiler.LCNF.Level
|
||||
import Lean.Compiler.LCNF.PhaseExt
|
||||
import Lean.Compiler.LCNF.MonadScope
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Specialize
|
||||
|
|
@ -46,11 +47,9 @@ structure State where
|
|||
|
||||
abbrev SpecializeM := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
@[inline] def withParams (ps : Array Param) (x : SpecializeM α) : SpecializeM α :=
|
||||
withReader (fun ctx => { ctx with scope := ps.foldl (init := ctx.scope) fun s p => s.insert p.fvarId }) x
|
||||
|
||||
@[inline] def withFVar (fvarId : FVarId) (x : SpecializeM α) : SpecializeM α :=
|
||||
withReader (fun ctx => { ctx with scope := ctx.scope.insert fvarId }) x
|
||||
instance : MonadScope SpecializeM where
|
||||
getScope := return (← read).scope
|
||||
withScope f := withReader (fun ctx => { ctx with scope := f ctx.scope })
|
||||
|
||||
/--
|
||||
Return `true` if `e` is a ground term. That is,
|
||||
|
|
@ -190,7 +189,7 @@ mutual
|
|||
partial def collectFVar (fvarId : FVarId) : CollectorM Unit := do
|
||||
unless (← get).visited.contains fvarId do
|
||||
markVisited fvarId
|
||||
if (← read).scope.contains fvarId then
|
||||
if (← inScope fvarId) then
|
||||
/- We only collect the variables in the scope of the function application being specialized. -/
|
||||
if let some funDecl ← findFunDecl? fvarId then
|
||||
collectFunDecl funDecl
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue