From 4cd7e359dfbfd56cca7f4a70d502ed159a61aa3e Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 3 Aug 2021 19:40:44 -0700 Subject: [PATCH] feat: elaborate strict implicit binders --- src/Lean/Elab/App.lean | 51 ++++++++++++------- src/Lean/Elab/Binders.lean | 26 ++++++---- .../PrettyPrinter/Delaborator/Builtins.lean | 17 ++++--- tests/lean/strictImplicit.lean | 7 +++ tests/lean/strictImplicit.lean.expected.out | 4 ++ 5 files changed, 69 insertions(+), 36 deletions(-) create mode 100644 tests/lean/strictImplicit.lean create mode 100644 tests/lean/strictImplicit.lean.expected.out diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index 920273258d..1ba08afccf 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -388,6 +388,11 @@ private def processExplictArg (k : M Expr) : M Expr := do else finalize +/- Return true if there are regular or named arguments to be processed. -/ +private def hasArgsToProcess : M Bool := do + let s ← get + return !s.args.isEmpty || !s.namedArgs.isEmpty + /- Process a `fType` of the form `{x : A} → B x`. This method assume `fType` is a function type -/ @@ -397,6 +402,17 @@ private def processImplicitArg (k : M Expr) : M Expr := do else addImplicitArg k +/- + Process a `fType` of the form `{{x : A}} → B x`. + This method assume `fType` is a function type -/ +private def processStrictImplicitArg (k : M Expr) : M Expr := do + if (← get).explicit then + processExplictArg k + else if (← hasArgsToProcess) then + addImplicitArg k + else + finalize + /- Return true if the next argument at `args` is of the form `_` -/ private def isNextArgHole : M Bool := do match (← get).args with @@ -423,11 +439,6 @@ private def processInstImplicitArg (k : M Expr) : M Expr := do addNewArg arg k -/- Return true if there are regular or named arguments to be processed. -/ -private def hasArgsToProcess : M Bool := do - let s ← get - pure $ !s.args.isEmpty || !s.namedArgs.isEmpty - /- Elaborate function application arguments. -/ partial def main : M Expr := do let s ← get @@ -444,9 +455,10 @@ partial def main : M Expr := do main | none => match binfo with - | BinderInfo.implicit => processImplicitArg main - | BinderInfo.instImplicit => processInstImplicitArg main - | _ => processExplictArg main + | BinderInfo.implicit => processImplicitArg main + | BinderInfo.instImplicit => processInstImplicitArg main + | BinderInfo.strictImplicit => processStrictImplicitArg main + | _ => processExplictArg main else if (← hasArgsToProcess) then synthesizePendingAndNormalizeFunType main @@ -572,25 +584,25 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L /- whnfCore + implicit consumption. Example: given `e` with `eType := {α : Type} → (fun β => List β) α `, it produces `(e ?m, List ?m)` where `?m` is fresh metavariable. -/ -private partial def consumeImplicits (stx : Syntax) (e eType : Expr) : TermElabM (Expr × Expr) := do +private partial def consumeImplicits (stx : Syntax) (e eType : Expr) (hasArgs : Bool) : TermElabM (Expr × Expr) := do let eType ← whnfCore eType match eType with | Expr.forallE n d b c => - if c.binderInfo.isImplicit then + if c.binderInfo.isImplicit || (hasArgs && c.binderInfo.isStrictImplicit) then let mvar ← mkFreshExprMVar d registerMVarErrorHoleInfo mvar.mvarId! stx - consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar) + consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar) hasArgs else if c.binderInfo.isInstImplicit then let mvar ← mkInstMVar d - consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar) + consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar) hasArgs else match d.getOptParamDefault? with - | some defVal => consumeImplicits stx (mkApp e defVal) (b.instantiate1 defVal) + | some defVal => consumeImplicits stx (mkApp e defVal) (b.instantiate1 defVal) hasArgs -- TODO: we do not handle autoParams here. | _ => pure (e, eType) | _ => pure (e, eType) -private partial def resolveLValLoop (lval : LVal) (e eType : Expr) (previousExceptions : Array Exception) : TermElabM (Expr × LValResolution) := do - let (e, eType) ← consumeImplicits lval.getRef e eType +private partial def resolveLValLoop (lval : LVal) (e eType : Expr) (previousExceptions : Array Exception) (hasArgs : Bool) : TermElabM (Expr × LValResolution) := do + let (e, eType) ← consumeImplicits lval.getRef e eType hasArgs tryPostponeIfMVar eType try let lvalRes ← resolveLValAux e eType lval @@ -599,15 +611,15 @@ private partial def resolveLValLoop (lval : LVal) (e eType : Expr) (previousExce | ex@(Exception.error _ _) => let eType? ← unfoldDefinition? eType match eType? with - | some eType => resolveLValLoop lval e eType (previousExceptions.push ex) + | some eType => resolveLValLoop lval e eType (previousExceptions.push ex) hasArgs | none => previousExceptions.forM fun ex => logException ex throw ex | ex@(Exception.internal _ _) => throw ex -private def resolveLVal (e : Expr) (lval : LVal) : TermElabM (Expr × LValResolution) := do +private def resolveLVal (e : Expr) (lval : LVal) (hasArgs : Bool) : TermElabM (Expr × LValResolution) := do let eType ← inferType e - resolveLValLoop lval e eType #[] + resolveLValLoop lval e eType #[] hasArgs private partial def mkBaseProjections (baseStructName : Name) (structName : Name) (e : Expr) : TermElabM Expr := do let env ← getEnv @@ -675,7 +687,8 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp | f, lval::lvals => do if let LVal.fieldName (ref := fieldStx) (targetStx := targetStx) .. := lval then addDotCompletionInfo targetStx f expectedType? fieldStx - let (f, lvalRes) ← resolveLVal f lval + let hasArgs := !namedArgs.isEmpty || !args.isEmpty + let (f, lvalRes) ← resolveLVal f lval hasArgs match lvalRes with | LValResolution.projIdx structName idx => let f := mkProj structName idx f diff --git a/src/Lean/Elab/Binders.lean b/src/Lean/Elab/Binders.lean index e7fe909904..55b8bcedd2 100644 --- a/src/Lean/Elab/Binders.lean +++ b/src/Lean/Elab/Binders.lean @@ -104,24 +104,29 @@ private def getBinderIds (ids : Syntax) : TermElabM (Array Syntax) := private def matchBinder (stx : Syntax) : TermElabM (Array BinderView) := do let k := stx.getKind - if k == `Lean.Parser.Term.simpleBinder then + if k == ``Lean.Parser.Term.simpleBinder then -- binderIdent+ >> optType let ids ← getBinderIds stx[0] let type := expandOptType (mkNullNode ids) stx[1] ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.default } - else if k == `Lean.Parser.Term.explicitBinder then + else if k == ``Lean.Parser.Term.explicitBinder then -- `(` binderIdent+ binderType (binderDefault <|> binderTactic)? `)` let ids ← getBinderIds stx[1] let type := expandBinderType (mkNullNode ids) stx[2] let optModifier := stx[3] let type ← expandBinderModifier type optModifier ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.default } - else if k == `Lean.Parser.Term.implicitBinder then + else if k == ``Lean.Parser.Term.implicitBinder then -- `{` binderIdent+ binderType `}` let ids ← getBinderIds stx[1] let type := expandBinderType (mkNullNode ids) stx[2] ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.implicit } - else if k == `Lean.Parser.Term.instBinder then + else if k == ``Lean.Parser.Term.strictImplicitBinder then + -- `⦃` binderIdent+ binderType `⦄` + let ids ← getBinderIds stx[1] + let type := expandBinderType (mkNullNode ids) stx[2] + ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.strictImplicit } + else if k == ``Lean.Parser.Term.instBinder then -- `[` optIdent type `]` let id ← expandOptIdent stx[1] let type := stx[2] @@ -256,15 +261,16 @@ partial def expandFunBinders (binders : Array Syntax) (body : Syntax) : MacroM ( let newBody ← `(match $major:ident with | $pattern => $newBody) pure (binders, newBody, true) match binder with - | Syntax.node `Lean.Parser.Term.implicitBinder _ => loop body (i+1) (newBinders.push binder) - | Syntax.node `Lean.Parser.Term.instBinder _ => loop body (i+1) (newBinders.push binder) - | Syntax.node `Lean.Parser.Term.explicitBinder _ => loop body (i+1) (newBinders.push binder) - | Syntax.node `Lean.Parser.Term.simpleBinder _ => loop body (i+1) (newBinders.push binder) - | Syntax.node `Lean.Parser.Term.hole _ => + | Syntax.node ``Lean.Parser.Term.implicitBinder _ => loop body (i+1) (newBinders.push binder) + | Syntax.node ``Lean.Parser.Term.strictImplicitBinder _ => loop body (i+1) (newBinders.push binder) + | Syntax.node ``Lean.Parser.Term.instBinder _ => loop body (i+1) (newBinders.push binder) + | Syntax.node ``Lean.Parser.Term.explicitBinder _ => loop body (i+1) (newBinders.push binder) + | Syntax.node ``Lean.Parser.Term.simpleBinder _ => loop body (i+1) (newBinders.push binder) + | Syntax.node ``Lean.Parser.Term.hole _ => let ident ← mkFreshIdent binder let type := binder loop body (i+1) (newBinders.push <| mkExplicitBinder ident type) - | Syntax.node `Lean.Parser.Term.paren args => + | Syntax.node ``Lean.Parser.Term.paren args => -- `(` (termParser >> parenSpecial)? `)` -- parenSpecial := (tupleTail <|> typeAscription)? let binderBody := binder[1] diff --git a/src/Lean/PrettyPrinter/Delaborator/Builtins.lean b/src/Lean/PrettyPrinter/Delaborator/Builtins.lean index 6caaa51e03..0bf5cfe377 100644 --- a/src/Lean/PrettyPrinter/Delaborator/Builtins.lean +++ b/src/Lean/PrettyPrinter/Delaborator/Builtins.lean @@ -506,10 +506,12 @@ def delabLam : Delab := else pure $ curNames.get! 0; `(funBinder| ($stxCurNames : $stxT)) - | BinderInfo.default, false => pure curNames.back -- here `curNames.size == 1` - | BinderInfo.implicit, true => `(funBinder| {$curNames* : $stxT}) - | BinderInfo.implicit, false => `(funBinder| {$curNames*}) - | BinderInfo.instImplicit, _ => + | BinderInfo.default, false => pure curNames.back -- here `curNames.size == 1` + | BinderInfo.implicit, true => `(funBinder| {$curNames* : $stxT}) + | BinderInfo.implicit, false => `(funBinder| {$curNames*}) + | BinderInfo.strictImplicit, true => `(funBinder| ⦃$curNames* : $stxT⦄) + | BinderInfo.strictImplicit, false => `(funBinder| ⦃$curNames*⦄) + | BinderInfo.instImplicit, _ => if usedDownstream then `(funBinder| [$curNames.back : $stxT]) -- here `curNames.size == 1` else `(funBinder| [$stxT]) | _ , _ => unreachable!; @@ -524,10 +526,11 @@ def delabForall : Delab := let prop ← try isProp e catch _ => false let stxT ← withBindingDomain delab let group ← match e.binderInfo with - | BinderInfo.implicit => `(bracketedBinderF|{$curNames* : $stxT}) + | BinderInfo.implicit => `(bracketedBinderF|{$curNames* : $stxT}) + | BinderInfo.strictImplicit => `(bracketedBinderF|⦃$curNames* : $stxT⦄) -- here `curNames.size == 1` - | BinderInfo.instImplicit => `(bracketedBinderF|[$curNames.back : $stxT]) - | _ => + | BinderInfo.instImplicit => `(bracketedBinderF|[$curNames.back : $stxT]) + | _ => -- heuristic: use non-dependent arrows only if possible for whole group to avoid -- noisy mix like `(α : Type) → Type → (γ : Type) → ...`. let dependent := curNames.any $ fun n => hasIdent n.getId stxBody diff --git a/tests/lean/strictImplicit.lean b/tests/lean/strictImplicit.lean new file mode 100644 index 0000000000..1cc86a6855 --- /dev/null +++ b/tests/lean/strictImplicit.lean @@ -0,0 +1,7 @@ +def g {α : Type} (a : α) := a +def f {{α : Type}} (a : α) := a + +#check g +#check f +#check g 1 +#check f 1 diff --git a/tests/lean/strictImplicit.lean.expected.out b/tests/lean/strictImplicit.lean.expected.out new file mode 100644 index 0000000000..a8f0383ac9 --- /dev/null +++ b/tests/lean/strictImplicit.lean.expected.out @@ -0,0 +1,4 @@ +g : ?m → ?m +f : ⦃α : Type⦄ → α → α +g 1 : Nat +f 1 : Nat