fix: check arity in notation unexpander

Fixes #469
This commit is contained in:
Sebastian Ullrich 2021-07-22 16:52:43 +02:00
parent 98634b5554
commit dc3d94ff61
4 changed files with 29 additions and 17 deletions

View file

@ -52,13 +52,14 @@ def mkSimpleDelab (attrKind : Syntax) (vars : Array Syntax) (pat qrhs : Syntax)
guard <| args.all (Syntax.isIdent ∘ getAntiquotTerm)
guard <| args.allDiff
-- replace head constant with (unused) antiquotation so we're not dependent on the exact pretty printing of the head
let qrhs ← `($(mkAntiquotNode (← `(_))) $args*)
`(@[$attrKind:attrKind appUnexpander $(mkIdent c):ident] def unexpand : Lean.PrettyPrinter.Unexpander := fun
| `($qrhs) => `($pat)
| _ => throw ())
| `($$(_):ident $args*) => `($pat)
| _ => throw ())
| `($c:ident) =>
let [(c, [])] ← Macro.resolveGlobalName c.getId | failure
`(@[$attrKind:attrKind appUnexpander $(mkIdent c):ident] def unexpand : Lean.PrettyPrinter.Unexpander := fun _ => `($pat))
`(@[$attrKind:attrKind appUnexpander $(mkIdent c):ident] def unexpand : Lean.PrettyPrinter.Unexpander
| `($$(_):ident) => `($pat)
| _ => throw ())
| _ => failure
private def isLocalAttrKind (attrKind : Syntax) : Bool :=

4
tests/lean/469.lean Normal file
View file

@ -0,0 +1,4 @@
notation "(+)" => HAdd.hAdd
#check ((+) : Nat -> Nat -> Nat)
#check ((+) 2 : Nat -> Nat)
#check ((+) 2 3 : Nat)

View file

@ -0,0 +1,3 @@
(+) : Nat → Nat → Nat
HAdd.hAdd 2 : Nat → Nat
2 + 3 : Nat

View file

@ -27,20 +27,24 @@ fun (x : Lean.Syntax) =>
let discr : Lean.Syntax := x;
if Lean.Syntax.isOfKind discr `Lean.Parser.Term.app = true then
let discr_1 : Lean.Syntax := Lean.Syntax.getArg discr 0;
let discr_2 : Lean.Syntax := Lean.Syntax.getArg discr 1;
if Lean.Syntax.matchesNull discr_2 2 = true then
let discr : Lean.Syntax := Lean.Syntax.getArg discr_2 0;
let discr_3 : Lean.Syntax := Lean.Syntax.getArg discr_2 1;
let rhs : Lean.Syntax := discr_3;
let lhs : Lean.Syntax := discr;
do
let info ← Lean.MonadRef.mkInfoFromRefPos
Lean.getCurrMacroScope
Lean.getMainModule
pure (Lean.Syntax.node `«term_+++_» #[lhs, Lean.Syntax.atom info "+++", rhs])
else
cond (Lean.Syntax.isOfKind discr_1 `ident)
(let discr_2 : Lean.Syntax := Lean.Syntax.getArg discr 1;
if Lean.Syntax.matchesNull discr_2 2 = true then
let discr : Lean.Syntax := Lean.Syntax.getArg discr_2 0;
let discr_3 : Lean.Syntax := Lean.Syntax.getArg discr_2 1;
let rhs : Lean.Syntax := discr_3;
let lhs : Lean.Syntax := discr;
do
let info ← Lean.MonadRef.mkInfoFromRefPos
Lean.getCurrMacroScope
Lean.getMainModule
pure (Lean.Syntax.node `«term_+++_» #[lhs, Lean.Syntax.atom info "+++", rhs])
else
let discr : Lean.Syntax := Lean.Syntax.getArg discr 1;
throw Unit.unit)
(let discr_2 : Lean.Syntax := Lean.Syntax.getArg discr 0;
let discr : Lean.Syntax := Lean.Syntax.getArg discr 1;
throw Unit.unit
throw Unit.unit)
else
let discr : Lean.Syntax := x;
throw Unit.unit