diff --git a/src/Lean/Elab/Do/Basic.lean b/src/Lean/Elab/Do/Basic.lean index ac8a7d0791..c9eaf573a0 100644 --- a/src/Lean/Elab/Do/Basic.lean +++ b/src/Lean/Elab/Do/Basic.lean @@ -116,6 +116,10 @@ structure DoOps where `pure e >>= k` to `let x := e; k x`. -/ isPureApp? : Expr → Option Expr + /-- Match a monad application `m α`, returning `MonadInfo` for `m` and `α`. -/ + splitMonadApp? : Expr → Term.TermElabM (Option (MonadInfo × Expr)) + /-- Construct `m α` from `α`. -/ + mkMonadApp : Expr → DoElabM Expr deriving Inhabited unsafe def DoOps.toDoOpsRefImpl (o : DoOps) : DoOpsRef := @@ -225,8 +229,8 @@ unsafe def ContInfoRef.toContInfoImpl (m : ContInfoRef) : ContInfo := opaque ContInfoRef.toContInfo (m : ContInfoRef) : ContInfo /-- Constructs `m α` from `α`. -/ -def mkMonadApp (resultType : Expr) : DoElabM Expr := - return mkApp (← read).monadInfo.m resultType +def mkMonadApp (resultType : Expr) : DoElabM Expr := do + (← read).ops.toDoOps.mkMonadApp resultType /-- The cached `PUnit` expression. -/ def mkPUnit : DoElabM Expr := do @@ -282,6 +286,14 @@ def DoOps.default : DoOps where return mkApp6 (mkConst ``Bind.bind [info.u, info.v]) info.m instBind α β e k isPureApp? e := if e.isAppOfArity ``Pure.pure 4 then some (e.getArg! 3) else none + splitMonadApp? type := do + let .app m resultType := type.consumeMData | return none + unless ← isType resultType do return none + let u ← getDecLevel resultType + let v ← getDecLevel type + return some ({ m, u := u.normalize, v := v.normalize }, resultType) + mkMonadApp α := do + return mkApp (← read).monadInfo.m α /-- Register the given name as that of a `mut` variable. -/ def declareMutVar (x : Ident) (k : DoElabM α) : DoElabM α := do @@ -694,19 +706,11 @@ def enterFinally (resultType : Expr) (k : DoElabM Expr) : DoElabM Expr := do withDoBlockResultType resultType k /-- Extracts `MonadInfo` and monadic result type `α` from the expected type of a `do` block `m α`. -/ -private partial def extractMonadInfo (expectedType? : Option Expr) : Term.TermElabM (MonadInfo × Expr) := do +private partial def extractMonadInfo (ops : DoOps) (expectedType? : Option Expr) : Term.TermElabM (MonadInfo × Expr) := do let some expectedType := expectedType? | mkUnknownMonadResult let expectedType ← instantiateMVars expectedType - let extractStep? (type : Expr) : Term.TermElabM (Option (MonadInfo × Expr)) := do - let .app m resultType := type.consumeMData | return none - unless ← isType resultType do return none - let u ← getDecLevel resultType - let v ← getDecLevel type - let u := u.normalize - let v := v.normalize - return some ({ m, u, v }, resultType) let rec extract? (type : Expr) : Term.TermElabM (Option (MonadInfo × Expr)) := do - match (← extractStep? type) with + match (← ops.splitMonadApp? type) with | some r => return r | none => let typeNew ← whnfCore type @@ -731,7 +735,7 @@ where /-- Create the `Context` for `do` elaboration from the given expected type of a `do` block. -/ def mkContext (expectedType? : Option Expr) (ops : DoOps := .default) : TermElabM Context := do - let (mi, resultType) ← extractMonadInfo expectedType? + let (mi, resultType) ← extractMonadInfo ops expectedType? let returnCont ← ReturnCont.mkPure resultType let contInfo := ContInfo.toContInfoRef { returnCont } return { monadInfo := mi, doBlockResultType := resultType, contInfo, diff --git a/tests/elab/doNotationIndexedMonad.lean b/tests/elab/doNotationIndexedMonad.lean new file mode 100644 index 0000000000..3125011247 --- /dev/null +++ b/tests/elab/doNotationIndexedMonad.lean @@ -0,0 +1,73 @@ +import Lean + +/-! +Tests that `DoOps` callbacks can take apart and reconstruct an indexed monad +application like `Measure Nat`, where `Measure : (α : Type) → [MeasureSpace α] → Type` +has an instance argument the default extractor cannot peel off. + +`splitMonadApp?` lets the caller decompose the expected type into the +`Measure` constant plus the result type, and `mkMonadApp` lets it rebuild +`Measure α` with the instance argument synthesised. +-/ + +open Lean Lean.Parser Lean.Meta Lean.Elab Lean.Elab.Do Lean.Elab.Term + +set_option backward.do.legacy false + +class MeasureSpace (α : Type u) where + +structure Measure (α : Type u) [MeasureSpace α] where + value : α + +def Measure.pure {α} [MeasureSpace α] (x : α) : Measure α := ⟨x⟩ +def Measure.bind {α β} [MeasureSpace α] [MeasureSpace β] + (mx : Measure α) (f : α → Measure β) : Measure β := f mx.value + +def randOps : DoOps := { DoOps.default with + mkPureApp _ e := do + let eStx ← Term.exprToSyntax e + Term.elabTermEnsuringType (← `(Measure.pure $eStx)) none + mkBindApp _ _ e k := do + let eStx ← Term.exprToSyntax e + let kStx ← Term.exprToSyntax k + Term.elabTermEnsuringType (← `(Measure.bind $eStx $kStx)) none + isPureApp? e := + if e.isAppOfArity ``Measure.pure 3 then some e.appArg! else none + splitMonadApp? type := do + let type := type.consumeMData + unless type.isAppOfArity ``Measure 2 do return none + let resultType := type.getAppArgs[0]! + let u ← getDecLevel resultType + return some ({ m := type.getAppFn, u := u.normalize, v := u.normalize }, resultType) + mkMonadApp α := do + let m ← Term.exprToSyntax (← read).monadInfo.m + Term.elabTermEnsuringType (← `($m $(← Term.exprToSyntax α))) none +} + +syntax (name := randKind) "do_rand " doSeq : term + +@[term_elab randKind] def elabRand : Term.TermElab := fun stx et? => do + let `(do_rand $doSeq) := stx | throwUnsupportedSyntax + elabDoWith randOps doSeq et? + +instance : MeasureSpace Nat := ⟨⟩ + +def uniform (n : Nat) : Measure Nat := ⟨n/2⟩ + +/-- info: Measure.pure 42 : Measure Nat -/ +#guard_msgs in +#check (do_rand return 42 : Measure Nat) + +/-- info: uniform 10 : Measure Nat -/ +#guard_msgs in +#check (do_rand do + let a : Nat ← uniform 10 + return a : Measure Nat) + +/-- +info: (uniform 10).bind fun a => Measure.pure (a + 1) : Measure Nat +-/ +#guard_msgs in +#check (do_rand do + let a : Nat ← uniform 10 + return a + 1 : Measure Nat) diff --git a/tests/elab/doNotationPluggableOps.lean b/tests/elab/doNotationPluggableOps.lean index 3b66218547..a585486f38 100644 --- a/tests/elab/doNotationPluggableOps.lean +++ b/tests/elab/doNotationPluggableOps.lean @@ -41,7 +41,7 @@ def modifyN (f : Nat → Nat) : IState Nat Nat Unit := fun i => ((), f i) /-! ## Pluggable ops emitting `IxMonad.pure` / `IxMonad.bind` -/ -def ixOps : DoOps where +def ixOps : DoOps := { DoOps.default with mkPureApp α e := do let info := (← read).monadInfo let mα := mkApp info.m α @@ -58,6 +58,7 @@ def ixOps : DoOps where isPureApp? e := -- `@IxMonad.pure ι m inst α i e` — 6 args. if e.isAppOfArity ``IxMonad.pure 6 then some (e.getArg! 5) else none +} /-! ## `ido` surface syntax -/