diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index b8a96c55fd..ce604ac0b0 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -1118,9 +1118,17 @@ where /-- Auxiliary inductive datatype that represents the resolution of an `LVal`. -/ inductive LValResolution where + /-- When applied to `f`, effectively expands to `BaseStruct.fieldName (self := Struct.toBase f)`. + This is a special named argument where it suppresses any explicit arguments depending on it so that type parameters don't need to be supplied. -/ | projFn (baseStructName : Name) (structName : Name) (fieldName : Name) + /-- Similar to `projFn`, but for extracting field indexed by `idx`. Works for structure-like inductive types in general. -/ | projIdx (structName : Name) (idx : Nat) + /-- When applied to `f`, effectively expands to `constName ... (Struct.toBase f)`, with the argument placed in the correct + positional argument if possible, or otherwise as a named argument. The `Struct.toBase` is not present if `baseStructName == structName`, + in which case these do not need to be structures. Supports generalized field notation. -/ | const (baseStructName : Name) (structName : Name) (constName : Name) + /-- Like `const`, but with `fvar` instead of `constName`. + The `fullName` is the name of the recursive function, and `baseName` is the base name of the type to search for in the parameter list. -/ | localRec (baseName : Name) (fullName : Name) (fvar : Expr) private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermElabM α := @@ -1290,45 +1298,70 @@ private def typeMatchesBaseName (type : Expr) (baseName : Name) : MetaM Bool := else return (← whnfR type).isAppOf baseName -/-- Auxiliary method for field notation. It tries to add `e` as a new argument to `args` or `namedArgs`. - This method first finds the parameter with a type of the form `(baseName ...)`. - When the parameter is found, if it an explicit one and `args` is big enough, we add `e` to `args`. - Otherwise, if there isn't another parameter with the same name, we add `e` to `namedArgs`. +/-- +Auxiliary method for field notation. Tries to add `e` as a new argument to `args` or `namedArgs`. +This method first finds the parameter with a type of the form `(baseName ...)`. +When the parameter is found, if it an explicit one and `args` is big enough, we add `e` to `args`. +Otherwise, if there isn't another parameter with the same name, we add `e` to `namedArgs`. - Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors -/ -private def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (fType : Expr) - : TermElabM (Array Arg × Array NamedArg) := - forallTelescopeReducing fType fun xs _ => do - let mut argIdx := 0 -- position of the next explicit argument - let mut remainingNamedArgs := namedArgs - for h : i in [:xs.size] do - let x := xs[i] - let xDecl ← x.fvarId!.getDecl - /- If there is named argument with name `xDecl.userName`, then we skip it. -/ - match remainingNamedArgs.findIdx? (fun namedArg => namedArg.name == xDecl.userName) with - | some idx => +Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors +-/ +private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) : + MetaM (Array Arg × Array NamedArg) := do + withoutModifyingState <| go f (← inferType f) 0 namedArgs (namedArgs.map (·.name)) true +where + /-- + * `argIdx` is the position into `args` for the next place an explicit argument can be inserted. + * `remainingNamedArgs` keeps track of named arguments that haven't been visited yet, + for handling the case where multiple parameters have the same name. + * `unusableNamedArgs` keeps track of names that can't be used as named arguments. This is initialized with user-provided named arguments. + * `allowNamed` is whether or not to allow using named arguments. + Disabled after using `CoeFun` since those parameter names unlikely to be meaningful, + and otherwise whether dot notation works or not could feel random. + -/ + go (f fType : Expr) (argIdx : Nat) (remainingNamedArgs : Array NamedArg) (unusableNamedArgs : Array Name) (allowNamed : Bool) := withIncRecDepth do + /- Use metavariables (rather than `forallTelescope`) to prevent `coerceToFunction?` from succeeding when multiple instances could apply -/ + let (xs, bInfos, fType') ← forallMetaTelescope fType + let mut argIdx := argIdx + let mut remainingNamedArgs := remainingNamedArgs + let mut unusableNamedArgs := unusableNamedArgs + for x in xs, bInfo in bInfos do + let xDecl ← x.mvarId!.getDecl + if let some idx := remainingNamedArgs.findIdx? (·.name == xDecl.userName) then + /- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/ remainingNamedArgs := remainingNamedArgs.eraseIdx idx - | none => - let type := xDecl.type - if (← typeMatchesBaseName type baseName) then + else + if (← typeMatchesBaseName xDecl.type baseName) then /- We found a type of the form (baseName ...). First, we check if the current argument is an explicit one, - and the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/ - if argIdx ≤ args.size && xDecl.binderInfo.isExplicit then - /- We insert `e` as an explicit argument -/ + and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/ + if argIdx ≤ args.size && bInfo.isExplicit then + /- We can insert `e` as an explicit argument -/ return (args.insertAt! argIdx (Arg.expr e), namedArgs) - /- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible - if there isn't an argument with the same name occurring before it. -/ - for j in [:i] do - let prev := xs[j]! - let prevDecl ← prev.fvarId!.getDecl - if prevDecl.userName == xDecl.userName then - throwError "invalid field notation, function '{fullName}' has argument with the expected type{indentExpr type}\nbut it cannot be used" - return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e }) - if xDecl.binderInfo.isExplicit then - -- advance explicit argument position + else + /- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible + if there isn't an argument with the same name occurring before it. -/ + if !allowNamed || unusableNamedArgs.contains xDecl.userName then + throwError "\ + invalid field notation, function '{fullName}' has argument with the expected type\ + {indentExpr xDecl.type}\n\ + but it cannot be used" + else + return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e }) + /- Advance `argIdx` and update seen named arguments. -/ + if bInfo.isExplicit then argIdx := argIdx + 1 - throwError "invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, it must be explicit or implicit with a unique name" + unusableNamedArgs := unusableNamedArgs.push xDecl.userName + /- If named arguments aren't allowed, then it must still be possible to pass the value as an explicit argument. + Otherwise, we can abort now. -/ + if allowNamed || argIdx ≤ args.size then + if let fType'@(.forallE ..) ← whnf fType' then + return ← go (mkAppN f xs) fType' argIdx remainingNamedArgs unusableNamedArgs allowNamed + if let some f' ← coerceToFunction? (mkAppN f xs) then + return ← go f' (← inferType f') argIdx remainingNamedArgs unusableNamedArgs false + throwError "\ + invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, \ + it must be explicit or implicit with a unique name" /-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/ private def addProjTermInfo @@ -1375,8 +1408,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp let projFn ← mkConst constName let projFn ← addProjTermInfo lval.getRef projFn if lvals.isEmpty then - let projFnType ← inferType projFn - let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFnType + let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn elabAppArgs projFn namedArgs args expectedType? explicit ellipsis else let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false) @@ -1384,8 +1416,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp | LValResolution.localRec baseName fullName fvar => let fvar ← addProjTermInfo lval.getRef fvar if lvals.isEmpty then - let fvarType ← inferType fvar - let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvarType + let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar elabAppArgs fvar namedArgs args expectedType? explicit ellipsis else let f ← elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false) diff --git a/tests/lean/run/1910.lean b/tests/lean/run/1910.lean new file mode 100644 index 0000000000..dfc13d6ec9 --- /dev/null +++ b/tests/lean/run/1910.lean @@ -0,0 +1,95 @@ +/-! +# Dot notation and CoeFun + +https://github.com/leanprover/lean4/issues/1910 +-/ + +set_option pp.mvars false + +/-! +Test that dot notation resolution can see through CoeFun instances. +-/ + +structure Equiv (α β : Sort _) where + toFun : α → β + invFun : β → α + +infixl:25 " ≃ " => Equiv + +instance : CoeFun (α ≃ β) fun _ => α → β where + coe := Equiv.toFun + +structure Foo where + n : Nat + +def Foo.n' : Foo ≃ Nat := ⟨Foo.n, Foo.mk⟩ + +variable (f : Foo) +/-- info: Foo.n'.toFun f : Nat -/ +#guard_msgs in #check f.n' + +example (f : Foo) : f.n' = f.n := rfl + +/-! +Fail dot notation if it requires using a named argument from the CoeFun instance. +-/ +structure F where + f : Bool → Nat → Nat + +instance : CoeFun F (fun _ => (x : Bool) → (y : Nat) → Nat) where + coe x := fun (a : Bool) (b : Nat) => x.f a b + +-- Recall CoeFun oddity: it uses the unfolded *value* to figure out parameter names. +-- That's why this is `a` and `b` rather than `x` and `y`. +/-- info: fun x => (fun a b => x.f a b) true 2 : F → Nat -/ +#guard_msgs in #check fun (x : F) => x (a := true) (b := 2) + +def Nat.foo : F := { f := fun _ b => b } + +-- Ok: +/-- info: fun n x => (fun a b => Nat.foo.f a b) x n : Nat → Bool → Nat -/ +#guard_msgs in #check fun (n : Nat) => (Nat.foo · n) + +-- Intentionally fails: +/-- +error: invalid field notation, function 'Nat.foo' has argument with the expected type + Nat +but it cannot be used +--- +info: fun n => sorryAx (?_ n) true : (n : Nat) → ?_ n +-/ +#guard_msgs in #check fun (n : Nat) => n.foo + +/-! +Make sure that dot notation does not use the wrong CoeFun instance. +The following instances rely on the second one having higher priority, +so we need to fail completely when the instances would depend on argument values. +-/ + +structure Bar (b : Bool) where + +instance : CoeFun (Bar b) (fun _ => Bar b → Bool) where + coe := fun _ _ => b + +instance : CoeFun (Bar true) (fun _ => (b : Bool) → Bar b) where + coe := fun _ _ => {} + +def Bar.bar : Bar true := {} + +/-- info: fun f => (fun x => false) f : Bar false → Bool -/ +#guard_msgs in #check fun (f : Bar false) => Bar.bar false f +/-- +error: invalid field notation, function 'Bar.bar' does not have argument with type (Bar ...) that can be used, it must be explicit or implicit with a unique name +--- +info: fun f => sorryAx (?_ f) true : (f : Bar false) → ?_ f +-/ +#guard_msgs in #check fun (f : Bar false) => f.bar false + +/-- info: fun f => (fun x => false) f : Bar false → Bool -/ +#guard_msgs in #check fun (f : Bar false) => Bar.bar true false f +/-- +error: invalid field notation, function 'Bar.bar' does not have argument with type (Bar ...) that can be used, it must be explicit or implicit with a unique name +--- +info: fun f => sorryAx (?_ f) true : (f : Bar false) → ?_ f +-/ +#guard_msgs in #check fun (f : Bar false) => f.bar true false