From 606aeddf067ca0b4bb110ee157ec4a6672d93373 Mon Sep 17 00:00:00 2001 From: Kyle Miller Date: Mon, 25 Nov 2024 10:38:17 -0800 Subject: [PATCH] feat: make dot notation be affected by `export`/`open` (#6189) This PR changes how generalized field notation ("dot notation") resolves the function. The new resolution rule is that if `x : S`, then `x.f` resolves the name `S.f` relative to the root namespace (hence it now affected by `export` and `open`). Breaking change: aliases now resolve differently. Before, if `x : S`, and if `S.f` is an alias for `S'.f`, then `x.f` would use `S'.f` and look for an argument of type `S'`. Now, it looks for an argument of type `S`, which is more generally useful behavior. Code making use of the old behavior should consider defining `S` or `S'` in terms of the other, since dot notation can unfold definitions during resolution. This also fixes a bug in explicit-mode generalized field notation (`@x.f`) where `x` could be passed as the wrong argument. This was not a bug for explicit-mode structure projections. Closes #3031. Addresses the `Function` namespace issue in #1629. --- src/Lean/Elab/App.lean | 101 ++++++++++++++--------------------- tests/lean/run/3031.lean | 109 ++++++++++++++++++++++++++++++++++++++ tests/lean/run/DVec.lean | 13 ++++- tests/lean/run/alias.lean | 21 +++++++- 4 files changed, 179 insertions(+), 65 deletions(-) create mode 100644 tests/lean/run/3031.lean diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index 4ba28f696e..12ebca9e7b 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -1150,48 +1150,33 @@ private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermE throwError "{msg}{indentExpr e}\nhas type{indentExpr eType}" /-- -`findMethod? S fName` tries the following for each namespace `S'` in the resolution order for `S`: -- If `env` contains `S' ++ fName`, returns `(S', S' ++ fName)` -- Otherwise if `env` contains private name `prv` for `S' ++ fName`, returns `(S', prv)` +`findMethod? S fName` tries the for each namespace `S'` in the resolution order for `S` to resolve the name `S'.fname`. +If it resolves to `name`, returns `(S', name)`. -/ private partial def findMethod? (structName fieldName : Name) : MetaM (Option (Name × Name)) := do let env ← getEnv let find? structName' : MetaM (Option (Name × Name)) := do let fullName := structName' ++ fieldName - if env.contains fullName then - return some (structName', fullName) - let fullNamePrv := mkPrivateName env fullName - if env.contains fullNamePrv then - return some (structName', fullNamePrv) - return none + -- We do not want to make use of the current namespace for resolution. + let candidates := ResolveName.resolveGlobalName (← getEnv) Name.anonymous (← getOpenDecls) fullName + |>.filter (fun (_, fieldList) => fieldList.isEmpty) + |>.map Prod.fst + match candidates with + | [] => return none + | [fullName'] => return some (structName', fullName') + | _ => throwError "\ + invalid field notation '{fieldName}', the name '{fullName}' is ambiguous, possible interpretations: \ + {MessageData.joinSep (candidates.map (m!"'{.ofConstName ·}'")) ", "}" -- Optimization: the first element of the resolution order is `structName`, -- so we can skip computing the resolution order in the common case -- of the name resolving in the `structName` namespace. find? structName <||> do let resolutionOrder ← if isStructure env structName then getStructureResolutionOrder structName else pure #[structName] - for h : i in [1:resolutionOrder.size] do - if let some res ← find? resolutionOrder[i] then + for ns in resolutionOrder[1:resolutionOrder.size] do + if let some res ← find? ns then return res return none -/-- - Return `some (structName', fullName)` if `structName ++ fieldName` is an alias for `fullName`, and - `fullName` is of the form `structName' ++ fieldName`. - - TODO: if there is more than one applicable alias, it returns `none`. We should consider throwing an error or - warning. --/ -private def findMethodAlias? (env : Environment) (structName fieldName : Name) : Option (Name × Name) := - let fullName := structName ++ fieldName - -- We never skip `protected` aliases when resolving dot-notation. - let aliasesCandidates := getAliases env fullName (skipProtected := false) |>.filterMap fun alias => - match alias.eraseSuffix? fieldName with - | none => none - | some structName' => some (structName', alias) - match aliasesCandidates with - | [r] => some r - | _ => none - private def throwInvalidFieldNotation (e eType : Expr) : TermElabM α := throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant" @@ -1223,30 +1208,22 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)" | some structName, LVal.fieldName _ fieldName _ _ => let env ← getEnv - let searchEnv : Unit → TermElabM LValResolution := fun _ => do - if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then - return LValResolution.const baseStructName structName fullName - else if let some (structName', fullName) := findMethodAlias? env structName (.mkSimple fieldName) then - return LValResolution.const structName' structName' fullName - else - throwLValError e eType - m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'" - -- search local context first, then environment - let searchCtx : Unit → TermElabM LValResolution := fun _ => do - let fullName := Name.mkStr structName fieldName - for localDecl in (← getLCtx) do - if localDecl.isAuxDecl then - if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then - if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then - /- LVal notation is being used to make a "local" recursive call. -/ - return LValResolution.localRec structName fullName localDecl.toExpr - searchEnv () if isStructure env structName then - match findField? env structName (Name.mkSimple fieldName) with - | some baseStructName => return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName) - | none => searchCtx () - else - searchCtx () + if let some baseStructName := findField? env structName (Name.mkSimple fieldName) then + return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName) + -- Search the local context first + let fullName := Name.mkStr structName fieldName + for localDecl in (← getLCtx) do + if localDecl.isAuxDecl then + if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then + if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then + /- LVal notation is being used to make a "local" recursive call. -/ + return LValResolution.localRec structName fullName localDecl.toExpr + -- Then search the environment + if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then + return LValResolution.const baseStructName structName fullName + throwLValError e eType + m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'" | none, LVal.fieldName _ _ (some suffix) _ => if e.isConst then throwUnknownConstant (e.constName! ++ suffix) @@ -1326,7 +1303,7 @@ Otherwise, if there isn't another parameter with the same name, we add `e` to `n Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors -/ -private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) : +private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) (explicit : Bool) : MetaM (Array Arg × Array NamedArg) := do withoutModifyingState <| go f (← inferType f) 0 namedArgs (namedArgs.map (·.name)) true where @@ -1351,11 +1328,11 @@ where /- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/ remainingNamedArgs := remainingNamedArgs.eraseIdx idx else - if (← typeMatchesBaseName xDecl.type baseName) then - /- We found a type of the form (baseName ...). - First, we check if the current argument is an explicit one, + if ← typeMatchesBaseName xDecl.type baseName then + /- We found a type of the form (baseName ...), or we found the first explicit argument in useFirstExplicit mode. + First, we check if the current argument is one that can be used positionally, and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/ - if h : argIdx ≤ args.size ∧ bInfo.isExplicit then + if h : argIdx ≤ args.size ∧ (explicit || bInfo.isExplicit) then /- We can insert `e` as an explicit argument -/ return (args.insertIdx argIdx (Arg.expr e), namedArgs) else @@ -1363,13 +1340,13 @@ where if there isn't an argument with the same name occurring before it. -/ if !allowNamed || unusableNamedArgs.contains xDecl.userName then throwError "\ - invalid field notation, function '{fullName}' has argument with the expected type\ + invalid field notation, function '{.ofConstName fullName}' has argument with the expected type\ {indentExpr xDecl.type}\n\ but it cannot be used" else return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e }) /- Advance `argIdx` and update seen named arguments. -/ - if bInfo.isExplicit then + if explicit || bInfo.isExplicit then argIdx := argIdx + 1 unusableNamedArgs := unusableNamedArgs.push xDecl.userName /- If named arguments aren't allowed, then it must still be possible to pass the value as an explicit argument. @@ -1380,7 +1357,7 @@ where if let some f' ← coerceToFunction? (mkAppN f xs) then return ← go f' (← inferType f') argIdx remainingNamedArgs unusableNamedArgs false throwError "\ - invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, \ + invalid field notation, function '{.ofConstName fullName}' does not have argument with type ({.ofConstName baseName} ...) that can be used, \ it must be explicit or implicit with a unique name" /-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/ @@ -1426,7 +1403,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp let projFn ← mkConst constName let projFn ← addProjTermInfo lval.getRef projFn if lvals.isEmpty then - let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn + let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn explicit elabAppArgs projFn namedArgs args expectedType? explicit ellipsis else let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false) @@ -1434,7 +1411,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp | LValResolution.localRec baseName fullName fvar => let fvar ← addProjTermInfo lval.getRef fvar if lvals.isEmpty then - let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar + let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar explicit elabAppArgs fvar namedArgs args expectedType? explicit ellipsis else let f ← elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false) diff --git a/tests/lean/run/3031.lean b/tests/lean/run/3031.lean new file mode 100644 index 0000000000..d9956916e5 --- /dev/null +++ b/tests/lean/run/3031.lean @@ -0,0 +1,109 @@ +/-! +# Tests for generalized field notation through aliases and "top-level" dot notation +https://github.com/leanprover/lean4/issues/3031 +-/ + +/-! +Alias dot notation. There used to be a different kind of alias dot notation; +in the following example, it would have looked for an argument of type `Common.String`. +Now it looks for one of type `String`, allowing libraries to add "extension methods" from within their own namespaces. +-/ +def Common.String.a (s : String) : Nat := s.length + +export Common (String.a) + +/-- info: String.a "x" : Nat -/ +#guard_msgs in #check String.a "x" +/-- info: String.a "x" : Nat -/ +#guard_msgs in #check "x".a + +/-! +Declarations take precedence over aliases +-/ +def String.a (s : String) : Nat := s.length + 100 +/-- info: "x".a : Nat -/ +#guard_msgs in #check "x".a +/-- info: 100 -/ +#guard_msgs in #eval "".a + +/-! +Private declarations take precedence over aliases +-/ +private def String.b (s : String) : Nat := 0 +def Common.String.b (s : String) : Nat := 1 +export Common (String.b) +/-- info: 0 -/ +#guard_msgs in #eval "".b + +/-! +Multiple aliases is an error +-/ +def Common.String.c (s : String) : Nat := 0 +def Common'.String.c (s : String) : Nat := 0 +export Common (String.c) +export Common' (String.c) +/-- +error: invalid field notation 'c', the name 'String.c' is ambiguous, possible interpretations: 'Common'.String.c', 'Common.String.c' +-/ +#guard_msgs in #eval "".c + +/-! +Aliases work with inheritance +-/ +namespace Ex1 +structure A +structure B extends A +def Common.A.x (_ : A) : Nat := 0 +export Common (A.x) +/-- info: fun b => A.x b.toA : B → Nat -/ +#guard_msgs in #check fun (b : B) => b.x +end Ex1 + +/-! +`open` also works +-/ +def Common.String.parse (_ : String) : List Nat := [] + +namespace ExOpen1 +/-- +error: invalid field 'parse', the environment does not contain 'String.parse' + "" +has type + String +-/ +#guard_msgs in #check "".parse +section +open Common +/-- info: String.parse "" : List Nat -/ +#guard_msgs in #check "".parse +end +section +open Common (String.parse) +/-- info: String.parse "" : List Nat -/ +#guard_msgs in #check "".parse +end +end ExOpen1 + + +namespace Ex2 +class A (n : Nat) where + x : Nat + +/-! +Incidental fix: `@` for generalized field notation was failing if there were implicit arguments. +True projections were ok. +-/ +def A.x' {n : Nat} (a : A n) := a.x + +/-- info: fun a => a.x' : A 2 → Nat -/ +#guard_msgs in #check fun (a : A 2) => @a.x' +end Ex2 + +namespace Ex3 +variable (f : α → β) (g : β → γ) +/-! +Functions use the "top-level" dot notation rule: they use the first explicit argument, rather than the first function argument. +-/ +/-- info: g ∘ f : α → γ -/ +#guard_msgs in #check g.comp f +end Ex3 diff --git a/tests/lean/run/DVec.lean b/tests/lean/run/DVec.lean index e9f98f8b65..0006fe7aa6 100644 --- a/tests/lean/run/DVec.lean +++ b/tests/lean/run/DVec.lean @@ -38,6 +38,17 @@ example (v : Vec Nat 1) : Nat := #check @Vec.hd --- works +-- Does not work: Aliases find that `v` could be the `TypeVec` argument since `TypeVec` is an abbrev for `Vec`. +/-- +error: application type mismatch + @Vec.hd ?_ v +argument + v +has type + Vec Nat 1 : Type +but is expected to have type + TypeVec (?_ + 1) : Type (_ + 1) +-/ +#guard_msgs in set_option pp.mvars false in example (v : Vec Nat 1) : Nat := v.hd diff --git a/tests/lean/run/alias.lean b/tests/lean/run/alias.lean index 1174ec65aa..126c8ab1e4 100644 --- a/tests/lean/run/alias.lean +++ b/tests/lean/run/alias.lean @@ -3,14 +3,31 @@ def Set (α : Type) := α → Prop def Set.union (s₁ s₂ : Set α) : Set α := fun a => s₁ a ∨ s₂ a -def FinSet (n : Nat) := Fin n → Prop +def FinSet (n : Nat) := Set (Fin n) + +/-! +The type of `x` is unfolded to find `Set.union` +-/ +example (x y : FinSet 10) : FinSet 10 := + x.union y namespace FinSet - export Set (union) +export Set (union) end FinSet +/-! +Since the types are defeq, this alias works: +-/ example (x y : FinSet 10) : FinSet 10 := FinSet.union x y +/-! +However, this dot notation fails since there is no `FinSet` argument. +However, unfolding is the preferred method. +-/ +/-- +error: invalid field notation, function 'FinSet.union' does not have argument with type (FinSet ...) that can be used, it must be explicit or implicit with a unique name +-/ +#guard_msgs in example (x y : FinSet 10) : FinSet 10 := x.union y