feat: use StateT.run instead of function application (#5121)

This PR using `StateT.run` rather than the "defeq abuse" of function
application. There remain many places where we still use function
application for `ReaderT`, but I've updated this in the touched files.

(To really solve this, we would make `StateT` irreducible, but that is
not happening here.)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Kim Morrison 2026-03-03 14:12:26 +11:00 committed by GitHub
parent 9a841125e7
commit dd710dd1bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 21 additions and 21 deletions

View file

@ -414,7 +414,7 @@ Renders a `Format` to a string.
-/
def pretty (f : Format) (width : Nat := defWidth) (indent : Nat := 0) (column := 0) : String :=
let act : StateM State Unit := prettyM f width indent
State.out <| act (State.mk "" column) |>.snd
State.out <| act.run (State.mk "" column) |>.snd
end Format

View file

@ -126,7 +126,7 @@ private partial def replaceRecApps (recArgInfos : Array RecArgInfo) (positions :
(below : Expr) (e : Expr) : M Expr :=
let recFnNames := recArgInfos.map (·.fnName)
let containsRecFn (e : Expr) : StateRefT (HasConstCache recFnNames) M Bool :=
modifyGet (·.contains e)
modifyGet (HasConstCache.contains e |>.run ·)
let rec loop (below : Expr) (e : Expr) : StateRefT (HasConstCache recFnNames) M Expr := do
if !(← containsRecFn e) then
return e

View file

@ -93,7 +93,7 @@ where
e.withApp fun f args => return mkAppN (← loop F f) (← args.mapM (loop F))
containsRecFn (e : Expr) : RecM recFnName Bool := do
modifyGet (·.contains e)
modifyGet (HasConstCache.contains e |>.run)
loop (F : Expr) (e : Expr) : RecM recFnName Expr := do
if !(← containsRecFn e) then

View file

@ -229,7 +229,7 @@ where
loop param f
containsRecFn (e : Expr) : M recFnName α Bool := do
modifyGetThe (HasConstCache #[recFnName]) (·.contains e)
modifyGetThe (HasConstCache #[recFnName]) (HasConstCache.contains e |>.run)
loop (param : Expr) (e : Expr) : M recFnName α Unit := do
if !(← containsRecFn e) then
@ -551,7 +551,7 @@ try first, when the mutually recursive functions have similar argument structure
-/
partial def generateCombinations? (numMeasures : Array Nat) (threshold : Nat := 32) :
Option (Array (Array Nat)) :=
(do goUniform 0; go 0 #[]) |>.run #[] |>.2
(do goUniform 0; go 0 #[]) |>.run.run #[] |>.2
where
-- Enumerate all permissible uniform combinations
goUniform (idx : Nat) : OptionT (StateM (Array (Array Nat))) Unit := do

View file

@ -137,7 +137,7 @@ end AbstractMVars
-/
def abstractMVars (e : Expr) (levels : Bool := true): MetaM AbstractMVarsResult := do
let e ← instantiateMVars e
let (e, s) := AbstractMVars.abstractExprMVars e
let (e, s) := AbstractMVars.abstractExprMVars e |>.run
{ mctx := (← getMCtx), lctx := (← getLCtx), ngen := (← getNGen), abstractLevels := levels }
setNGen s.ngen
setMCtx s.mctx

View file

@ -242,7 +242,7 @@ def tell (x : Expr) : M Unit := fun xs => pure ((), xs.push x)
def localM (f : Array Expr → MetaM (Array Expr)) (act : M α) : M α := fun xs => do
let n := xs.size
let (b, xs') ← act xs
let (b, xs') ← StateT.run act xs
pure (b, xs'[*...n] ++ (← f xs'[n...*]))
def localMapM (f : Expr → MetaM Expr) (act : M α) : M α :=

View file

@ -1474,7 +1474,7 @@ structure UnivMVarParamResult where
def levelMVarToParam (mctx : MetavarContext) (alreadyUsedPred : Name → Bool) (except : LMVarId → Bool) (e : Expr) (paramNamePrefix : Name := `u) (nextParamIdx : Nat := 1)
: UnivMVarParamResult :=
let (e, s) := LevelMVarToParam.main e { except, paramNamePrefix, alreadyUsedPred } { mctx, nextParamIdx }
let (e, s) := LevelMVarToParam.main e { except, paramNamePrefix, alreadyUsedPred } |>.run { mctx, nextParamIdx }
{ mctx := s.mctx
newParamNames := s.paramNames
nextParamIdx := s.nextParamIdx

View file

@ -661,7 +661,7 @@ def handleRpcRelease (p : Lsp.RpcReleaseParams) : WorkerM Unit := do
discard do rpcReleaseRef ref
seshRef.modify fun st =>
let st := st.keptAlive monoMsNow
let ((), objects) := discardRefs st.objects
let ((), objects) := discardRefs.run st.objects
{ st with objects }
def handleRpcKeepAlive (p : Lsp.RpcKeepAliveParams) : WorkerM Unit := do

View file

@ -97,7 +97,7 @@ def wrapRpcProcedure (method : Name) paramType respType
| Except.error e => throw e
| Except.ok ret =>
seshRef.modifyGet fun st =>
rpcEncode ret st.objects |>.map id ({st with objects := ·})
rpcEncode ret |>.run st.objects |>.map id ({st with objects := ·})
def registerBuiltinRpcProcedure (method : Name) paramType respType
[RpcEncodable paramType] [RpcEncodable respType]

View file

@ -116,7 +116,7 @@ private abbrev MsgFmtM := StateT (Array EmbedFmt) IO
open MessageData in
private partial def msgToInteractiveAux (msgData : MessageData) : IO (Format × Array EmbedFmt) :=
go { currNamespace := Name.anonymous, openDecls := [] } none msgData #[]
go { currNamespace := Name.anonymous, openDecls := [] } none msgData |>.run #[]
where
pushEmbed (e : EmbedFmt) : MsgFmtM Nat :=
modifyGet fun es => (es.size, es.push e)

View file

@ -100,7 +100,7 @@ private instance : Std.Format.MonadPrettyFormat (StateM TaggedState) where
is the indentation level at this point. The latter is used to print sub-trees accurately by passing
it again as the `indent` argument. -/
def prettyTagged (f : Format) (indent := 0) (w : Nat := Std.Format.defWidth) : TaggedText (Nat × Nat) :=
(f.prettyM w indent : StateM TaggedState Unit) {} |>.snd.out
(f.prettyM w indent : StateM TaggedState Unit).run {} |>.snd.out
/-- Remove tags, leaving just the pretty-printed string. -/
partial def stripTags (tt : TaggedText α) : String :=

View file

@ -70,7 +70,7 @@ private abbrev MonitorM := ReaderT MonitorContext <| StateT MonitorState BaseIO
@[inline] private def MonitorM.run
(ctx : MonitorContext) (s : MonitorState) (self : MonitorM α)
: BaseIO (α × MonitorState) :=
self ctx s
StateT.run (ReaderT.run self ctx) s
/--
The ANSI escape sequence for clearing the current line

View file

@ -67,7 +67,7 @@ def elabConfigFile
let input ← IO.FS.readFile configFile
let inputCtx := Parser.mkInputContext input configFile.toString
let (header, parserState, messages) ← Parser.parseHeader inputCtx
let (env, messages) ← processHeader header leanOpts inputCtx messages
let (env, messages) ← StateT.run (processHeader header leanOpts inputCtx) messages
let env := env.setMainModule configModuleName
-- Configure extensions

View file

@ -145,7 +145,7 @@ public def run?' {ε σ α : Type u} [Functor m] (init : σ) (x : EStateT ε σ
: StateT σ m α := fun s => do
match (← x s) with
| .ok a s => return (a, s)
| .error e s => h e s
| .error e s => StateT.run (h e) s
/-- Lift the `m` monad into the `EStateT ε σ m` monad transformer. -/
@[always_inline, inline]

View file

@ -32,11 +32,11 @@ public instance (priority := low) [Pure m] [MonadExceptOf ε m] : MonadLiftT (Ex
-- Remark: not necessarily optimal; uses context non-linearly
public instance (priority := low) [Monad m] [MonadReaderOf ρ m] [MonadLiftT n m] : MonadLiftT (ReaderT ρ n) m where
monadLift act := do act (← read)
monadLift act := do act.run (← read)
-- Remark: not necessarily optimal; uses state non-linearly
public instance (priority := low) [Monad m] [MonadStateOf σ m] [MonadLiftT n m] : MonadLiftT (StateT σ n) m where
monadLift act := do let (a, s) ← act (← get); set s; pure a
monadLift act := do let (a, s) ← act.run (m := n) (← get); set s; pure a
public instance (priority := low) [Monad m] [Alternative m] [MonadLiftT n m] : MonadLiftT (OptionT n) m where
monadLift act := act.run >>= liftM

View file

@ -264,7 +264,7 @@ public instance [Monad n] [MonadLiftT m n] : MonadLog (MonadLogT m n) where
ReaderT.adapt f self
@[inline] public def ignoreLog [Pure m] (self : MonadLogT m n α) : n α :=
self MonadLog.nop
self.run MonadLog.nop
end MonadLogT
@ -486,7 +486,7 @@ public instance [Monad m] : MonadLog (LogT m) := .ofMonadState
namespace LogT
public abbrev run [Functor m] (self : LogT m α) (log : Log := {}) : m (α × Log) :=
public abbrev run (self : LogT m α) (log : Log := {}) : m (α × Log) :=
StateT.run self log
public abbrev run' [Functor m] (self : LogT m α) (log : Log := {}) : m α :=
@ -502,7 +502,7 @@ Thus, this is best used when the lift cannot fail.
[Monad n] [MonadStateOf Log n] [MonadLiftT m n] [MonadFinally n]
(self : LogT m α)
: n α := do
let (a, log) ← self (← takeLog)
let (a, log) ← self.run (← takeLog)
set log
return a
@ -513,7 +513,7 @@ using the new monad's `logger`.
@[inline] public def replayLog
[Monad n] [logger : MonadLog n] [MonadLiftT m n] (self : LogT m α)
: n α := do
let (a, log) ← self {}
let (a, log) ← self.run {}
log.replay (logger := logger)
return a