diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index 2fc3790992..d194cf3951 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -200,6 +200,19 @@ inductive Code where | jmp (ref : Syntax) (jpName : Name) (args : Array Syntax) deriving Inhabited +def Code.getRef? : Code → Option Syntax + | .decl _ doElem _ => doElem + | .reassign _ doElem _ => doElem + | .joinpoint .. => none + | .seq a _ => a + | .action a => a + | .break ref => ref + | .continue ref => ref + | .return ref _ => ref + | .ite ref .. => ref + | .match ref .. => ref + | .jmp ref .. => ref + abbrev VarSet := Std.RBMap Name Syntax Name.cmp /-- A code block, and the collection of variables updated by it. -/ @@ -1014,25 +1027,32 @@ def mkJmp (ref : Syntax) (j : Name) (args : Array Syntax) : Syntax := Syntax.mkApp (mkIdentFrom ref j) args partial def toTerm (c : Code) : M Syntax := do - match c with - | Code.return ref val => withRef ref <| returnToTerm val - | Code.continue ref => withRef ref continueToTerm - | Code.break ref => withRef ref breakToTerm - | Code.action e => actionTerminalToTerm e - | Code.joinpoint j ps b k => mkJoinPoint j ps (← toTerm b) (← toTerm k) - | Code.jmp ref j args => return mkJmp ref j args - | Code.decl _ stx k => declToTerm stx (← toTerm k) - | Code.reassign _ stx k => reassignToTerm stx (← toTerm k) - | Code.seq stx k => seqToTerm stx (← toTerm k) - | Code.ite ref _ o c t e => withRef ref <| do mkIte o c (← toTerm t) (← toTerm e) - | Code.«match» ref genParam discrs optMotive alts => - let mut termAlts := #[] - for alt in alts do - let rhs ← toTerm alt.rhs - let termAlt := mkNode `Lean.Parser.Term.matchAlt #[mkAtomFrom alt.ref "|", mkNullNode #[alt.patterns], mkAtomFrom alt.ref "=>", rhs] - termAlts := termAlts.push termAlt - let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts] - return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts] + let term ← go c + if let some ref := c.getRef? then + `(with_annotate_term $ref $term) + else + return term +where + go (c : Code) : M Syntax := do + match c with + | Code.return ref val => withRef ref <| returnToTerm val + | Code.continue ref => withRef ref continueToTerm + | Code.break ref => withRef ref breakToTerm + | Code.action e => actionTerminalToTerm e + | Code.joinpoint j ps b k => mkJoinPoint j ps (← toTerm b) (← toTerm k) + | Code.jmp ref j args => return mkJmp ref j args + | Code.decl _ stx k => declToTerm stx (← toTerm k) + | Code.reassign _ stx k => reassignToTerm stx (← toTerm k) + | Code.seq stx k => seqToTerm stx (← toTerm k) + | Code.ite ref _ o c t e => withRef ref <| do mkIte o c (← toTerm t) (← toTerm e) + | Code.«match» ref genParam discrs optMotive alts => + let mut termAlts := #[] + for alt in alts do + let rhs ← toTerm alt.rhs + let termAlt := mkNode `Lean.Parser.Term.matchAlt #[mkAtomFrom alt.ref "|", mkNullNode #[alt.patterns], mkAtomFrom alt.ref "=>", rhs] + termAlts := termAlts.push termAlt + let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts] + return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts] def run (code : Code) (m : Syntax) (returnType : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax := toTerm code { m, returnType, kind, uvars } @@ -1421,9 +1441,11 @@ mutual let forInBody ← liftMacroM <| destructTuple uvars (← `(r)) forInBody let optType ← `(Option $((← read).returnType)) let forInTerm ← if let some h := h? then - `(for_in'% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x $h (r : MProd $optType _) => let r := r.2; $forInBody) + `(with_annotate_term $doFor + for_in'% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x $h (r : MProd $optType _) => let r := r.2; $forInBody) else - `(for_in% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x (r : MProd $optType _) => let r := r.2; $forInBody) + `(with_annotate_term $doFor + for_in% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x (r : MProd $optType _) => let r := r.2; $forInBody) let auxDo ← `(do let r ← $forInTerm:term; $uvarsTuple:term := r.2; match r.1 with @@ -1433,9 +1455,11 @@ mutual else let forInBody ← liftMacroM <| destructTuple uvars (← `(r)) forInBody let forInTerm ← if let some h := h? then - `(for_in'% $(xs) $uvarsTuple fun $x $h r => $forInBody) + `(with_annotate_term $doFor + for_in'% $(xs) $uvarsTuple fun $x $h r => $forInBody) else - `(for_in% $(xs) $uvarsTuple fun $x r => $forInBody) + `(with_annotate_term $doFor + for_in% $(xs) $uvarsTuple fun $x r => $forInBody) if doElems.isEmpty then let auxDo ← `(do let r ← $forInTerm:term; $uvarsTuple:term := r; @@ -1507,10 +1531,10 @@ mutual let term ← catches.foldlM (init := term) fun term «catch» => do let catchTerm ← toTerm «catch».codeBlock if catch.optType.isNone then - ``(MonadExcept.tryCatch $term (fun $(«catch».x):ident => $catchTerm)) + `(with_annotate_term $doTry MonadExcept.tryCatch $term (fun $(«catch».x):ident => $catchTerm)) else let type := «catch».optType[1] - ``(tryCatchThe $type $term (fun $(«catch».x):ident => $catchTerm)) + `(with_annotate_term $doTry tryCatchThe $type $term (fun $(«catch».x):ident => $catchTerm)) let term ← match finallyCode? with | none => pure term | some finallyCode => withRef optFinally do @@ -1519,7 +1543,7 @@ mutual if hasBreakContinueReturn finallyCode.code then throwError "`finally` currently does `return`, `break`, nor `continue`" let finallyTerm ← liftMacroM <| ToTerm.run finallyCode.code ctx.m ctx.returnType {} ToTerm.Kind.regular - ``(tryFinally $term $finallyTerm) + `(with_annotate_term $doTry tryFinally $term $finallyTerm) let doElemsNew ← liftMacroM <| ToTerm.matchNestedTermResult term uvars a r bc doSeqToCode (doElemsNew ++ doElems) diff --git a/tests/lean/interactive/plainTermGoal.lean.expected.out b/tests/lean/interactive/plainTermGoal.lean.expected.out index f4855befc2..1e85669f56 100644 --- a/tests/lean/interactive/plainTermGoal.lean.expected.out +++ b/tests/lean/interactive/plainTermGoal.lean.expected.out @@ -26,8 +26,8 @@ {"textDocument": {"uri": "file://plainTermGoal.lean"}, "position": {"line": 11, "character": 10}} {"range": - {"start": {"line": 9, "character": 25}, "end": {"line": 13, "character": 11}}, - "goal": "⊢ Option Unit"} + {"start": {"line": 11, "character": 2}, "end": {"line": 11, "character": 19}}, + "goal": "y : Int\n⊢ Option Unit"} {"textDocument": {"uri": "file://plainTermGoal.lean"}, "position": {"line": 16, "character": 17}} {"range":