feat: handle explicit node kinds in macro_rules, and handle choice kind

This commit is contained in:
Leonardo de Moura 2020-01-26 09:39:46 -08:00
parent 6a1712717f
commit 7133d3fc84
2 changed files with 59 additions and 13 deletions

View file

@ -180,26 +180,54 @@ fun stx => do
trace `Elab stx $ fun _ => d;
withMacroExpansion stx d $ elabCommand d
def getMacroRulesAltKind (alt : Syntax) : CommandElabM SyntaxNodeKind :=
def elabMacroRulesAux (k : SyntaxNodeKind) (alts : Array Syntax) : CommandElabM Syntax := do
alts ← alts.mapSepElemsM $ fun alt => do {
let lhs := alt.getArg 0;
let pat := lhs.getArg 0;
match_syntax pat with
| `(`($quot)) =>
let k' := quot.getKind;
if k' == k then
pure alt
else if k' == choiceKind then do
match quot.getArgs.find? $ fun quotAlt => if quotAlt.getKind == k then some quotAlt else none with
| none => throwError alt ("invalid macro_rules alternative, expected syntax node kind '" ++ k ++ "'")
| some quot => do
pat ← `(`($quot));
let lhs := lhs.setArg 0 pat;
pure $ alt.setArg 0 lhs
else
throwError alt ("invalid macro_rules alternative, unexpected syntax node kind '" ++ k' ++ "'")
| stx => throwUnsupportedSyntax
};
`(@[macro $(Lean.mkIdent k)] def myMacro : Macro := fun stx => match_syntax stx with $alts:matchAlt* | _ => throw Lean.Macro.Exception.unsupportedSyntax)
def inferMacroRulesAltKind (alt : Syntax) : CommandElabM SyntaxNodeKind :=
match_syntax (alt.getArg 0).getArg 0 with
| `(`($quot)) => pure quot.getKind
| stx => throwUnsupportedSyntax
def elabMacroRulesAux (alts : Array Syntax) : CommandElabM Syntax := do
k ← getMacroRulesAltKind (alts.get! 0);
altsK ← alts.filterSepElemsM (fun alt => do k' ← getMacroRulesAltKind alt; pure $ k == k');
altsNotK ← alts.filterSepElemsM (fun alt => do k' ← getMacroRulesAltKind alt; pure $ k != k');
altsKDef ← `(@[macro $(Lean.mkIdent k)] def myMacro : Macro := fun stx => match_syntax stx with $altsK:matchAlt* | _ => throw Lean.Macro.Exception.unsupportedSyntax);
if altsNotK.isEmpty then
pure altsKDef
else
`($altsKDef:command macro_rules $altsNotK:matchAlt*)
def elabNoKindMacroRulesAux (alts : Array Syntax) : CommandElabM Syntax := do
k ← inferMacroRulesAltKind (alts.get! 0);
if k == choiceKind then
throwError (alts.get! 0)
"invalid macro_rules alternative, multiple interpretations for first element (solution: specify node kind using `macro_rules [<kind>] ...`)"
else do
altsK ← alts.filterSepElemsM (fun alt => do k' ← inferMacroRulesAltKind alt; pure $ k == k');
altsNotK ← alts.filterSepElemsM (fun alt => do k' ← inferMacroRulesAltKind alt; pure $ k != k');
defCmd ← elabMacroRulesAux k altsK;
if altsNotK.isEmpty then
pure defCmd
else
`($defCmd:command macro_rules $altsNotK:matchAlt*)
@[builtinCommandElab «macro_rules»] def elabMacroRules : CommandElab :=
adaptExpander $ fun stx => match_syntax stx with
| `(macro_rules $alts:matchAlt*) => elabMacroRulesAux alts
| `(macro_rules | $alts:matchAlt*) => elabMacroRulesAux alts
| _ => throwUnsupportedSyntax
| `(macro_rules $alts:matchAlt*) => elabNoKindMacroRulesAux alts
| `(macro_rules | $alts:matchAlt*) => elabNoKindMacroRulesAux alts
| `(macro_rules [$kind] $alts:matchAlt*) => elabMacroRulesAux kind.getId alts
| `(macro_rules [$kind] | $alts:matchAlt*) => elabMacroRulesAux kind.getId alts
| _ => throwUnsupportedSyntax
/- We just ignore Lean3 notation declaration commands. -/
@[builtinCommandElab «mixfix»] def elabMixfix : CommandElab := fun _ => pure ()

View file

@ -782,4 +782,22 @@ filterSepElemsMAux a p 0 #[]
def filterSepElems (a : Array Syntax) (p : Syntax → Bool) : Array Syntax :=
Id.run $ a.filterSepElemsM p
private partial def mapSepElemsMAux {m : Type → Type} [Monad m] (a : Array Syntax) (f : Syntax → m Syntax) : Nat → Array Syntax → m (Array Syntax)
| i, acc =>
if h : i < a.size then do
let stx := a.get ⟨i, h⟩;
if i % 2 == 0 then do
stx ← f stx;
mapSepElemsMAux (i+1) (acc.push stx)
else
mapSepElemsMAux (i+1) (acc.push stx)
else
pure acc
def mapSepElemsM {m : Type → Type} [Monad m] (a : Array Syntax) (f : Syntax → m Syntax) : m (Array Syntax) :=
mapSepElemsMAux a f 0 #[]
def mapSepElems (a : Array Syntax) (f : Syntax → Syntax) : Array Syntax :=
Id.run $ a.mapSepElemsM f
end Array