refactor: State ==> StateM

This commit is contained in:
Leonardo de Moura 2019-11-05 07:56:19 -08:00
parent cb554d6473
commit 5bb5ef6296
14 changed files with 42 additions and 42 deletions

View file

@ -128,7 +128,7 @@ fun s =>
instance {ε σ σ'} : MonadStateAdapter σ σ' (EState ε σ) (EState ε σ') :=
⟨fun σ'' α => EState.adaptState⟩
@[inline] def fromState {ε σ α : Type} (x : State σ α) : EState ε σ α :=
@[inline] def fromStateMσ α : Type} (x : StateM σ α) : EState ε σ α :=
fun s =>
match x.run s with
| (a, s') => EState.Result.ok a s'

View file

@ -21,7 +21,7 @@ x s
@[inline] def StateT.run' {σ : Type u} {m : Type u → Type v} [Functor m] {α : Type u} (x : StateT σ m α) (s : σ) : m α :=
Prod.fst <$> x s
@[reducible] def State (σ α : Type u) : Type u := StateT σ Id α
@[reducible] def StateM (σ α : Type u) : Type u := StateT σ Id α
namespace StateT
section

View file

@ -164,7 +164,7 @@ def hasAssignedMVar (mctx : σ) : Expr → Bool
| Expr.proj _ _ e => e.hasMVar && hasAssignedMVar e
| Expr.mvar mvarId => isExprAssigned mctx mvarId
partial def instantiateLevelMVars : Level → State σ Level
partial def instantiateLevelMVars : Level → StateM σ Level
| lvl@(Level.succ lvl₁) => do lvl₁ ← instantiateLevelMVars lvl₁; pure (Level.updateSucc! lvl lvl₁)
| lvl@(Level.max lvl₁ lvl₂) => do lvl₁ ← instantiateLevelMVars lvl₁; lvl₂ ← instantiateLevelMVars lvl₂; pure (Level.updateMax! lvl lvl₁ lvl₂)
| lvl@(Level.imax lvl₁ lvl₂) => do lvl₁ ← instantiateLevelMVars lvl₁; lvl₂ ← instantiateLevelMVars lvl₂; pure (Level.updateIMax! lvl lvl₁ lvl₂)
@ -181,7 +181,7 @@ partial def instantiateLevelMVars : Level → State σ Level
| lvl => pure lvl
namespace InstantiateExprMVars
abbrev M (σ : Type) := State (WithHashMapCache Expr Expr σ)
abbrev M (σ : Type) := StateM (WithHashMapCache Expr Expr σ)
@[inline] def instantiateLevelMVars (lvl : Level) : M σ Level :=
WithHashMapCache.fromState $ AbstractMetavarContext.instantiateLevelMVars lvl
@ -281,13 +281,13 @@ partial def main : Expr → M σ Expr
end InstantiateExprMVars
def instantiateMVars (e : Expr) : State σ Expr :=
def instantiateMVars (e : Expr) : StateM σ Expr :=
if !e.hasMVar then pure e
else WithHashMapCache.toState $ InstantiateExprMVars.main e
namespace DependsOn
abbrev M := State ExprSet
abbrev M := StateM ExprSet
@[inline] def visit (main : Expr → M Bool) (e : Expr) : M Bool :=
if !e.hasMVar && !e.hasFVar then

View file

@ -69,7 +69,7 @@ ps.map $ fun p => { borrow := p.ty.isObj, .. p }
def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param :=
if exported then ps else initBorrow ps
partial def visitFnBody (fnid : FunId) : FnBody → State ParamMap Unit
partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
| FnBody.jdecl j xs v b => do
modify $ fun m => m.insert (Key.jp fnid j) (initBorrow xs);
visitFnBody v;
@ -79,7 +79,7 @@ partial def visitFnBody (fnid : FunId) : FnBody → State ParamMap Unit
let (instr, b) := e.split;
visitFnBody b
def visitDecls (env : Environment) (decls : Array Decl) : State ParamMap Unit :=
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
decls.forM $ fun decl => match decl with
| Decl.fdecl f xs _ b => do
let exported := isExport env f;
@ -137,7 +137,7 @@ structure BorrowInfState :=
(modifiedOwned : Bool := false)
(modifiedParamMap : Bool := false)
abbrev M := ReaderT BorrowInfCtx (State BorrowInfState)
abbrev M := ReaderT BorrowInfCtx (StateM BorrowInfState)
def markModifiedParamMap : M Unit :=
modify $ fun s => { modifiedParamMap := true, .. s }

View file

@ -41,7 +41,7 @@ def isBoxedName : Name → Bool
| Name.mkString _ "_boxed" => true
| _ => false
abbrev N := State Nat
abbrev N := StateM Nat
private def mkFresh : N VarId :=
modifyGet $ fun n => ({ idx := n }, n + 1)

View file

@ -126,7 +126,7 @@ structure InterpState :=
(assignments : Array Assignment)
(funVals : PArray Value) -- we take snapshots during fixpoint computations
abbrev M := ReaderT InterpContext (State InterpState)
abbrev M := ReaderT InterpContext (StateM InterpState)
open Value

View file

@ -20,7 +20,7 @@ match b with
namespace UsesLeanNamespace
abbrev M := ReaderT Environment (State NameSet)
abbrev M := ReaderT Environment (StateM NameSet)
def leanNameSpacePrefix := `Lean
@ -57,7 +57,7 @@ def usesLeanNamespace (env : Environment) : Decl → Bool
namespace CollectUsedDecls
abbrev M := ReaderT Environment (State NameSet)
abbrev M := ReaderT Environment (StateM NameSet)
@[inline] def collect (f : FunId) : M Unit :=
modify $ fun s => s.insert f

View file

@ -136,7 +136,7 @@ mask.foldl
| none => b)
b
abbrev M := ReaderT Context (State Nat)
abbrev M := ReaderT Context (StateM Nat)
def mkFresh : M VarId :=
modifyGet $ fun n => ({ idx := n }, n + 1)

View file

@ -38,7 +38,7 @@ namespace IsLive
Remark: we don't need to track local join points because we assume there is
no variable or join point shadowing in our IR.
-/
abbrev M := State LocalContext
abbrev M := StateM LocalContext
@[inline] def visitVar (w : Index) (x : VarId) : M Bool := pure (HasIndex.visitVar w x)
@[inline] def visitJP (w : Index) (x : JoinPointId) : M Bool := pure (HasIndex.visitJP w x)

View file

@ -76,7 +76,7 @@ def normExpr : Expr → M Expr
| Expr.isTaggedPtr x, m => Expr.isTaggedPtr (normVar x m)
| e@(Expr.lit v), m => e
abbrev N := ReaderT IndexRenaming (State Nat)
abbrev N := ReaderT IndexRenaming (StateM Nat)
@[inline] def withVar {α : Type} (x : VarId) (k : VarId → N α) : N α :=
fun m => do

View file

@ -61,13 +61,13 @@ structure WithHashMapCache (α β σ : Type) [HasBeq α] [Hashable α] :=
namespace WithHashMapCache
@[inline] def getCache {α β σ : Type} [HasBeq α] [Hashable α] : State (WithHashMapCache α β σ) (HashMap α β) :=
@[inline] def getCache {α β σ : Type} [HasBeq α] [Hashable α] : StateM (WithHashMapCache α β σ) (HashMap α β) :=
do s ← get; pure s.cache
@[inline] def modifyCache {α β σ : Type} [HasBeq α] [Hashable α] (f : HashMap α β → HashMap α β) : State (WithHashMapCache α β σ) Unit :=
@[inline] def modifyCache {α β σ : Type} [HasBeq α] [Hashable α] (f : HashMap α β → HashMap α β) : StateM (WithHashMapCache α β σ) Unit :=
modify $ fun s => { cache := f s.cache, .. s }
instance stateAdapter (α β σ : Type) [HasBeq α] [Hashable α] : MonadHashMapCacheAdapter α β (State (WithHashMapCache α β σ)) :=
instance stateAdapter (α β σ : Type) [HasBeq α] [Hashable α] : MonadHashMapCacheAdapter α β (StateM (WithHashMapCache α β σ)) :=
{ getCache := WithHashMapCache.getCache,
modifyCache := WithHashMapCache.modifyCache }
@ -81,13 +81,13 @@ instance estateAdapter (α β ε σ : Type) [HasBeq α] [Hashable α] : MonadHas
{ getCache := WithHashMapCache.getCacheE,
modifyCache := WithHashMapCache.modifyCacheE }
@[inline] def fromState {α β σ δ : Type} [HasBeq α] [Hashable α] (x : State σ δ) : State (WithHashMapCache α β σ) δ :=
@[inline] def fromState {α β σ δ : Type} [HasBeq α] [Hashable α] (x : StateM σ δ) : StateM (WithHashMapCache α β σ) δ :=
adaptState
(fun (s : WithHashMapCache α β σ) => (s.state, s.cache))
(fun (s : σ) (cache : HashMap α β) => { state := s, cache := cache })
x
@[inline] def toState {α β σ δ : Type} [HasBeq α] [Hashable α] (x : State (WithHashMapCache α β σ) δ) : State σ δ :=
@[inline] def toState {α β σ δ : Type} [HasBeq α] [Hashable α] (x : StateM (WithHashMapCache α β σ) δ) : StateM σ δ :=
adaptState'
(fun (s : σ) => ({ state := s } : WithHashMapCache α β σ))
(fun (s : WithHashMapCache α β σ) => s.state)

View file

@ -207,7 +207,7 @@ private def updateInfo : SourceInfo → String.Pos → SourceInfo
/- Remark: the State `String.Pos` is the `SourceInfo.trailing.stopPos` of the previous token,
or the beginning of the String. -/
@[inline]
private def updateLeadingAux {α} : Syntax α → State String.Pos (Option (Syntax α))
private def updateLeadingAux {α} : Syntax α → StateM String.Pos (Option (Syntax α))
| atom (some info) val => do
last ← get;
set info.trailing.stopPos;

View file

@ -44,16 +44,16 @@ def eMetaIdx : Expr → Option Nat
def eIsMeta (e : Expr) : Bool := (eMetaIdx e).toBool
def eNewMeta (type : Expr) : State Context Expr :=
def eNewMeta (type : Expr) : StateM Context Expr :=
do ctx ← get;
let idx := ctx.eTypes.size;
set { eTypes := ctx.eTypes.push type, eVals := ctx.eVals.push none, .. ctx };
pure $ Expr.mvar (mkNumName metaPrefix idx)
def eLookupIdx (idx : Nat) : State Context (Option Expr) :=
def eLookupIdx (idx : Nat) : StateM Context (Option Expr) :=
do ctx ← get; pure $ ctx.eVals.get! idx
partial def eShallowInstantiate : Expr → State Context Expr
partial def eShallowInstantiate : Expr → StateM Context Expr
| e =>
match eMetaIdx e with
| some idx => get >>= λ ctx =>
@ -62,7 +62,7 @@ partial def eShallowInstantiate : Expr → State Context Expr
| some v => eShallowInstantiate v
| none => pure e
def eInferIdx (idx : Nat) : State Context Expr :=
def eInferIdx (idx : Nat) : StateM Context Expr :=
do ctx ← get; pure $ ctx.eTypes.get! idx
def eInfer (ctx : Context) (mvar : Expr) : Expr :=
@ -70,10 +70,10 @@ match eMetaIdx mvar with
| some idx => ctx.eTypes.get! idx
| none => panic! "eInfer called on non-(tmp-)mvar"
def eAssignIdx (idx : Nat) (e : Expr) : State Context Unit :=
def eAssignIdx (idx : Nat) (e : Expr) : StateM Context Unit :=
modify $ λ ctx => { eVals := ctx.eVals.set idx (some e) .. ctx }
def eAssign (mvar : Expr) (e : Expr) : State Context Unit :=
def eAssign (mvar : Expr) (e : Expr) : StateM Context Unit :=
match eMetaIdx mvar with
| some idx => modify $ λ ctx => { eVals := ctx.eVals.set idx (some e) .. ctx }
| _ => panic! "eAssign called on non-(tmp-)mvar"
@ -102,16 +102,16 @@ def uMetaIdx : Level → Option Nat
def uIsMeta (l : Level) : Bool := (uMetaIdx l).toBool
def uNewMeta : State Context Level :=
def uNewMeta : StateM Context Level :=
do ctx ← get;
let idx := ctx.uVals.size;
set { uVals := ctx.uVals.push none, .. ctx };
pure $ Level.mvar (mkNumName metaPrefix idx)
def uLookupIdx (idx : Nat) : State Context (Option Level) :=
def uLookupIdx (idx : Nat) : StateM Context (Option Level) :=
do ctx ← get; pure $ ctx.uVals.get! idx
partial def uShallowInstantiate : Level → State Context Level
partial def uShallowInstantiate : Level → StateM Context Level
| l =>
match uMetaIdx l with
| some idx => get >>= λ ctx =>
@ -120,10 +120,10 @@ partial def uShallowInstantiate : Level → State Context Level
| some v => uShallowInstantiate v
| none => pure l
def uAssignIdx (idx : Nat) (l : Level) : State Context Unit :=
def uAssignIdx (idx : Nat) (l : Level) : StateM Context Unit :=
modify $ λ ctx => { uVals := ctx.uVals.set idx (some l) .. ctx }
def uAssign (umvar : Level) (l : Level) : State Context Unit :=
def uAssign (umvar : Level) (l : Level) : StateM Context Unit :=
match uMetaIdx umvar with
| some idx => modify $ λ ctx => { uVals := ctx.uVals.set idx (some l) .. ctx }
| _ => panic! "uassign called on non-(tmp-)mvar"
@ -147,8 +147,8 @@ uFind uIsMeta l
partial def uUnify : Level → Level → EState String Context Unit
| l₁, l₂ => do
l₁ ← EState.fromState $ uShallowInstantiate l₁;
l₂ ← EState.fromState $ uShallowInstantiate l₂;
l₁ ← EState.fromStateM $ uShallowInstantiate l₁;
l₂ ← EState.fromStateM $ uShallowInstantiate l₂;
if uIsMeta l₂ && !(uIsMeta l₁)
then uUnify l₂ l₁
else
@ -162,7 +162,7 @@ partial def uUnify : Level → Level → EState String Context Unit
match uMetaIdx l₁ with
| none => when (!(l₁ == l₂)) $ throw "Level.mvar clash"
| some idx => do when (uOccursIn l₁ l₂) $ throw "occurs";
EState.fromState $ uAssignIdx idx l₂
EState.fromStateM $ uAssignIdx idx l₂
| _, _ => throw $ "lUnify: " ++ toString l₁ ++ " !=?= " ++ toString l₂
partial def uInstantiate (ctx : Context) : Level → Level
@ -207,8 +207,8 @@ partial def eUnify : Expr → Expr → EState String Context Unit
if !e₁.hasMVar && !e₂.hasMVar
then unless (e₁ == e₂) $ throw $ "eUnify: " ++ toString e₁ ++ " !=?= " ++ toString e₂
else do
e₁ ← slowWhnf <$> (EState.fromState $ eShallowInstantiate e₁);
e₂ ← slowWhnf <$> (EState.fromState $ eShallowInstantiate e₂);
e₁ ← slowWhnf <$> (EState.fromStateM $ eShallowInstantiate e₁);
e₂ ← slowWhnf <$> (EState.fromStateM $ eShallowInstantiate e₂);
if e₁.isMVar && e₂.isMVar && e₁ == e₂ then pure ()
else if eIsMeta e₂ && !(eIsMeta e₁) then eUnify e₂ e₁
else if e₁.isBVar && e₂.isBVar && e₁.bvarIdx! == e₂.bvarIdx! then pure ()
@ -225,7 +225,7 @@ partial def eUnify : Expr → Expr → EState String Context Unit
eUnify e₁.bindingBody! e₂.bindingBody!
else if eIsMeta e₁ && !(eOccursIn e₂ e₁) then
match eMetaIdx e₁ with
| some idx => EState.fromState $ eAssignIdx idx e₂
| some idx => EState.fromStateM $ eAssignIdx idx e₂
| none => panic! "UNREACHABLE"
else
throw $ "eUnify: " ++ toString e₁ ++ " !=?= " ++ toString e₂
@ -254,7 +254,7 @@ structure AlphaNormData : Type :=
(eRenameMap : RBMap Nat Nat (λ n₁ n₂ => n₁ < n₂) := mkRBMap _ _ _)
(uRenameMap : RBMap Nat Nat (λ n₁ n₂ => n₁ < n₂) := mkRBMap _ _ _)
partial def uAlphaNormalizeCore : Level → State AlphaNormData Level
partial def uAlphaNormalizeCore : Level → StateM AlphaNormData Level
| l =>
if !l.hasMVar then pure l else
match l with
@ -280,7 +280,7 @@ partial def uAlphaNormalizeCore : Level → State AlphaNormData Level
pure l
| some alphaIdx => pure $ Level.mvar (mkNumName alphaMetaPrefix alphaIdx)
partial def eAlphaNormalizeCore : Expr → State AlphaNormData Expr
partial def eAlphaNormalizeCore : Expr → StateM AlphaNormData Expr
| e =>
if e.isConst then pure e
else if e.isFVar then pure e

View file

@ -97,7 +97,7 @@ instance tracer : SimpleMonadTracerAdapter (TypeUtilM σ ϕ) :=
getTraceState := getTraceState,
modifyTraceState := fun f => modify $ fun s => { traceState := f s.traceState, .. s } }
@[inline] private def liftStateMCtx {α} (x : State σ α) : TypeUtilM σ ϕ α :=
@[inline] private def liftStateMCtx {α} (x : StateM σ α) : TypeUtilM σ ϕ α :=
fun _ s =>
let (a, mctx) := x.run s.mctx;
EState.Result.ok a { mctx := mctx, .. s }