feat: support for arrow types in the dot notation

cc @gebner
This commit is contained in:
Leonardo de Moura 2022-03-11 15:39:41 -08:00
parent 8d42978e63
commit ddf93d2f8a
3 changed files with 30 additions and 14 deletions

View file

@ -1,6 +1,9 @@
v4.0.0-m4 (WIP)
---------
* Extend dot-notation `x.field` for arrow types. If type of `x` is an arrow, we look up for `Function.field`.
For example, given `f : Nat → Nat` and `g : Nat → Nat`, `f.comp g` is now notation for `Function.comp f g`.
* [Add code folding support to the language server](https://github.com/leanprover/lean4/pull/1014).
* Support notation `let <pattern> := <expr> | <else-case>` in `do` blocks.

View file

@ -538,7 +538,17 @@ private partial def findMethod? (env : Environment) (structName fieldName : Name
else
none
private def throwInvalidFieldNotation (e eType : Expr) : TermElabM α :=
throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant"
private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM LValResolution := do
if eType.isForall then
match lval with
| LVal.fieldName _ fieldName _ _ =>
let fullName := `Function ++ fieldName
if (← getEnv).contains fullName then
return LValResolution.const `Function `Function fullName
| _ => pure ()
match eType.getAppFn.constName?, lval with
| some structName, LVal.fieldIdx _ idx =>
if idx == 0 then
@ -595,11 +605,9 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
if e.isConst then
throwUnknownConstant (e.constName! ++ suffix)
else
throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant"
| _, LVal.getOp _ idx =>
throwLValError e eType "invalid [..] notation, type is not of the form (C ...) where C is a constant"
| _, _ =>
throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant"
throwInvalidFieldNotation e eType
| _, LVal.getOp _ idx => throwInvalidFieldNotation e eType
| _, _ => throwInvalidFieldNotation e eType
/- whnfCore + implicit consumption.
Example: given `e` with `eType := {α : Type} → (fun β => List β) α `, it produces `(e ?m, List ?m)` where `?m` is fresh metavariable. -/
@ -653,6 +661,14 @@ private partial def mkBaseProjections (baseStructName : Name) (structName : Name
e ← elabAppArgs projFn #[{ name := `self, val := Arg.expr e }] (args := #[]) (expectedType? := none) (explicit := false) (ellipsis := false)
return e
private def typeMatchesBaseName (type : Expr) (baseName : Name) : MetaM Bool := do
if baseName == `Function then
return (← whnfR type).isForall
else if type.consumeMData.isAppOf baseName then
return true
else
return (← whnfR type).isAppOf baseName
/- Auxiliary method for field notation. It tries to add `e` as a new argument to `args` or `namedArgs`.
This method first finds the parameter with a type of the form `(baseName ...)`.
When the parameter is found, if it an explicit one and `args` is big enough, we add `e` to `args`.
@ -672,16 +688,8 @@ private def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Ar
| some idx =>
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
| none =>
let mut foundIt := false
let type := xDecl.type
if type.consumeMData.isAppOf baseName then
foundIt := true
if !foundIt then
/- Normalize type and try again -/
let type ← withReducible $ whnf type
if type.consumeMData.isAppOf baseName then
foundIt := true
if foundIt then
if (← typeMatchesBaseName type baseName) then
/- We found a type of the form (baseName ...).
First, we check if the current argument is an explicit one,
and the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/

View file

@ -0,0 +1,5 @@
def test (f : Nat → Nat) (g : Nat → Nat) :=
f.comp g $ 10
example : test (·+1) (·*2) = 21 :=
rfl