feat: add withReader method
@Kha `withReader` is a well-behaved version of `adaptReader`. `adaptReader` is
too general, and it often produces counterintuitive elaboration
errors.
Here are two super annoying issues I hit all the time:
1- `adaptReader` + polymorphic code
```
def ex1 : ReaderT Nat IO Unit :=
adaptReader (fun x => x + 1) $
IO.println "foo" -- 3 Errors here failed to synthesize `Monad ?m` and `MonadIO ?m`, and don't know how to synthesize `Type → Type`
```
2- `adaptReader` and notation that requires the expected type
```
structure Context :=
(x y : Nat)
def ex2 : ReaderT Context IO Nat :=
adaptReader (fun s => { s with x := 10 }) $ -- Error at the structure instance
...
```
In the example above, I have to write `fun (s : Context) => ...` to
fix the problem.
The two problems above happen in the old and new frontends. However,
there is a new problem specific for the new frontend. In the new
frontend, a `do` is only elaborated when the expected type is known.
So, `adaptReader (fun ctx => ...) do ...` seldom works :(
As I said above, the issue is that `adaptReader` is too general. Its
type is
```
{ρ ρ' : Type u_1} → {m m' : Type u_1 → Type u_2} → [MonadReaderAdapter ρ ρ' m m'] → {α : Type u_1} → (ρ' → ρ) → m α → m' α
```
`withReader` is a simpler version of `adaptReader`
```
withReader : {ρ : Type u_1} → {m : Type u_1 → Type u_2} → [MonadWithReader ρ m] → {α : Type u_1} → (ρ → ρ) → m α → m α
```
It doesn't have any of the problems above. Moreover, I managed to replace
every single instance of `adaptReader` with `withReader` at the stdlib
and tests. We don't need the `adaptReader` generality.
This commit is contained in:
parent
4fce06c468
commit
9af0a0e18b
29 changed files with 92 additions and 70 deletions
|
|
@ -112,7 +112,6 @@ instance monadReaderTrans {ρ : Type u} {m : Type u → Type v} {n : Type u →
|
|||
instance {ρ : Type u} {m : Type u → Type v} [Monad m] : MonadReaderOf ρ (ReaderT ρ m) :=
|
||||
⟨ReaderT.read⟩
|
||||
|
||||
|
||||
/-- Adapt a Monad stack, changing the Type of its top-most environment.
|
||||
|
||||
This class is comparable to [Control.Lens.Magnify](https://hackage.haskell.org/package/lens-4.15.4/docs/Control-Lens-Zoom.html#t:Magnify), but does not use lenses (why would it), and is derived automatically for any transformer implementing `MonadFunctor`.
|
||||
|
|
@ -173,3 +172,27 @@ instance ReaderT.monadControl (ρ : Type u) (m : Type u → Type v) : MonadContr
|
|||
|
||||
instance ReaderT.finally {m : Type u → Type v} {ρ : Type u} [MonadFinally m] [Monad m] : MonadFinally (ReaderT ρ m) :=
|
||||
{ finally' := fun α β x h ctx => finally' (x ctx) (fun a? => h a? ctx) }
|
||||
|
||||
class MonadWithReaderOf (ρ : Type u) (m : Type u → Type v) :=
|
||||
(withReader {α : Type u} : (ρ → ρ) → m α → m α)
|
||||
|
||||
@[inline] def withTheReader (ρ : Type u) {m : Type u → Type v} [MonadWithReaderOf ρ m] {α : Type u} (f : ρ → ρ) (x : m α) : m α :=
|
||||
MonadWithReaderOf.withReader f x
|
||||
|
||||
class MonadWithReader (ρ : outParam (Type u)) (m : Type u → Type v) :=
|
||||
(withReader {α : Type u} : (ρ → ρ) → m α → m α)
|
||||
|
||||
export MonadWithReader (withReader)
|
||||
|
||||
instance MonadWithReaderOf.isMonadWithReader (ρ : Type u) (m : Type u → Type v) [MonadWithReaderOf ρ m] : MonadWithReader ρ m :=
|
||||
⟨fun α => withTheReader ρ⟩
|
||||
|
||||
section
|
||||
variables {ρ : Type u} {m : Type u → Type v}
|
||||
|
||||
instance monadWithReaderOfTrans {n : Type u → Type v} [MonadWithReaderOf ρ m] [MonadFunctor m m n n] : MonadWithReaderOf ρ n :=
|
||||
⟨fun α f => monadMap fun β => (withTheReader ρ f : m β → m β)⟩
|
||||
|
||||
instance ReaderT.monadWithReaderOf [Monad m] : MonadWithReaderOf ρ (ReaderT ρ m) :=
|
||||
⟨fun α f x ctx => x (f ctx)⟩
|
||||
end
|
||||
|
|
|
|||
|
|
@ -454,13 +454,12 @@ throw $ Exception.error ref msg
|
|||
|
||||
@[inline] protected def withFreshMacroScope {α} (x : MacroM α) : MacroM α := do
|
||||
fresh ← modifyGet (fun s => (s, s+1));
|
||||
adaptReader (fun (ctx : Context) => { ctx with currMacroScope := fresh }) x
|
||||
withReader (fun ctx => { ctx with currMacroScope := fresh }) x
|
||||
|
||||
@[inline] def withIncRecDepth {α} (ref : Syntax) (x : MacroM α) : MacroM α := do
|
||||
ctx ← read;
|
||||
when (ctx.currRecDepth == ctx.maxRecDepth) $ throw $ Exception.error ref maxRecDepthErrorMessage;
|
||||
adaptReader (fun (ctx : Context) => { ctx with currRecDepth := ctx.currRecDepth + 1 }) x
|
||||
|
||||
withReader (fun ctx => { ctx with currRecDepth := ctx.currRecDepth + 1 }) x
|
||||
instance monadQuotation : MonadQuotation MacroM :=
|
||||
{ getCurrMacroScope := fun ctx => pure ctx.currMacroScope,
|
||||
getMainModule := fun ctx => pure ctx.mainModule,
|
||||
|
|
|
|||
|
|
@ -270,7 +270,7 @@ def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx :=
|
|||
|
||||
partial def collectFnBody : FnBody → M Unit
|
||||
| FnBody.jdecl j ys v b => do
|
||||
adaptReader (fun ctx => updateParamSet ctx ys) (collectFnBody v);
|
||||
withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v);
|
||||
ctx ← read;
|
||||
updateParamMap (ParamMap.Key.jp ctx.currFn j);
|
||||
collectFnBody b
|
||||
|
|
@ -285,7 +285,7 @@ partial def collectFnBody : FnBody → M Unit
|
|||
|
||||
partial def collectDecl : Decl → M Unit
|
||||
| Decl.fdecl f ys _ b =>
|
||||
adaptReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) $ do
|
||||
withReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) $ do
|
||||
collectFnBody b;
|
||||
updateParamMap (ParamMap.Key.decl f)
|
||||
| _ => pure ()
|
||||
|
|
|
|||
|
|
@ -151,13 +151,13 @@ match findEnvDecl' ctx.env fid ctx.decls with
|
|||
| none => pure (arbitrary _) -- unreachable if well-formed
|
||||
|
||||
@[inline] def withParams {α : Type} (xs : Array Param) (k : M α) : M α :=
|
||||
adaptReader (fun (ctx : BoxingContext) => { ctx with localCtx := ctx.localCtx.addParams xs }) k
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addParams xs }) k
|
||||
|
||||
@[inline] def withVDecl {α : Type} (x : VarId) (ty : IRType) (v : Expr) (k : M α) : M α :=
|
||||
adaptReader (fun (ctx : BoxingContext) => { ctx with localCtx := ctx.localCtx.addLocal x ty v }) k
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x ty v }) k
|
||||
|
||||
@[inline] def withJDecl {α : Type} (j : JoinPointId) (xs : Array Param) (v : FnBody) (k : M α) : M α :=
|
||||
adaptReader (fun (ctx : BoxingContext) => { ctx with localCtx := ctx.localCtx.addJP j xs v }) k
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j xs v }) k
|
||||
|
||||
/- If `x` declaration is of the form `x := Expr.lit _` or `x := Expr.fap c #[]`,
|
||||
and `x`'s type is not cheap to box (e.g., it is `UInt64), then return its value. -/
|
||||
|
|
|
|||
|
|
@ -119,19 +119,19 @@ ctx ← read;
|
|||
localCtx ← ps.foldlM (fun (ctx : LocalContext) p => do
|
||||
markVar p.x;
|
||||
pure $ ctx.addParam p) ctx.localCtx;
|
||||
adaptReader (fun _ => { ctx with localCtx := localCtx }) k
|
||||
withReader (fun _ => { ctx with localCtx := localCtx }) k
|
||||
|
||||
partial def checkFnBody : FnBody → M Unit
|
||||
| FnBody.vdecl x t v b => do
|
||||
checkExpr t v;
|
||||
markVar x;
|
||||
ctx ← read;
|
||||
adaptReader (fun (ctx : CheckerContext) => { ctx with localCtx := ctx.localCtx.addLocal x t v }) (checkFnBody b)
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x t v }) (checkFnBody b)
|
||||
| FnBody.jdecl j ys v b => do
|
||||
markJP j;
|
||||
withParams ys (checkFnBody v);
|
||||
ctx ← read;
|
||||
adaptReader (fun (ctx : CheckerContext) => { ctx with localCtx := ctx.localCtx.addJP j ys v }) (checkFnBody b)
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j ys v }) (checkFnBody b)
|
||||
| FnBody.set x _ y b => checkVar x *> checkArg y *> checkFnBody b
|
||||
| FnBody.uset x _ y b => checkVar x *> checkVar y *> checkFnBody b
|
||||
| FnBody.sset x _ _ y _ b => checkVar x *> checkVar y *> checkFnBody b
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ partial def interpFnBody : FnBody → M Unit
|
|||
updateVarAssignment x v;
|
||||
interpFnBody b
|
||||
| FnBody.jdecl j ys v b =>
|
||||
adaptReader (fun (ctx : InterpContext) => { ctx with lctx := ctx.lctx.addJP j ys v }) $
|
||||
withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) $
|
||||
interpFnBody b
|
||||
| FnBody.case _ x _ alts => do
|
||||
v ← findVarValue x;
|
||||
|
|
@ -247,7 +247,7 @@ ctx.decls.size.foldM (fun idx modified => do
|
|||
s ← get;
|
||||
-- dbgTrace (">> " ++ toString fid) $ fun _ =>
|
||||
let currVals := s.funVals.get! idx;
|
||||
adaptReader (fun (ctx : InterpContext) => { ctx with currFnIdx := idx }) $ do
|
||||
withReader (fun ctx => { ctx with currFnIdx := idx }) $ do
|
||||
ys.forM $ fun y => updateVarAssignment y.x top;
|
||||
interpFnBody b;
|
||||
s ← get;
|
||||
|
|
|
|||
|
|
@ -622,7 +622,7 @@ emitLn "}"
|
|||
def emitDeclAux (d : Decl) : M Unit := do
|
||||
env ← getEnv;
|
||||
let (vMap, jpMap) := mkVarJPMaps d;
|
||||
adaptReader (fun (ctx : Context) => { ctx with jpMap := jpMap }) $ do
|
||||
withReader (fun ctx => { ctx with jpMap := jpMap }) $ do
|
||||
unless (hasInitAttr env d.name) $
|
||||
match d with
|
||||
| Decl.fdecl f xs t b => do
|
||||
|
|
@ -653,7 +653,7 @@ unless (hasInitAttr env d.name) $
|
|||
emit "lean_object* "; emit x.x; emit " = _args["; emit i; emitLn "];"
|
||||
};
|
||||
emitLn "_start:";
|
||||
adaptReader (fun (ctx : Context) => { ctx with mainFn := f, mainParams := xs }) (emitFnBody b);
|
||||
withReader (fun ctx => { ctx with mainFn := f, mainParams := xs }) (emitFnBody b);
|
||||
emitLn "}"
|
||||
| _ => pure ()
|
||||
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ private partial def Dmain (x : VarId) (c : CtorInfo) : FnBody → M (FnBody × B
|
|||
pure (FnBody.case tid y yType alts, true)
|
||||
else pure (e, false)
|
||||
| FnBody.jdecl j ys v b => do
|
||||
(b, found) ← adaptReader (fun (ctx : LocalContext) => ctx.addJP j ys v) (Dmain b);
|
||||
(b, found) ← withReader (fun ctx => ctx.addJP j ys v) (Dmain b);
|
||||
(v, _ /- found' -/) ← Dmain v;
|
||||
/- If `found' == true`, then `Dmain b` must also have returned `(b, true)` since
|
||||
we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`),
|
||||
|
|
@ -138,7 +138,7 @@ partial def R : FnBody → M FnBody
|
|||
pure $ FnBody.case tid x xType alts
|
||||
| FnBody.jdecl j ys v b => do
|
||||
v ← R v;
|
||||
b ← adaptReader (fun (ctx : LocalContext) => ctx.addJP j ys v) (R b);
|
||||
b ← withReader (fun ctx => ctx.addJP j ys v) (R b);
|
||||
pure $ FnBody.jdecl j ys v b
|
||||
| e => do
|
||||
if e.isTerminal then pure e
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ instance CoreM.inhabited {α} : Inhabited (CoreM α) :=
|
|||
|
||||
instance : Ref CoreM :=
|
||||
{ getRef := do ctx ← read; pure ctx.ref,
|
||||
withRef := fun α ref x => adaptReader (fun (ctx : Context) => { ctx with ref := ref }) x }
|
||||
withRef := fun α ref x => withReader (fun ctx => { ctx with ref := ref }) x }
|
||||
|
||||
instance : MonadEnv CoreM :=
|
||||
{ getEnv := do s ← get; pure s.env,
|
||||
|
|
@ -52,7 +52,7 @@ instance : MonadNameGenerator CoreM :=
|
|||
setNGen := fun ngen => modify fun s => { s with ngen := ngen } }
|
||||
|
||||
instance : MonadRecDepth CoreM :=
|
||||
{ withRecDepth := fun α d x => adaptReader (fun (ctx : Context) => { ctx with currRecDepth := d }) x,
|
||||
{ withRecDepth := fun α d x => withReader (fun ctx => { ctx with currRecDepth := d }) x,
|
||||
getRecDepth := do ctx ← read; pure ctx.currRecDepth,
|
||||
getMaxRecDepth := do ctx ← read; pure ctx.maxRecDepth }
|
||||
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ Because `childIdx < 3` in the case of `Expr`, we can injectively map a path
|
|||
Note that `pos` is initialized to `1` (case `childIdxs == []`).
|
||||
-/
|
||||
def descend {α} (child : Expr) (childIdx : Nat) (d : DelabM α) : DelabM α :=
|
||||
adaptReader (fun (cfg : Context) => { cfg with expr := child, pos := cfg.pos * 3 + childIdx }) d
|
||||
withReader (fun cfg => { cfg with expr := child, pos := cfg.pos * 3 + childIdx }) d
|
||||
|
||||
def withAppFn {α} (d : DelabM α) : DelabM α := do
|
||||
Expr.app fn _ _ ← getExpr | unreachable!;
|
||||
|
|
|
|||
|
|
@ -683,7 +683,7 @@ match fIdent with
|
|||
funLVals ← withRef fIdent $ resolveName n preresolved fExplicitUnivs;
|
||||
let overloaded := overloaded || funLVals.length > 1;
|
||||
-- Set `errToSorry` to `false` if `funLVals` > 1. See comment above about the interaction between `errToSorry` and `observing`.
|
||||
adaptReader (fun (ctx : Context) => { ctx with errToSorry := funLVals.length == 1 && ctx.errToSorry }) $
|
||||
withReader (fun ctx => { ctx with errToSorry := funLVals.length == 1 && ctx.errToSorry }) $
|
||||
funLVals.foldlM
|
||||
(fun acc ⟨f, fields⟩ => do
|
||||
let lvals' := fields.map LVal.fieldName;
|
||||
|
|
@ -699,7 +699,7 @@ private partial def elabAppFn : Syntax → List LVal → Array NamedArg → Arra
|
|||
| f, lvals, namedArgs, args, expectedType?, explicit, overloaded, acc =>
|
||||
if f.getKind == choiceKind then
|
||||
-- Set `errToSorry` to `false` when processing choice nodes. See comment above about the interaction between `errToSorry` and `observing`.
|
||||
adaptReader (fun (ctx : Context) => { ctx with errToSorry := false }) do
|
||||
withReader (fun ctx => { ctx with errToSorry := false }) do
|
||||
f.getArgs.foldlM (fun acc f => elabAppFn f lvals namedArgs args expectedType? explicit true acc) acc
|
||||
else match_syntax f with
|
||||
| `($(e).$idx:fieldIdx) =>
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ instance : AddMessageContext CommandElabM :=
|
|||
|
||||
instance : Ref CommandElabM :=
|
||||
{ getRef := Command.getRef,
|
||||
withRef := fun α ref x => adaptReader (fun (ctx : Context) => { ctx with ref := ref }) x }
|
||||
withRef := fun α ref x => withReader (fun ctx => { ctx with ref := ref }) x }
|
||||
|
||||
instance : AddErrorMessageContext CommandElabM :=
|
||||
{ add := fun ref msg => do
|
||||
|
|
@ -138,7 +138,7 @@ protected def getMainModule : CommandElabM Name := do env ← getEnv; pure e
|
|||
|
||||
@[inline] protected def withFreshMacroScope {α} (x : CommandElabM α) : CommandElabM α := do
|
||||
fresh ← modifyGet (fun st => (st.nextMacroScope, { st with nextMacroScope := st.nextMacroScope + 1 }));
|
||||
adaptReader (fun (ctx : Context) => { ctx with currMacroScope := fresh }) x
|
||||
withReader (fun ctx => { ctx with currMacroScope := fresh }) x
|
||||
|
||||
instance CommandElabM.MonadQuotation : MonadQuotation CommandElabM := {
|
||||
getCurrMacroScope := Command.getCurrMacroScope,
|
||||
|
|
@ -158,7 +158,7 @@ private def elabCommandUsing (s : State) (stx : Syntax) : List CommandElab → C
|
|||
|
||||
/- Elaborate `x` with `stx` on the macro stack -/
|
||||
@[inline] def withMacroExpansion {α} (beforeStx afterStx : Syntax) (x : CommandElabM α) : CommandElabM α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x
|
||||
withReader (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x
|
||||
|
||||
instance : MonadMacroAdapter CommandElabM :=
|
||||
{ getCurrMacroScope := getCurrMacroScope,
|
||||
|
|
@ -166,7 +166,7 @@ instance : MonadMacroAdapter CommandElabM :=
|
|||
setNextMacroScope := fun next => modify $ fun s => { s with nextMacroScope := next } }
|
||||
|
||||
instance : MonadRecDepth CommandElabM :=
|
||||
{ withRecDepth := fun α d x => adaptReader (fun (ctx : Context) => { ctx with currRecDepth := d }) x,
|
||||
{ withRecDepth := fun α d x => withReader (fun ctx => { ctx with currRecDepth := d }) x,
|
||||
getRecDepth := do ctx ← read; pure ctx.currRecDepth,
|
||||
getMaxRecDepth := do s ← get; pure s.maxRecDepth }
|
||||
|
||||
|
|
|
|||
|
|
@ -1070,7 +1070,7 @@ structure Context :=
|
|||
abbrev M := ReaderT Context TermElabM
|
||||
|
||||
@[inline] def withNewVars {α} (newVars : Array Name) (x : M α) : M α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with varSet := insertVars ctx.varSet newVars }) x
|
||||
withReader (fun ctx => { ctx with varSet := insertVars ctx.varSet newVars }) x
|
||||
|
||||
def checkReassignable (xs : Array Name) : M Unit := do
|
||||
ctx ← read;
|
||||
|
|
@ -1082,7 +1082,7 @@ xs.forM fun x =>
|
|||
| _ => throwError ("'" ++ x.simpMacroScopes ++ "' cannot be reassigned")
|
||||
|
||||
@[inline] def withFor {α} (x : M α) : M α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with insideFor := true }) x
|
||||
withReader (fun ctx => { ctx with insideFor := true }) x
|
||||
|
||||
structure ToForInTermResult :=
|
||||
(uvars : Array Name)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ abbrev LevelElabM := ReaderT Context (EStateM Exception State)
|
|||
|
||||
instance : Ref LevelElabM :=
|
||||
{ getRef := do return (← read).ref,
|
||||
withRef := fun ref x => adaptReader (fun ctx => { ctx with ref := ref }) x }
|
||||
withRef := fun ref x => withReader (fun ctx => { ctx with ref := ref }) x }
|
||||
|
||||
instance : AddMessageContext LevelElabM :=
|
||||
{ addMessageContext := fun msg => pure msg }
|
||||
|
|
|
|||
|
|
@ -702,7 +702,7 @@ localDecls ← s.localDecls.mapM fun d => instantiateLocalDeclMVars d;
|
|||
lctx ← getLCtx;
|
||||
let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.erase d.fvarId) lctx;
|
||||
let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx;
|
||||
adaptTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) $ k localDecls patterns
|
||||
withTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) $ k localDecls patterns
|
||||
|
||||
private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr)
|
||||
(k : AltLHS → Expr → TermElabM α) : TermElabM α := do
|
||||
|
|
|
|||
|
|
@ -352,7 +352,7 @@ private unsafe partial def toPreterm : Syntax → TermElabM Expr
|
|||
let lctx := lctx.mkLocalDecl n n ty;
|
||||
let params := params.eraseIdx 0;
|
||||
stx ← `(fun $params* => $body);
|
||||
adaptTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) $ do
|
||||
withTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) $ do
|
||||
e ← toPreterm stx;
|
||||
pure $ lctx.mkLambda #[mkFVar n] e
|
||||
| `Lean.Parser.Term.let => do
|
||||
|
|
@ -363,7 +363,7 @@ private unsafe partial def toPreterm : Syntax → TermElabM Expr
|
|||
val ← toPreterm val;
|
||||
lctx ← getLCtx;
|
||||
let lctx := lctx.mkLetDecl n n exprPlaceholder val;
|
||||
adaptTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) $ do
|
||||
withTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) $ do
|
||||
e ← toPreterm $ body;
|
||||
pure $ lctx.mkLambda #[mkFVar n] e
|
||||
| `Lean.Parser.Term.app => do
|
||||
|
|
|
|||
|
|
@ -742,7 +742,7 @@ def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name
|
|||
tryToSynthesizeDefaultAux structs allStructNames maxDistance fieldName mvarId 0 0
|
||||
|
||||
partial def step : Struct → M Unit
|
||||
| struct => unlessM isRoundDone $ adaptReader (fun (ctx : Context) => { ctx with structs := ctx.structs.push struct }) $ do
|
||||
| struct => unlessM isRoundDone $ withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) $ do
|
||||
struct.fields.forM $ fun field =>
|
||||
match field.val with
|
||||
| FieldVal.nested struct => step struct
|
||||
|
|
@ -764,7 +764,7 @@ partial def propagateLoop (hierarchyDepth : Nat) : Nat → Struct → M Unit
|
|||
| some field =>
|
||||
if d > hierarchyDepth then
|
||||
throwErrorAt field.ref ("field '" ++ getFieldName field ++ "' is missing")
|
||||
else adaptReader (fun (ctx : Context) => { ctx with maxDistance := d }) $ do
|
||||
else withReader (fun ctx => { ctx with maxDistance := d }) $ do
|
||||
modify $ fun (s : State) => { s with progress := false };
|
||||
step struct;
|
||||
s ← get;
|
||||
|
|
|
|||
|
|
@ -44,10 +44,10 @@ abbrev ToParserDescrM := ReaderT ToParserDescrContext (StateRefT Bool TermElabM)
|
|||
private def markAsTrailingParser : ToParserDescrM Unit := set true
|
||||
|
||||
@[inline] private def withNotFirst {α} (x : ToParserDescrM α) : ToParserDescrM α :=
|
||||
adaptReader (fun (ctx : ToParserDescrContext) => { ctx with first := false }) x
|
||||
withReader (fun ctx => { ctx with first := false }) x
|
||||
|
||||
@[inline] private def withoutLeftRec {α} (x : ToParserDescrM α) : ToParserDescrM α :=
|
||||
adaptReader (fun (ctx : ToParserDescrContext) => { ctx with leftRec := false }) x
|
||||
withReader (fun ctx => { ctx with leftRec := false }) x
|
||||
|
||||
def checkLeftRec (stx : Syntax) : ToParserDescrM Bool := do
|
||||
ctx ← read;
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ ensureAssignmentHasNoMVars mvarId
|
|||
/-- Auxiliary function used to implement `synthesizeSyntheticMVars`. -/
|
||||
private def resumeElabTerm (stx : Syntax) (expectedType? : Option Expr) (errToSorry := true) : TermElabM Expr :=
|
||||
-- Remark: if `ctx.errToSorry` is already false, then we don't enable it. Recall tactics disable `errToSorry`
|
||||
adaptReader (fun ctx => { ctx with errToSorry := ctx.errToSorry && errToSorry }) $
|
||||
withReader (fun ctx => { ctx with errToSorry := ctx.errToSorry && errToSorry }) do
|
||||
elabTerm stx expectedType? false
|
||||
|
||||
/--
|
||||
|
|
@ -50,7 +50,7 @@ private def resumePostponed (macroStack : MacroStack) (declName? : Option Name)
|
|||
withRef stx $ withMVarContext mvarId do
|
||||
let s ← get
|
||||
try
|
||||
adaptReader (m := TermElabM) (fun ctx => { ctx with macroStack := macroStack, declName? := declName? }) do -- TODO: remove (m := TermElabM)
|
||||
withReader (fun ctx => { ctx with macroStack := macroStack, declName? := declName? }) do
|
||||
let mvarDecl ← getMVarDecl mvarId
|
||||
let expectedType ← instantiateMVars mvarDecl.type
|
||||
let result ← resumeElabTerm stx expectedType (!postponeOnError)
|
||||
|
|
@ -104,7 +104,7 @@ pure $ !val.getAppFn.isMVar
|
|||
|
||||
/-- Try to synthesize the given pending synthetic metavariable. -/
|
||||
private def synthesizeSyntheticMVar (mvarSyntheticDecl : SyntheticMVarDecl) (postponeOnError : Bool) (runTactics : Bool) : TermElabM Bool :=
|
||||
withRef mvarSyntheticDecl.stx $
|
||||
withRef mvarSyntheticDecl.stx do
|
||||
match mvarSyntheticDecl.kind with
|
||||
| SyntheticMVarKind.typeClass => synthesizePendingInstMVar mvarSyntheticDecl.mvarId
|
||||
| SyntheticMVarKind.coe header? expectedType eType e f? => synthesizePendingCoeInstMVar mvarSyntheticDecl.mvarId header? expectedType eType e f?
|
||||
|
|
@ -112,7 +112,7 @@ match mvarSyntheticDecl.kind with
|
|||
| SyntheticMVarKind.withDefault _ => checkWithDefault mvarSyntheticDecl.mvarId
|
||||
| SyntheticMVarKind.postponed macroStack declName? => resumePostponed macroStack declName? mvarSyntheticDecl.stx mvarSyntheticDecl.mvarId postponeOnError
|
||||
| SyntheticMVarKind.tactic declName? tacticCode =>
|
||||
adaptReader (m := TermElabM) (fun (ctx : Context) => { ctx with declName? := declName? }) do -- TODO: remove (m := TermElabM)
|
||||
withReader (fun ctx => { ctx with declName? := declName? }) do
|
||||
if runTactics then
|
||||
runTactic mvarSyntheticDecl.mvarId tacticCode
|
||||
pure true
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ loop tactics
|
|||
|
||||
/- Elaborate `x` with `stx` on the macro stack -/
|
||||
@[inline] def withMacroExpansion {α} (beforeStx afterStx : Syntax) (x : TacticM α) : TacticM α :=
|
||||
adaptTheReader Term.Context (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x
|
||||
withTheReader Term.Context (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x
|
||||
|
||||
mutual
|
||||
|
||||
|
|
|
|||
|
|
@ -262,7 +262,7 @@ protected def getMainModule : TermElabM Name := do env ← getEnv; pure env.
|
|||
|
||||
@[inline] protected def withFreshMacroScope {α} (x : TermElabM α) : TermElabM α := do
|
||||
fresh ← modifyGetThe Core.State (fun st => (st.nextMacroScope, { st with nextMacroScope := st.nextMacroScope + 1 }));
|
||||
adaptReader (fun (ctx : Context) => { ctx with currMacroScope := fresh }) x
|
||||
withReader (fun ctx => { ctx with currMacroScope := fresh }) x
|
||||
|
||||
instance monadQuotation : MonadQuotation TermElabM := {
|
||||
getCurrMacroScope := Term.getCurrMacroScope,
|
||||
|
|
@ -299,13 +299,13 @@ def getMVarDecl (mvarId : MVarId) : TermElabM MetavarDecl := do mctx ← getMCtx
|
|||
def assignLevelMVar (mvarId : MVarId) (val : Level) : TermElabM Unit := modifyThe Meta.State $ fun s => { s with mctx := s.mctx.assignLevel mvarId val }
|
||||
|
||||
def withDeclName {α} (name : Name) (x : TermElabM α) : TermElabM α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with declName? := name }) x
|
||||
withReader (fun ctx => { ctx with declName? := name }) x
|
||||
|
||||
def withLevelNames {α} (levelNames : List Name) (x : TermElabM α) : TermElabM α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with levelNames := levelNames }) x
|
||||
withReader (fun ctx => { ctx with levelNames := levelNames }) x
|
||||
|
||||
def withoutErrToSorry {α} (x : TermElabM α) : TermElabM α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with errToSorry := false }) x
|
||||
withReader (fun ctx => { ctx with errToSorry := false }) x
|
||||
|
||||
/-- For testing `TermElabM` methods. The #eval command will sign the error. -/
|
||||
def throwErrorIfErrors : TermElabM Unit := do
|
||||
|
|
@ -337,7 +337,7 @@ liftLevelM $ Level.elabLevel stx
|
|||
|
||||
/- Elaborate `x` with `stx` on the macro stack -/
|
||||
@[inline] def withMacroExpansion {α} (beforeStx afterStx : Syntax) (x : TermElabM α) : TermElabM α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x
|
||||
withReader (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x
|
||||
|
||||
/-
|
||||
Add the given metavariable to the list of pending synthetic metavariables.
|
||||
|
|
@ -423,7 +423,7 @@ when foundError throwAbort
|
|||
Execute `x` without allowing it to postpone elaboration tasks.
|
||||
That is, `tryPostpone` is a noop. -/
|
||||
@[inline] def withoutPostponing {α} (x : TermElabM α) : TermElabM α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with mayPostpone := false }) x
|
||||
withReader (fun ctx => { ctx with mayPostpone := false }) x
|
||||
|
||||
/-- Creates syntax for `(` <ident> `:` <type> `)` -/
|
||||
def mkExplicitBinder (ident : Syntax) (type : Syntax) : Syntax :=
|
||||
|
|
@ -559,7 +559,7 @@ match f? with
|
|||
| some f => Meta.throwAppTypeMismatch f e extraMsg
|
||||
|
||||
@[inline] def withoutMacroStackAtErr {α} (x : TermElabM α) : TermElabM α :=
|
||||
adaptTheReader Core.Context (fun (ctx : Core.Context) => { ctx with options := setMacroStackOption ctx.options false }) x
|
||||
withTheReader Core.Context (fun (ctx : Core.Context) => { ctx with options := setMacroStackOption ctx.options false }) x
|
||||
|
||||
/- Try to synthesize metavariable using type class resolution.
|
||||
This method assumes the local context and local instances of `instMVar` coincide
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ instance MonadQuotation : MonadQuotation Unhygienic := {
|
|||
getMainModule := pure `UnhygienicMain,
|
||||
withFreshMacroScope := fun α x => do
|
||||
fresh ← modifyGet (fun n => (n, n + 1));
|
||||
adaptReader (fun _ => fresh) x
|
||||
withReader (fun _ => fresh) x
|
||||
}
|
||||
protected def run {α : Type} (x : Unhygienic α) : α := run x firstFrontendMacroScope (firstFrontendMacroScope+1)
|
||||
end Unhygienic
|
||||
|
|
|
|||
|
|
@ -413,7 +413,7 @@ def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool := false) :
|
|||
if xs.isEmpty then pure e else liftMkBindingM $ MetavarContext.elimMVarDeps xs e preserveOrder
|
||||
|
||||
@[inline] def withConfig {α} (f : Config → Config) : n α → n α :=
|
||||
mapMetaM fun _ => adaptReader (fun (ctx : Context) => { ctx with config := f ctx.config })
|
||||
mapMetaM fun _ => withReader (fun ctx => { ctx with config := f ctx.config })
|
||||
|
||||
@[inline] def withTrackingZeta {α} (x : n α) : n α :=
|
||||
withConfig (fun cfg => { cfg with trackZeta := true }) x
|
||||
|
|
@ -525,8 +525,8 @@ match localDecl.binderInfo with
|
|||
| BinderInfo.auxDecl => k
|
||||
| _ =>
|
||||
resettingSynthInstanceCache $
|
||||
adaptReader
|
||||
(fun (ctx : Context) => { ctx with localInstances := ctx.localInstances.push { className := className, fvar := fvar } })
|
||||
withReader
|
||||
(fun ctx => { ctx with localInstances := ctx.localInstances.push { className := className, fvar := fvar } })
|
||||
k
|
||||
|
||||
/-- Add entry `{ className := className, fvar := fvar }` to localInstances,
|
||||
|
|
@ -610,12 +610,12 @@ match maxFVars? with
|
|||
process ()
|
||||
else
|
||||
let type := type.instantiateRevRange j fvars.size fvars;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) $
|
||||
withNewLocalInstancesImp isClassExpensive? fvars j $
|
||||
k fvars type
|
||||
| lctx, fvars, j, type =>
|
||||
let type := type.instantiateRevRange j fvars.size fvars;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) $
|
||||
withNewLocalInstancesImp isClassExpensive? fvars j $
|
||||
if reducing? && fvarsSizeLtMaxFVars fvars maxFVars? then do
|
||||
newType ← whnf type;
|
||||
|
|
@ -714,7 +714,7 @@ private partial def lambdaTelescopeAux {α}
|
|||
lambdaTelescopeAux true lctx (fvars.push fvar) j b
|
||||
| _, lctx, fvars, j, e =>
|
||||
let e := e.instantiateRevRange j fvars.size fvars;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) $
|
||||
withNewLocalInstancesImp isClassExpensive? fvars j $ do
|
||||
k fvars e
|
||||
|
||||
|
|
@ -822,7 +822,7 @@ fvarId ← mkFreshId;
|
|||
ctx ← read;
|
||||
let lctx := ctx.lctx.mkLocalDecl fvarId n type bi;
|
||||
let fvar := mkFVar fvarId;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) $
|
||||
withNewFVar fvar type k
|
||||
|
||||
def withLocalDecl {α} (name : Name) (bi : BinderInfo) (type : Expr) (k : Expr → n α) : n α :=
|
||||
|
|
@ -836,7 +836,7 @@ fvarId ← mkFreshId;
|
|||
ctx ← read;
|
||||
let lctx := ctx.lctx.mkLetDecl fvarId n type val;
|
||||
let fvar := mkFVar fvarId;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) $
|
||||
withNewFVar fvar type k
|
||||
|
||||
def withLetDecl {α} (name : Name) (type : Expr) (val : Expr) (k : Expr → n α) : n α :=
|
||||
|
|
@ -846,7 +846,7 @@ private def withExistingLocalDeclsImp {α} (decls : List LocalDecl) (k : MetaM
|
|||
ctx ← read;
|
||||
let numLocalInstances := ctx.localInstances.size;
|
||||
let lctx := decls.foldl (fun (lctx : LocalContext) decl => lctx.addDecl decl) ctx.lctx;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) do
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) do
|
||||
newLocalInsts ← decls.foldlM
|
||||
(fun (newlocalInsts : Array LocalInstance) (decl : LocalDecl) => (do {
|
||||
c? ← isClass? decl.type;
|
||||
|
|
@ -857,7 +857,7 @@ adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) do
|
|||
if newLocalInsts.size == numLocalInstances then
|
||||
k
|
||||
else
|
||||
resettingSynthInstanceCache $ adaptReader (fun (ctx : Context) => { ctx with localInstances := newLocalInsts }) k
|
||||
resettingSynthInstanceCache $ withReader (fun ctx => { ctx with localInstances := newLocalInsts }) k
|
||||
|
||||
def withExistingLocalDecls {α} (decls : List LocalDecl) : n α → n α :=
|
||||
mapMetaM fun _ => withExistingLocalDeclsImp decls
|
||||
|
|
@ -876,7 +876,7 @@ mapMetaM fun _ => withNewMCtxDepthImp
|
|||
|
||||
private def withLocalContextImp {α} (lctx : LocalContext) (localInsts : LocalInstances) (x : MetaM α) : MetaM α := do
|
||||
localInstsCurr ← getLocalInstances;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx, localInstances := localInsts }) $
|
||||
withReader (fun ctx => { ctx with lctx := lctx, localInstances := localInsts }) $
|
||||
if localInsts == localInstsCurr then
|
||||
x
|
||||
else
|
||||
|
|
|
|||
|
|
@ -217,7 +217,7 @@ private partial def isDefEqBindingAux : LocalContext → Array Expr → Expr →
|
|||
| Expr.forallE n d₁ b₁ _, Expr.forallE _ d₂ b₂ _ => process n d₁ d₂ b₁ b₂
|
||||
| Expr.lam n d₁ b₁ _, Expr.lam _ d₂ b₂ _ => process n d₁ d₂ b₁ b₂
|
||||
| _, _ =>
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) $
|
||||
isDefEqBindingDomain fvars ds₂ 0 $
|
||||
Meta.isExprDefEqAux (e₁.instantiateRev fvars) (e₂.instantiateRev fvars)
|
||||
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ lambdaLetTelescope e $ fun xs e => do
|
|||
@[inline] private def withLocalDecl {α} (name : Name) (bi : BinderInfo) (type : Expr) (x : Expr → MetaM α) : MetaM α :=
|
||||
savingCache $ do
|
||||
fvarId ← mkFreshId;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := ctx.lctx.mkLocalDecl fvarId name type bi }) $
|
||||
withReader (fun ctx => { ctx with lctx := ctx.lctx.mkLocalDecl fvarId name type bi }) $
|
||||
x (mkFVar fvarId)
|
||||
|
||||
def throwUnknownMVar {α} (mvarId : MVarId) : MetaM α :=
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ private partial def introNImpAux {σ} (mvarId : MVarId) (mkName : LocalContext
|
|||
: Nat → LocalContext → Array Expr → Nat → σ → Expr → MetaM (Array Expr × MVarId)
|
||||
| 0, lctx, fvars, j, _, type =>
|
||||
let type := type.instantiateRevRange j fvars.size fvars;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) $
|
||||
withNewLocalInstances fvars j $ do
|
||||
tag ← getMVarTag mvarId;
|
||||
let type := type.headBeta;
|
||||
|
|
@ -43,7 +43,7 @@ private partial def introNImpAux {σ} (mvarId : MVarId) (mkName : LocalContext
|
|||
introNImpAux i lctx fvars j s body
|
||||
| (i+1), lctx, fvars, j, s, type =>
|
||||
let type := type.instantiateRevRange j fvars.size fvars;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $
|
||||
withReader (fun ctx => { ctx with lctx := lctx }) $
|
||||
withNewLocalInstances fvars j $ do
|
||||
newType ← whnf type;
|
||||
if newType.isForall then
|
||||
|
|
|
|||
|
|
@ -275,7 +275,7 @@ def withAntiquot.parenthesizer (antiP p : Parenthesizer) : Parenthesizer :=
|
|||
orelse.parenthesizer antiP p
|
||||
|
||||
def parenthesizeCategoryCore (cat : Name) (prec : Nat) : Parenthesizer :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with cat := cat }) do
|
||||
withReader (fun ctx => { ctx with cat := cat }) do
|
||||
stx ← getCur;
|
||||
if stx.getKind == `choice then
|
||||
visitArgs $ stx.getArgs.size.forM $ fun _ => do
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ pure s
|
|||
#eval (f "hello").run' 10 true
|
||||
|
||||
def g : M Nat :=
|
||||
let a : M Nat := adaptTheReader Bool not f
|
||||
adaptReader (fun s => s ++ " world") a
|
||||
let a : M Nat := withTheReader Bool not f
|
||||
withReader (fun s => s ++ " world") a
|
||||
|
||||
#eval (g "hello").run' 10 true
|
||||
|
|
|
|||
|
|
@ -40,8 +40,8 @@ do traceCtx `module $ do {
|
|||
trace! `slow ("slow message: " ++ toString (slow b))
|
||||
|
||||
def run (x : M Unit) : M Unit :=
|
||||
adaptReader
|
||||
(fun (ctx : Core.Context) =>
|
||||
withReader
|
||||
(fun ctx =>
|
||||
-- Try commeting/uncommeting the following `setBool`s
|
||||
let opts := ctx.options;
|
||||
let opts := opts.setBool `trace.module true;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue