feat: use StateT.run instead of function application (#5121)
This PR using `StateT.run` rather than the "defeq abuse" of function application. There remain many places where we still use function application for `ReaderT`, but I've updated this in the touched files. (To really solve this, we would make `StateT` irreducible, but that is not happening here.) --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
9a841125e7
commit
dd710dd1bd
16 changed files with 21 additions and 21 deletions
|
|
@ -414,7 +414,7 @@ Renders a `Format` to a string.
|
|||
-/
|
||||
def pretty (f : Format) (width : Nat := defWidth) (indent : Nat := 0) (column := 0) : String :=
|
||||
let act : StateM State Unit := prettyM f width indent
|
||||
State.out <| act (State.mk "" column) |>.snd
|
||||
State.out <| act.run (State.mk "" column) |>.snd
|
||||
|
||||
end Format
|
||||
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ private partial def replaceRecApps (recArgInfos : Array RecArgInfo) (positions :
|
|||
(below : Expr) (e : Expr) : M Expr :=
|
||||
let recFnNames := recArgInfos.map (·.fnName)
|
||||
let containsRecFn (e : Expr) : StateRefT (HasConstCache recFnNames) M Bool :=
|
||||
modifyGet (·.contains e)
|
||||
modifyGet (HasConstCache.contains e |>.run ·)
|
||||
let rec loop (below : Expr) (e : Expr) : StateRefT (HasConstCache recFnNames) M Expr := do
|
||||
if !(← containsRecFn e) then
|
||||
return e
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ where
|
|||
e.withApp fun f args => return mkAppN (← loop F f) (← args.mapM (loop F))
|
||||
|
||||
containsRecFn (e : Expr) : RecM recFnName Bool := do
|
||||
modifyGet (·.contains e)
|
||||
modifyGet (HasConstCache.contains e |>.run)
|
||||
|
||||
loop (F : Expr) (e : Expr) : RecM recFnName Expr := do
|
||||
if !(← containsRecFn e) then
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ where
|
|||
loop param f
|
||||
|
||||
containsRecFn (e : Expr) : M recFnName α Bool := do
|
||||
modifyGetThe (HasConstCache #[recFnName]) (·.contains e)
|
||||
modifyGetThe (HasConstCache #[recFnName]) (HasConstCache.contains e |>.run)
|
||||
|
||||
loop (param : Expr) (e : Expr) : M recFnName α Unit := do
|
||||
if !(← containsRecFn e) then
|
||||
|
|
@ -551,7 +551,7 @@ try first, when the mutually recursive functions have similar argument structure
|
|||
-/
|
||||
partial def generateCombinations? (numMeasures : Array Nat) (threshold : Nat := 32) :
|
||||
Option (Array (Array Nat)) :=
|
||||
(do goUniform 0; go 0 #[]) |>.run #[] |>.2
|
||||
(do goUniform 0; go 0 #[]) |>.run.run #[] |>.2
|
||||
where
|
||||
-- Enumerate all permissible uniform combinations
|
||||
goUniform (idx : Nat) : OptionT (StateM (Array (Array Nat))) Unit := do
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ end AbstractMVars
|
|||
-/
|
||||
def abstractMVars (e : Expr) (levels : Bool := true): MetaM AbstractMVarsResult := do
|
||||
let e ← instantiateMVars e
|
||||
let (e, s) := AbstractMVars.abstractExprMVars e
|
||||
let (e, s) := AbstractMVars.abstractExprMVars e |>.run
|
||||
{ mctx := (← getMCtx), lctx := (← getLCtx), ngen := (← getNGen), abstractLevels := levels }
|
||||
setNGen s.ngen
|
||||
setMCtx s.mctx
|
||||
|
|
|
|||
|
|
@ -242,7 +242,7 @@ def tell (x : Expr) : M Unit := fun xs => pure ((), xs.push x)
|
|||
|
||||
def localM (f : Array Expr → MetaM (Array Expr)) (act : M α) : M α := fun xs => do
|
||||
let n := xs.size
|
||||
let (b, xs') ← act xs
|
||||
let (b, xs') ← StateT.run act xs
|
||||
pure (b, xs'[*...n] ++ (← f xs'[n...*]))
|
||||
|
||||
def localMapM (f : Expr → MetaM Expr) (act : M α) : M α :=
|
||||
|
|
|
|||
|
|
@ -1474,7 +1474,7 @@ structure UnivMVarParamResult where
|
|||
|
||||
def levelMVarToParam (mctx : MetavarContext) (alreadyUsedPred : Name → Bool) (except : LMVarId → Bool) (e : Expr) (paramNamePrefix : Name := `u) (nextParamIdx : Nat := 1)
|
||||
: UnivMVarParamResult :=
|
||||
let (e, s) := LevelMVarToParam.main e { except, paramNamePrefix, alreadyUsedPred } { mctx, nextParamIdx }
|
||||
let (e, s) := LevelMVarToParam.main e { except, paramNamePrefix, alreadyUsedPred } |>.run { mctx, nextParamIdx }
|
||||
{ mctx := s.mctx
|
||||
newParamNames := s.paramNames
|
||||
nextParamIdx := s.nextParamIdx
|
||||
|
|
|
|||
|
|
@ -661,7 +661,7 @@ def handleRpcRelease (p : Lsp.RpcReleaseParams) : WorkerM Unit := do
|
|||
discard do rpcReleaseRef ref
|
||||
seshRef.modify fun st =>
|
||||
let st := st.keptAlive monoMsNow
|
||||
let ((), objects) := discardRefs st.objects
|
||||
let ((), objects) := discardRefs.run st.objects
|
||||
{ st with objects }
|
||||
|
||||
def handleRpcKeepAlive (p : Lsp.RpcKeepAliveParams) : WorkerM Unit := do
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ def wrapRpcProcedure (method : Name) paramType respType
|
|||
| Except.error e => throw e
|
||||
| Except.ok ret =>
|
||||
seshRef.modifyGet fun st =>
|
||||
rpcEncode ret st.objects |>.map id ({st with objects := ·})
|
||||
rpcEncode ret |>.run st.objects |>.map id ({st with objects := ·})
|
||||
|
||||
def registerBuiltinRpcProcedure (method : Name) paramType respType
|
||||
[RpcEncodable paramType] [RpcEncodable respType]
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ private abbrev MsgFmtM := StateT (Array EmbedFmt) IO
|
|||
|
||||
open MessageData in
|
||||
private partial def msgToInteractiveAux (msgData : MessageData) : IO (Format × Array EmbedFmt) :=
|
||||
go { currNamespace := Name.anonymous, openDecls := [] } none msgData #[]
|
||||
go { currNamespace := Name.anonymous, openDecls := [] } none msgData |>.run #[]
|
||||
where
|
||||
pushEmbed (e : EmbedFmt) : MsgFmtM Nat :=
|
||||
modifyGet fun es => (es.size, es.push e)
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ private instance : Std.Format.MonadPrettyFormat (StateM TaggedState) where
|
|||
is the indentation level at this point. The latter is used to print sub-trees accurately by passing
|
||||
it again as the `indent` argument. -/
|
||||
def prettyTagged (f : Format) (indent := 0) (w : Nat := Std.Format.defWidth) : TaggedText (Nat × Nat) :=
|
||||
(f.prettyM w indent : StateM TaggedState Unit) {} |>.snd.out
|
||||
(f.prettyM w indent : StateM TaggedState Unit).run {} |>.snd.out
|
||||
|
||||
/-- Remove tags, leaving just the pretty-printed string. -/
|
||||
partial def stripTags (tt : TaggedText α) : String :=
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ private abbrev MonitorM := ReaderT MonitorContext <| StateT MonitorState BaseIO
|
|||
@[inline] private def MonitorM.run
|
||||
(ctx : MonitorContext) (s : MonitorState) (self : MonitorM α)
|
||||
: BaseIO (α × MonitorState) :=
|
||||
self ctx s
|
||||
StateT.run (ReaderT.run self ctx) s
|
||||
|
||||
/--
|
||||
The ANSI escape sequence for clearing the current line
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ def elabConfigFile
|
|||
let input ← IO.FS.readFile configFile
|
||||
let inputCtx := Parser.mkInputContext input configFile.toString
|
||||
let (header, parserState, messages) ← Parser.parseHeader inputCtx
|
||||
let (env, messages) ← processHeader header leanOpts inputCtx messages
|
||||
let (env, messages) ← StateT.run (processHeader header leanOpts inputCtx) messages
|
||||
let env := env.setMainModule configModuleName
|
||||
|
||||
-- Configure extensions
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ public def run?' {ε σ α : Type u} [Functor m] (init : σ) (x : EStateT ε σ
|
|||
: StateT σ m α := fun s => do
|
||||
match (← x s) with
|
||||
| .ok a s => return (a, s)
|
||||
| .error e s => h e s
|
||||
| .error e s => StateT.run (h e) s
|
||||
|
||||
/-- Lift the `m` monad into the `EStateT ε σ m` monad transformer. -/
|
||||
@[always_inline, inline]
|
||||
|
|
|
|||
|
|
@ -32,11 +32,11 @@ public instance (priority := low) [Pure m] [MonadExceptOf ε m] : MonadLiftT (Ex
|
|||
|
||||
-- Remark: not necessarily optimal; uses context non-linearly
|
||||
public instance (priority := low) [Monad m] [MonadReaderOf ρ m] [MonadLiftT n m] : MonadLiftT (ReaderT ρ n) m where
|
||||
monadLift act := do act (← read)
|
||||
monadLift act := do act.run (← read)
|
||||
|
||||
-- Remark: not necessarily optimal; uses state non-linearly
|
||||
public instance (priority := low) [Monad m] [MonadStateOf σ m] [MonadLiftT n m] : MonadLiftT (StateT σ n) m where
|
||||
monadLift act := do let (a, s) ← act (← get); set s; pure a
|
||||
monadLift act := do let (a, s) ← act.run (m := n) (← get); set s; pure a
|
||||
|
||||
public instance (priority := low) [Monad m] [Alternative m] [MonadLiftT n m] : MonadLiftT (OptionT n) m where
|
||||
monadLift act := act.run >>= liftM
|
||||
|
|
|
|||
|
|
@ -264,7 +264,7 @@ public instance [Monad n] [MonadLiftT m n] : MonadLog (MonadLogT m n) where
|
|||
ReaderT.adapt f self
|
||||
|
||||
@[inline] public def ignoreLog [Pure m] (self : MonadLogT m n α) : n α :=
|
||||
self MonadLog.nop
|
||||
self.run MonadLog.nop
|
||||
|
||||
end MonadLogT
|
||||
|
||||
|
|
@ -486,7 +486,7 @@ public instance [Monad m] : MonadLog (LogT m) := .ofMonadState
|
|||
|
||||
namespace LogT
|
||||
|
||||
public abbrev run [Functor m] (self : LogT m α) (log : Log := {}) : m (α × Log) :=
|
||||
public abbrev run (self : LogT m α) (log : Log := {}) : m (α × Log) :=
|
||||
StateT.run self log
|
||||
|
||||
public abbrev run' [Functor m] (self : LogT m α) (log : Log := {}) : m α :=
|
||||
|
|
@ -502,7 +502,7 @@ Thus, this is best used when the lift cannot fail.
|
|||
[Monad n] [MonadStateOf Log n] [MonadLiftT m n] [MonadFinally n]
|
||||
(self : LogT m α)
|
||||
: n α := do
|
||||
let (a, log) ← self (← takeLog)
|
||||
let (a, log) ← self.run (← takeLog)
|
||||
set log
|
||||
return a
|
||||
|
||||
|
|
@ -513,7 +513,7 @@ using the new monad's `logger`.
|
|||
@[inline] public def replayLog
|
||||
[Monad n] [logger : MonadLog n] [MonadLiftT m n] (self : LogT m α)
|
||||
: n α := do
|
||||
let (a, log) ← self {}
|
||||
let (a, log) ← self.run {}
|
||||
log.replay (logger := logger)
|
||||
return a
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue