From dd710dd1bd8e34e0ffff7db3e1117e11039834d0 Mon Sep 17 00:00:00 2001 From: Kim Morrison <477956+kim-em@users.noreply.github.com> Date: Tue, 3 Mar 2026 14:12:26 +1100 Subject: [PATCH] feat: use StateT.run instead of function application (#5121) This PR using `StateT.run` rather than the "defeq abuse" of function application. There remain many places where we still use function application for `ReaderT`, but I've updated this in the touched files. (To really solve this, we would make `StateT` irreducible, but that is not happening here.) --------- Co-authored-by: Claude Opus 4.6 --- src/Init/Data/Format/Basic.lean | 2 +- src/Lean/Elab/PreDefinition/Structural/BRecOn.lean | 2 +- src/Lean/Elab/PreDefinition/WF/Fix.lean | 2 +- src/Lean/Elab/PreDefinition/WF/GuessLex.lean | 4 ++-- src/Lean/Meta/AbstractMVars.lean | 2 +- src/Lean/Meta/Tactic/FunInd.lean | 2 +- src/Lean/MetavarContext.lean | 2 +- src/Lean/Server/FileWorker.lean | 2 +- src/Lean/Server/Rpc/RequestHandling.lean | 2 +- src/Lean/Widget/InteractiveDiagnostic.lean | 2 +- src/Lean/Widget/TaggedText.lean | 2 +- src/lake/Lake/Build/Run.lean | 2 +- src/lake/Lake/Load/Lean/Elab.lean | 2 +- src/lake/Lake/Util/EStateT.lean | 2 +- src/lake/Lake/Util/Lift.lean | 4 ++-- src/lake/Lake/Util/Log.lean | 8 ++++---- 16 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/Init/Data/Format/Basic.lean b/src/Init/Data/Format/Basic.lean index a9805b366b..e0dbb5fa8b 100644 --- a/src/Init/Data/Format/Basic.lean +++ b/src/Init/Data/Format/Basic.lean @@ -414,7 +414,7 @@ Renders a `Format` to a string. -/ def pretty (f : Format) (width : Nat := defWidth) (indent : Nat := 0) (column := 0) : String := let act : StateM State Unit := prettyM f width indent - State.out <| act (State.mk "" column) |>.snd + State.out <| act.run (State.mk "" column) |>.snd end Format diff --git a/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean b/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean index 259ced1c96..9280dea15a 100644 --- a/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean +++ b/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean @@ -126,7 +126,7 @@ private partial def replaceRecApps (recArgInfos : Array RecArgInfo) (positions : (below : Expr) (e : Expr) : M Expr := let recFnNames := recArgInfos.map (·.fnName) let containsRecFn (e : Expr) : StateRefT (HasConstCache recFnNames) M Bool := - modifyGet (·.contains e) + modifyGet (HasConstCache.contains e |>.run ·) let rec loop (below : Expr) (e : Expr) : StateRefT (HasConstCache recFnNames) M Expr := do if !(← containsRecFn e) then return e diff --git a/src/Lean/Elab/PreDefinition/WF/Fix.lean b/src/Lean/Elab/PreDefinition/WF/Fix.lean index 04757ab0d9..0fef8a9aac 100644 --- a/src/Lean/Elab/PreDefinition/WF/Fix.lean +++ b/src/Lean/Elab/PreDefinition/WF/Fix.lean @@ -93,7 +93,7 @@ where e.withApp fun f args => return mkAppN (← loop F f) (← args.mapM (loop F)) containsRecFn (e : Expr) : RecM recFnName Bool := do - modifyGet (·.contains e) + modifyGet (HasConstCache.contains e |>.run) loop (F : Expr) (e : Expr) : RecM recFnName Expr := do if !(← containsRecFn e) then diff --git a/src/Lean/Elab/PreDefinition/WF/GuessLex.lean b/src/Lean/Elab/PreDefinition/WF/GuessLex.lean index eaca707d1b..aa68104a56 100644 --- a/src/Lean/Elab/PreDefinition/WF/GuessLex.lean +++ b/src/Lean/Elab/PreDefinition/WF/GuessLex.lean @@ -229,7 +229,7 @@ where loop param f containsRecFn (e : Expr) : M recFnName α Bool := do - modifyGetThe (HasConstCache #[recFnName]) (·.contains e) + modifyGetThe (HasConstCache #[recFnName]) (HasConstCache.contains e |>.run) loop (param : Expr) (e : Expr) : M recFnName α Unit := do if !(← containsRecFn e) then @@ -551,7 +551,7 @@ try first, when the mutually recursive functions have similar argument structure -/ partial def generateCombinations? (numMeasures : Array Nat) (threshold : Nat := 32) : Option (Array (Array Nat)) := - (do goUniform 0; go 0 #[]) |>.run #[] |>.2 + (do goUniform 0; go 0 #[]) |>.run.run #[] |>.2 where -- Enumerate all permissible uniform combinations goUniform (idx : Nat) : OptionT (StateM (Array (Array Nat))) Unit := do diff --git a/src/Lean/Meta/AbstractMVars.lean b/src/Lean/Meta/AbstractMVars.lean index da60e24bcd..175738e203 100644 --- a/src/Lean/Meta/AbstractMVars.lean +++ b/src/Lean/Meta/AbstractMVars.lean @@ -137,7 +137,7 @@ end AbstractMVars -/ def abstractMVars (e : Expr) (levels : Bool := true): MetaM AbstractMVarsResult := do let e ← instantiateMVars e - let (e, s) := AbstractMVars.abstractExprMVars e + let (e, s) := AbstractMVars.abstractExprMVars e |>.run { mctx := (← getMCtx), lctx := (← getLCtx), ngen := (← getNGen), abstractLevels := levels } setNGen s.ngen setMCtx s.mctx diff --git a/src/Lean/Meta/Tactic/FunInd.lean b/src/Lean/Meta/Tactic/FunInd.lean index 7902200d21..6ac89206b4 100644 --- a/src/Lean/Meta/Tactic/FunInd.lean +++ b/src/Lean/Meta/Tactic/FunInd.lean @@ -242,7 +242,7 @@ def tell (x : Expr) : M Unit := fun xs => pure ((), xs.push x) def localM (f : Array Expr → MetaM (Array Expr)) (act : M α) : M α := fun xs => do let n := xs.size - let (b, xs') ← act xs + let (b, xs') ← StateT.run act xs pure (b, xs'[*...n] ++ (← f xs'[n...*])) def localMapM (f : Expr → MetaM Expr) (act : M α) : M α := diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index 4ae7f52cf4..63ebeb43ed 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -1474,7 +1474,7 @@ structure UnivMVarParamResult where def levelMVarToParam (mctx : MetavarContext) (alreadyUsedPred : Name → Bool) (except : LMVarId → Bool) (e : Expr) (paramNamePrefix : Name := `u) (nextParamIdx : Nat := 1) : UnivMVarParamResult := - let (e, s) := LevelMVarToParam.main e { except, paramNamePrefix, alreadyUsedPred } { mctx, nextParamIdx } + let (e, s) := LevelMVarToParam.main e { except, paramNamePrefix, alreadyUsedPred } |>.run { mctx, nextParamIdx } { mctx := s.mctx newParamNames := s.paramNames nextParamIdx := s.nextParamIdx diff --git a/src/Lean/Server/FileWorker.lean b/src/Lean/Server/FileWorker.lean index 1a28a4aa18..fad934d4bd 100644 --- a/src/Lean/Server/FileWorker.lean +++ b/src/Lean/Server/FileWorker.lean @@ -661,7 +661,7 @@ def handleRpcRelease (p : Lsp.RpcReleaseParams) : WorkerM Unit := do discard do rpcReleaseRef ref seshRef.modify fun st => let st := st.keptAlive monoMsNow - let ((), objects) := discardRefs st.objects + let ((), objects) := discardRefs.run st.objects { st with objects } def handleRpcKeepAlive (p : Lsp.RpcKeepAliveParams) : WorkerM Unit := do diff --git a/src/Lean/Server/Rpc/RequestHandling.lean b/src/Lean/Server/Rpc/RequestHandling.lean index 56c4ec768e..5fb7b7587d 100644 --- a/src/Lean/Server/Rpc/RequestHandling.lean +++ b/src/Lean/Server/Rpc/RequestHandling.lean @@ -97,7 +97,7 @@ def wrapRpcProcedure (method : Name) paramType respType | Except.error e => throw e | Except.ok ret => seshRef.modifyGet fun st => - rpcEncode ret st.objects |>.map id ({st with objects := ·}) + rpcEncode ret |>.run st.objects |>.map id ({st with objects := ·}) def registerBuiltinRpcProcedure (method : Name) paramType respType [RpcEncodable paramType] [RpcEncodable respType] diff --git a/src/Lean/Widget/InteractiveDiagnostic.lean b/src/Lean/Widget/InteractiveDiagnostic.lean index b4712facd6..5fc9e0b22a 100644 --- a/src/Lean/Widget/InteractiveDiagnostic.lean +++ b/src/Lean/Widget/InteractiveDiagnostic.lean @@ -116,7 +116,7 @@ private abbrev MsgFmtM := StateT (Array EmbedFmt) IO open MessageData in private partial def msgToInteractiveAux (msgData : MessageData) : IO (Format × Array EmbedFmt) := - go { currNamespace := Name.anonymous, openDecls := [] } none msgData #[] + go { currNamespace := Name.anonymous, openDecls := [] } none msgData |>.run #[] where pushEmbed (e : EmbedFmt) : MsgFmtM Nat := modifyGet fun es => (es.size, es.push e) diff --git a/src/Lean/Widget/TaggedText.lean b/src/Lean/Widget/TaggedText.lean index f903b03faa..48895a5a4e 100644 --- a/src/Lean/Widget/TaggedText.lean +++ b/src/Lean/Widget/TaggedText.lean @@ -100,7 +100,7 @@ private instance : Std.Format.MonadPrettyFormat (StateM TaggedState) where is the indentation level at this point. The latter is used to print sub-trees accurately by passing it again as the `indent` argument. -/ def prettyTagged (f : Format) (indent := 0) (w : Nat := Std.Format.defWidth) : TaggedText (Nat × Nat) := - (f.prettyM w indent : StateM TaggedState Unit) {} |>.snd.out + (f.prettyM w indent : StateM TaggedState Unit).run {} |>.snd.out /-- Remove tags, leaving just the pretty-printed string. -/ partial def stripTags (tt : TaggedText α) : String := diff --git a/src/lake/Lake/Build/Run.lean b/src/lake/Lake/Build/Run.lean index bd0f8fef28..5980f5c06c 100644 --- a/src/lake/Lake/Build/Run.lean +++ b/src/lake/Lake/Build/Run.lean @@ -70,7 +70,7 @@ private abbrev MonitorM := ReaderT MonitorContext <| StateT MonitorState BaseIO @[inline] private def MonitorM.run (ctx : MonitorContext) (s : MonitorState) (self : MonitorM α) : BaseIO (α × MonitorState) := - self ctx s + StateT.run (ReaderT.run self ctx) s /-- The ANSI escape sequence for clearing the current line diff --git a/src/lake/Lake/Load/Lean/Elab.lean b/src/lake/Lake/Load/Lean/Elab.lean index a520d11439..17086b358f 100644 --- a/src/lake/Lake/Load/Lean/Elab.lean +++ b/src/lake/Lake/Load/Lean/Elab.lean @@ -67,7 +67,7 @@ def elabConfigFile let input ← IO.FS.readFile configFile let inputCtx := Parser.mkInputContext input configFile.toString let (header, parserState, messages) ← Parser.parseHeader inputCtx - let (env, messages) ← processHeader header leanOpts inputCtx messages + let (env, messages) ← StateT.run (processHeader header leanOpts inputCtx) messages let env := env.setMainModule configModuleName -- Configure extensions diff --git a/src/lake/Lake/Util/EStateT.lean b/src/lake/Lake/Util/EStateT.lean index 3b5f4475b4..d05f339a1e 100644 --- a/src/lake/Lake/Util/EStateT.lean +++ b/src/lake/Lake/Util/EStateT.lean @@ -145,7 +145,7 @@ public def run?' {ε σ α : Type u} [Functor m] (init : σ) (x : EStateT ε σ : StateT σ m α := fun s => do match (← x s) with | .ok a s => return (a, s) - | .error e s => h e s + | .error e s => StateT.run (h e) s /-- Lift the `m` monad into the `EStateT ε σ m` monad transformer. -/ @[always_inline, inline] diff --git a/src/lake/Lake/Util/Lift.lean b/src/lake/Lake/Util/Lift.lean index 60c228dde2..c90a73c09a 100644 --- a/src/lake/Lake/Util/Lift.lean +++ b/src/lake/Lake/Util/Lift.lean @@ -32,11 +32,11 @@ public instance (priority := low) [Pure m] [MonadExceptOf ε m] : MonadLiftT (Ex -- Remark: not necessarily optimal; uses context non-linearly public instance (priority := low) [Monad m] [MonadReaderOf ρ m] [MonadLiftT n m] : MonadLiftT (ReaderT ρ n) m where - monadLift act := do act (← read) + monadLift act := do act.run (← read) -- Remark: not necessarily optimal; uses state non-linearly public instance (priority := low) [Monad m] [MonadStateOf σ m] [MonadLiftT n m] : MonadLiftT (StateT σ n) m where - monadLift act := do let (a, s) ← act (← get); set s; pure a + monadLift act := do let (a, s) ← act.run (m := n) (← get); set s; pure a public instance (priority := low) [Monad m] [Alternative m] [MonadLiftT n m] : MonadLiftT (OptionT n) m where monadLift act := act.run >>= liftM diff --git a/src/lake/Lake/Util/Log.lean b/src/lake/Lake/Util/Log.lean index 257b81dc72..f10be14232 100644 --- a/src/lake/Lake/Util/Log.lean +++ b/src/lake/Lake/Util/Log.lean @@ -264,7 +264,7 @@ public instance [Monad n] [MonadLiftT m n] : MonadLog (MonadLogT m n) where ReaderT.adapt f self @[inline] public def ignoreLog [Pure m] (self : MonadLogT m n α) : n α := - self MonadLog.nop + self.run MonadLog.nop end MonadLogT @@ -486,7 +486,7 @@ public instance [Monad m] : MonadLog (LogT m) := .ofMonadState namespace LogT -public abbrev run [Functor m] (self : LogT m α) (log : Log := {}) : m (α × Log) := +public abbrev run (self : LogT m α) (log : Log := {}) : m (α × Log) := StateT.run self log public abbrev run' [Functor m] (self : LogT m α) (log : Log := {}) : m α := @@ -502,7 +502,7 @@ Thus, this is best used when the lift cannot fail. [Monad n] [MonadStateOf Log n] [MonadLiftT m n] [MonadFinally n] (self : LogT m α) : n α := do - let (a, log) ← self (← takeLog) + let (a, log) ← self.run (← takeLog) set log return a @@ -513,7 +513,7 @@ using the new monad's `logger`. @[inline] public def replayLog [Monad n] [logger : MonadLog n] [MonadLiftT m n] (self : LogT m α) : n α := do - let (a, log) ← self {} + let (a, log) ← self.run {} log.replay (logger := logger) return a