chore: try macros first

This commit is contained in:
Leonardo de Moura 2020-02-01 00:53:49 -08:00
parent ca919c2021
commit cd4ec6313e
2 changed files with 38 additions and 12 deletions

View file

@ -480,18 +480,20 @@ partial def elabTermAux (expectedType? : Option Expr) (catchExPostpone := true)
| stx => withFreshMacroScope $ withIncRecDepth stx $ do
trace `Elab.step stx $ fun _ => stx;
s ← get;
let table := (termElabAttribute.ext.getState s.env).table;
let k := stx.getKind;
match table.find? k with
| some elabFns => elabTermUsing s stx expectedType? catchExPostpone elabFns
| none => do
env ← getEnv;
stx' ← catch
(adaptMacro (getMacros env) stx)
(fun ex => match ex with
| Exception.ex Elab.Exception.unsupportedSyntax => throwError stx ("elaboration function for '" ++ toString k ++ "' has not been implemented")
| _ => throw ex);
withMacroExpansion stx stx' $ elabTermAux stx'
env ← getEnv;
stxNew? ← catch
(do newStx ← adaptMacro (getMacros env) stx; pure (some newStx))
(fun ex => match ex with
| Exception.ex Elab.Exception.unsupportedSyntax => pure none
| _ => throw ex);
match stxNew? with
| some stxNew => withMacroExpansion stx stxNew $ elabTermAux stxNew
| _ =>
let table := (termElabAttribute.ext.getState s.env).table;
let k := stx.getKind;
match table.find? k with
| some elabFns => elabTermUsing s stx expectedType? catchExPostpone elabFns
| none => throwError stx ("elaboration function for '" ++ toString k ++ "' has not been implemented")
/--
Main function for elaborating terms.

View file

@ -1,3 +1,19 @@
open Lean
partial def expandHash : Syntax → StateT Bool MacroM Syntax
| Syntax.node k args =>
if k == `doHash then do set true; `((^MonadState.get))
else do
args ← args.mapM expandHash;
pure $ Syntax.node k args
| stx => pure stx
@[macro Lean.Parser.Term.do] def expandDo : Macro :=
fun stx => do
(stx, expanded) ← expandHash stx false;
if expanded then pure stx
else Macro.throwUnsupported
new_frontend
def f : IO Nat :=
@ -26,3 +42,11 @@ else do
x ← f;
y ← g x;
IO.println y
syntax [doHash] "#":max : term
def tst4 : StateT (Nat × Nat) IO Unit := do
if #.1 == 0 then
IO.println "first field is zero"
else
IO.println "first field is not zero"