diff --git a/src/Lean/Elab/Syntax.lean b/src/Lean/Elab/Syntax.lean index 70ddf2bc33..db8632b8e1 100644 --- a/src/Lean/Elab/Syntax.lean +++ b/src/Lean/Elab/Syntax.lean @@ -458,7 +458,7 @@ def expandNotationItemIntoPattern (stx : Syntax) : CommandElabM Syntax := /-- Try to derive a `SimpleDelab` from a notation. The notation must be of the form `notation ... => c var_1 ... var_n` where `c` is a declaration in the current scope and the `var_i` are a permutation of the LHS vars. -/ -def mkSimpleDelab (vars : Array Syntax) (pat qrhs : Syntax) : OptionT CommandElabM Syntax := do +def mkSimpleDelab (attrKind : Syntax) (vars : Array Syntax) (pat qrhs : Syntax) : OptionT CommandElabM Syntax := do match qrhs with | `($c:ident $args*) => let [(c, [])] ← resolveGlobalName c.getId | failure @@ -466,16 +466,16 @@ def mkSimpleDelab (vars : Array Syntax) (pat qrhs : Syntax) : OptionT CommandEla guard <| args.allDiff -- replace head constant with fresh (unused) antiquotation so we're not dependent on the exact pretty printing of the head let qrhs ← `($(mkAntiquotNode (mkIdent "c")) $args*) - `(@[appUnexpander $(mkIdent c):ident] def unexpand : Lean.PrettyPrinter.Unexpander := fun + `(@[$attrKind:attrKind appUnexpander $(mkIdent c):ident] def unexpand : Lean.PrettyPrinter.Unexpander := fun | `($qrhs) => `($pat) | _ => throw ()) | `($c:ident) => let [(c, [])] ← resolveGlobalName c.getId | failure - `(@[appUnexpander $(mkIdent c):ident] def unexpand : Lean.PrettyPrinter.Unexpander := fun _ => `($pat)) + `(@[$attrKind:attrKind appUnexpander $(mkIdent c):ident] def unexpand : Lean.PrettyPrinter.Unexpander := fun _ => `($pat)) | _ => failure private def expandNotationAux (ref : Syntax) - (currNamespace : Name) (attrKind : AttributeKind) (prec? : Option Syntax) (name? : Option Syntax) (prio? : Option Syntax) (items : Array Syntax) (rhs : Syntax) : CommandElabM Syntax := do + (currNamespace : Name) (attrKind : Syntax) (prec? : Option Syntax) (name? : Option Syntax) (prio? : Option Syntax) (items : Array Syntax) (rhs : Syntax) : CommandElabM Syntax := do let prio ← liftMacroM <| evalOptPrio prio? -- build parser let syntaxParts ← items.mapM expandNotationItemIntoSyntaxItem @@ -493,22 +493,19 @@ private def expandNotationAux (ref : Syntax) So, we must include current namespace when we create a pattern for the following `macro_rules` commands. -/ let fullName := currNamespace ++ name let pat := Syntax.node fullName patArgs - let stxDecl ← match attrKind with - | AttributeKind.global => `(syntax $[: $prec?]? (name := $(mkIdent name)) (priority := $(quote prio):numLit) $[$syntaxParts]* : $cat) - | AttributeKind.scoped => `(scoped syntax $[: $prec? ]? (name := $(mkIdent name)) (priority := $(quote prio):numLit) $[$syntaxParts]* : $cat) - | AttributeKind.local => `(local syntax $[: $prec? ]? (name := $(mkIdent name)) (priority := $(quote prio):numLit) $[$syntaxParts]* : $cat) + let stxDecl ← `($attrKind:attrKind syntax $[: $prec?]? (name := $(mkIdent name)) (priority := $(quote prio):numLit) $[$syntaxParts]* : $cat) let macroDecl ← `(macro_rules | `($pat) => `($qrhs)) - match (← mkSimpleDelab vars pat qrhs |>.run) with + match (← mkSimpleDelab attrKind vars pat qrhs |>.run) with | some delabDecl => mkNullNode #[stxDecl, macroDecl, delabDecl] | none => mkNullNode #[stxDecl, macroDecl] @[builtinCommandElab «notation»] def expandNotation : CommandElab := adaptExpander fun stx => do - let attrKind ← toAttributeKind stx[0] - let stx := stx.setArg 0 mkAttrKindGlobal + -- trigger scoped checks early and only once + let _ ← toAttributeKind stx[0] let currNamespace ← getCurrNamespace match stx with - | `(notation $[: $prec? ]? $[(name := $name?)]? $[(priority := $prio?)]? $items* => $rhs) => + | `($attrKind:attrKind notation $[: $prec? ]? $[(name := $name?)]? $[(priority := $prio?)]? $items* => $rhs) => expandNotationAux stx currNamespace attrKind prec? name? prio? items rhs | _ => throwUnsupportedSyntax diff --git a/tests/lean/localNotationPP.lean b/tests/lean/localNotationPP.lean new file mode 100644 index 0000000000..1e0b3fd805 --- /dev/null +++ b/tests/lean/localNotationPP.lean @@ -0,0 +1,5 @@ +axiom n : Type → Type +section +local notation "ℕ" x => n x +end +#check n Nat -- should *not* be `ℕ Nat : Type` diff --git a/tests/lean/localNotationPP.lean.expected.out b/tests/lean/localNotationPP.lean.expected.out new file mode 100644 index 0000000000..7061084127 --- /dev/null +++ b/tests/lean/localNotationPP.lean.expected.out @@ -0,0 +1 @@ +n Nat : Type