feat: support more indexed monads in elabDoWith (#13801)

This PR adds two new fields to `DoOps`, `splitMonadApp?` and
`mkMonadApp`, so that callers of `elabDoWith` can use indexed monads
like `Measure α` (where `Measure : (α : Type u) → [MeasureSpace α] →
Type u` carries instance arguments) that the default `m α` decomposition
cannot handle. The existing behavior moves into `DoOps.default`.

`splitMonadApp?` replaces the hard-coded `.app m α` step inside the
`extractMonadInfo` recursion, and `mkMonadApp` replaces the hard-coded
`mkApp m α` used to construct the monadic type.

---------

Co-authored-by: Sebastian Graf <sg@lean-fro.org>
This commit is contained in:
Sebastian Graf 2026-05-20 18:59:25 +01:00 committed by GitHub
parent da8bcf7916
commit 65b34530d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 92 additions and 14 deletions

View file

@ -116,6 +116,10 @@ structure DoOps where
`pure e >>= k` to `let x := e; k x`.
-/
isPureApp? : Expr → Option Expr
/-- Match a monad application `m α`, returning `MonadInfo` for `m` and `α`. -/
splitMonadApp? : Expr → Term.TermElabM (Option (MonadInfo × Expr))
/-- Construct `m α` from `α`. -/
mkMonadApp : Expr → DoElabM Expr
deriving Inhabited
unsafe def DoOps.toDoOpsRefImpl (o : DoOps) : DoOpsRef :=
@ -225,8 +229,8 @@ unsafe def ContInfoRef.toContInfoImpl (m : ContInfoRef) : ContInfo :=
opaque ContInfoRef.toContInfo (m : ContInfoRef) : ContInfo
/-- Constructs `m α` from `α`. -/
def mkMonadApp (resultType : Expr) : DoElabM Expr :=
return mkApp (← read).monadInfo.m resultType
def mkMonadApp (resultType : Expr) : DoElabM Expr := do
(← read).ops.toDoOps.mkMonadApp resultType
/-- The cached `PUnit` expression. -/
def mkPUnit : DoElabM Expr := do
@ -282,6 +286,14 @@ def DoOps.default : DoOps where
return mkApp6 (mkConst ``Bind.bind [info.u, info.v]) info.m instBind α β e k
isPureApp? e :=
if e.isAppOfArity ``Pure.pure 4 then some (e.getArg! 3) else none
splitMonadApp? type := do
let .app m resultType := type.consumeMData | return none
unless ← isType resultType do return none
let u ← getDecLevel resultType
let v ← getDecLevel type
return some ({ m, u := u.normalize, v := v.normalize }, resultType)
mkMonadApp α := do
return mkApp (← read).monadInfo.m α
/-- Register the given name as that of a `mut` variable. -/
def declareMutVar (x : Ident) (k : DoElabM α) : DoElabM α := do
@ -694,19 +706,11 @@ def enterFinally (resultType : Expr) (k : DoElabM Expr) : DoElabM Expr := do
withDoBlockResultType resultType k
/-- Extracts `MonadInfo` and monadic result type `α` from the expected type of a `do` block `m α`. -/
private partial def extractMonadInfo (expectedType? : Option Expr) : Term.TermElabM (MonadInfo × Expr) := do
private partial def extractMonadInfo (ops : DoOps) (expectedType? : Option Expr) : Term.TermElabM (MonadInfo × Expr) := do
let some expectedType := expectedType? | mkUnknownMonadResult
let expectedType ← instantiateMVars expectedType
let extractStep? (type : Expr) : Term.TermElabM (Option (MonadInfo × Expr)) := do
let .app m resultType := type.consumeMData | return none
unless ← isType resultType do return none
let u ← getDecLevel resultType
let v ← getDecLevel type
let u := u.normalize
let v := v.normalize
return some ({ m, u, v }, resultType)
let rec extract? (type : Expr) : Term.TermElabM (Option (MonadInfo × Expr)) := do
match (← extractStep? type) with
match (← ops.splitMonadApp? type) with
| some r => return r
| none =>
let typeNew ← whnfCore type
@ -731,7 +735,7 @@ where
/-- Create the `Context` for `do` elaboration from the given expected type of a `do` block. -/
def mkContext (expectedType? : Option Expr) (ops : DoOps := .default) : TermElabM Context := do
let (mi, resultType) ← extractMonadInfo expectedType?
let (mi, resultType) ← extractMonadInfo ops expectedType?
let returnCont ← ReturnCont.mkPure resultType
let contInfo := ContInfo.toContInfoRef { returnCont }
return { monadInfo := mi, doBlockResultType := resultType, contInfo,

View file

@ -0,0 +1,73 @@
import Lean
/-!
Tests that `DoOps` callbacks can take apart and reconstruct an indexed monad
application like `Measure Nat`, where `Measure : (α : Type) → [MeasureSpace α] → Type`
has an instance argument the default extractor cannot peel off.
`splitMonadApp?` lets the caller decompose the expected type into the
`Measure` constant plus the result type, and `mkMonadApp` lets it rebuild
`Measure α` with the instance argument synthesised.
-/
open Lean Lean.Parser Lean.Meta Lean.Elab Lean.Elab.Do Lean.Elab.Term
set_option backward.do.legacy false
class MeasureSpace (α : Type u) where
structure Measure (α : Type u) [MeasureSpace α] where
value : α
def Measure.pure {α} [MeasureSpace α] (x : α) : Measure α := ⟨x⟩
def Measure.bind {α β} [MeasureSpace α] [MeasureSpace β]
(mx : Measure α) (f : α → Measure β) : Measure β := f mx.value
def randOps : DoOps := { DoOps.default with
mkPureApp _ e := do
let eStx ← Term.exprToSyntax e
Term.elabTermEnsuringType (← `(Measure.pure $eStx)) none
mkBindApp _ _ e k := do
let eStx ← Term.exprToSyntax e
let kStx ← Term.exprToSyntax k
Term.elabTermEnsuringType (← `(Measure.bind $eStx $kStx)) none
isPureApp? e :=
if e.isAppOfArity ``Measure.pure 3 then some e.appArg! else none
splitMonadApp? type := do
let type := type.consumeMData
unless type.isAppOfArity ``Measure 2 do return none
let resultType := type.getAppArgs[0]!
let u ← getDecLevel resultType
return some ({ m := type.getAppFn, u := u.normalize, v := u.normalize }, resultType)
mkMonadApp α := do
let m ← Term.exprToSyntax (← read).monadInfo.m
Term.elabTermEnsuringType (← `($m $(← Term.exprToSyntax α))) none
}
syntax (name := randKind) "do_rand " doSeq : term
@[term_elab randKind] def elabRand : Term.TermElab := fun stx et? => do
let `(do_rand $doSeq) := stx | throwUnsupportedSyntax
elabDoWith randOps doSeq et?
instance : MeasureSpace Nat := ⟨⟩
def uniform (n : Nat) : Measure Nat := ⟨n/2⟩
/-- info: Measure.pure 42 : Measure Nat -/
#guard_msgs in
#check (do_rand return 42 : Measure Nat)
/-- info: uniform 10 : Measure Nat -/
#guard_msgs in
#check (do_rand do
let a : Nat ← uniform 10
return a : Measure Nat)
/--
info: (uniform 10).bind fun a => Measure.pure (a + 1) : Measure Nat
-/
#guard_msgs in
#check (do_rand do
let a : Nat ← uniform 10
return a + 1 : Measure Nat)

View file

@ -41,7 +41,7 @@ def modifyN (f : Nat → Nat) : IState Nat Nat Unit := fun i => ((), f i)
/-! ## Pluggable ops emitting `IxMonad.pure` / `IxMonad.bind` -/
def ixOps : DoOps where
def ixOps : DoOps := { DoOps.default with
mkPureApp α e := do
let info := (← read).monadInfo
let mα := mkApp info.m α
@ -58,6 +58,7 @@ def ixOps : DoOps where
isPureApp? e :=
-- `@IxMonad.pure ι m inst α i e` — 6 args.
if e.isAppOfArity ``IxMonad.pure 6 then some (e.getArg! 5) else none
}
/-! ## `ido` surface syntax -/