feat: improve the heuristic for notation delab

Instead of the previous constraints on the right hand side that only
allowed a permutation of variables as parameters to a function the
new heuristic allows anything to the right of a function as long as
each variable only appears at most once.
This commit is contained in:
Henrik Böving 2022-06-18 19:48:44 +02:00 committed by Leonardo de Moura
parent 3a89723f8c
commit 0fde2db75e
4 changed files with 76 additions and 6 deletions

View file

@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
import Lean.Elab.Syntax
import Lean.Elab.AuxDef
import Lean.Elab.BuiltinNotation
namespace Lean.Elab.Command
open Lean.Syntax
@ -43,15 +44,55 @@ def expandNotationItemIntoPattern (stx : Syntax) : MacroM Syntax :=
else
Macro.throwUnsupported
def removeParenthesesAux (parens body : Syntax) : Syntax :=
match parens.getHeadInfo, body.getHeadInfo, body.getTailInfo, parens.getTailInfo with
| .original lead _ _ _, .original _ pos trail pos',
.original endLead endPos _ endPos', .original _ _ endTrail _ =>
body.setHeadInfo (.original lead pos trail pos') |>.setTailInfo (.original endLead endPos endTrail endPos')
| _, _, _, _ => body
partial def removeParentheses (stx : Syntax) : MacroM Syntax := do
match stx with
| `(($e)) => pure $ removeParenthesesAux stx (←removeParentheses $ (←Term.expandCDot? e).getD e)
| _ =>
match stx with
| .node info kind args => pure $ .node info kind (←args.mapM removeParentheses)
| _ => pure stx
partial def hasDuplicateAntiquot (stxs : Array Syntax) : Bool := Id.run do
let mut seen := NameSet.empty
for stx in stxs do
for node in Syntax.topDown stx true do
if node.isAntiquot then
let ident := node.getAntiquotTerm.getId
if seen.contains ident then
return true
else
seen := seen.insert ident
pure false
/-- Try to derive a `SimpleDelab` from a notation.
The notation must be of the form `notation ... => c var_1 ... var_n`
where `c` is a declaration in the current scope and the `var_i` are a permutation of the LHS vars. -/
The notation must be of the form `notation ... => c body`
where `c` is a declaration in the current scope and `body` any syntax
that contains each variable from the LHS at most once. -/
def mkSimpleDelab (attrKind : Syntax) (pat qrhs : Syntax) : OptionT MacroM Syntax := do
match qrhs with
| `($c:ident $args*) =>
let [(c, [])] ← Macro.resolveGlobalName c.getId | failure
guard <| args.all (Syntax.isIdent ∘ getAntiquotTerm)
guard <| args.allDiff
/-
Try to remove all non semantic parenthesis. Since the parenthesizer
runs after appUnexpanders we should not match on parenthesis that the user
syntax inserted here for example the right hand side of:
notation "{" x "|" p "}" => setOf (fun x => p)
Should be matched as: setOf fun x => p
-/
let args ← args.mapM (liftM ∘ removeParentheses)
/-
The user could mention the same antiquotation from the lhs multiple
times on the rhs, this heuristic does not support this.
-/
let dup := hasDuplicateAntiquot args
guard !dup
-- replace head constant with (unused) antiquotation so we're not dependent on the exact pretty printing of the head
-- The reference is attached to the syntactic representation of the called function itself, not the entire function application
`(@[$attrKind:attrKind appUnexpander $(mkIdent c):ident]
@ -106,5 +147,4 @@ private def expandNotationAux (ref : Syntax)
expandNotationAux stx (← Macro.getCurrNamespace) attrKind prec? name? prio? items rhs
| _ => Macro.throwUnsupported
end Lean.Elab.Command

View file

@ -1,4 +1,4 @@
id x : α
A : α
id x✝ : α
255.lean:16:7-16:8: error: unknown constant 'x✝'
id (sorryAx ?m true) : ?m

View file

@ -0,0 +1,24 @@
notation "unitTest " x => Prod.mk x ()
#check unitTest 42
notation "parenthesisTest " x => Nat.sub (x)
#check parenthesisTest 12
def Set (α : Type u) := α → Prop
def setOf {α : Type} (p : α → Prop) : Set α := p
notation "{ " x " | " p " }" => setOf (fun x => p)
#check { (x : Nat) | x ≤ 1 }
notation "cdotTest " "(" x ", " y ")" => Prod.map (· + 1) (1 + ·) (x, y)
#check cdotTest (13, 12)
notation "tupleFunctionTest " "(" x ", " y ")"=> Prod.map (Nat.add 1) (Nat.add 2) (x, y)
#check tupleFunctionTest (15, 12)
notation "doubleRhsTest " x => Prod.mk x x
#check doubleRhsTest 12

View file

@ -0,0 +1,6 @@
unitTest 42 : Nat × Unit
parenthesisTest 12 : Nat → Nat
{ x | x ≤ 1 } : Set Nat
cdotTest (13, 12) : Nat × Nat
tupleFunctionTest (15, 12) : Nat × Nat
(12, 12) : Nat × Nat