diff --git a/src/Lean/Elab/Do/Basic.lean b/src/Lean/Elab/Do/Basic.lean index 6231d9fbe2..16dd038f74 100644 --- a/src/Lean/Elab/Do/Basic.lean +++ b/src/Lean/Elab/Do/Basic.lean @@ -45,6 +45,15 @@ def ContInfoRef : Type := ContInfoRefPointed.type instance : Nonempty ContInfoRef := by exact ContInfoRefPointed.property +-- Same pattern as `ContInfoRef` above; used so `Context` can carry `DoOps` without +-- depending on `DoElabM`. +private opaque DoOpsRefPointed : NonemptyType.{0} + +def DoOpsRef : Type := DoOpsRefPointed.type + +instance : Nonempty DoOpsRef := + by exact DoOpsRefPointed.property + /-- Whether a code block is alive or dead. -/ inductive CodeLiveness where /-- We inferred the code is semantically dead and don't need to elaborate it at all. -/ @@ -90,9 +99,37 @@ structure Context where Whether the current `do` element is dead code. `elabDoElem` will emit a warning if not `.alive`. -/ deadCode : CodeLiveness := .alive + /-- Pluggable builders for `pure` and `bind` applications. -/ + ops : DoOpsRef abbrev DoElabM := ReaderT Context Term.TermElabM +/-- Pluggable builders for the `pure` / `bind` applications emitted by the `do` elaborator. -/ +structure DoOps where + /-- Build `pure (α:=α) e : m α`. -/ + mkPureApp : (α e : Expr) → DoElabM Expr + /-- Build `bind (α:=α) (β:=β) e k : m β`. -/ + mkBindApp : (α β e k : Expr) → DoElabM Expr + /-- + If `e` is syntactically a `pure …` application, return the pure value; otherwise `none`. + Used by `DoElemCont.mkBindUnlessPure` to contract `e >>= pure` to `e` and + `pure e >>= k` to `let x := e; k x`. + -/ + isPureApp? : Expr → Option Expr + deriving Inhabited + +unsafe def DoOps.toDoOpsRefImpl (o : DoOps) : DoOpsRef := + unsafeCast o + +@[implemented_by DoOps.toDoOpsRefImpl] +opaque DoOps.toDoOpsRef (o : DoOps) : DoOpsRef + +unsafe def DoOpsRef.toDoOpsImpl (r : DoOpsRef) : DoOps := + unsafeCast r + +@[implemented_by DoOpsRef.toDoOpsImpl] +opaque DoOpsRef.toDoOps (r : DoOpsRef) : DoOps + /-- Whether the continuation of a `do` element is duplicable and if so whether it is just `pure r` for the result variable `r`. Saying `nonDuplicable` is always safe; `duplicable` allows for more @@ -201,16 +238,7 @@ def mkPUnitUnit : DoElabM Expr := do /-- The expression ``pure (α:=α) e``. -/ def mkPureApp (α e : Expr) : DoElabM Expr := do - let info := (← read).monadInfo - if (← read).deadCode matches .deadSyntactically then - -- There is no dead syntax here. Just return a fresh metavariable so that we don't - -- do the `Term.ensureHasType` check below. - return ← mkFreshExprMVar (mkApp info.m α) - let α ← Term.ensureHasType (mkSort (mkLevelSucc info.u)) α - let e ← Term.ensureHasType α e - let instPure ← Term.mkInstMVar (mkApp (mkConst ``Pure [info.u, info.v]) info.m) - let instPure ← instantiateMVars instPure - return mkApp4 (mkConst ``Pure.pure [info.u, info.v]) info.m instPure α e + (← read).ops.toDoOps.mkPureApp α e /-- Create a `DoElemCont` returning the result using `pure`. -/ def DoElemCont.mkPure (resultType : Expr) : TermElabM DoElemCont := do @@ -229,13 +257,31 @@ def ReturnCont.mkPure (resultType : Expr) : TermElabM ReturnCont := do /-- The expression ``Bind.bind (α:=α) (β:=β) e k``. -/ def mkBindApp (α β e k : Expr) : DoElabM Expr := do - let info := (← read).monadInfo - let α ← Term.ensureHasType (mkSort (mkLevelSucc info.u)) α - let mα := mkApp info.m α - let e ← Term.ensureHasType mα e - let k ← Term.ensureHasType (← mkArrow α (mkApp info.m β)) k - let instBind ← Term.mkInstMVar (mkApp (mkConst ``Bind [info.u, info.v]) info.m) - return mkApp6 (mkConst ``Bind.bind [info.u, info.v]) info.m instBind α β e k + (← read).ops.toDoOps.mkBindApp α β e k + +/-- `DoOps` emitting `Pure.pure` / `Bind.bind`. -/ +def DoOps.default : DoOps where + mkPureApp α e := do + let info := (← read).monadInfo + if (← read).deadCode matches .deadSyntactically then + -- There is no dead syntax here. Just return a fresh metavariable so that we don't + -- do the `Term.ensureHasType` check below. + return ← mkFreshExprMVar (mkApp info.m α) + let α ← Term.ensureHasType (mkSort (mkLevelSucc info.u)) α + let e ← Term.ensureHasType α e + let instPure ← Term.mkInstMVar (mkApp (mkConst ``Pure [info.u, info.v]) info.m) + let instPure ← instantiateMVars instPure + return mkApp4 (mkConst ``Pure.pure [info.u, info.v]) info.m instPure α e + mkBindApp α β e k := do + let info := (← read).monadInfo + let α ← Term.ensureHasType (mkSort (mkLevelSucc info.u)) α + let mα := mkApp info.m α + let e ← Term.ensureHasType mα e + let k ← Term.ensureHasType (← mkArrow α (mkApp info.m β)) k + let instBind ← Term.mkInstMVar (mkApp (mkConst ``Bind [info.u, info.v]) info.m) + return mkApp6 (mkConst ``Bind.bind [info.u, info.v]) info.m instBind α β e k + isPureApp? e := + if e.isAppOfArity ``Pure.pure 4 then some (e.getArg! 3) else none /-- Register the given name as that of a `mut` variable. -/ def declareMutVar (x : Ident) (k : DoElabM α) : DoElabM α := do @@ -434,18 +480,19 @@ def DoElemCont.mkBindUnlessPure (dec : DoElemCont) (e : Expr) : DoElabM Expr := withLocalDecl x .default eResultTy (kind := declKind) fun xFVar => do let body ← k let body' := body.consumeMData + let ops := (← read).ops.toDoOps -- First try to contract `e >>= pure` into `e`. -- Reason: for `pure e >>= pure`, we want to get `pure e` and not `have xFVar := e; pure xFVar`. - if body'.isAppOfArity ``Pure.pure 4 && body'.getArg! 3 == xFVar then - let body'' ← mkPureApp eResultTy xFVar - if ← withNewMCtxDepth do isDefEq body' body'' then - return e + if let some pureArg := ops.isPureApp? body' then + if pureArg == xFVar then + let body'' ← mkPureApp eResultTy xFVar + if ← withNewMCtxDepth do isDefEq body' body'' then + return e -- Now test whether we can contract `pure e >>= k` into `have xFVar := e; k xFVar`. We zeta `xFVar` when -- `e` is duplicable; we don't look at `k` to see whether it is used at most once. let e' := e.consumeMData - if e'.isAppOfArity ``Pure.pure 4 then - let eRes := e'.getArg! 3 + if let some eRes := ops.isPureApp? e' then let e' ← mkPureApp eResultTy eRes let (isPure, isDuplicable) ← withNewMCtxDepth do let isPure ← isDefEq e e' @@ -683,11 +730,12 @@ where return ({ m, u, v }, resultType) /-- Create the `Context` for `do` elaboration from the given expected type of a `do` block. -/ -def mkContext (expectedType? : Option Expr) : TermElabM Context := do +def mkContext (expectedType? : Option Expr) (ops : DoOps := .default) : TermElabM Context := do let (mi, resultType) ← extractMonadInfo expectedType? let returnCont ← ReturnCont.mkPure resultType let contInfo := ContInfo.toContInfoRef { returnCont } - return { monadInfo := mi, doBlockResultType := resultType, contInfo } + return { monadInfo := mi, doBlockResultType := resultType, contInfo, + ops := ops.toDoOpsRef } section NestedActions @@ -903,11 +951,11 @@ def elabNestedAction : Term.TermElab := fun stx _ty? => do let `(← $_rhs) := stx | throwUnsupportedSyntax throwErrorAt stx "Nested action `{stx}` must be nested inside a `do` expression." --- @[builtin_term_elab «do»] -- once the legacy `do` elaborator has been phased out -def elabDo : Term.TermElab := fun e expectedType? => do - let `(do $doSeq) := e | throwError "unexpected `do` block syntax{indentD e}" +/-- Elaborate `doSeq` using `ops` for pure/bind construction. -/ +def elabDoWith (ops : DoOps) (doSeq : TSyntax ``doSeq) + (expectedType? : Option Expr) : TermElabM Expr := do Term.tryPostponeIfNoneOrMVar expectedType? - let ctx ← mkContext expectedType? + let ctx ← mkContext expectedType? (ops := ops) let cont ← DoElemCont.mkPure ctx.doBlockResultType let res ← elabDoSeq doSeq cont |>.run ctx -- Synthesizing default instances here is harmful for expressions such as @@ -920,3 +968,8 @@ def elabDo : Term.TermElab := fun e expectedType? => do -- Term.synthesizeSyntheticMVarsUsingDefault trace[Elab.do] "{← instantiateMVars res}" pure res + +-- @[builtin_term_elab «do»] -- once the legacy `do` elaborator has been phased out +def elabDo : Term.TermElab := fun e expectedType? => do + let `(do $doSeq) := e | throwError "unexpected `do` block syntax{indentD e}" + elabDoWith .default doSeq expectedType? diff --git a/tests/elab/doNotationPluggableOps.lean b/tests/elab/doNotationPluggableOps.lean new file mode 100644 index 0000000000..3b66218547 --- /dev/null +++ b/tests/elab/doNotationPluggableOps.lean @@ -0,0 +1,217 @@ +import Lean + +/-! +Tests for the pluggable pure/bind builders in the `do` elaborator (`DoOps`, `elabDoWith`). + +We define a surface `ido` notation that reuses the full `do` elaborator via `elabDoWith` +but emits `IxMonad.pure` / `IxMonad.bind` instead of `Pure.pure` / `Bind.bind`. + +`IxMonad` is the canonical Atkey parameterised monad (`m : ι → ι → Type u → Type v` with +`pure : α → m i i α` and `bind : m i j α → (α → m j k β) → m i k β`); the shape is +documented in `Control.Monad.Indexed` on Hackage and the PureScript `indexed-monad` +package. + +The control-stack features of `do` (`mut`, `return`, `break`, `continue`, `for`) remain +hard-coded to `Monad` and are therefore off-limits for `ido`. The `ido` programs below +avoid them. +-/ + +open Lean Lean.Parser Lean.Meta Lean.Elab Lean.Elab.Do Lean.Elab.Term + +set_option backward.do.legacy false + +/-! ## Indexed monad and a concrete instance -/ + +class IxMonad (m : ι → ι → Type u → Type v) where + pure : α → m i i α + bind : m i j α → (α → m j k β) → m i k β + +/-- Atkey-style indexed state: `IState i o α = i → α × o`. -/ +abbrev IState (i o α : Type) : Type := i → α × o + +instance : IxMonad IState where + pure a := fun i => (a, i) + bind p f := fun i => let (a, j) := p i; f a j + +/-! Helpers that keep the state type fixed at `Nat` for the common examples. -/ + +def getN : IState Nat Nat Nat := fun s => (s, s) +def putN (n : Nat) : IState Nat Nat Unit := fun _ => ((), n) +def modifyN (f : Nat → Nat) : IState Nat Nat Unit := fun i => ((), f i) + +/-! ## Pluggable ops emitting `IxMonad.pure` / `IxMonad.bind` -/ + +def ixOps : DoOps where + mkPureApp α e := do + let info := (← read).monadInfo + let mα := mkApp info.m α + let eStx ← Term.exprToSyntax e + let stx ← `(IxMonad.pure $eStx) + Term.elabTermEnsuringType stx mα + mkBindApp α β e k := do + let info := (← read).monadInfo + let mβ := mkApp info.m β + let eStx ← Term.exprToSyntax e + let kStx ← Term.exprToSyntax k + let stx ← `(IxMonad.bind $eStx $kStx) + Term.elabTermEnsuringType stx mβ + isPureApp? e := + -- `@IxMonad.pure ι m inst α i e` — 6 args. + if e.isAppOfArity ``IxMonad.pure 6 then some (e.getArg! 5) else none + +/-! ## `ido` surface syntax -/ + +syntax (name := idoKind) "ido " doSeq : term + +@[term_elab idoKind] def elabIDo : Term.TermElab := fun stx et? => do + let `(ido $doSeq) := stx | throwUnsupportedSyntax + elabDoWith ixOps doSeq et? + +/-! ## Example programs + +Each example pairs `#guard_msgs` with `#eval` (or `#check`) to lock behaviour in. +Most keep state type fixed at `Nat`; a couple at the end explore index-preserving +variants with different state types. -/ + +/-! ### 1. Bare pure -/ + +/-- info: (42, 10) -/ +#guard_msgs in +#eval (ido IxMonad.pure 42 : IState Nat Nat Nat) 10 + +/-! ### 2. Monadic `let ← ` -/ + +/-- info: (11, 10) -/ +#guard_msgs in +#eval (ido do + let x ← getN + IxMonad.pure (x + 1) : IState Nat Nat Nat) 10 + +/-! ### 3. Plain `let :=` -/ + +/-- info: (20, 10) -/ +#guard_msgs in +#eval (ido do + let x := 10 + let y ← getN + IxMonad.pure (x + y) : IState Nat Nat Nat) 10 + +/-! ### 4. Sequential unit-typed element -/ + +/-- info: (11, 11) -/ +#guard_msgs in +#eval (ido do + modifyN (· + 1) + getN : IState Nat Nat Nat) 10 + +/-! ### 5. Multi-step chain -/ + +/-- info: ((10, 11), 11) -/ +#guard_msgs in +#eval (ido do + let a ← getN + modifyN (· + 1) + let b ← getN + IxMonad.pure (a, b) : IState Nat Nat (Nat × Nat)) 10 + +/-! ### 6. Nested `(← …)` in pure context -/ + +/-- info: (11, 10) -/ +#guard_msgs in +#eval (ido IxMonad.pure ((← getN) + 1) : IState Nat Nat Nat) 10 + +/-! ### 7. Nested `(← …)` appearing twice in one expression -/ + +/-- info: (20, 10) -/ +#guard_msgs in +#eval (ido IxMonad.pure ((← getN) + (← getN)) : IState Nat Nat Nat) 10 + +/-! ### 8. `if/then/else` with do branches -/ + +/-- info: (5, 5) -/ +#guard_msgs in +#eval (ido do + let x ← getN + if x > 0 then IxMonad.pure x else IxMonad.pure 0 : IState Nat Nat Nat) 5 + +/-- info: (0, 0) -/ +#guard_msgs in +#eval (ido do + let x ← getN + if x > 0 then IxMonad.pure x else IxMonad.pure 0 : IState Nat Nat Nat) 0 + +/-! ### 9. `if` with a lifted action in the condition -/ + +/-- info: ((), 4) -/ +#guard_msgs in +#eval (ido do + if (← getN) > 0 then modifyN (· - 1) else IxMonad.pure () : IState Nat Nat Unit) 5 + +/-- info: ((), 0) -/ +#guard_msgs in +#eval (ido do + if (← getN) > 0 then modifyN (· - 1) else IxMonad.pure () : IState Nat Nat Unit) 0 + +/-! ### 10. `match` dispatching into do blocks -/ + +/-- info: (100, 7) -/ +#guard_msgs in +#eval (ido do + match (← getN) with + | 0 => IxMonad.pure 0 + | _ => IxMonad.pure 100 : IState Nat Nat Nat) 7 + +/-! ### 11. Pattern `let` -/ + +/-- info: (3, 0) -/ +#guard_msgs in +#eval (ido do + let (a, b) ← IxMonad.pure (1, 2) + IxMonad.pure (a + b) : IState Nat Nat Nat) 0 + +/-! ### 12. Nested `ido` inside `ido` -/ + +/-- info: (42, 0) -/ +#guard_msgs in +#eval (ido do + let y ← (ido IxMonad.pure 42 : IState Nat Nat Nat) + IxMonad.pure y : IState Nat Nat Nat) 0 + +/-! ### 13. `ido` composing with ordinary `do` -/ + +/-- info: 84 -/ +#guard_msgs in +#eval Id.run do + let (n, _) := (ido IxMonad.pure 42 : IState Nat Nat Nat) 0 + pure (n * 2) + +/-! ### 14. `pure e >>= pure` peephole — confirms the generated term has no redundant + `IxMonad.bind`. + +The equation `(ido do let x ← IxMonad.pure 17; IxMonad.pure x) = IxMonad.pure 17` holds +definitionally only if the peephole in `mkBindUnlessPure` fired and contracted the bind +away, emitting a bare `IxMonad.pure 17`. If the peephole failed, the result would be +`IxMonad.bind (IxMonad.pure 17) IxMonad.pure`, which is not definitionally equal to +`IxMonad.pure 17` because `IxMonad` is a plain `class` without beta-reduction laws. -/ + +example : (ido do + let x ← IxMonad.pure 17 + IxMonad.pure x : IState Nat Nat Nat) = IxMonad.pure 17 := rfl + +/-! ### 15. Deeper chains of binds -/ + +/-- info: (6, 10) -/ +#guard_msgs in +#eval (ido do + let a ← IxMonad.pure 1 + let b ← IxMonad.pure 2 + let c ← IxMonad.pure 3 + IxMonad.pure (a + b + c) : IState Nat Nat Nat) 10 + +/-! ### 16. Index-preserving monad with a different fixed state type -/ + +/-- info: ("hi there", "hi") -/ +#guard_msgs in +#eval (ido do + let s ← (fun (σ : String) => (σ, σ) : IState String String String) + IxMonad.pure (s ++ " there") : IState String String String) "hi"