refactor: State ==> StateM
This commit is contained in:
parent
cb554d6473
commit
5bb5ef6296
14 changed files with 42 additions and 42 deletions
|
|
@ -128,7 +128,7 @@ fun s =>
|
|||
instance {ε σ σ'} : MonadStateAdapter σ σ' (EState ε σ) (EState ε σ') :=
|
||||
⟨fun σ'' α => EState.adaptState⟩
|
||||
|
||||
@[inline] def fromState {ε σ α : Type} (x : State σ α) : EState ε σ α :=
|
||||
@[inline] def fromStateM {ε σ α : Type} (x : StateM σ α) : EState ε σ α :=
|
||||
fun s =>
|
||||
match x.run s with
|
||||
| (a, s') => EState.Result.ok a s'
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ x s
|
|||
@[inline] def StateT.run' {σ : Type u} {m : Type u → Type v} [Functor m] {α : Type u} (x : StateT σ m α) (s : σ) : m α :=
|
||||
Prod.fst <$> x s
|
||||
|
||||
@[reducible] def State (σ α : Type u) : Type u := StateT σ Id α
|
||||
@[reducible] def StateM (σ α : Type u) : Type u := StateT σ Id α
|
||||
|
||||
namespace StateT
|
||||
section
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ def hasAssignedMVar (mctx : σ) : Expr → Bool
|
|||
| Expr.proj _ _ e => e.hasMVar && hasAssignedMVar e
|
||||
| Expr.mvar mvarId => isExprAssigned mctx mvarId
|
||||
|
||||
partial def instantiateLevelMVars : Level → State σ Level
|
||||
partial def instantiateLevelMVars : Level → StateM σ Level
|
||||
| lvl@(Level.succ lvl₁) => do lvl₁ ← instantiateLevelMVars lvl₁; pure (Level.updateSucc! lvl lvl₁)
|
||||
| lvl@(Level.max lvl₁ lvl₂) => do lvl₁ ← instantiateLevelMVars lvl₁; lvl₂ ← instantiateLevelMVars lvl₂; pure (Level.updateMax! lvl lvl₁ lvl₂)
|
||||
| lvl@(Level.imax lvl₁ lvl₂) => do lvl₁ ← instantiateLevelMVars lvl₁; lvl₂ ← instantiateLevelMVars lvl₂; pure (Level.updateIMax! lvl lvl₁ lvl₂)
|
||||
|
|
@ -181,7 +181,7 @@ partial def instantiateLevelMVars : Level → State σ Level
|
|||
| lvl => pure lvl
|
||||
|
||||
namespace InstantiateExprMVars
|
||||
abbrev M (σ : Type) := State (WithHashMapCache Expr Expr σ)
|
||||
abbrev M (σ : Type) := StateM (WithHashMapCache Expr Expr σ)
|
||||
|
||||
@[inline] def instantiateLevelMVars (lvl : Level) : M σ Level :=
|
||||
WithHashMapCache.fromState $ AbstractMetavarContext.instantiateLevelMVars lvl
|
||||
|
|
@ -281,13 +281,13 @@ partial def main : Expr → M σ Expr
|
|||
|
||||
end InstantiateExprMVars
|
||||
|
||||
def instantiateMVars (e : Expr) : State σ Expr :=
|
||||
def instantiateMVars (e : Expr) : StateM σ Expr :=
|
||||
if !e.hasMVar then pure e
|
||||
else WithHashMapCache.toState $ InstantiateExprMVars.main e
|
||||
|
||||
namespace DependsOn
|
||||
|
||||
abbrev M := State ExprSet
|
||||
abbrev M := StateM ExprSet
|
||||
|
||||
@[inline] def visit (main : Expr → M Bool) (e : Expr) : M Bool :=
|
||||
if !e.hasMVar && !e.hasFVar then
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ ps.map $ fun p => { borrow := p.ty.isObj, .. p }
|
|||
def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param :=
|
||||
if exported then ps else initBorrow ps
|
||||
|
||||
partial def visitFnBody (fnid : FunId) : FnBody → State ParamMap Unit
|
||||
partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
|
||||
| FnBody.jdecl j xs v b => do
|
||||
modify $ fun m => m.insert (Key.jp fnid j) (initBorrow xs);
|
||||
visitFnBody v;
|
||||
|
|
@ -79,7 +79,7 @@ partial def visitFnBody (fnid : FunId) : FnBody → State ParamMap Unit
|
|||
let (instr, b) := e.split;
|
||||
visitFnBody b
|
||||
|
||||
def visitDecls (env : Environment) (decls : Array Decl) : State ParamMap Unit :=
|
||||
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
|
||||
decls.forM $ fun decl => match decl with
|
||||
| Decl.fdecl f xs _ b => do
|
||||
let exported := isExport env f;
|
||||
|
|
@ -137,7 +137,7 @@ structure BorrowInfState :=
|
|||
(modifiedOwned : Bool := false)
|
||||
(modifiedParamMap : Bool := false)
|
||||
|
||||
abbrev M := ReaderT BorrowInfCtx (State BorrowInfState)
|
||||
abbrev M := ReaderT BorrowInfCtx (StateM BorrowInfState)
|
||||
|
||||
def markModifiedParamMap : M Unit :=
|
||||
modify $ fun s => { modifiedParamMap := true, .. s }
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def isBoxedName : Name → Bool
|
|||
| Name.mkString _ "_boxed" => true
|
||||
| _ => false
|
||||
|
||||
abbrev N := State Nat
|
||||
abbrev N := StateM Nat
|
||||
|
||||
private def mkFresh : N VarId :=
|
||||
modifyGet $ fun n => ({ idx := n }, n + 1)
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ structure InterpState :=
|
|||
(assignments : Array Assignment)
|
||||
(funVals : PArray Value) -- we take snapshots during fixpoint computations
|
||||
|
||||
abbrev M := ReaderT InterpContext (State InterpState)
|
||||
abbrev M := ReaderT InterpContext (StateM InterpState)
|
||||
|
||||
open Value
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ match b with
|
|||
|
||||
namespace UsesLeanNamespace
|
||||
|
||||
abbrev M := ReaderT Environment (State NameSet)
|
||||
abbrev M := ReaderT Environment (StateM NameSet)
|
||||
|
||||
def leanNameSpacePrefix := `Lean
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ def usesLeanNamespace (env : Environment) : Decl → Bool
|
|||
|
||||
namespace CollectUsedDecls
|
||||
|
||||
abbrev M := ReaderT Environment (State NameSet)
|
||||
abbrev M := ReaderT Environment (StateM NameSet)
|
||||
|
||||
@[inline] def collect (f : FunId) : M Unit :=
|
||||
modify $ fun s => s.insert f
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ mask.foldl
|
|||
| none => b)
|
||||
b
|
||||
|
||||
abbrev M := ReaderT Context (State Nat)
|
||||
abbrev M := ReaderT Context (StateM Nat)
|
||||
def mkFresh : M VarId :=
|
||||
modifyGet $ fun n => ({ idx := n }, n + 1)
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ namespace IsLive
|
|||
Remark: we don't need to track local join points because we assume there is
|
||||
no variable or join point shadowing in our IR.
|
||||
-/
|
||||
abbrev M := State LocalContext
|
||||
abbrev M := StateM LocalContext
|
||||
|
||||
@[inline] def visitVar (w : Index) (x : VarId) : M Bool := pure (HasIndex.visitVar w x)
|
||||
@[inline] def visitJP (w : Index) (x : JoinPointId) : M Bool := pure (HasIndex.visitJP w x)
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ def normExpr : Expr → M Expr
|
|||
| Expr.isTaggedPtr x, m => Expr.isTaggedPtr (normVar x m)
|
||||
| e@(Expr.lit v), m => e
|
||||
|
||||
abbrev N := ReaderT IndexRenaming (State Nat)
|
||||
abbrev N := ReaderT IndexRenaming (StateM Nat)
|
||||
|
||||
@[inline] def withVar {α : Type} (x : VarId) (k : VarId → N α) : N α :=
|
||||
fun m => do
|
||||
|
|
|
|||
|
|
@ -61,13 +61,13 @@ structure WithHashMapCache (α β σ : Type) [HasBeq α] [Hashable α] :=
|
|||
|
||||
namespace WithHashMapCache
|
||||
|
||||
@[inline] def getCache {α β σ : Type} [HasBeq α] [Hashable α] : State (WithHashMapCache α β σ) (HashMap α β) :=
|
||||
@[inline] def getCache {α β σ : Type} [HasBeq α] [Hashable α] : StateM (WithHashMapCache α β σ) (HashMap α β) :=
|
||||
do s ← get; pure s.cache
|
||||
|
||||
@[inline] def modifyCache {α β σ : Type} [HasBeq α] [Hashable α] (f : HashMap α β → HashMap α β) : State (WithHashMapCache α β σ) Unit :=
|
||||
@[inline] def modifyCache {α β σ : Type} [HasBeq α] [Hashable α] (f : HashMap α β → HashMap α β) : StateM (WithHashMapCache α β σ) Unit :=
|
||||
modify $ fun s => { cache := f s.cache, .. s }
|
||||
|
||||
instance stateAdapter (α β σ : Type) [HasBeq α] [Hashable α] : MonadHashMapCacheAdapter α β (State (WithHashMapCache α β σ)) :=
|
||||
instance stateAdapter (α β σ : Type) [HasBeq α] [Hashable α] : MonadHashMapCacheAdapter α β (StateM (WithHashMapCache α β σ)) :=
|
||||
{ getCache := WithHashMapCache.getCache,
|
||||
modifyCache := WithHashMapCache.modifyCache }
|
||||
|
||||
|
|
@ -81,13 +81,13 @@ instance estateAdapter (α β ε σ : Type) [HasBeq α] [Hashable α] : MonadHas
|
|||
{ getCache := WithHashMapCache.getCacheE,
|
||||
modifyCache := WithHashMapCache.modifyCacheE }
|
||||
|
||||
@[inline] def fromState {α β σ δ : Type} [HasBeq α] [Hashable α] (x : State σ δ) : State (WithHashMapCache α β σ) δ :=
|
||||
@[inline] def fromState {α β σ δ : Type} [HasBeq α] [Hashable α] (x : StateM σ δ) : StateM (WithHashMapCache α β σ) δ :=
|
||||
adaptState
|
||||
(fun (s : WithHashMapCache α β σ) => (s.state, s.cache))
|
||||
(fun (s : σ) (cache : HashMap α β) => { state := s, cache := cache })
|
||||
x
|
||||
|
||||
@[inline] def toState {α β σ δ : Type} [HasBeq α] [Hashable α] (x : State (WithHashMapCache α β σ) δ) : State σ δ :=
|
||||
@[inline] def toState {α β σ δ : Type} [HasBeq α] [Hashable α] (x : StateM (WithHashMapCache α β σ) δ) : StateM σ δ :=
|
||||
adaptState'
|
||||
(fun (s : σ) => ({ state := s } : WithHashMapCache α β σ))
|
||||
(fun (s : WithHashMapCache α β σ) => s.state)
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ private def updateInfo : SourceInfo → String.Pos → SourceInfo
|
|||
/- Remark: the State `String.Pos` is the `SourceInfo.trailing.stopPos` of the previous token,
|
||||
or the beginning of the String. -/
|
||||
@[inline]
|
||||
private def updateLeadingAux {α} : Syntax α → State String.Pos (Option (Syntax α))
|
||||
private def updateLeadingAux {α} : Syntax α → StateM String.Pos (Option (Syntax α))
|
||||
| atom (some info) val => do
|
||||
last ← get;
|
||||
set info.trailing.stopPos;
|
||||
|
|
|
|||
|
|
@ -44,16 +44,16 @@ def eMetaIdx : Expr → Option Nat
|
|||
|
||||
def eIsMeta (e : Expr) : Bool := (eMetaIdx e).toBool
|
||||
|
||||
def eNewMeta (type : Expr) : State Context Expr :=
|
||||
def eNewMeta (type : Expr) : StateM Context Expr :=
|
||||
do ctx ← get;
|
||||
let idx := ctx.eTypes.size;
|
||||
set { eTypes := ctx.eTypes.push type, eVals := ctx.eVals.push none, .. ctx };
|
||||
pure $ Expr.mvar (mkNumName metaPrefix idx)
|
||||
|
||||
def eLookupIdx (idx : Nat) : State Context (Option Expr) :=
|
||||
def eLookupIdx (idx : Nat) : StateM Context (Option Expr) :=
|
||||
do ctx ← get; pure $ ctx.eVals.get! idx
|
||||
|
||||
partial def eShallowInstantiate : Expr → State Context Expr
|
||||
partial def eShallowInstantiate : Expr → StateM Context Expr
|
||||
| e =>
|
||||
match eMetaIdx e with
|
||||
| some idx => get >>= λ ctx =>
|
||||
|
|
@ -62,7 +62,7 @@ partial def eShallowInstantiate : Expr → State Context Expr
|
|||
| some v => eShallowInstantiate v
|
||||
| none => pure e
|
||||
|
||||
def eInferIdx (idx : Nat) : State Context Expr :=
|
||||
def eInferIdx (idx : Nat) : StateM Context Expr :=
|
||||
do ctx ← get; pure $ ctx.eTypes.get! idx
|
||||
|
||||
def eInfer (ctx : Context) (mvar : Expr) : Expr :=
|
||||
|
|
@ -70,10 +70,10 @@ match eMetaIdx mvar with
|
|||
| some idx => ctx.eTypes.get! idx
|
||||
| none => panic! "eInfer called on non-(tmp-)mvar"
|
||||
|
||||
def eAssignIdx (idx : Nat) (e : Expr) : State Context Unit :=
|
||||
def eAssignIdx (idx : Nat) (e : Expr) : StateM Context Unit :=
|
||||
modify $ λ ctx => { eVals := ctx.eVals.set idx (some e) .. ctx }
|
||||
|
||||
def eAssign (mvar : Expr) (e : Expr) : State Context Unit :=
|
||||
def eAssign (mvar : Expr) (e : Expr) : StateM Context Unit :=
|
||||
match eMetaIdx mvar with
|
||||
| some idx => modify $ λ ctx => { eVals := ctx.eVals.set idx (some e) .. ctx }
|
||||
| _ => panic! "eAssign called on non-(tmp-)mvar"
|
||||
|
|
@ -102,16 +102,16 @@ def uMetaIdx : Level → Option Nat
|
|||
|
||||
def uIsMeta (l : Level) : Bool := (uMetaIdx l).toBool
|
||||
|
||||
def uNewMeta : State Context Level :=
|
||||
def uNewMeta : StateM Context Level :=
|
||||
do ctx ← get;
|
||||
let idx := ctx.uVals.size;
|
||||
set { uVals := ctx.uVals.push none, .. ctx };
|
||||
pure $ Level.mvar (mkNumName metaPrefix idx)
|
||||
|
||||
def uLookupIdx (idx : Nat) : State Context (Option Level) :=
|
||||
def uLookupIdx (idx : Nat) : StateM Context (Option Level) :=
|
||||
do ctx ← get; pure $ ctx.uVals.get! idx
|
||||
|
||||
partial def uShallowInstantiate : Level → State Context Level
|
||||
partial def uShallowInstantiate : Level → StateM Context Level
|
||||
| l =>
|
||||
match uMetaIdx l with
|
||||
| some idx => get >>= λ ctx =>
|
||||
|
|
@ -120,10 +120,10 @@ partial def uShallowInstantiate : Level → State Context Level
|
|||
| some v => uShallowInstantiate v
|
||||
| none => pure l
|
||||
|
||||
def uAssignIdx (idx : Nat) (l : Level) : State Context Unit :=
|
||||
def uAssignIdx (idx : Nat) (l : Level) : StateM Context Unit :=
|
||||
modify $ λ ctx => { uVals := ctx.uVals.set idx (some l) .. ctx }
|
||||
|
||||
def uAssign (umvar : Level) (l : Level) : State Context Unit :=
|
||||
def uAssign (umvar : Level) (l : Level) : StateM Context Unit :=
|
||||
match uMetaIdx umvar with
|
||||
| some idx => modify $ λ ctx => { uVals := ctx.uVals.set idx (some l) .. ctx }
|
||||
| _ => panic! "uassign called on non-(tmp-)mvar"
|
||||
|
|
@ -147,8 +147,8 @@ uFind uIsMeta l
|
|||
|
||||
partial def uUnify : Level → Level → EState String Context Unit
|
||||
| l₁, l₂ => do
|
||||
l₁ ← EState.fromState $ uShallowInstantiate l₁;
|
||||
l₂ ← EState.fromState $ uShallowInstantiate l₂;
|
||||
l₁ ← EState.fromStateM $ uShallowInstantiate l₁;
|
||||
l₂ ← EState.fromStateM $ uShallowInstantiate l₂;
|
||||
if uIsMeta l₂ && !(uIsMeta l₁)
|
||||
then uUnify l₂ l₁
|
||||
else
|
||||
|
|
@ -162,7 +162,7 @@ partial def uUnify : Level → Level → EState String Context Unit
|
|||
match uMetaIdx l₁ with
|
||||
| none => when (!(l₁ == l₂)) $ throw "Level.mvar clash"
|
||||
| some idx => do when (uOccursIn l₁ l₂) $ throw "occurs";
|
||||
EState.fromState $ uAssignIdx idx l₂
|
||||
EState.fromStateM $ uAssignIdx idx l₂
|
||||
| _, _ => throw $ "lUnify: " ++ toString l₁ ++ " !=?= " ++ toString l₂
|
||||
|
||||
partial def uInstantiate (ctx : Context) : Level → Level
|
||||
|
|
@ -207,8 +207,8 @@ partial def eUnify : Expr → Expr → EState String Context Unit
|
|||
if !e₁.hasMVar && !e₂.hasMVar
|
||||
then unless (e₁ == e₂) $ throw $ "eUnify: " ++ toString e₁ ++ " !=?= " ++ toString e₂
|
||||
else do
|
||||
e₁ ← slowWhnf <$> (EState.fromState $ eShallowInstantiate e₁);
|
||||
e₂ ← slowWhnf <$> (EState.fromState $ eShallowInstantiate e₂);
|
||||
e₁ ← slowWhnf <$> (EState.fromStateM $ eShallowInstantiate e₁);
|
||||
e₂ ← slowWhnf <$> (EState.fromStateM $ eShallowInstantiate e₂);
|
||||
if e₁.isMVar && e₂.isMVar && e₁ == e₂ then pure ()
|
||||
else if eIsMeta e₂ && !(eIsMeta e₁) then eUnify e₂ e₁
|
||||
else if e₁.isBVar && e₂.isBVar && e₁.bvarIdx! == e₂.bvarIdx! then pure ()
|
||||
|
|
@ -225,7 +225,7 @@ partial def eUnify : Expr → Expr → EState String Context Unit
|
|||
eUnify e₁.bindingBody! e₂.bindingBody!
|
||||
else if eIsMeta e₁ && !(eOccursIn e₂ e₁) then
|
||||
match eMetaIdx e₁ with
|
||||
| some idx => EState.fromState $ eAssignIdx idx e₂
|
||||
| some idx => EState.fromStateM $ eAssignIdx idx e₂
|
||||
| none => panic! "UNREACHABLE"
|
||||
else
|
||||
throw $ "eUnify: " ++ toString e₁ ++ " !=?= " ++ toString e₂
|
||||
|
|
@ -254,7 +254,7 @@ structure AlphaNormData : Type :=
|
|||
(eRenameMap : RBMap Nat Nat (λ n₁ n₂ => n₁ < n₂) := mkRBMap _ _ _)
|
||||
(uRenameMap : RBMap Nat Nat (λ n₁ n₂ => n₁ < n₂) := mkRBMap _ _ _)
|
||||
|
||||
partial def uAlphaNormalizeCore : Level → State AlphaNormData Level
|
||||
partial def uAlphaNormalizeCore : Level → StateM AlphaNormData Level
|
||||
| l =>
|
||||
if !l.hasMVar then pure l else
|
||||
match l with
|
||||
|
|
@ -280,7 +280,7 @@ partial def uAlphaNormalizeCore : Level → State AlphaNormData Level
|
|||
pure l
|
||||
| some alphaIdx => pure $ Level.mvar (mkNumName alphaMetaPrefix alphaIdx)
|
||||
|
||||
partial def eAlphaNormalizeCore : Expr → State AlphaNormData Expr
|
||||
partial def eAlphaNormalizeCore : Expr → StateM AlphaNormData Expr
|
||||
| e =>
|
||||
if e.isConst then pure e
|
||||
else if e.isFVar then pure e
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ instance tracer : SimpleMonadTracerAdapter (TypeUtilM σ ϕ) :=
|
|||
getTraceState := getTraceState,
|
||||
modifyTraceState := fun f => modify $ fun s => { traceState := f s.traceState, .. s } }
|
||||
|
||||
@[inline] private def liftStateMCtx {α} (x : State σ α) : TypeUtilM σ ϕ α :=
|
||||
@[inline] private def liftStateMCtx {α} (x : StateM σ α) : TypeUtilM σ ϕ α :=
|
||||
fun _ s =>
|
||||
let (a, mctx) := x.run s.mctx;
|
||||
EState.Result.ok a { mctx := mctx, .. s }
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue