diff --git a/src/Init/Lean/Elab/Syntax.lean b/src/Init/Lean/Elab/Syntax.lean index 1efea57623..45f5c1abaf 100644 --- a/src/Init/Lean/Elab/Syntax.lean +++ b/src/Init/Lean/Elab/Syntax.lean @@ -180,26 +180,54 @@ fun stx => do trace `Elab stx $ fun _ => d; withMacroExpansion stx d $ elabCommand d -def getMacroRulesAltKind (alt : Syntax) : CommandElabM SyntaxNodeKind := +def elabMacroRulesAux (k : SyntaxNodeKind) (alts : Array Syntax) : CommandElabM Syntax := do +alts ← alts.mapSepElemsM $ fun alt => do { + let lhs := alt.getArg 0; + let pat := lhs.getArg 0; + match_syntax pat with + | `(`($quot)) => + let k' := quot.getKind; + if k' == k then + pure alt + else if k' == choiceKind then do + match quot.getArgs.find? $ fun quotAlt => if quotAlt.getKind == k then some quotAlt else none with + | none => throwError alt ("invalid macro_rules alternative, expected syntax node kind '" ++ k ++ "'") + | some quot => do + pat ← `(`($quot)); + let lhs := lhs.setArg 0 pat; + pure $ alt.setArg 0 lhs + else + throwError alt ("invalid macro_rules alternative, unexpected syntax node kind '" ++ k' ++ "'") + | stx => throwUnsupportedSyntax +}; +`(@[macro $(Lean.mkIdent k)] def myMacro : Macro := fun stx => match_syntax stx with $alts:matchAlt* | _ => throw Lean.Macro.Exception.unsupportedSyntax) + +def inferMacroRulesAltKind (alt : Syntax) : CommandElabM SyntaxNodeKind := match_syntax (alt.getArg 0).getArg 0 with | `(`($quot)) => pure quot.getKind | stx => throwUnsupportedSyntax -def elabMacroRulesAux (alts : Array Syntax) : CommandElabM Syntax := do -k ← getMacroRulesAltKind (alts.get! 0); -altsK ← alts.filterSepElemsM (fun alt => do k' ← getMacroRulesAltKind alt; pure $ k == k'); -altsNotK ← alts.filterSepElemsM (fun alt => do k' ← getMacroRulesAltKind alt; pure $ k != k'); -altsKDef ← `(@[macro $(Lean.mkIdent k)] def myMacro : Macro := fun stx => match_syntax stx with $altsK:matchAlt* | _ => throw Lean.Macro.Exception.unsupportedSyntax); -if altsNotK.isEmpty then - pure altsKDef -else - `($altsKDef:command macro_rules $altsNotK:matchAlt*) +def elabNoKindMacroRulesAux (alts : Array Syntax) : CommandElabM Syntax := do +k ← inferMacroRulesAltKind (alts.get! 0); +if k == choiceKind then + throwError (alts.get! 0) + "invalid macro_rules alternative, multiple interpretations for first element (solution: specify node kind using `macro_rules [] ...`)" +else do + altsK ← alts.filterSepElemsM (fun alt => do k' ← inferMacroRulesAltKind alt; pure $ k == k'); + altsNotK ← alts.filterSepElemsM (fun alt => do k' ← inferMacroRulesAltKind alt; pure $ k != k'); + defCmd ← elabMacroRulesAux k altsK; + if altsNotK.isEmpty then + pure defCmd + else + `($defCmd:command macro_rules $altsNotK:matchAlt*) @[builtinCommandElab «macro_rules»] def elabMacroRules : CommandElab := adaptExpander $ fun stx => match_syntax stx with -| `(macro_rules $alts:matchAlt*) => elabMacroRulesAux alts -| `(macro_rules | $alts:matchAlt*) => elabMacroRulesAux alts -| _ => throwUnsupportedSyntax +| `(macro_rules $alts:matchAlt*) => elabNoKindMacroRulesAux alts +| `(macro_rules | $alts:matchAlt*) => elabNoKindMacroRulesAux alts +| `(macro_rules [$kind] $alts:matchAlt*) => elabMacroRulesAux kind.getId alts +| `(macro_rules [$kind] | $alts:matchAlt*) => elabMacroRulesAux kind.getId alts +| _ => throwUnsupportedSyntax /- We just ignore Lean3 notation declaration commands. -/ @[builtinCommandElab «mixfix»] def elabMixfix : CommandElab := fun _ => pure () diff --git a/src/Init/LeanInit.lean b/src/Init/LeanInit.lean index e3f0ce566c..0498c22e75 100644 --- a/src/Init/LeanInit.lean +++ b/src/Init/LeanInit.lean @@ -782,4 +782,22 @@ filterSepElemsMAux a p 0 #[] def filterSepElems (a : Array Syntax) (p : Syntax → Bool) : Array Syntax := Id.run $ a.filterSepElemsM p +private partial def mapSepElemsMAux {m : Type → Type} [Monad m] (a : Array Syntax) (f : Syntax → m Syntax) : Nat → Array Syntax → m (Array Syntax) +| i, acc => + if h : i < a.size then do + let stx := a.get ⟨i, h⟩; + if i % 2 == 0 then do + stx ← f stx; + mapSepElemsMAux (i+1) (acc.push stx) + else + mapSepElemsMAux (i+1) (acc.push stx) + else + pure acc + +def mapSepElemsM {m : Type → Type} [Monad m] (a : Array Syntax) (f : Syntax → m Syntax) : m (Array Syntax) := +mapSepElemsMAux a f 0 #[] + +def mapSepElems (a : Array Syntax) (f : Syntax → Syntax) : Array Syntax := +Id.run $ a.mapSepElemsM f + end Array