feat: generalize observing

We can now "observe" any `TermElabM α` action.
This commit is contained in:
Leonardo de Moura 2021-01-18 15:10:42 -08:00
parent 0a6d83127d
commit 6df66bf0ac
2 changed files with 12 additions and 11 deletions

View file

@ -767,8 +767,8 @@ false, no elaboration function executed by `x` will reset it to
-/
private partial def elabAppFnId (fIdent : Syntax) (fExplicitUnivs : List Level) (lvals : List LVal)
(namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit ellipsis overloaded : Bool) (acc : Array TermElabResult)
: TermElabM (Array TermElabResult) := do
(namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit ellipsis overloaded : Bool) (acc : Array (TermElabResult Expr))
: TermElabM (Array (TermElabResult Expr)) := do
let funLVals ← withRef fIdent <| resolveName' fIdent fExplicitUnivs
let overloaded := overloaded || funLVals.length > 1
-- Set `errToSorry` to `false` if `funLVals` > 1. See comment above about the interaction between `errToSorry` and `observing`.
@ -784,7 +784,7 @@ private partial def elabAppFnId (fIdent : Syntax) (fExplicitUnivs : List Level)
private partial def elabAppFn (f : Syntax) (lvals : List LVal) (namedArgs : Array NamedArg) (args : Array Arg)
(expectedType? : Option Expr) (explicit ellipsis overloaded : Bool) (acc : Array TermElabResult) : TermElabM (Array TermElabResult) :=
(expectedType? : Option Expr) (explicit ellipsis overloaded : Bool) (acc : Array (TermElabResult Expr)) : TermElabM (Array (TermElabResult Expr)) :=
if f.getKind == choiceKind then
-- Set `errToSorry` to `false` when processing choice nodes. See comment above about the interaction between `errToSorry` and `observing`.
withReader (fun ctx => { ctx with errToSorry := false }) do
@ -834,12 +834,12 @@ private partial def elabAppFn (f : Syntax) (lvals : List LVal) (namedArgs : Arra
if overloaded then ensureHasType expectedType? e else pure e
pure $ acc.push s
private def isSuccess (candidate : TermElabResult) : Bool :=
private def isSuccess (candidate : TermElabResult Expr) : Bool :=
match candidate with
| EStateM.Result.ok _ _ => true
| _ => false
private def getSuccess (candidates : Array TermElabResult) : Array TermElabResult :=
private def getSuccess (candidates : Array (TermElabResult Expr)) : Array (TermElabResult Expr) :=
candidates.filter isSuccess
private def toMessageData (ex : Exception) : TermElabM MessageData := do
@ -856,7 +856,7 @@ private def toMessageData (ex : Exception) : TermElabM MessageData := do
private def toMessageList (msgs : Array MessageData) : MessageData :=
indentD (MessageData.joinSep msgs.toList m!"\n\n")
private def mergeFailures {α} (failures : Array TermElabResult) : TermElabM α := do
private def mergeFailures {α} (failures : Array (TermElabResult Expr)) : TermElabM α := do
let msgs ← failures.mapM fun failure =>
match failure with
| EStateM.Result.ok _ _ => unreachable!

View file

@ -169,8 +169,9 @@ def SavedState.restore (s : SavedState) : TermElabM Unit := do
set s.elab
setTraceState traceState
abbrev TermElabResult := EStateM.Result Exception SavedState Expr
instance : Inhabited TermElabResult where
abbrev TermElabResult (α : Type) := EStateM.Result Exception SavedState α
instance [Inhabited α] : Inhabited (TermElabResult α) where
default := EStateM.Result.ok arbitrary arbitrary
def setMessageLog (messages : MessageLog) : TermElabM Unit :=
@ -186,7 +187,7 @@ def getMessageLog : TermElabM MessageLog :=
Execute `x`, save resulting expression and new state.
If `x` fails, then it also stores exception and new state.
Remark: we do not capture `Exception.postpone`. -/
@[inline] def observing (x : TermElabM Expr) : TermElabM TermElabResult := do
@[inline] def observing (x : TermElabM α) : TermElabM (TermElabResult α) := do
let s ← saveAllState
try
let e ← x
@ -205,9 +206,9 @@ def getMessageLog : TermElabM MessageLog :=
/--
Apply the result/exception and state captured with `observing`.
We use this method to implement overloaded notation and symbols. -/
def applyResult (result : TermElabResult) : TermElabM Expr :=
def applyResult (result : TermElabResult α) : TermElabM α :=
match result with
| EStateM.Result.ok e r => do r.restore; pure e
| EStateM.Result.ok a r => do r.restore; pure a
| EStateM.Result.error ex r => do r.restore; throw ex
@[inline] protected def liftMetaM {α} (x : MetaM α) : TermElabM α :=