diff --git a/src/Lean/Compiler/LCNF/MonadScope.lean b/src/Lean/Compiler/LCNF/MonadScope.lean new file mode 100644 index 0000000000..b157cdeb9f --- /dev/null +++ b/src/Lean/Compiler/LCNF/MonadScope.lean @@ -0,0 +1,40 @@ +/- +Copyright (c) 2022 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Compiler.LCNF.Basic + +namespace Lean.Compiler.LCNF + +abbrev Scope := FVarIdSet + +class MonadScope (m : Type → Type) where + getScope : m Scope + withScope : (Scope → Scope) → m α → m α + +export MonadScope (getScope withScope) + +abbrev ScopeT (m : Type → Type) := ReaderT Scope m + +instance [Monad m] : MonadScope (ScopeT m) where + getScope := read + withScope := withReader + +instance (m n) [MonadLift m n] [MonadFunctor m n] [MonadScope m] : MonadScope n where + getScope := liftM (getScope : m _) + withScope f := monadMap (m := m) (withScope f) + +def inScope [MonadScope m] [Monad m] (fvarId : FVarId) : m Bool := + return (← getScope).contains fvarId + +@[inline] def withParams [MonadScope m] [Monad m] (ps : Array Param) (x : m α) : m α := + withScope (fun s => ps.foldl (init := s) fun s p => s.insert p.fvarId) x + +@[inline] def withFVar [MonadScope m] [Monad m] (fvarId : FVarId) (x : m α) : m α := + withScope (fun s => s.insert fvarId) x + +@[inline] def withNewScope [MonadScope m] [Monad m] (x : m α) : m α := do + withScope (fun _ => {}) x + +end Lean.Compiler.LCNF \ No newline at end of file diff --git a/src/Lean/Compiler/LCNF/Specialize.lean b/src/Lean/Compiler/LCNF/Specialize.lean index 0973ff9099..0d30d6221c 100644 --- a/src/Lean/Compiler/LCNF/Specialize.lean +++ b/src/Lean/Compiler/LCNF/Specialize.lean @@ -11,6 +11,7 @@ import Lean.Compiler.LCNF.PrettyPrinter import Lean.Compiler.LCNF.ToExpr import Lean.Compiler.LCNF.Level import Lean.Compiler.LCNF.PhaseExt +import Lean.Compiler.LCNF.MonadScope namespace Lean.Compiler.LCNF namespace Specialize @@ -46,11 +47,9 @@ structure State where abbrev SpecializeM := ReaderT Context $ StateRefT State CompilerM -@[inline] def withParams (ps : Array Param) (x : SpecializeM α) : SpecializeM α := - withReader (fun ctx => { ctx with scope := ps.foldl (init := ctx.scope) fun s p => s.insert p.fvarId }) x - -@[inline] def withFVar (fvarId : FVarId) (x : SpecializeM α) : SpecializeM α := - withReader (fun ctx => { ctx with scope := ctx.scope.insert fvarId }) x +instance : MonadScope SpecializeM where + getScope := return (← read).scope + withScope f := withReader (fun ctx => { ctx with scope := f ctx.scope }) /-- Return `true` if `e` is a ground term. That is, @@ -190,7 +189,7 @@ mutual partial def collectFVar (fvarId : FVarId) : CollectorM Unit := do unless (← get).visited.contains fvarId do markVisited fvarId - if (← read).scope.contains fvarId then + if (← inScope fvarId) then /- We only collect the variables in the scope of the function application being specialized. -/ if let some funDecl ← findFunDecl? fvarId then collectFunDecl funDecl