feat: new do elaborator, part 1: doElem_elab attribute (#11150)
This PR adds a new, inactive and unused `doElem_elab` attribute that will allow users to register custom elaborators for `doElem`s in the form of the new type `DoElab`. The old `do` elaborator is active by default but can be switched off by disabling the new option `backward.do.legacy`.
This commit is contained in:
parent
d464b13569
commit
09cf07b71c
11 changed files with 2756 additions and 1832 deletions
|
|
@ -148,6 +148,23 @@ This is the inverse of `ExceptT.mk`.
|
|||
@[always_inline, inline, expose]
|
||||
def ExceptT.run {ε : Type u} {m : Type u → Type v} {α : Type u} (x : ExceptT ε m α) : m (Except ε α) := x
|
||||
|
||||
/--
|
||||
Use a monadic action that may throw an exception by providing explicit success and failure
|
||||
continuations.
|
||||
-/
|
||||
@[always_inline, inline, expose]
|
||||
def ExceptT.runK [Monad m] (x : ExceptT ε m α) (ok : α → m β) (error : ε → m β) : m β :=
|
||||
x.run >>= (·.casesOn error ok)
|
||||
|
||||
/--
|
||||
Returns the value of a computation, forgetting whether it was an exception or a success.
|
||||
|
||||
This corresponds to early return.
|
||||
-/
|
||||
@[always_inline, inline, expose]
|
||||
def ExceptT.runCatch [Monad m] (x : ExceptT α m α) : m α :=
|
||||
x.runK pure pure
|
||||
|
||||
namespace ExceptT
|
||||
|
||||
variable {ε : Type u} {m : Type u → Type v} [Monad m]
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
743
src/Lean/Elab/Do/Basic.lean
Normal file
743
src/Lean/Elab/Do/Basic.lean
Normal file
|
|
@ -0,0 +1,743 @@
|
|||
/-
|
||||
Copyright (c) 2025 Lean FRO LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sebastian Graf
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Elab.Term.TermElabM
|
||||
public import Lean.Elab.Binders
|
||||
import Lean.Meta.ProdN
|
||||
meta import Lean.Parser.Do
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Elab.Do
|
||||
|
||||
open Lean Meta
|
||||
|
||||
builtin_initialize registerTraceClass `Elab.do
|
||||
|
||||
structure MonadInfo where
|
||||
/-- The inferred type of the monad of type `Type u → Type v`. -/
|
||||
m : Expr
|
||||
/-- The `u` in `m : Type u → Type v`. -/
|
||||
u : Level
|
||||
/-- The `v` in `m : Type u → Type v`. -/
|
||||
v : Level
|
||||
/-- The cached `PUnit` expression. -/
|
||||
cachedPUnit : Expr :=
|
||||
if u matches .zero then mkConst ``Unit else mkConst ``PUnit [mkLevelSucc u]
|
||||
/-- The cached `PUnit.unit` expression. -/
|
||||
cachedPUnitUnit : Expr :=
|
||||
if u matches .zero then mkConst ``Unit.unit else mkConst ``PUnit.unit [mkLevelSucc u]
|
||||
|
||||
-- Same pattern as for `Methods`/`MethodsRef` in `SimpM`.
|
||||
private opaque ContInfoRefPointed : NonemptyType.{0}
|
||||
|
||||
def ContInfoRef : Type := ContInfoRefPointed.type
|
||||
|
||||
instance : Nonempty ContInfoRef :=
|
||||
by exact ContInfoRefPointed.property
|
||||
|
||||
structure Context where
|
||||
/-- Inferred and cached information about the monad. -/
|
||||
monadInfo : MonadInfo
|
||||
/-- The mutable variables in declaration order. -/
|
||||
mutVars : Array Name := #[]
|
||||
/--
|
||||
The expected type of the current `do` block.
|
||||
This can be different from `earlyReturnType` in `for` loop `do` blocks, for example.
|
||||
-/
|
||||
doBlockResultType : Expr
|
||||
/-- Information about `return`, `break` and `continue` continuations. -/
|
||||
contInfo : ContInfoRef
|
||||
|
||||
structure MonadInstanceCache where
|
||||
/-- The inferred `Pure` instance of `(← read).monadInfo.m`. -/
|
||||
instPure : Option Expr := none
|
||||
/-- The inferred `Bind` instance of `(← read).monadInfo.m`. -/
|
||||
instBind : Option Expr := none
|
||||
/-- The cached `Pure.pure` expression. -/
|
||||
cachedPure : Option Expr := none
|
||||
/-- The cached `Bind.bind` expression. -/
|
||||
cachedBind : Option Expr := none
|
||||
deriving Nonempty
|
||||
|
||||
/--
|
||||
A continuation metavariable.
|
||||
|
||||
When generating jumps to join points or filling in expressions for `break` or `continue`, it is
|
||||
still unclear what mutable variables need to be passed, because it depends on which mutable
|
||||
variables were reassigned in the control flow path to *any* of the jumps.
|
||||
|
||||
The mechanism of `ContVarId` allows to delay the assignment of the jump expressions until the local
|
||||
contexts of all the jumps are known.
|
||||
-/
|
||||
structure ContVarId where
|
||||
name : Name
|
||||
deriving Inhabited, BEq, Hashable
|
||||
|
||||
/--
|
||||
Information about a jump site associated to `ContVarId`.
|
||||
There will be one instance per jump site to a join point, or for each `break` or `continue`
|
||||
element.
|
||||
-/
|
||||
structure ContVarJump where
|
||||
/--
|
||||
The metavariable to be assigned with the jump to the join point.
|
||||
Conveniently, its captured local context is that of the jump, in which the new mutable variable
|
||||
definitions and result variable are in scope.
|
||||
-/
|
||||
mvar : Expr
|
||||
/-- A reference for error reporting. -/
|
||||
ref : Syntax
|
||||
|
||||
/--
|
||||
Information about a `ContVarId`.
|
||||
-/
|
||||
structure ContVarInfo where
|
||||
/-- The monadic type of the continuation. -/
|
||||
type : Expr
|
||||
/--
|
||||
A superset of the local variable names that the jumps will refer to. Often the `mut` variables.
|
||||
Any `let`-bound FV will be turned into a `have`-bound FV by setting their `nondep` flag in the
|
||||
local context of the metavariable for the jump site. This is a technicality to ensure that
|
||||
`isDefEq` will not inline the `let`s.
|
||||
-/
|
||||
tunneledVars : Std.HashSet Name
|
||||
/-- Local context at the time the continuation variable was created. -/
|
||||
lctx : LocalContext
|
||||
/-- The tracked jumps to the continuation. Each contains a metavariable to be assigned later. -/
|
||||
jumps : Array ContVarJump
|
||||
|
||||
structure State where
|
||||
monadInstanceCache : MonadInstanceCache := {}
|
||||
contVars : Std.HashMap ContVarId ContVarInfo := {}
|
||||
deriving Nonempty
|
||||
|
||||
abbrev DoElabM := ReaderT Context <| StateRefT State Term.TermElabM
|
||||
|
||||
/--
|
||||
Elaboration of a `do` block `do $e; $rest`, results in a call
|
||||
``elabTerm `(do $e; $rest) = elabElem e dec``, where `elabElem e ·` is the elaborator for `do`
|
||||
element `e`, and `dec` is the `DoElemCont` describing the elaboration of the rest of the block
|
||||
`rest`.
|
||||
|
||||
If the semantics of `e` resumes its continuation `rest`, its elaborator must bind its result to
|
||||
`resultName`, ensure that it has type `resultType` and then elaborate `rest` using `dec`.
|
||||
|
||||
Clearly, for term elements `e : m α`, the result has type `α`.
|
||||
More subtly, for binding elements `let x := e` or `let x ← e`, the result has type `PUnit` and is
|
||||
unrelated to the type of the bound variable `x`.
|
||||
|
||||
Examples:
|
||||
* `return` drops the continuation; `return x; pure ()` elaborates to `pure x`.
|
||||
* `let x ← e; rest x` elaborates to `e >>= fun x => rest x`.
|
||||
* `let x := 3; let y ← (let x ← e); rest x` elaborates to
|
||||
`let x := 3; e >>= fun x_1 => let y := (); rest x`, which is immediately zeta-reduced to
|
||||
`let x := 3; e >>= fun x_1 => rest x`.
|
||||
* `one; two` elaborates to `one >>= fun (_ : PUnit) => two`; it is an error if `one` does not have
|
||||
type `PUnit`.
|
||||
-/
|
||||
structure DoElemCont where
|
||||
/-- The name of the monadic result variable. -/
|
||||
resultName : Name
|
||||
/-- The type of the monadic result. -/
|
||||
resultType : Expr
|
||||
/-- The continuation to elaborate the `rest` of the block. -/
|
||||
k : DoElabM Expr
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
The type of elaborators for `do` block elements.
|
||||
|
||||
It is ``elabTerm `(do $e; $rest) = elabElem e dec``, where `elabElem e ·` is the elaborator for `do`
|
||||
element `e`, and `dec` is the `DoElemCont` describing the elaboration of the rest of the block
|
||||
`rest`.
|
||||
-/
|
||||
abbrev DoElab := TSyntax `doElem → DoElemCont → DoElabM Expr
|
||||
|
||||
/--
|
||||
Information about a success, `return`, `break` or `continue` continuation that will be filled in
|
||||
after the code using it has been elaborated.
|
||||
-/
|
||||
structure ContInfo where
|
||||
returnCont : DoElemCont
|
||||
breakCont : Option (DoElabM Expr) := none
|
||||
continueCont : Option (DoElabM Expr) := none
|
||||
deriving Inhabited
|
||||
|
||||
unsafe def ContInfo.toContInfoRefImpl (m : ContInfo) : ContInfoRef :=
|
||||
unsafeCast m
|
||||
|
||||
@[implemented_by ContInfo.toContInfoRefImpl]
|
||||
opaque ContInfo.toContInfoRef (m : ContInfo) : ContInfoRef
|
||||
|
||||
unsafe def ContInfoRef.toContInfoImpl (m : ContInfoRef) : ContInfo :=
|
||||
unsafeCast m
|
||||
|
||||
@[implemented_by ContInfoRef.toContInfoImpl]
|
||||
opaque ContInfoRef.toContInfo (m : ContInfoRef) : ContInfo
|
||||
|
||||
/-- Constructs `m α` from `α`. -/
|
||||
def mkMonadicType (resultType : Expr) : DoElabM Expr := do
|
||||
return mkApp (← read).monadInfo.m resultType
|
||||
|
||||
/-- The cached `PUnit` expression. -/
|
||||
def mkPUnit : DoElabM Expr := do
|
||||
return (← read).monadInfo.cachedPUnit
|
||||
|
||||
/-- The cached ``PUnit.unit`` expression. -/
|
||||
def mkPUnitUnit : DoElabM Expr := do
|
||||
return (← read).monadInfo.cachedPUnitUnit
|
||||
|
||||
/-- The cached `@Pure.pure m instPure` expression. -/
|
||||
private def getCachedPure : DoElabM Expr := do
|
||||
let s ← get
|
||||
if let some cachedPure := s.monadInstanceCache.cachedPure then return cachedPure
|
||||
let info := (← read).monadInfo
|
||||
let instPure ← Term.mkInstMVar (mkApp (mkConst ``Pure [info.u, info.v]) info.m)
|
||||
let cachedPure := mkApp2 (mkConst ``Pure.pure [info.u, info.v]) info.m instPure
|
||||
set { s with monadInstanceCache := { s.monadInstanceCache with cachedPure := some cachedPure } : State}
|
||||
return cachedPure
|
||||
|
||||
/-- The expression ``pure (α:=α) e``. -/
|
||||
def mkPureApp (α e : Expr) : DoElabM Expr := do
|
||||
let e ← Term.ensureHasType α e
|
||||
return mkApp2 (← getCachedPure) α e
|
||||
|
||||
/-- Create a `DoElemCont` returning the result using `pure`. -/
|
||||
def DoElemCont.mkPure (resultType : Expr) : TermElabM DoElemCont := do
|
||||
let r ← mkFreshUserName `r
|
||||
return { resultName := r, resultType, k := do mkPureApp resultType (← getFVarFromUserName r) }
|
||||
|
||||
/-- The cached `@Bind.bind m instBind` expression. -/
|
||||
private def getCachedBind : DoElabM Expr := do
|
||||
let s ← get
|
||||
if let some cachedBind := s.monadInstanceCache.cachedBind then return cachedBind
|
||||
let info := (← read).monadInfo
|
||||
let instBind ← Term.mkInstMVar (mkApp (mkConst ``Bind [info.u, info.v]) info.m)
|
||||
let cachedBind := mkApp2 (mkConst ``Bind.bind [info.u, info.v]) info.m instBind
|
||||
set { s with monadInstanceCache := { s.monadInstanceCache with cachedBind := some cachedBind } : State}
|
||||
return cachedBind
|
||||
|
||||
/-- The expression ``Bind.bind (α:=α) (β:=β) e k``. -/
|
||||
def mkBindApp (α β e k : Expr) : DoElabM Expr := do
|
||||
let mα ← mkMonadicType α
|
||||
let e ← Term.ensureHasType mα e
|
||||
let k ← Term.ensureHasType (← mkArrow α (← mkMonadicType β)) k
|
||||
let cachedBind ← getCachedBind
|
||||
return mkApp4 cachedBind α β e k
|
||||
|
||||
/-- Register the given name as that of a `mut` variable. -/
|
||||
def declareMutVar (x : Name) : DoElabM α → DoElabM α :=
|
||||
withReader fun ctx => { ctx with mutVars := ctx.mutVars.push x }
|
||||
|
||||
/-- Register the given name as that of a `mut` variable if the syntax token `mut` is present. -/
|
||||
def declareMutVar? (mutTk? : Option Syntax) (x : Name) (k : DoElabM α) : DoElabM α :=
|
||||
if mutTk?.isSome then declareMutVar x k else k
|
||||
|
||||
/-- Throw an error if the given name is not a declared `mut` variable. -/
|
||||
def throwUnlessMutVarDeclared (x : Name) : DoElabM Unit := do
|
||||
unless (← read).mutVars.contains x do
|
||||
throwError "undeclared mutable variable `{x}`"
|
||||
|
||||
/-- Throw an error if a declaration of the given name would shadow a `mut` variable. -/
|
||||
def checkMutVarsForShadowing (x : Name) : DoElabM Unit := do
|
||||
if (← read).mutVars.contains x then
|
||||
throwError "mutable variable `{x.simpMacroScopes}` cannot be shadowed"
|
||||
|
||||
/-- Create a fresh `α` that would fit in `m α`. -/
|
||||
def mkFreshResultType (userName := `α) : DoElabM Expr := do
|
||||
mkFreshExprMVar (mkSort (mkLevelSucc (← read).monadInfo.u)) (userName := userName)
|
||||
|
||||
def synthUsingDefEq (msg : String) (expected : Expr) (actual : Expr) : DoElabM Unit := do
|
||||
unless ← isDefEq expected actual do
|
||||
throwError "Failed to synthesize {msg}. {expected} is not definitionally equal to {actual}."
|
||||
|
||||
/--
|
||||
Has the effect of ``e >>= fun (x : eResultTy) => $(← k `(x))``.
|
||||
Ensures that `e` has type `m eResultTy`.
|
||||
-/
|
||||
def mkBindCancellingPure (x : Name) (eResultTy e : Expr) (k : Expr → DoElabM Expr) : DoElabM Expr := do
|
||||
withLocalDeclD x eResultTy fun x => do
|
||||
let body ← k x
|
||||
let body' := body.consumeMData
|
||||
if body'.isAppOfArity ``Pure.pure 4 && body'.getArg! 3 == x then
|
||||
return e
|
||||
let kResultTy ← mkFreshResultType `kResultTy
|
||||
let k ← mkLambdaFVars #[x] body
|
||||
mkBindApp eResultTy kResultTy e k
|
||||
|
||||
/--
|
||||
A variant of `Term.elabType` that takes the universe of the monad into account, unless
|
||||
`freshLevel` is set.
|
||||
-/
|
||||
def elabType (ty? : Option (TSyntax `term)) (freshLevel := false) : DoElabM Expr := do
|
||||
let u ← if freshLevel then mkFreshLevelMVar else (mkLevelSucc ·.monadInfo.u) <$> read
|
||||
let sort := mkSort u
|
||||
match ty? with
|
||||
| none => mkFreshExprMVar sort
|
||||
| some ty => Term.elabTermEnsuringType ty sort
|
||||
|
||||
private partial def withPendingMVars (k : TermElabM α) : TermElabM (α × List MVarId) := do
|
||||
let pendingMVarsSaved := (← get).pendingMVars
|
||||
modify fun s => { s with pendingMVars := [] }
|
||||
try
|
||||
let a ← k
|
||||
let pendingMVars := (← get).pendingMVars
|
||||
return (a, pendingMVars)
|
||||
finally
|
||||
modify fun s => { s with pendingMVars := s.pendingMVars ++ pendingMVarsSaved }
|
||||
|
||||
def elabTerm (stx : Syntax) (expectedType? : Option Expr) : DoElabM Expr := do
|
||||
let (e, _pendingMVars) ← withPendingMVars <| Term.elabTerm stx expectedType?
|
||||
-- for mvarId in pendingMVars.reverse do
|
||||
-- let some mvarDecl ← Term.getSyntheticMVarDecl? mvarId | continue
|
||||
-- let .postponed _ := mvarDecl.kind | continue
|
||||
-- match mvarDecl.stx with
|
||||
-- | `(<== $e) => logInfo m!"Elaborate {e}"
|
||||
-- | _ => continue
|
||||
return e
|
||||
|
||||
def elabTermEnsuringType (stx : Syntax) (expectedType? : Option Expr) : DoElabM Expr := do
|
||||
let e ← Term.elabTermEnsuringType stx expectedType?
|
||||
-- nandle nested actions
|
||||
return e
|
||||
|
||||
def elabBinder (binder : Syntax) (x : Expr → DoElabM α) : DoElabM α := do
|
||||
controlAt TermElabM fun runInBase => Term.elabBinder binder (runInBase ∘ x)
|
||||
|
||||
/--
|
||||
The subset of `mutVars` that were reassigned in any of the `childCtxs` relative to the given
|
||||
`rootCtx`.
|
||||
-/
|
||||
def getReassignedMutVars (rootCtx : LocalContext) (mutVars : Std.HashSet Name) (childCtxs : Array LocalContext) : Std.HashSet Name := Id.run do
|
||||
let mut reassignedMutVars := Std.HashSet.emptyWithCapacity mutVars.size
|
||||
for childCtx in childCtxs do
|
||||
let newDefs := childCtx.findFromUserNames mutVars (start := rootCtx.numIndices)
|
||||
reassignedMutVars := reassignedMutVars.insertMany (newDefs.map (·.userName))
|
||||
return reassignedMutVars
|
||||
|
||||
/--
|
||||
Adds the new reaching definitions of the given `tunneledVars` in `childCtx` relative to `rootCtx` as
|
||||
non-dependent decls.
|
||||
-/
|
||||
def addReachingDefsAsNonDep (rootCtx childCtx : LocalContext) (tunneledVars : Std.HashSet Name) : LocalContext := Id.run do
|
||||
let tunnelDecls := childCtx.findFromUserNames tunneledVars (start := rootCtx.numIndices)
|
||||
let mut rootCtx := rootCtx
|
||||
for decl in tunnelDecls do
|
||||
rootCtx := rootCtx.addDecl (decl.setNondep true)
|
||||
return rootCtx
|
||||
|
||||
/--
|
||||
Creates a new continuation variable of type `m α` given the result type `α`.
|
||||
The `tunneledVars` is a superset of the `let`-bound variable names that the jumps will refer to.
|
||||
Often it will be the `mut` variables. Leaving it empty inlines `let`-bound variables at jump sites.
|
||||
-/
|
||||
def mkFreshContVar (resultType : Expr) (tunneledVars : Array Name) : DoElabM ContVarId := do
|
||||
let name ← mkFreshId
|
||||
let contVarId := ContVarId.mk name
|
||||
let type ← mkMonadicType resultType
|
||||
let tunneledVars := Std.HashSet.ofArray tunneledVars
|
||||
let cvInfo := { type, jumps := #[], lctx := (← getLCtx), tunneledVars }
|
||||
modify fun s => { s with contVars := s.contVars.insert contVarId cvInfo }
|
||||
return contVarId
|
||||
|
||||
def ContVarId.find (contVarId : ContVarId) : DoElabM ContVarInfo := do
|
||||
match (← get).contVars.get? contVarId with
|
||||
| some info => return info
|
||||
| none => throwError "contVarId {contVarId.name} not found"
|
||||
|
||||
/-- Creates a new jump site for the continuation variable, to be synthesized later. -/
|
||||
def ContVarId.mkJump (contVarId : ContVarId) : DoElabM Expr := do
|
||||
let info ← contVarId.find
|
||||
let lctx := addReachingDefsAsNonDep info.lctx (← getLCtx) info.tunneledVars
|
||||
let mvar ← withLCtx' lctx (mkFreshExprMVar info.type)
|
||||
let jumps := info.jumps.push { mvar, ref := (← getRef) }
|
||||
modify fun s => { s with contVars := s.contVars.insert contVarId { info with jumps } }
|
||||
return mvar
|
||||
|
||||
/-- The number of jump sites allocated for the continuation variable. -/
|
||||
def ContVarId.jumpCount (contVarId : ContVarId) : DoElabM Nat := do
|
||||
let info ← contVarId.find
|
||||
return info.jumps.size
|
||||
|
||||
/--
|
||||
Synthesize the jump sites for the continuation variable.
|
||||
`k` is run once for each jump site, in the `LocalContext` of the jump site.
|
||||
The result of `k` is used to fill in the jump site.
|
||||
-/
|
||||
def ContVarId.synthesizeJumps (contVarId : ContVarId) (k : DoElabM Expr) : DoElabM Unit := do
|
||||
let info ← contVarId.find
|
||||
for jump in info.jumps do
|
||||
jump.mvar.mvarId!.withContext do withRef jump.ref do
|
||||
let res ← k
|
||||
fullApproxDefEq <| synthUsingDefEq "jump site" jump.mvar res
|
||||
|
||||
def ContVarId.erase (contVarId : ContVarId) : DoElabM Unit := do
|
||||
modify fun s => { s with contVars := s.contVars.erase contVarId }
|
||||
|
||||
/--
|
||||
The subset of `(← read).mutVars` that were reassigned at any of the jump sites of the continuation
|
||||
variable. The result array has the same order as `(← read).mutVars`.
|
||||
-/
|
||||
def ContVarId.getReassignedMutVars (contVarId : ContVarId) (rootCtx : LocalContext) : DoElabM (Std.HashSet Name) := do
|
||||
let info ← contVarId.find
|
||||
let childCtxs ← info.jumps.mapM fun j => return (← j.mvar.mvarId!.getDecl).lctx
|
||||
return Lean.Elab.Do.getReassignedMutVars rootCtx (.ofArray (← read).mutVars) childCtxs
|
||||
|
||||
/--
|
||||
Restores the local context to `oldCtx` and adds the new reaching definitions of the mut vars and
|
||||
result. Then resume the continuation `k` with the `mutVars` restored to the given `oldMutVars`.
|
||||
|
||||
This function is useful to de-nest
|
||||
```
|
||||
let mut x := 0
|
||||
let y := 3
|
||||
let z ← do
|
||||
let mut y ← e
|
||||
x := y + 1
|
||||
pure y
|
||||
let y := y + 3
|
||||
pure (x + y + z)
|
||||
```
|
||||
into
|
||||
```
|
||||
let mut x := 0
|
||||
let y := 3
|
||||
let mut y† ← e
|
||||
x := y† + 1
|
||||
let z ← pure y†
|
||||
let y := y + 3
|
||||
pure (x + y + z)
|
||||
```
|
||||
Note that the continuation of the `let z ← ...` bind, roughly
|
||||
``k := .cont `z _ `(let y := y + 3; pure (x + y + z))``,
|
||||
needs to elaborated in a local context that contains the reassignment of `x`, but not the shadowing
|
||||
mut var definition of `y`.
|
||||
-/
|
||||
def withLCtxKeepingMutVarDefs (oldCtx : LocalContext) (oldMutVars : Array Name) (resultName : Name) (k : DoElabM α) : DoElabM α := do
|
||||
let newCtx := addReachingDefsAsNonDep oldCtx (← getLCtx) (.ofArray <| oldMutVars.push resultName)
|
||||
withLCtx' newCtx <| withReader (fun ctx => { ctx with mutVars := oldMutVars }) k
|
||||
|
||||
/--
|
||||
Return `$e >>= fun ($dec.resultName : $dec.resultType) => $(← dec.k)`, cancelling
|
||||
the bind if `$(← dec.k)` is `pure $dec.resultName`.
|
||||
-/
|
||||
def DoElemCont.mkBindUnlessPure (dec : DoElemCont) (e : Expr) : DoElabM Expr := do
|
||||
mkBindCancellingPure dec.resultName dec.resultType e (fun _ => dec.k)
|
||||
|
||||
/--
|
||||
Return `let $k.resultName : PUnit := PUnit.unit; $(← k.k)`, ensuring that the result type of `k.k`
|
||||
is `PUnit` and then immediately zeta-reduce the `let`.
|
||||
-/
|
||||
def DoElemCont.continueWithUnit (dec : DoElemCont) : DoElabM Expr := do
|
||||
let unit ← mkPUnitUnit
|
||||
discard <| Term.ensureHasType dec.resultType unit
|
||||
mapLetDeclZeta dec.resultName (← mkPUnit) unit (fun _ => dec.k)
|
||||
|
||||
/--
|
||||
Call `caller` with a duplicable proxy of `dec`.
|
||||
When the proxy is elaborated more than once, a join point is introduced so that `dec` is only
|
||||
elaborated once to fill in the RHS of this join point.
|
||||
|
||||
This is useful for control-flow constructs like `if` and `match`, where multiple tail-called
|
||||
branches share the continuation.
|
||||
-/
|
||||
def DoElemCont.withDuplicableCont (nondupDec : DoElemCont) (caller : DoElemCont → DoElabM Expr) : DoElabM Expr := do
|
||||
let α := (← read).doBlockResultType
|
||||
let mα ← mkMonadicType α
|
||||
let joinTy ← mkFreshExprMVar (mkSort (mkLevelSucc (← read).monadInfo.v)) (userName := `joinTy)
|
||||
let joinRhs ← mkFreshExprMVar joinTy (userName := `joinRhs)
|
||||
withLetDecl (← mkFreshUserName `__do_jp) joinTy joinRhs (kind := .implDetail) (nondep := true) fun jp => do
|
||||
let mutVars := (← read).mutVars
|
||||
let contVarId ← mkFreshContVar α (mutVars.push nondupDec.resultName)
|
||||
let duplicableDec := { nondupDec with k := contVarId.mkJump }
|
||||
let e ← caller duplicableDec
|
||||
|
||||
-- Now determine whether we need to realize the join point.
|
||||
let jumpCount ← contVarId.jumpCount
|
||||
if jumpCount = 0 then
|
||||
-- Do nothing. No MVar needs to be assigned.
|
||||
Term.ensureHasType mα e
|
||||
else if jumpCount = 1 then
|
||||
-- Linear use of the continuation. Do not introduce a join point; just emit the continuation
|
||||
-- directly.
|
||||
contVarId.synthesizeJumps nondupDec.k
|
||||
let e ← Term.ensureHasType mα e
|
||||
-- Now zeta-reduce `jp`. Should be a semantic no-op.
|
||||
let e ← elimMVarDeps #[jp] e
|
||||
return e.replaceFVar jp joinRhs
|
||||
else -- jumps.size > 1
|
||||
-- Non-linear use of the continuation. Introduce a join point and synthesize jumps to it.
|
||||
|
||||
-- Compute the union of all reassigned mut vars. These + `r` constitute the parameters
|
||||
-- of the join point. We take a little care to preserve the declaration order that is manifest
|
||||
-- in the array `(← read).mutVars`.
|
||||
let reassignedMutVars ← contVarId.getReassignedMutVars (← joinRhs.mvarId!.getDecl).lctx
|
||||
let reassignedMutVars := mutVars.filter reassignedMutVars.contains
|
||||
|
||||
-- Assign the `joinTy` based on the types of the reassigned mut vars and the result type.
|
||||
let reassignedDecls ← reassignedMutVars.mapM (getLocalDeclFromUserName ·)
|
||||
let reassignedTys := reassignedDecls.map (·.type)
|
||||
let resTy ← mkFreshResultType
|
||||
let joinTy' ← mkArrowN (reassignedTys.push resTy) mα
|
||||
synthUsingDefEq "join point type" joinTy joinTy'
|
||||
|
||||
-- Assign the `joinRhs` with the result of the continuation.
|
||||
let rhs ← joinRhs.mvarId!.withContext do
|
||||
withLocalDeclsDND (reassignedDecls.map (fun d => (d.userName, d.type)) |>.push (nondupDec.resultName, resTy)) fun xs => do
|
||||
mkLambdaFVars xs (← nondupDec.k)
|
||||
synthUsingDefEq "join point RHS" joinRhs rhs
|
||||
|
||||
-- Finally, assign the MVars with the jump to `jp`.
|
||||
contVarId.synthesizeJumps do
|
||||
let r ← getFVarFromUserName nondupDec.resultName
|
||||
let mut jump := jp
|
||||
for name in reassignedMutVars do
|
||||
let newDefn ← getLocalDeclFromUserName name
|
||||
jump := mkApp jump newDefn.toExpr
|
||||
return mkApp jump (← Term.ensureHasType resTy r "Mismatched result type for match arm. It")
|
||||
|
||||
mkLetFVars #[jp] (generalizeNondepLet := false) (← Term.ensureHasType mα e)
|
||||
|
||||
/--
|
||||
Given a list of mut vars `vars` and an FVar `tupleVar` binding a tuple, bind the mut vars to the
|
||||
fields of the tuple and call `k` in the resulting local context.
|
||||
-/
|
||||
def bindMutVarsFromTuple (vars : List Name) (tupleVar : FVarId) (k : DoElabM Expr) : DoElabM Expr :=
|
||||
do go vars tupleVar (← tupleVar.getType) #[]
|
||||
where
|
||||
go vars tupleVar tupleTy letFVars := do
|
||||
let tuple := mkFVar tupleVar
|
||||
match vars with
|
||||
| [] => mkLetFVars letFVars (← k)
|
||||
| [x] =>
|
||||
withLetDecl x tupleTy tuple fun x => do mkLetFVars (letFVars.push x) (← k)
|
||||
| [x, y] =>
|
||||
let (fst, fstTy, snd, sndTy) ← getProdFields tuple tupleTy
|
||||
withLetDecl x fstTy fst fun x =>
|
||||
withLetDecl y sndTy snd fun y => do mkLetFVars (letFVars.push x |>.push y) (← k)
|
||||
| x :: xs => do
|
||||
let (fst, fstTy, snd, sndTy) ← getProdFields tuple tupleTy
|
||||
withLetDecl x fstTy fst fun x => do
|
||||
withLetDecl (← tupleVar.getUserName) sndTy snd fun r => do
|
||||
go xs r.fvarId! sndTy (letFVars |>.push x |>.push r)
|
||||
|
||||
def getReturnCont : DoElabM DoElemCont := do
|
||||
return (← read).contInfo.toContInfo.returnCont
|
||||
|
||||
def getBreakCont : DoElabM (Option (DoElabM Expr)) := do
|
||||
return (← read).contInfo.toContInfo.breakCont
|
||||
|
||||
def getContinueCont : DoElabM (Option (DoElabM Expr)) := do
|
||||
return (← read).contInfo.toContInfo.continueCont
|
||||
|
||||
/--
|
||||
Introduce proxy redefinitions for *all* mut vars and call the continuation `k` with a function
|
||||
`elimProxyDefs : Expr → MetaM Expr` similar to `mkLetFVars` that will replace the proxy defs with
|
||||
the actual reassigned or original definitions.
|
||||
-/
|
||||
@[inline]
|
||||
def withProxyMutVarDefs [Inhabited α] (k : (Expr → MetaM Expr) → DoElabM α) : DoElabM α := do
|
||||
let mutVars := (← read).mutVars
|
||||
let outerCtx ← getLCtx
|
||||
let outerDecls := mutVars.map outerCtx.getFromUserName!
|
||||
withLocalDeclsDND (← outerDecls.mapM fun x => do return (x.userName, x.type)) (kind := .implDetail) fun proxyDefs => do
|
||||
let proxyCtx ← getLCtx
|
||||
let elimProxyDefs e : MetaM Expr := do
|
||||
let innerCtx ← getLCtx
|
||||
|
||||
let actualDefs := proxyDefs.map fun pDef =>
|
||||
let x := (proxyCtx.getFVar! pDef).userName
|
||||
let iDef := (innerCtx.getFromUserName! x).toExpr
|
||||
if iDef == pDef then
|
||||
(outerCtx.getFromUserName! x).toExpr -- original definition
|
||||
else
|
||||
iDef -- reassigned definition
|
||||
let e ← elimMVarDeps proxyDefs e
|
||||
return e.replaceFVars proxyDefs actualDefs
|
||||
k elimProxyDefs
|
||||
|
||||
/--
|
||||
Prepare the context for elaborating the body of a loop.
|
||||
This includes setting the return continuation, break continuation, continue continuation, as
|
||||
well as the changed result type of the `do` block in the loop body.
|
||||
-/
|
||||
def enterLoopBody (resultType : Expr) (returnCont : DoElemCont) (breakCont continueCont : DoElabM Expr) : (body : DoElabM α) → DoElabM α :=
|
||||
let contInfo := ContInfo.toContInfoRef { breakCont, continueCont, returnCont }
|
||||
withReader fun ctx => { ctx with contInfo, doBlockResultType := resultType }
|
||||
|
||||
/--
|
||||
Prepare the context for elaborating the body of a `do` block that does not support `mut` vars,
|
||||
`break`, `continue` or `return`.
|
||||
-/
|
||||
def withoutControl (k : DoElabM Expr) : DoElabM Expr := do
|
||||
let error := throwError "This `do` block does not support `break`, `continue` or `return`."
|
||||
let dec ← getReturnCont
|
||||
let contInfo := { breakCont := error, continueCont := error, returnCont := { dec with k := error }}
|
||||
let contInfo := ContInfo.toContInfoRef contInfo
|
||||
withReader (fun ctx => { ctx with contInfo }) k
|
||||
|
||||
/--
|
||||
Prepare the context for elaborating the body of a `finally` block.
|
||||
There is no support for `mut` vars, `break`, `continue` or `return` in a `finally` block.
|
||||
-/
|
||||
def enterFinally (resultType : Expr) (k : DoElabM Expr) : DoElabM Expr := do
|
||||
withoutControl do
|
||||
withReader (fun ctx => { ctx with doBlockResultType := resultType }) k
|
||||
|
||||
/-- Extracts `MonadInfo` and monadic result type `α` from the expected type of a `do` block `m α`. -/
|
||||
private partial def extractMonadInfo (expectedType? : Option Expr) : Term.TermElabM (MonadInfo × Expr) := do
|
||||
let some expectedType := expectedType? | mkUnknownMonadResult
|
||||
let extractStep? (type : Expr) : Term.TermElabM (Option (MonadInfo × Expr)) := do
|
||||
let .app m resultType := type.consumeMData | return none
|
||||
unless ← isType resultType do return none
|
||||
let .succ u ← getLevel resultType | return none
|
||||
let .succ v ← getLevel type | return none
|
||||
let u := u.normalize
|
||||
let v := v.normalize
|
||||
return some ({ m, u, v }, resultType)
|
||||
let rec extract? (type : Expr) : Term.TermElabM (Option (MonadInfo × Expr)) := do
|
||||
match (← extractStep? type) with
|
||||
| some r => return r
|
||||
| none =>
|
||||
let typeNew ← whnfCore type
|
||||
if typeNew != type then
|
||||
extract? typeNew
|
||||
else
|
||||
if typeNew.getAppFn.isMVar then
|
||||
mkUnknownMonadResult
|
||||
else match (← unfoldDefinition? typeNew) with
|
||||
| some typeNew => extract? typeNew
|
||||
| none => return none
|
||||
match (← extract? expectedType) with
|
||||
| some r => return r
|
||||
| none => throwError "invalid `do` notation, expected type is not a monad application{indentExpr expectedType}\nYou can use the `do` notation in pure code by writing `Id.run do` instead of `do`, where `Id` is the identity monad."
|
||||
where
|
||||
mkUnknownMonadResult : TermElabM (MonadInfo × Expr) := do
|
||||
let u ← mkFreshLevelMVar
|
||||
let v ← mkFreshLevelMVar
|
||||
let m ← mkFreshExprMVar (← mkArrow (mkSort (mkLevelSucc u)) (mkSort (mkLevelSucc v))) (userName := `m)
|
||||
let resultType ← mkFreshExprMVar (mkSort (mkLevelSucc u)) (userName := `α)
|
||||
return ({ m, u, v }, resultType)
|
||||
|
||||
/-- Create the `Context` for `do` elaboration from the given expected type of a `do` block. -/
|
||||
def mkContext (expectedType? : Option Expr) : TermElabM Context := do
|
||||
let (mi, resultType) ← extractMonadInfo expectedType?
|
||||
let returnCont ← DoElemCont.mkPure resultType
|
||||
let contInfo := ContInfo.toContInfoRef { returnCont }
|
||||
return { monadInfo := mi, doBlockResultType := resultType, contInfo }
|
||||
|
||||
/--
|
||||
Backtrackable state for the `TermElabM` monad.
|
||||
-/
|
||||
structure SavedState where
|
||||
«term» : Term.SavedState
|
||||
«do» : State
|
||||
deriving Nonempty
|
||||
|
||||
def SavedState.restore (s : SavedState) : DoElabM Unit := do
|
||||
s.term.restore
|
||||
set s.do
|
||||
|
||||
protected def DoElabM.saveState : DoElabM SavedState :=
|
||||
return { «term» := (← Term.saveState), «do» := (← get) }
|
||||
|
||||
instance : MonadBacktrack SavedState DoElabM where
|
||||
saveState := DoElabM.saveState
|
||||
restoreState b := b.restore
|
||||
|
||||
unsafe def mkDoElemElabAttributeUnsafe (ref : Name) : IO (KeyedDeclsAttribute DoElab) :=
|
||||
mkElabAttribute DoElab `builtin_doElem_elab `doElem_elab `Lean.Parser.Term.doElem ``Lean.Elab.Do.DoElab "do element" ref
|
||||
|
||||
@[implemented_by mkDoElemElabAttributeUnsafe]
|
||||
opaque mkDoElemElabAttribute (ref : Name) : IO (KeyedDeclsAttribute DoElab)
|
||||
|
||||
/--
|
||||
Registers a `do` element elaborator for the given syntax node kind.
|
||||
|
||||
A `do` element elaborator should have type `DoElab` (which is
|
||||
`Lean.Syntax → DoElemCont → DoElabM Expr`), i.e. should take syntax of the given syntax node kind
|
||||
and a `DoElemCont` as parameters and produce an expression.
|
||||
|
||||
When elaborating a `do` block `do e; rest`, the elaborator for `e` is invoked with the syntax of `e`
|
||||
and the `DoElemCont` representing `rest`.
|
||||
|
||||
The `elab_rules` and `elab` commands should usually be preferred over using this attribute
|
||||
directly.
|
||||
-/
|
||||
@[builtin_doc]
|
||||
builtin_initialize doElemElabAttribute : KeyedDeclsAttribute DoElab ← mkDoElemElabAttribute decl_name%
|
||||
|
||||
private def elabDoElemFns (stx : TSyntax `doElem) (cont : DoElemCont)
|
||||
(fns : List (KeyedDeclsAttribute.AttributeEntry DoElab)) : DoElabM Expr := do
|
||||
let s ← saveState
|
||||
match fns with
|
||||
| [] => throwError "unexpected `do` element syntax{indentD stx}"
|
||||
| elabFn :: elabFns =>
|
||||
try
|
||||
elabFn.value stx cont
|
||||
catch ex => match ex with
|
||||
| .internal id _ =>
|
||||
if id == unsupportedSyntaxExceptionId then
|
||||
s.restore
|
||||
elabDoElemFns stx cont elabFns
|
||||
else
|
||||
throw ex
|
||||
| _ => throw ex
|
||||
|
||||
partial def elabDoElem (stx : TSyntax `doElem) (cont : DoElemCont) : DoElabM Expr := do
|
||||
-- withTraceNode `Elab.step (fun _ => return m!"expected type: {expectedType?}, term\n{stx}")
|
||||
-- (tag := stx.getKind.toString) do
|
||||
let k := stx.raw.getKind
|
||||
checkSystem "do element elaborator"
|
||||
profileitM Exception "do element elaborator" (decl := k) (← getOptions) <|
|
||||
withRef stx <| withIncRecDepth <| withFreshMacroScope <| do
|
||||
let env ← getEnv
|
||||
let result ← match (← liftMacroM (expandMacroImpl? env stx)) with
|
||||
| some (_decl, stxNew?) =>
|
||||
let stxNew ← liftMacroM <| liftExcept stxNew?
|
||||
-- withTermInfoContext' decl stx (expectedType? := expectedType?) <|
|
||||
Term.withMacroExpansion stx stxNew <|
|
||||
withRef stxNew <| elabDoElem stx cont
|
||||
| none =>
|
||||
match doElemElabAttribute.getEntries (← getEnv) k with
|
||||
| [] => throwError "elaboration function for `{k}` has not been implemented{indentD stx}"
|
||||
| elabFns => elabDoElemFns stx cont elabFns
|
||||
|
||||
def elabDoElems1 (doElems : Array (TSyntax `doElem)) (cont : DoElemCont) : DoElabM Expr := do
|
||||
if h : doElems.size = 0 then
|
||||
throwError "Empty array of `do` elements passed to `elabDoElems1`."
|
||||
else
|
||||
let back := doElems.back
|
||||
let unit ← mkPUnit
|
||||
let r ← mkFreshUserName `r
|
||||
doElems.pop.foldr (init := elabDoElem back cont) fun el k => elabDoElem el (.mk r unit k)
|
||||
|
||||
def elabDoSeq (doSeq : TSyntax ``Lean.Parser.Term.doSeq) (cont : DoElemCont) : DoElabM Expr :=
|
||||
elabDoElems1 (Lean.Parser.Term.getDoElems doSeq) cont
|
||||
|
||||
syntax:arg (name := dooBlock) "doo" doSeq : term
|
||||
|
||||
@[builtin_term_elab «dooBlock»] def elabDooBlock : Term.TermElab := fun e expectedType? => do
|
||||
let `(doo $doSeq) := e | throwError "unexpected `do` block syntax{indentD e}"
|
||||
Term.tryPostponeIfNoneOrMVar expectedType?
|
||||
let ctx ← mkContext expectedType?
|
||||
let cont ← DoElemCont.mkPure ctx.doBlockResultType
|
||||
let res ← elabDoSeq doSeq cont |>.run ctx |>.run' {}
|
||||
trace[Elab.do] "{res}"
|
||||
pure res
|
||||
|
||||
-- @[builtin_term_elab «do»]
|
||||
def elabDo : Term.TermElab := fun e expectedType? => do
|
||||
let `(do $doSeq) := e | throwError "unexpected `do` block syntax{indentD e}"
|
||||
Term.tryPostponeIfNoneOrMVar expectedType?
|
||||
let ctx ← mkContext expectedType?
|
||||
let cont ← DoElemCont.mkPure ctx.doBlockResultType
|
||||
let res ← elabDoSeq doSeq cont |>.run ctx |>.run' {}
|
||||
trace[Elab.do] "{res}"
|
||||
pure res
|
||||
1812
src/Lean/Elab/Do/Legacy.lean
Normal file
1812
src/Lean/Elab/Do/Legacy.lean
Normal file
File diff suppressed because it is too large
Load diff
52
src/Lean/Elab/Do/Switch.lean
Normal file
52
src/Lean/Elab/Do/Switch.lean
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
/-
|
||||
Copyright (c) 2025 Lean FRO LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sebastian Graf
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.System.IO
|
||||
public import Lean.Data.Options
|
||||
public import Lean.Elab.Term.TermElabM
|
||||
import Lean.Elab.Do.Basic
|
||||
import Lean.Elab.Do.Legacy
|
||||
meta import Lean.Parser.Do
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Elab.Term
|
||||
|
||||
register_builtin_option backward.do.legacy : Bool := {
|
||||
defValue := true
|
||||
descr := "Use the legacy `do` elaborator instead of the new, extensible implementation."
|
||||
}
|
||||
|
||||
private def toDoElem (newKind : SyntaxNodeKind) : Macro := fun stx => do
|
||||
let stx := stx.setKind newKind
|
||||
withRef stx `(do $(⟨stx⟩):doElem)
|
||||
|
||||
@[builtin_macro Lean.Parser.Term.termFor]
|
||||
def expandTermFor : Macro := toDoElem ``Parser.Term.doFor
|
||||
|
||||
@[builtin_macro Lean.Parser.Term.termTry]
|
||||
def expandTermTry : Macro := toDoElem ``Parser.Term.doTry
|
||||
|
||||
@[builtin_macro Lean.Parser.Term.termUnless]
|
||||
def expandTermUnless : Macro := toDoElem ``Parser.Term.doUnless
|
||||
|
||||
@[builtin_macro Lean.Parser.Term.termReturn]
|
||||
def expandTermReturn : Macro := toDoElem ``Parser.Term.doReturn
|
||||
|
||||
@[builtin_term_elab «do»]
|
||||
def elabDo : TermElab := fun stx expectedType? => do
|
||||
if backward.do.legacy.get (← getOptions) then
|
||||
Term.Do.elabDo stx expectedType?
|
||||
else
|
||||
Elab.Do.elabDo stx expectedType?
|
||||
|
||||
@[builtin_term_elab liftMethod] def elabTermLiftMethod : TermElab := fun stx ty => do
|
||||
if backward.do.legacy.get (← getOptions) then
|
||||
Term.elabLiftMethod stx ty
|
||||
else
|
||||
throwError "Not implemented yet"
|
||||
|
|
@ -8,6 +8,7 @@ module
|
|||
prelude
|
||||
public import Lean.Elab.MacroArgUtil
|
||||
public import Lean.Elab.AuxDef
|
||||
public import Lean.Elab.Do.Basic
|
||||
meta import Lean.Parser.Syntax
|
||||
|
||||
public section
|
||||
|
|
@ -57,6 +58,11 @@ def elabElabRulesAux (doc? : Option (TSyntax ``docComment))
|
|||
aux_def elabRules $(mkIdent k) : Lean.Elab.Term.TermElab :=
|
||||
fun stx expectedType? => Lean.Elab.Term.withExpectedType expectedType? fun $expId => match stx with
|
||||
$alts:matchAlt* | _ => no_error_if_unused% throwUnsupportedSyntax)
|
||||
else if catName == `doElem then
|
||||
`($[$doc?:docComment]? @[$(← mkAttrs `do_elab),*] $vis:visibility
|
||||
aux_def elabRules $(mkIdent k) : Lean.Elab.Do.DoElab :=
|
||||
fun stx $expId => match stx with
|
||||
$alts:matchAlt* | _ => no_error_if_unused% throwUnsupportedSyntax)
|
||||
else
|
||||
throwErrorAt expId "syntax category `{catName}` does not support expected type specification"
|
||||
else if catName == `term then
|
||||
|
|
@ -72,6 +78,11 @@ def elabElabRulesAux (doc? : Option (TSyntax ``docComment))
|
|||
`($[$doc?:docComment]? @[$(← mkAttrs `tactic),*] $vis:visibility
|
||||
aux_def elabRules $(mkIdent k) : Lean.Elab.Tactic.Tactic :=
|
||||
fun $alts:matchAlt* | _ => no_error_if_unused% throwUnsupportedSyntax)
|
||||
else if catName == `doElem then
|
||||
`($[$doc?:docComment]? @[$(← mkAttrs `do_elab),*] $vis:visibility
|
||||
aux_def elabRules $(mkIdent k) : Lean.Elab.Do.DoElab :=
|
||||
fun stx cont => match stx with
|
||||
$alts:matchAlt* | _ => no_error_if_unused% throwUnsupportedSyntax)
|
||||
else
|
||||
-- We considered making the command extensible and support new user-defined categories. We think it is unnecessary.
|
||||
-- If users want this feature, they add their own `elab_rules` macro that uses this one as a fallback.
|
||||
|
|
|
|||
|
|
@ -730,9 +730,11 @@ def withPushMacroExpansionStack (beforeStx afterStx : Syntax) (x : TermElabM α)
|
|||
withReader (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x
|
||||
|
||||
/-- Elaborate `x` with `stx` on the macro stack and produce macro expansion info -/
|
||||
def withMacroExpansion (beforeStx afterStx : Syntax) (x : TermElabM α) : TermElabM α :=
|
||||
withMacroExpansionInfo beforeStx afterStx do
|
||||
withPushMacroExpansionStack beforeStx afterStx x
|
||||
@[specialize]
|
||||
def withMacroExpansion [Monad n] [MonadControlT TermElabM n] (beforeStx afterStx : Syntax) (x : n α) : n α :=
|
||||
controlAt TermElabM fun runInBase => do
|
||||
withMacroExpansionInfo beforeStx afterStx do
|
||||
withPushMacroExpansionStack beforeStx afterStx <| runInBase x
|
||||
|
||||
/--
|
||||
Add the given metavariable to the list of pending synthetic metavariables.
|
||||
|
|
|
|||
|
|
@ -395,6 +395,11 @@ def findFromUserName? (lctx : LocalContext) (userName : Name) : Option LocalDecl
|
|||
| none => none
|
||||
| some decl => if decl.userName == userName then some decl else none
|
||||
|
||||
def getFromUserName! (lctx : LocalContext) (userName : Name) : LocalDecl :=
|
||||
match lctx.findFromUserName? userName with
|
||||
| some decl => decl
|
||||
| none => panic! s!"unknown local declaration `{userName}`"
|
||||
|
||||
def usesUserName (lctx : LocalContext) (userName : Name) : Bool :=
|
||||
(lctx.findFromUserName? userName).isSome
|
||||
|
||||
|
|
@ -631,6 +636,22 @@ def sortFVarsByContextOrder (lctx : LocalContext) (hyps : Array FVarId) : Array
|
|||
| some ldecl => (ldecl.index, fvarId)
|
||||
hyps.qsort (fun h i => h.fst < i.fst) |>.map (·.snd)
|
||||
|
||||
/--
|
||||
Batched version of `Lean.LocalContext.findFromUserName?`.
|
||||
Finds the visible local declarations for each of the given `userNames` up to a certain `start`
|
||||
index exclusively, if any.
|
||||
-/
|
||||
def findFromUserNames (lctx : LocalContext) (userNames : Std.HashSet Name) (start := 0) : Array LocalDecl :=
|
||||
Array.reverse <| Id.run <| ExceptT.runCatch do
|
||||
let (_, _, acc) ← lctx.foldrM (init := (userNames, lctx.numIndices, #[])) fun decl (userNames, num, acc) => do
|
||||
if userNames.isEmpty then throw acc -- stop when we found all user names
|
||||
if num ≤ start then throw acc -- stop when we reached the start index
|
||||
if userNames.contains decl.userName then
|
||||
pure (userNames.erase decl.userName, num - 1, acc.push decl)
|
||||
else
|
||||
pure (userNames, num - 1, acc)
|
||||
return acc.reverse
|
||||
|
||||
end LocalContext
|
||||
|
||||
/-- Class used to denote that `m` has a local context. -/
|
||||
|
|
|
|||
|
|
@ -1785,30 +1785,31 @@ partial def withLocalDecls
|
|||
[Inhabited α]
|
||||
(declInfos : Array (Name × BinderInfo × (Array Expr → n Expr)))
|
||||
(k : (xs : Array Expr) → n α)
|
||||
(kind : LocalDeclKind := .default)
|
||||
: n α :=
|
||||
loop #[]
|
||||
where
|
||||
loop [Inhabited α] (acc : Array Expr) : n α := do
|
||||
if acc.size < declInfos.size then
|
||||
let (name, bi, typeCtor) := declInfos[acc.size]!
|
||||
withLocalDecl name bi (←typeCtor acc) fun x => loop (acc.push x)
|
||||
withLocalDecl name bi (←typeCtor acc) (fun x => loop (acc.push x)) kind
|
||||
else
|
||||
k acc
|
||||
|
||||
/--
|
||||
Variant of `withLocalDecls` using `BinderInfo.default`
|
||||
-/
|
||||
def withLocalDeclsD [Inhabited α] (declInfos : Array (Name × (Array Expr → n Expr))) (k : (xs : Array Expr) → n α) : n α :=
|
||||
def withLocalDeclsD [Inhabited α] (declInfos : Array (Name × (Array Expr → n Expr))) (k : (xs : Array Expr) → n α) (kind : LocalDeclKind := .default) : n α :=
|
||||
withLocalDecls
|
||||
(declInfos.map (fun (name, typeCtor) => (name, BinderInfo.default, typeCtor))) k
|
||||
(declInfos.map (fun (name, typeCtor) => (name, BinderInfo.default, typeCtor))) k kind
|
||||
|
||||
/--
|
||||
Simpler variant of `withLocalDeclsD` for bringing variables into scope whose types do not depend
|
||||
on each other.
|
||||
-/
|
||||
def withLocalDeclsDND [Inhabited α] (declInfos : Array (Name × Expr)) (k : (xs : Array Expr) → n α) : n α :=
|
||||
def withLocalDeclsDND [Inhabited α] (declInfos : Array (Name × Expr)) (k : (xs : Array Expr) → n α) (kind : LocalDeclKind := .default) : n α :=
|
||||
withLocalDeclsD
|
||||
(declInfos.map (fun (name, typeCtor) => (name, fun _ => pure typeCtor))) k
|
||||
(declInfos.map (fun (name, typeCtor) => (name, fun _ => pure typeCtor))) k (kind := kind)
|
||||
|
||||
private def withAuxDeclImp (shortDeclName : Name) (type : Expr) (declName : Name) (k : Expr → MetaM α) : MetaM α := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
|
|
@ -1869,6 +1870,15 @@ def mapLetDecl [MonadLiftT MetaM n] (name : Name) (type : Expr) (val : Expr) (k
|
|||
withLetDecl name type val (nondep := nondep) (kind := kind) fun x => do
|
||||
mkLetFVars (usedLetOnly := usedLetOnly) (generalizeNondepLet := false) #[x] (← k x)
|
||||
|
||||
/--
|
||||
Runs `k x` with the local declaration `<name> : <type> := <val>` added to the local context, where `x` is the new free variable.
|
||||
Afterwards, the local declaration is zeta-reduced into the result.
|
||||
-/
|
||||
def mapLetDeclZeta [MonadLiftT MetaM n] (name : Name) (type rhs : Expr) (k : Expr → n Expr) : n Expr := do
|
||||
withLetDecl (n:=n) name type rhs fun x => do
|
||||
let e ← elimMVarDeps #[x] (← k x)
|
||||
return e.replaceFVar x rhs
|
||||
|
||||
def withLocalInstancesImp (decls : List LocalDecl) (k : MetaM α) : MetaM α := do
|
||||
let mut localInsts := (← read).localInstances
|
||||
let size := localInsts.size
|
||||
|
|
|
|||
67
src/Lean/Meta/ProdN.lean
Normal file
67
src/Lean/Meta/ProdN.lean
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sebastian Graf
|
||||
-/
|
||||
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Meta.InferType
|
||||
import Lean.Meta.DecLevel
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Meta
|
||||
|
||||
/--
|
||||
Given types `tᵢ`, return the tuple type `t₁ × t₂ × … × tₙ`.
|
||||
For `n = 0`, return `PUnit`.
|
||||
-/
|
||||
def mkProdN (ts : Array Expr) : MetaM Expr := do
|
||||
if h : ts.size > 0 then
|
||||
let mut tupleTy := ts.back
|
||||
let mut u ← getDecLevel tupleTy
|
||||
let mut ts := ts.pop
|
||||
for i in 0...ts.size do
|
||||
let ty := ts.back!
|
||||
let u' ← getDecLevel ty
|
||||
tupleTy := mkApp2 (mkConst ``Prod [u', u]) ty tupleTy
|
||||
u := (mkLevelMax u u').normalize
|
||||
ts := ts.pop
|
||||
return tupleTy
|
||||
else
|
||||
let u ← mkFreshLevelMVar
|
||||
return mkConst ``PUnit [u]
|
||||
|
||||
/--
|
||||
Given expressions `eᵢ`, return the tuple `(e₁, e₂, …, eₙ)` and its type `t₁ × t₂ × … × tₙ`.
|
||||
For `n = 0`, return `PUnit.unit`.
|
||||
-/
|
||||
def mkProdMkN (es : Array Expr) : MetaM (Expr × Expr) := do
|
||||
if h : es.size > 0 then
|
||||
let mut tuple := es.back
|
||||
let mut tupleTy ← inferType tuple
|
||||
let mut u ← getDecLevel tupleTy
|
||||
let mut es := es.pop
|
||||
for i in 0...es.size do
|
||||
let e := es.back!
|
||||
let ty ← inferType e
|
||||
let u' ← getDecLevel ty
|
||||
tuple := mkApp4 (mkConst ``Prod.mk [u', u]) ty tupleTy e tuple
|
||||
tupleTy := mkApp2 (mkConst ``Prod [u', u]) ty tupleTy
|
||||
u := (mkLevelMax u u').normalize
|
||||
es := es.pop
|
||||
return (tuple, tupleTy)
|
||||
else
|
||||
let u ← mkFreshLevelMVar
|
||||
return (mkConst ``PUnit.unit [u], mkConst ``PUnit [u])
|
||||
|
||||
/-- Given a product `(e₁, e₂)` of type `t₁ × t₂`, return `(e₁, t₁, e₂, t₂)`. -/
|
||||
def getProdFields (tuple tupleTy : Expr) : MetaM (Expr × Expr × Expr × Expr) := do
|
||||
let tupleTy ← instantiateMVarsIfMVarApp tupleTy
|
||||
let_expr c@Prod fstTy sndTy := tupleTy
|
||||
| throwError "Internal error: Expected Prod, got {tuple} of type {tupleTy}"
|
||||
let fst := mkApp3 (mkConst ``Prod.fst c.constLevels!) fstTy sndTy tuple
|
||||
let snd := mkApp3 (mkConst ``Prod.snd c.constLevels!) fstTy sndTy tuple
|
||||
return (fst, fstTy, snd, sndTy)
|
||||
|
|
@ -49,6 +49,14 @@ builtin_initialize
|
|||
register_parser_alias doSeq
|
||||
register_parser_alias termBeforeDo
|
||||
|
||||
def getDoElems (doSeq : TSyntax ``doSeq) : Array (TSyntax `doElem) :=
|
||||
if doSeq.raw.getKind == ``Parser.Term.doSeqBracketed then
|
||||
doSeq.raw[1].getArgs.map fun arg => ⟨arg[0]⟩
|
||||
else if doSeq.raw.getKind == ``Parser.Term.doSeqIndent then
|
||||
doSeq.raw[0].getArgs.map fun arg => ⟨arg[0]⟩
|
||||
else
|
||||
#[]
|
||||
|
||||
def notFollowedByRedefinedTermToken :=
|
||||
-- Remark: we don't currently support `open` and `set_option` in `do`-blocks,
|
||||
-- but we include them in the following list to fix the ambiguity where
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue