diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index 09a437349d..b914811980 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -578,38 +578,53 @@ private partial def mkBaseProjections (baseStructName : Name) (structName : Name e ← elabAppArgs projFn #[{ name := `self, val := Arg.expr e }] (args := #[]) (expectedType? := none) (explicit := false) (ellipsis := false) return e -/- Auxiliary method for field notation. It tries to add `e` to `args` as the first explicit parameter - which takes an element of type `(C ...)` where `C` is `baseName`. - `fullName` is the name of the resolved "field" access function. It is used for reporting errors -/ +/- Auxiliary method for field notation. It tries to add `e` as a new argument to `args` or `namedArgs`. + This method first finds the parameter with a type of the form `(baseName ...)`. + When the parameter is found, if it an explicit one and `args` is big enough, we add `e` to `args`. + Otherwise, if there isn't another parameter with the same name, we add `e` to `namedArgs`. + + Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors -/ private def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (fType : Expr) - : TermElabM (Array Arg) := + : TermElabM (Array Arg × Array NamedArg) := forallTelescopeReducing fType fun xs _ => do - let mut i := 0 - let mut namedArgs := namedArgs - for x in xs do + let mut argIdx := 0 -- position of the next explicit argument + let mut remainingNamedArgs := namedArgs + for i in [:xs.size] do + let x := xs[i] let xDecl ← getLocalDecl x.fvarId! - if xDecl.binderInfo.isExplicit then - /- If there is named argument with name `xDecl.userName`, then we skip it. -/ - match namedArgs.findIdx? (fun namedArg => namedArg.name == xDecl.userName) with - | some idx => - namedArgs := namedArgs.eraseIdx idx - | none => - let type := xDecl.type - if type.consumeMData.isAppOf baseName then - -- found it - return args.insertAt i (Arg.expr e) - -- normalize type and try again + /- If there is named argument with name `xDecl.userName`, then we skip it. -/ + match remainingNamedArgs.findIdx? (fun namedArg => namedArg.name == xDecl.userName) with + | some idx => + remainingNamedArgs := remainingNamedArgs.eraseIdx idx + | none => + let mut foundIt := false + let type := xDecl.type + if type.consumeMData.isAppOf baseName then + foundIt := true + if !foundIt then + /- Normalize type and try again -/ let type ← withReducible $ whnf type if type.consumeMData.isAppOf baseName then - -- found it - return args.insertAt i (Arg.expr e) - if i < args.size then - i := i + 1 - else - for namedArg in namedArgs do - throwInvalidNamedArg namedArg fullName - throwError! "invalid field notation, function '{fullName}' does not have explicit argument with type ({baseName} ...)" - return args + foundIt := true + if foundIt then + /- We found a type of the form (baseName ...). + First, we check if the current argument is an explicit one, + and the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/ + if argIdx ≤ args.size && xDecl.binderInfo.isExplicit then + /- We insert `e` as an explicit argument -/ + return (args.insertAt argIdx (Arg.expr e), namedArgs) + /- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible + if there isn't an argument with the same name occurring before it. -/ + for j in [:i] do + let prev := xs[j] + let prevDecl ← getLocalDecl prev.fvarId! + if prevDecl.userName == xDecl.userName then + throwError! "invalid field notation, function '{fullName}' has argument with the expected type{indentExpr type}\nbut it cannot be used" + return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e }) + if xDecl.binderInfo.isExplicit then + -- advance explicit argument position + argIdx := argIdx + 1 + throwError! "invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, it must be explicit or implicit with an unique name" private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (expectedType? : Option Expr) (explicit ellipsis : Bool) (f : Expr) (lvals : List LVal) : TermElabM Expr := @@ -635,7 +650,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp let projFn ← mkConst constName if lvals.isEmpty then let projFnType ← inferType projFn - let args ← addLValArg baseStructName constName f args namedArgs projFnType + let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFnType elabAppArgs projFn namedArgs args expectedType? explicit ellipsis else let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false) @@ -643,7 +658,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp | LValResolution.localRec baseName fullName fvar => if lvals.isEmpty then let fvarType ← inferType fvar - let args ← addLValArg baseName fullName f args namedArgs fvarType + let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvarType elabAppArgs fvar namedArgs args expectedType? explicit ellipsis else let f ← elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false) diff --git a/tests/lean/run/modAsClasses.lean b/tests/lean/run/modAsClasses.lean new file mode 100644 index 0000000000..564577f51e --- /dev/null +++ b/tests/lean/run/modAsClasses.lean @@ -0,0 +1,12 @@ +class MyMod := +(a : Nat) + +namespace MyMod +variable [MyMod] +def b := a + 1 +end MyMod + +def myMod1 : MyMod := ⟨0⟩ + +#eval myMod1.a +#eval myMod1.b