diff --git a/src/Lean/Elab/Notation.lean b/src/Lean/Elab/Notation.lean index 5260ff91c2..2a0402c9e9 100644 --- a/src/Lean/Elab/Notation.lean +++ b/src/Lean/Elab/Notation.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ import Lean.Elab.Syntax import Lean.Elab.AuxDef +import Lean.Elab.BuiltinNotation namespace Lean.Elab.Command open Lean.Syntax @@ -43,15 +44,55 @@ def expandNotationItemIntoPattern (stx : Syntax) : MacroM Syntax := else Macro.throwUnsupported +def removeParenthesesAux (parens body : Syntax) : Syntax := + match parens.getHeadInfo, body.getHeadInfo, body.getTailInfo, parens.getTailInfo with + | .original lead _ _ _, .original _ pos trail pos', + .original endLead endPos _ endPos', .original _ _ endTrail _ => + body.setHeadInfo (.original lead pos trail pos') |>.setTailInfo (.original endLead endPos endTrail endPos') + | _, _, _, _ => body + +partial def removeParentheses (stx : Syntax) : MacroM Syntax := do + match stx with + | `(($e)) => pure $ removeParenthesesAux stx (←removeParentheses $ (←Term.expandCDot? e).getD e) + | _ => + match stx with + | .node info kind args => pure $ .node info kind (←args.mapM removeParentheses) + | _ => pure stx + +partial def hasDuplicateAntiquot (stxs : Array Syntax) : Bool := Id.run do + let mut seen := NameSet.empty + for stx in stxs do + for node in Syntax.topDown stx true do + if node.isAntiquot then + let ident := node.getAntiquotTerm.getId + if seen.contains ident then + return true + else + seen := seen.insert ident + pure false + /-- Try to derive a `SimpleDelab` from a notation. - The notation must be of the form `notation ... => c var_1 ... var_n` - where `c` is a declaration in the current scope and the `var_i` are a permutation of the LHS vars. -/ + The notation must be of the form `notation ... => c body` + where `c` is a declaration in the current scope and `body` any syntax + that contains each variable from the LHS at most once. -/ def mkSimpleDelab (attrKind : Syntax) (pat qrhs : Syntax) : OptionT MacroM Syntax := do match qrhs with | `($c:ident $args*) => let [(c, [])] ← Macro.resolveGlobalName c.getId | failure - guard <| args.all (Syntax.isIdent ∘ getAntiquotTerm) - guard <| args.allDiff + /- + Try to remove all non semantic parenthesis. Since the parenthesizer + runs after appUnexpanders we should not match on parenthesis that the user + syntax inserted here for example the right hand side of: + notation "{" x "|" p "}" => setOf (fun x => p) + Should be matched as: setOf fun x => p + -/ + let args ← args.mapM (liftM ∘ removeParentheses) + /- + The user could mention the same antiquotation from the lhs multiple + times on the rhs, this heuristic does not support this. + -/ + let dup := hasDuplicateAntiquot args + guard !dup -- replace head constant with (unused) antiquotation so we're not dependent on the exact pretty printing of the head -- The reference is attached to the syntactic representation of the called function itself, not the entire function application `(@[$attrKind:attrKind appUnexpander $(mkIdent c):ident] @@ -106,5 +147,4 @@ private def expandNotationAux (ref : Syntax) expandNotationAux stx (← Macro.getCurrNamespace) attrKind prec? name? prio? items rhs | _ => Macro.throwUnsupported - end Lean.Elab.Command diff --git a/tests/lean/255.lean.expected.out b/tests/lean/255.lean.expected.out index 907805e394..b163a7dcb8 100644 --- a/tests/lean/255.lean.expected.out +++ b/tests/lean/255.lean.expected.out @@ -1,4 +1,4 @@ -id x : α +A : α id x✝ : α 255.lean:16:7-16:8: error: unknown constant 'x✝' id (sorryAx ?m true) : ?m diff --git a/tests/lean/notationDelab.lean b/tests/lean/notationDelab.lean new file mode 100644 index 0000000000..3476d0d155 --- /dev/null +++ b/tests/lean/notationDelab.lean @@ -0,0 +1,24 @@ +notation "unitTest " x => Prod.mk x () + +#check unitTest 42 + +notation "parenthesisTest " x => Nat.sub (x) +#check parenthesisTest 12 + +def Set (α : Type u) := α → Prop +def setOf {α : Type} (p : α → Prop) : Set α := p +notation "{ " x " | " p " }" => setOf (fun x => p) + +#check { (x : Nat) | x ≤ 1 } + +notation "cdotTest " "(" x ", " y ")" => Prod.map (· + 1) (1 + ·) (x, y) + +#check cdotTest (13, 12) + +notation "tupleFunctionTest " "(" x ", " y ")"=> Prod.map (Nat.add 1) (Nat.add 2) (x, y) + +#check tupleFunctionTest (15, 12) + +notation "doubleRhsTest " x => Prod.mk x x + +#check doubleRhsTest 12 diff --git a/tests/lean/notationDelab.lean.expected.out b/tests/lean/notationDelab.lean.expected.out new file mode 100644 index 0000000000..6de7670b0b --- /dev/null +++ b/tests/lean/notationDelab.lean.expected.out @@ -0,0 +1,6 @@ +unitTest 42 : Nat × Unit +parenthesisTest 12 : Nat → Nat +{ x | x ≤ 1 } : Set Nat +cdotTest (13, 12) : Nat × Nat +tupleFunctionTest (15, 12) : Nat × Nat +(12, 12) : Nat × Nat