feat: support more indexed monads in elabDoWith (#13801)
This PR adds two new fields to `DoOps`, `splitMonadApp?` and `mkMonadApp`, so that callers of `elabDoWith` can use indexed monads like `Measure α` (where `Measure : (α : Type u) → [MeasureSpace α] → Type u` carries instance arguments) that the default `m α` decomposition cannot handle. The existing behavior moves into `DoOps.default`. `splitMonadApp?` replaces the hard-coded `.app m α` step inside the `extractMonadInfo` recursion, and `mkMonadApp` replaces the hard-coded `mkApp m α` used to construct the monadic type. --------- Co-authored-by: Sebastian Graf <sg@lean-fro.org>
This commit is contained in:
parent
da8bcf7916
commit
65b34530d3
3 changed files with 92 additions and 14 deletions
|
|
@ -116,6 +116,10 @@ structure DoOps where
|
|||
`pure e >>= k` to `let x := e; k x`.
|
||||
-/
|
||||
isPureApp? : Expr → Option Expr
|
||||
/-- Match a monad application `m α`, returning `MonadInfo` for `m` and `α`. -/
|
||||
splitMonadApp? : Expr → Term.TermElabM (Option (MonadInfo × Expr))
|
||||
/-- Construct `m α` from `α`. -/
|
||||
mkMonadApp : Expr → DoElabM Expr
|
||||
deriving Inhabited
|
||||
|
||||
unsafe def DoOps.toDoOpsRefImpl (o : DoOps) : DoOpsRef :=
|
||||
|
|
@ -225,8 +229,8 @@ unsafe def ContInfoRef.toContInfoImpl (m : ContInfoRef) : ContInfo :=
|
|||
opaque ContInfoRef.toContInfo (m : ContInfoRef) : ContInfo
|
||||
|
||||
/-- Constructs `m α` from `α`. -/
|
||||
def mkMonadApp (resultType : Expr) : DoElabM Expr :=
|
||||
return mkApp (← read).monadInfo.m resultType
|
||||
def mkMonadApp (resultType : Expr) : DoElabM Expr := do
|
||||
(← read).ops.toDoOps.mkMonadApp resultType
|
||||
|
||||
/-- The cached `PUnit` expression. -/
|
||||
def mkPUnit : DoElabM Expr := do
|
||||
|
|
@ -282,6 +286,14 @@ def DoOps.default : DoOps where
|
|||
return mkApp6 (mkConst ``Bind.bind [info.u, info.v]) info.m instBind α β e k
|
||||
isPureApp? e :=
|
||||
if e.isAppOfArity ``Pure.pure 4 then some (e.getArg! 3) else none
|
||||
splitMonadApp? type := do
|
||||
let .app m resultType := type.consumeMData | return none
|
||||
unless ← isType resultType do return none
|
||||
let u ← getDecLevel resultType
|
||||
let v ← getDecLevel type
|
||||
return some ({ m, u := u.normalize, v := v.normalize }, resultType)
|
||||
mkMonadApp α := do
|
||||
return mkApp (← read).monadInfo.m α
|
||||
|
||||
/-- Register the given name as that of a `mut` variable. -/
|
||||
def declareMutVar (x : Ident) (k : DoElabM α) : DoElabM α := do
|
||||
|
|
@ -694,19 +706,11 @@ def enterFinally (resultType : Expr) (k : DoElabM Expr) : DoElabM Expr := do
|
|||
withDoBlockResultType resultType k
|
||||
|
||||
/-- Extracts `MonadInfo` and monadic result type `α` from the expected type of a `do` block `m α`. -/
|
||||
private partial def extractMonadInfo (expectedType? : Option Expr) : Term.TermElabM (MonadInfo × Expr) := do
|
||||
private partial def extractMonadInfo (ops : DoOps) (expectedType? : Option Expr) : Term.TermElabM (MonadInfo × Expr) := do
|
||||
let some expectedType := expectedType? | mkUnknownMonadResult
|
||||
let expectedType ← instantiateMVars expectedType
|
||||
let extractStep? (type : Expr) : Term.TermElabM (Option (MonadInfo × Expr)) := do
|
||||
let .app m resultType := type.consumeMData | return none
|
||||
unless ← isType resultType do return none
|
||||
let u ← getDecLevel resultType
|
||||
let v ← getDecLevel type
|
||||
let u := u.normalize
|
||||
let v := v.normalize
|
||||
return some ({ m, u, v }, resultType)
|
||||
let rec extract? (type : Expr) : Term.TermElabM (Option (MonadInfo × Expr)) := do
|
||||
match (← extractStep? type) with
|
||||
match (← ops.splitMonadApp? type) with
|
||||
| some r => return r
|
||||
| none =>
|
||||
let typeNew ← whnfCore type
|
||||
|
|
@ -731,7 +735,7 @@ where
|
|||
|
||||
/-- Create the `Context` for `do` elaboration from the given expected type of a `do` block. -/
|
||||
def mkContext (expectedType? : Option Expr) (ops : DoOps := .default) : TermElabM Context := do
|
||||
let (mi, resultType) ← extractMonadInfo expectedType?
|
||||
let (mi, resultType) ← extractMonadInfo ops expectedType?
|
||||
let returnCont ← ReturnCont.mkPure resultType
|
||||
let contInfo := ContInfo.toContInfoRef { returnCont }
|
||||
return { monadInfo := mi, doBlockResultType := resultType, contInfo,
|
||||
|
|
|
|||
73
tests/elab/doNotationIndexedMonad.lean
Normal file
73
tests/elab/doNotationIndexedMonad.lean
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
import Lean
|
||||
|
||||
/-!
|
||||
Tests that `DoOps` callbacks can take apart and reconstruct an indexed monad
|
||||
application like `Measure Nat`, where `Measure : (α : Type) → [MeasureSpace α] → Type`
|
||||
has an instance argument the default extractor cannot peel off.
|
||||
|
||||
`splitMonadApp?` lets the caller decompose the expected type into the
|
||||
`Measure` constant plus the result type, and `mkMonadApp` lets it rebuild
|
||||
`Measure α` with the instance argument synthesised.
|
||||
-/
|
||||
|
||||
open Lean Lean.Parser Lean.Meta Lean.Elab Lean.Elab.Do Lean.Elab.Term
|
||||
|
||||
set_option backward.do.legacy false
|
||||
|
||||
class MeasureSpace (α : Type u) where
|
||||
|
||||
structure Measure (α : Type u) [MeasureSpace α] where
|
||||
value : α
|
||||
|
||||
def Measure.pure {α} [MeasureSpace α] (x : α) : Measure α := ⟨x⟩
|
||||
def Measure.bind {α β} [MeasureSpace α] [MeasureSpace β]
|
||||
(mx : Measure α) (f : α → Measure β) : Measure β := f mx.value
|
||||
|
||||
def randOps : DoOps := { DoOps.default with
|
||||
mkPureApp _ e := do
|
||||
let eStx ← Term.exprToSyntax e
|
||||
Term.elabTermEnsuringType (← `(Measure.pure $eStx)) none
|
||||
mkBindApp _ _ e k := do
|
||||
let eStx ← Term.exprToSyntax e
|
||||
let kStx ← Term.exprToSyntax k
|
||||
Term.elabTermEnsuringType (← `(Measure.bind $eStx $kStx)) none
|
||||
isPureApp? e :=
|
||||
if e.isAppOfArity ``Measure.pure 3 then some e.appArg! else none
|
||||
splitMonadApp? type := do
|
||||
let type := type.consumeMData
|
||||
unless type.isAppOfArity ``Measure 2 do return none
|
||||
let resultType := type.getAppArgs[0]!
|
||||
let u ← getDecLevel resultType
|
||||
return some ({ m := type.getAppFn, u := u.normalize, v := u.normalize }, resultType)
|
||||
mkMonadApp α := do
|
||||
let m ← Term.exprToSyntax (← read).monadInfo.m
|
||||
Term.elabTermEnsuringType (← `($m $(← Term.exprToSyntax α))) none
|
||||
}
|
||||
|
||||
syntax (name := randKind) "do_rand " doSeq : term
|
||||
|
||||
@[term_elab randKind] def elabRand : Term.TermElab := fun stx et? => do
|
||||
let `(do_rand $doSeq) := stx | throwUnsupportedSyntax
|
||||
elabDoWith randOps doSeq et?
|
||||
|
||||
instance : MeasureSpace Nat := ⟨⟩
|
||||
|
||||
def uniform (n : Nat) : Measure Nat := ⟨n/2⟩
|
||||
|
||||
/-- info: Measure.pure 42 : Measure Nat -/
|
||||
#guard_msgs in
|
||||
#check (do_rand return 42 : Measure Nat)
|
||||
|
||||
/-- info: uniform 10 : Measure Nat -/
|
||||
#guard_msgs in
|
||||
#check (do_rand do
|
||||
let a : Nat ← uniform 10
|
||||
return a : Measure Nat)
|
||||
|
||||
/--
|
||||
info: (uniform 10).bind fun a => Measure.pure (a + 1) : Measure Nat
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check (do_rand do
|
||||
let a : Nat ← uniform 10
|
||||
return a + 1 : Measure Nat)
|
||||
|
|
@ -41,7 +41,7 @@ def modifyN (f : Nat → Nat) : IState Nat Nat Unit := fun i => ((), f i)
|
|||
|
||||
/-! ## Pluggable ops emitting `IxMonad.pure` / `IxMonad.bind` -/
|
||||
|
||||
def ixOps : DoOps where
|
||||
def ixOps : DoOps := { DoOps.default with
|
||||
mkPureApp α e := do
|
||||
let info := (← read).monadInfo
|
||||
let mα := mkApp info.m α
|
||||
|
|
@ -58,6 +58,7 @@ def ixOps : DoOps where
|
|||
isPureApp? e :=
|
||||
-- `@IxMonad.pure ι m inst α i e` — 6 args.
|
||||
if e.isAppOfArity ``IxMonad.pure 6 then some (e.getArg! 5) else none
|
||||
}
|
||||
|
||||
/-! ## `ido` surface syntax -/
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue