feat: handle explicit node kinds in macro_rules, and handle choice kind
This commit is contained in:
parent
6a1712717f
commit
7133d3fc84
2 changed files with 59 additions and 13 deletions
|
|
@ -180,26 +180,54 @@ fun stx => do
|
|||
trace `Elab stx $ fun _ => d;
|
||||
withMacroExpansion stx d $ elabCommand d
|
||||
|
||||
def getMacroRulesAltKind (alt : Syntax) : CommandElabM SyntaxNodeKind :=
|
||||
def elabMacroRulesAux (k : SyntaxNodeKind) (alts : Array Syntax) : CommandElabM Syntax := do
|
||||
alts ← alts.mapSepElemsM $ fun alt => do {
|
||||
let lhs := alt.getArg 0;
|
||||
let pat := lhs.getArg 0;
|
||||
match_syntax pat with
|
||||
| `(`($quot)) =>
|
||||
let k' := quot.getKind;
|
||||
if k' == k then
|
||||
pure alt
|
||||
else if k' == choiceKind then do
|
||||
match quot.getArgs.find? $ fun quotAlt => if quotAlt.getKind == k then some quotAlt else none with
|
||||
| none => throwError alt ("invalid macro_rules alternative, expected syntax node kind '" ++ k ++ "'")
|
||||
| some quot => do
|
||||
pat ← `(`($quot));
|
||||
let lhs := lhs.setArg 0 pat;
|
||||
pure $ alt.setArg 0 lhs
|
||||
else
|
||||
throwError alt ("invalid macro_rules alternative, unexpected syntax node kind '" ++ k' ++ "'")
|
||||
| stx => throwUnsupportedSyntax
|
||||
};
|
||||
`(@[macro $(Lean.mkIdent k)] def myMacro : Macro := fun stx => match_syntax stx with $alts:matchAlt* | _ => throw Lean.Macro.Exception.unsupportedSyntax)
|
||||
|
||||
def inferMacroRulesAltKind (alt : Syntax) : CommandElabM SyntaxNodeKind :=
|
||||
match_syntax (alt.getArg 0).getArg 0 with
|
||||
| `(`($quot)) => pure quot.getKind
|
||||
| stx => throwUnsupportedSyntax
|
||||
|
||||
def elabMacroRulesAux (alts : Array Syntax) : CommandElabM Syntax := do
|
||||
k ← getMacroRulesAltKind (alts.get! 0);
|
||||
altsK ← alts.filterSepElemsM (fun alt => do k' ← getMacroRulesAltKind alt; pure $ k == k');
|
||||
altsNotK ← alts.filterSepElemsM (fun alt => do k' ← getMacroRulesAltKind alt; pure $ k != k');
|
||||
altsKDef ← `(@[macro $(Lean.mkIdent k)] def myMacro : Macro := fun stx => match_syntax stx with $altsK:matchAlt* | _ => throw Lean.Macro.Exception.unsupportedSyntax);
|
||||
if altsNotK.isEmpty then
|
||||
pure altsKDef
|
||||
else
|
||||
`($altsKDef:command macro_rules $altsNotK:matchAlt*)
|
||||
def elabNoKindMacroRulesAux (alts : Array Syntax) : CommandElabM Syntax := do
|
||||
k ← inferMacroRulesAltKind (alts.get! 0);
|
||||
if k == choiceKind then
|
||||
throwError (alts.get! 0)
|
||||
"invalid macro_rules alternative, multiple interpretations for first element (solution: specify node kind using `macro_rules [<kind>] ...`)"
|
||||
else do
|
||||
altsK ← alts.filterSepElemsM (fun alt => do k' ← inferMacroRulesAltKind alt; pure $ k == k');
|
||||
altsNotK ← alts.filterSepElemsM (fun alt => do k' ← inferMacroRulesAltKind alt; pure $ k != k');
|
||||
defCmd ← elabMacroRulesAux k altsK;
|
||||
if altsNotK.isEmpty then
|
||||
pure defCmd
|
||||
else
|
||||
`($defCmd:command macro_rules $altsNotK:matchAlt*)
|
||||
|
||||
@[builtinCommandElab «macro_rules»] def elabMacroRules : CommandElab :=
|
||||
adaptExpander $ fun stx => match_syntax stx with
|
||||
| `(macro_rules $alts:matchAlt*) => elabMacroRulesAux alts
|
||||
| `(macro_rules | $alts:matchAlt*) => elabMacroRulesAux alts
|
||||
| _ => throwUnsupportedSyntax
|
||||
| `(macro_rules $alts:matchAlt*) => elabNoKindMacroRulesAux alts
|
||||
| `(macro_rules | $alts:matchAlt*) => elabNoKindMacroRulesAux alts
|
||||
| `(macro_rules [$kind] $alts:matchAlt*) => elabMacroRulesAux kind.getId alts
|
||||
| `(macro_rules [$kind] | $alts:matchAlt*) => elabMacroRulesAux kind.getId alts
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
/- We just ignore Lean3 notation declaration commands. -/
|
||||
@[builtinCommandElab «mixfix»] def elabMixfix : CommandElab := fun _ => pure ()
|
||||
|
|
|
|||
|
|
@ -782,4 +782,22 @@ filterSepElemsMAux a p 0 #[]
|
|||
def filterSepElems (a : Array Syntax) (p : Syntax → Bool) : Array Syntax :=
|
||||
Id.run $ a.filterSepElemsM p
|
||||
|
||||
private partial def mapSepElemsMAux {m : Type → Type} [Monad m] (a : Array Syntax) (f : Syntax → m Syntax) : Nat → Array Syntax → m (Array Syntax)
|
||||
| i, acc =>
|
||||
if h : i < a.size then do
|
||||
let stx := a.get ⟨i, h⟩;
|
||||
if i % 2 == 0 then do
|
||||
stx ← f stx;
|
||||
mapSepElemsMAux (i+1) (acc.push stx)
|
||||
else
|
||||
mapSepElemsMAux (i+1) (acc.push stx)
|
||||
else
|
||||
pure acc
|
||||
|
||||
def mapSepElemsM {m : Type → Type} [Monad m] (a : Array Syntax) (f : Syntax → m Syntax) : m (Array Syntax) :=
|
||||
mapSepElemsMAux a f 0 #[]
|
||||
|
||||
def mapSepElems (a : Array Syntax) (f : Syntax → Syntax) : Array Syntax :=
|
||||
Id.run $ a.mapSepElemsM f
|
||||
|
||||
end Array
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue