diff --git a/src/Init/Lean/Elab/Term.lean b/src/Init/Lean/Elab/Term.lean index 3a34e3dd81..cd360dcf85 100644 --- a/src/Init/Lean/Elab/Term.lean +++ b/src/Init/Lean/Elab/Term.lean @@ -480,18 +480,20 @@ partial def elabTermAux (expectedType? : Option Expr) (catchExPostpone := true) | stx => withFreshMacroScope $ withIncRecDepth stx $ do trace `Elab.step stx $ fun _ => stx; s ← get; - let table := (termElabAttribute.ext.getState s.env).table; - let k := stx.getKind; - match table.find? k with - | some elabFns => elabTermUsing s stx expectedType? catchExPostpone elabFns - | none => do - env ← getEnv; - stx' ← catch - (adaptMacro (getMacros env) stx) - (fun ex => match ex with - | Exception.ex Elab.Exception.unsupportedSyntax => throwError stx ("elaboration function for '" ++ toString k ++ "' has not been implemented") - | _ => throw ex); - withMacroExpansion stx stx' $ elabTermAux stx' + env ← getEnv; + stxNew? ← catch + (do newStx ← adaptMacro (getMacros env) stx; pure (some newStx)) + (fun ex => match ex with + | Exception.ex Elab.Exception.unsupportedSyntax => pure none + | _ => throw ex); + match stxNew? with + | some stxNew => withMacroExpansion stx stxNew $ elabTermAux stxNew + | _ => + let table := (termElabAttribute.ext.getState s.env).table; + let k := stx.getKind; + match table.find? k with + | some elabFns => elabTermUsing s stx expectedType? catchExPostpone elabFns + | none => throwError stx ("elaboration function for '" ++ toString k ++ "' has not been implemented") /-- Main function for elaborating terms. diff --git a/tests/lean/run/doNotation1.lean b/tests/lean/run/doNotation1.lean index 98a492616b..d317fc3181 100644 --- a/tests/lean/run/doNotation1.lean +++ b/tests/lean/run/doNotation1.lean @@ -1,3 +1,19 @@ +open Lean + +partial def expandHash : Syntax → StateT Bool MacroM Syntax +| Syntax.node k args => + if k == `doHash then do set true; `((^MonadState.get)) + else do + args ← args.mapM expandHash; + pure $ Syntax.node k args +| stx => pure stx + +@[macro Lean.Parser.Term.do] def expandDo : Macro := +fun stx => do + (stx, expanded) ← expandHash stx false; + if expanded then pure stx + else Macro.throwUnsupported + new_frontend def f : IO Nat := @@ -26,3 +42,11 @@ else do x ← f; y ← g x; IO.println y + +syntax [doHash] "#":max : term + +def tst4 : StateT (Nat × Nat) IO Unit := do +if #.1 == 0 then + IO.println "first field is zero" +else + IO.println "first field is not zero"