feat: let dot notation see through CoeFun instances (#5692)
Projects like mathlib like to define projection functions with extra structure, for example one could imagine defining `Multiset.card : Multiset α →+ Nat`, which bundles the fact that `Multiset.card (m1 + m2) = Multiset.card m1 + Multiset.card m2` for all `m1 m2 : Multiset α`. A problem though is that so far this has prevented dot notation from working: you can't write `(m1 + m2).card = m1.card + m2.card`. With this PR, now you can. The way it works is that "LValue resolution" will apply CoeFun instances when trying to resolve which argument should receive the object of dot notation. A contrived-yet-representative example: ```lean structure Equiv (α β : Sort _) where toFun : α → β invFun : β → α infixl:25 " ≃ " => Equiv instance: CoeFun (α ≃ β) fun _ => α → β where coe := Equiv.toFun structure Foo where n : Nat def Foo.n' : Foo ≃ Nat := ⟨Foo.n, Foo.mk⟩ variable (f : Foo) #check f.n' -- Foo.n'.toFun f : Nat ``` Design note 1: While LValue resolution attempts to make use of named arguments when positional arguments cannot be used, when we apply CoeFun instances we disallow making use of named arguments. The rationale is that argument names for CoeFun instances tend to be random, which could lead dot notation randomly succeeding or failing. It is better to be uniform, and so it uniformly fails in this case. Design note 2: There is a limitation in that this will *not* make use of the values of any of the provided arguments when synthesizing the CoeFun instances (see the tests for an example), since argument elaboration takes place after LValue resolution. However, we make sure that synthesis will fail rather than choose the wrong CoeFun instance. Performance note: Such instances will be synthesized twice, once during LValue resolution, and again when applying arguments. This also adds in a small optimization to the parameter list computation in LValue resolution so that it lazily reduces when a relevant parameter hasn't been found yet, rather than using `forallTelescopeReducing`. It also switches to using `forallMetaTelescope` to make sure the CoeFun synthesis will fail if multiple instances could apply. Getting this to pretty print will be deferred to future work. Closes #1910
This commit is contained in:
parent
36c2511b27
commit
a026bc7edb
2 changed files with 163 additions and 37 deletions
|
|
@ -1118,9 +1118,17 @@ where
|
|||
|
||||
/-- Auxiliary inductive datatype that represents the resolution of an `LVal`. -/
|
||||
inductive LValResolution where
|
||||
/-- When applied to `f`, effectively expands to `BaseStruct.fieldName (self := Struct.toBase f)`.
|
||||
This is a special named argument where it suppresses any explicit arguments depending on it so that type parameters don't need to be supplied. -/
|
||||
| projFn (baseStructName : Name) (structName : Name) (fieldName : Name)
|
||||
/-- Similar to `projFn`, but for extracting field indexed by `idx`. Works for structure-like inductive types in general. -/
|
||||
| projIdx (structName : Name) (idx : Nat)
|
||||
/-- When applied to `f`, effectively expands to `constName ... (Struct.toBase f)`, with the argument placed in the correct
|
||||
positional argument if possible, or otherwise as a named argument. The `Struct.toBase` is not present if `baseStructName == structName`,
|
||||
in which case these do not need to be structures. Supports generalized field notation. -/
|
||||
| const (baseStructName : Name) (structName : Name) (constName : Name)
|
||||
/-- Like `const`, but with `fvar` instead of `constName`.
|
||||
The `fullName` is the name of the recursive function, and `baseName` is the base name of the type to search for in the parameter list. -/
|
||||
| localRec (baseName : Name) (fullName : Name) (fvar : Expr)
|
||||
|
||||
private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermElabM α :=
|
||||
|
|
@ -1290,45 +1298,70 @@ private def typeMatchesBaseName (type : Expr) (baseName : Name) : MetaM Bool :=
|
|||
else
|
||||
return (← whnfR type).isAppOf baseName
|
||||
|
||||
/-- Auxiliary method for field notation. It tries to add `e` as a new argument to `args` or `namedArgs`.
|
||||
This method first finds the parameter with a type of the form `(baseName ...)`.
|
||||
When the parameter is found, if it an explicit one and `args` is big enough, we add `e` to `args`.
|
||||
Otherwise, if there isn't another parameter with the same name, we add `e` to `namedArgs`.
|
||||
/--
|
||||
Auxiliary method for field notation. Tries to add `e` as a new argument to `args` or `namedArgs`.
|
||||
This method first finds the parameter with a type of the form `(baseName ...)`.
|
||||
When the parameter is found, if it an explicit one and `args` is big enough, we add `e` to `args`.
|
||||
Otherwise, if there isn't another parameter with the same name, we add `e` to `namedArgs`.
|
||||
|
||||
Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors -/
|
||||
private def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (fType : Expr)
|
||||
: TermElabM (Array Arg × Array NamedArg) :=
|
||||
forallTelescopeReducing fType fun xs _ => do
|
||||
let mut argIdx := 0 -- position of the next explicit argument
|
||||
let mut remainingNamedArgs := namedArgs
|
||||
for h : i in [:xs.size] do
|
||||
let x := xs[i]
|
||||
let xDecl ← x.fvarId!.getDecl
|
||||
/- If there is named argument with name `xDecl.userName`, then we skip it. -/
|
||||
match remainingNamedArgs.findIdx? (fun namedArg => namedArg.name == xDecl.userName) with
|
||||
| some idx =>
|
||||
Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors
|
||||
-/
|
||||
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) :
|
||||
MetaM (Array Arg × Array NamedArg) := do
|
||||
withoutModifyingState <| go f (← inferType f) 0 namedArgs (namedArgs.map (·.name)) true
|
||||
where
|
||||
/--
|
||||
* `argIdx` is the position into `args` for the next place an explicit argument can be inserted.
|
||||
* `remainingNamedArgs` keeps track of named arguments that haven't been visited yet,
|
||||
for handling the case where multiple parameters have the same name.
|
||||
* `unusableNamedArgs` keeps track of names that can't be used as named arguments. This is initialized with user-provided named arguments.
|
||||
* `allowNamed` is whether or not to allow using named arguments.
|
||||
Disabled after using `CoeFun` since those parameter names unlikely to be meaningful,
|
||||
and otherwise whether dot notation works or not could feel random.
|
||||
-/
|
||||
go (f fType : Expr) (argIdx : Nat) (remainingNamedArgs : Array NamedArg) (unusableNamedArgs : Array Name) (allowNamed : Bool) := withIncRecDepth do
|
||||
/- Use metavariables (rather than `forallTelescope`) to prevent `coerceToFunction?` from succeeding when multiple instances could apply -/
|
||||
let (xs, bInfos, fType') ← forallMetaTelescope fType
|
||||
let mut argIdx := argIdx
|
||||
let mut remainingNamedArgs := remainingNamedArgs
|
||||
let mut unusableNamedArgs := unusableNamedArgs
|
||||
for x in xs, bInfo in bInfos do
|
||||
let xDecl ← x.mvarId!.getDecl
|
||||
if let some idx := remainingNamedArgs.findIdx? (·.name == xDecl.userName) then
|
||||
/- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/
|
||||
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
|
||||
| none =>
|
||||
let type := xDecl.type
|
||||
if (← typeMatchesBaseName type baseName) then
|
||||
else
|
||||
if (← typeMatchesBaseName xDecl.type baseName) then
|
||||
/- We found a type of the form (baseName ...).
|
||||
First, we check if the current argument is an explicit one,
|
||||
and the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
|
||||
if argIdx ≤ args.size && xDecl.binderInfo.isExplicit then
|
||||
/- We insert `e` as an explicit argument -/
|
||||
and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
|
||||
if argIdx ≤ args.size && bInfo.isExplicit then
|
||||
/- We can insert `e` as an explicit argument -/
|
||||
return (args.insertAt! argIdx (Arg.expr e), namedArgs)
|
||||
/- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible
|
||||
if there isn't an argument with the same name occurring before it. -/
|
||||
for j in [:i] do
|
||||
let prev := xs[j]!
|
||||
let prevDecl ← prev.fvarId!.getDecl
|
||||
if prevDecl.userName == xDecl.userName then
|
||||
throwError "invalid field notation, function '{fullName}' has argument with the expected type{indentExpr type}\nbut it cannot be used"
|
||||
return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e })
|
||||
if xDecl.binderInfo.isExplicit then
|
||||
-- advance explicit argument position
|
||||
else
|
||||
/- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible
|
||||
if there isn't an argument with the same name occurring before it. -/
|
||||
if !allowNamed || unusableNamedArgs.contains xDecl.userName then
|
||||
throwError "\
|
||||
invalid field notation, function '{fullName}' has argument with the expected type\
|
||||
{indentExpr xDecl.type}\n\
|
||||
but it cannot be used"
|
||||
else
|
||||
return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e })
|
||||
/- Advance `argIdx` and update seen named arguments. -/
|
||||
if bInfo.isExplicit then
|
||||
argIdx := argIdx + 1
|
||||
throwError "invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, it must be explicit or implicit with a unique name"
|
||||
unusableNamedArgs := unusableNamedArgs.push xDecl.userName
|
||||
/- If named arguments aren't allowed, then it must still be possible to pass the value as an explicit argument.
|
||||
Otherwise, we can abort now. -/
|
||||
if allowNamed || argIdx ≤ args.size then
|
||||
if let fType'@(.forallE ..) ← whnf fType' then
|
||||
return ← go (mkAppN f xs) fType' argIdx remainingNamedArgs unusableNamedArgs allowNamed
|
||||
if let some f' ← coerceToFunction? (mkAppN f xs) then
|
||||
return ← go f' (← inferType f') argIdx remainingNamedArgs unusableNamedArgs false
|
||||
throwError "\
|
||||
invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, \
|
||||
it must be explicit or implicit with a unique name"
|
||||
|
||||
/-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/
|
||||
private def addProjTermInfo
|
||||
|
|
@ -1375,8 +1408,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
|
|||
let projFn ← mkConst constName
|
||||
let projFn ← addProjTermInfo lval.getRef projFn
|
||||
if lvals.isEmpty then
|
||||
let projFnType ← inferType projFn
|
||||
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFnType
|
||||
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn
|
||||
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
|
||||
else
|
||||
let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
|
||||
|
|
@ -1384,8 +1416,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
|
|||
| LValResolution.localRec baseName fullName fvar =>
|
||||
let fvar ← addProjTermInfo lval.getRef fvar
|
||||
if lvals.isEmpty then
|
||||
let fvarType ← inferType fvar
|
||||
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvarType
|
||||
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar
|
||||
elabAppArgs fvar namedArgs args expectedType? explicit ellipsis
|
||||
else
|
||||
let f ← elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
|
||||
|
|
|
|||
95
tests/lean/run/1910.lean
Normal file
95
tests/lean/run/1910.lean
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
/-!
|
||||
# Dot notation and CoeFun
|
||||
|
||||
https://github.com/leanprover/lean4/issues/1910
|
||||
-/
|
||||
|
||||
set_option pp.mvars false
|
||||
|
||||
/-!
|
||||
Test that dot notation resolution can see through CoeFun instances.
|
||||
-/
|
||||
|
||||
structure Equiv (α β : Sort _) where
|
||||
toFun : α → β
|
||||
invFun : β → α
|
||||
|
||||
infixl:25 " ≃ " => Equiv
|
||||
|
||||
instance : CoeFun (α ≃ β) fun _ => α → β where
|
||||
coe := Equiv.toFun
|
||||
|
||||
structure Foo where
|
||||
n : Nat
|
||||
|
||||
def Foo.n' : Foo ≃ Nat := ⟨Foo.n, Foo.mk⟩
|
||||
|
||||
variable (f : Foo)
|
||||
/-- info: Foo.n'.toFun f : Nat -/
|
||||
#guard_msgs in #check f.n'
|
||||
|
||||
example (f : Foo) : f.n' = f.n := rfl
|
||||
|
||||
/-!
|
||||
Fail dot notation if it requires using a named argument from the CoeFun instance.
|
||||
-/
|
||||
structure F where
|
||||
f : Bool → Nat → Nat
|
||||
|
||||
instance : CoeFun F (fun _ => (x : Bool) → (y : Nat) → Nat) where
|
||||
coe x := fun (a : Bool) (b : Nat) => x.f a b
|
||||
|
||||
-- Recall CoeFun oddity: it uses the unfolded *value* to figure out parameter names.
|
||||
-- That's why this is `a` and `b` rather than `x` and `y`.
|
||||
/-- info: fun x => (fun a b => x.f a b) true 2 : F → Nat -/
|
||||
#guard_msgs in #check fun (x : F) => x (a := true) (b := 2)
|
||||
|
||||
def Nat.foo : F := { f := fun _ b => b }
|
||||
|
||||
-- Ok:
|
||||
/-- info: fun n x => (fun a b => Nat.foo.f a b) x n : Nat → Bool → Nat -/
|
||||
#guard_msgs in #check fun (n : Nat) => (Nat.foo · n)
|
||||
|
||||
-- Intentionally fails:
|
||||
/--
|
||||
error: invalid field notation, function 'Nat.foo' has argument with the expected type
|
||||
Nat
|
||||
but it cannot be used
|
||||
---
|
||||
info: fun n => sorryAx (?_ n) true : (n : Nat) → ?_ n
|
||||
-/
|
||||
#guard_msgs in #check fun (n : Nat) => n.foo
|
||||
|
||||
/-!
|
||||
Make sure that dot notation does not use the wrong CoeFun instance.
|
||||
The following instances rely on the second one having higher priority,
|
||||
so we need to fail completely when the instances would depend on argument values.
|
||||
-/
|
||||
|
||||
structure Bar (b : Bool) where
|
||||
|
||||
instance : CoeFun (Bar b) (fun _ => Bar b → Bool) where
|
||||
coe := fun _ _ => b
|
||||
|
||||
instance : CoeFun (Bar true) (fun _ => (b : Bool) → Bar b) where
|
||||
coe := fun _ _ => {}
|
||||
|
||||
def Bar.bar : Bar true := {}
|
||||
|
||||
/-- info: fun f => (fun x => false) f : Bar false → Bool -/
|
||||
#guard_msgs in #check fun (f : Bar false) => Bar.bar false f
|
||||
/--
|
||||
error: invalid field notation, function 'Bar.bar' does not have argument with type (Bar ...) that can be used, it must be explicit or implicit with a unique name
|
||||
---
|
||||
info: fun f => sorryAx (?_ f) true : (f : Bar false) → ?_ f
|
||||
-/
|
||||
#guard_msgs in #check fun (f : Bar false) => f.bar false
|
||||
|
||||
/-- info: fun f => (fun x => false) f : Bar false → Bool -/
|
||||
#guard_msgs in #check fun (f : Bar false) => Bar.bar true false f
|
||||
/--
|
||||
error: invalid field notation, function 'Bar.bar' does not have argument with type (Bar ...) that can be used, it must be explicit or implicit with a unique name
|
||||
---
|
||||
info: fun f => sorryAx (?_ f) true : (f : Bar false) → ?_ f
|
||||
-/
|
||||
#guard_msgs in #check fun (f : Bar false) => f.bar true false
|
||||
Loading…
Add table
Reference in a new issue