From ddf93d2f8a44ac2efcb73ce16e653886ec5de32d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 11 Mar 2022 15:39:41 -0800 Subject: [PATCH] feat: support for arrow types in the dot notation cc @gebner --- RELEASES.md | 3 +++ src/Lean/Elab/App.lean | 36 ++++++++++++++++++++++-------------- tests/lean/run/arrowDot.lean | 5 +++++ 3 files changed, 30 insertions(+), 14 deletions(-) create mode 100644 tests/lean/run/arrowDot.lean diff --git a/RELEASES.md b/RELEASES.md index bd7f9d49ae..4edd762fec 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,6 +1,9 @@ v4.0.0-m4 (WIP) --------- +* Extend dot-notation `x.field` for arrow types. If type of `x` is an arrow, we look up for `Function.field`. +For example, given `f : Nat → Nat` and `g : Nat → Nat`, `f.comp g` is now notation for `Function.comp f g`. + * [Add code folding support to the language server](https://github.com/leanprover/lean4/pull/1014). * Support notation `let := | ` in `do` blocks. diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index d23df4f4a6..517c2e988b 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -538,7 +538,17 @@ private partial def findMethod? (env : Environment) (structName fieldName : Name else none +private def throwInvalidFieldNotation (e eType : Expr) : TermElabM α := + throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant" + private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM LValResolution := do + if eType.isForall then + match lval with + | LVal.fieldName _ fieldName _ _ => + let fullName := `Function ++ fieldName + if (← getEnv).contains fullName then + return LValResolution.const `Function `Function fullName + | _ => pure () match eType.getAppFn.constName?, lval with | some structName, LVal.fieldIdx _ idx => if idx == 0 then @@ -595,11 +605,9 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L if e.isConst then throwUnknownConstant (e.constName! ++ suffix) else - throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant" - | _, LVal.getOp _ idx => - throwLValError e eType "invalid [..] notation, type is not of the form (C ...) where C is a constant" - | _, _ => - throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant" + throwInvalidFieldNotation e eType + | _, LVal.getOp _ idx => throwInvalidFieldNotation e eType + | _, _ => throwInvalidFieldNotation e eType /- whnfCore + implicit consumption. Example: given `e` with `eType := {α : Type} → (fun β => List β) α `, it produces `(e ?m, List ?m)` where `?m` is fresh metavariable. -/ @@ -653,6 +661,14 @@ private partial def mkBaseProjections (baseStructName : Name) (structName : Name e ← elabAppArgs projFn #[{ name := `self, val := Arg.expr e }] (args := #[]) (expectedType? := none) (explicit := false) (ellipsis := false) return e +private def typeMatchesBaseName (type : Expr) (baseName : Name) : MetaM Bool := do + if baseName == `Function then + return (← whnfR type).isForall + else if type.consumeMData.isAppOf baseName then + return true + else + return (← whnfR type).isAppOf baseName + /- Auxiliary method for field notation. It tries to add `e` as a new argument to `args` or `namedArgs`. This method first finds the parameter with a type of the form `(baseName ...)`. When the parameter is found, if it an explicit one and `args` is big enough, we add `e` to `args`. @@ -672,16 +688,8 @@ private def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Ar | some idx => remainingNamedArgs := remainingNamedArgs.eraseIdx idx | none => - let mut foundIt := false let type := xDecl.type - if type.consumeMData.isAppOf baseName then - foundIt := true - if !foundIt then - /- Normalize type and try again -/ - let type ← withReducible $ whnf type - if type.consumeMData.isAppOf baseName then - foundIt := true - if foundIt then + if (← typeMatchesBaseName type baseName) then /- We found a type of the form (baseName ...). First, we check if the current argument is an explicit one, and the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/ diff --git a/tests/lean/run/arrowDot.lean b/tests/lean/run/arrowDot.lean new file mode 100644 index 0000000000..327ef9689c --- /dev/null +++ b/tests/lean/run/arrowDot.lean @@ -0,0 +1,5 @@ +def test (f : Nat → Nat) (g : Nat → Nat) := + f.comp g $ 10 + +example : test (·+1) (·*2) = 21 := + rfl