diff --git a/src/Init/Lean/Delaborator.lean b/src/Init/Lean/Delaborator.lean index 0f2a1df4a1..28ca37e83b 100644 --- a/src/Init/Lean/Delaborator.lean +++ b/src/Init/Lean/Delaborator.lean @@ -29,6 +29,7 @@ prelude import Init.Lean.KeyedDeclsAttribute import Init.Lean.ProjFns import Init.Lean.Syntax +import Init.Lean.Elab.Term namespace Lean @@ -348,29 +349,36 @@ private partial def delabBinders (delabGroup : Array Syntax → Syntax → Delab (withBindingBody n delab >>= delabGroup curNames) @[builtinDelab lam] -def delabExplicitLam : Delab := +def delabLam : Delab := delabBinders $ fun curNames stxBody => do e ← getExpr | unreachable!; stxT ← withBindingDomain delab; ppTypes ← getPPOption getPPBinderTypes; - group ← match e.binderInfo, ppTypes with - | BinderInfo.default, true => do - -- "default" binder group is the only one that expects binder names - -- as a term, i.e. a single `Term.id` or an application thereof - let curNames := curNames.map mkTermIdFromIdent; - stxCurNames ← if curNames.size > 1 then `($(curNames.get! 0) $(curNames.eraseIdx 0)*) - else pure $ curNames.get! 0; - `(funBinder| ($stxCurNames : $stxT)) - | BinderInfo.default, false => pure $ mkTermIdFromIdent curNames.back -- here `curNames.size == 1` - | BinderInfo.implicit, true => `(funBinder| {$curNames* : $stxT}) - | BinderInfo.implicit, false => `(funBinder| {$curNames*}) - | BinderInfo.instImplicit, _ => `(funBinder| [$curNames.back : $stxT]) -- here `curNames.size == 1` - | _ , _ => unreachable!; - match_syntax stxBody with - | `(@(fun $binderGroups* => $stxBody)) => `(@(fun $group $binderGroups* => $stxBody)) - | _ => `(@(fun $group => $stxBody)) - --- TODO: implicit lambdas + expl ← getPPOption getPPExplicit; + -- leave lambda implicit if possible + let blockImplicitLambda := expl || + e.binderInfo == BinderInfo.default || + Elab.Term.blockImplicitLambda stxBody || + curNames.any (fun n => hasIdent n.getId stxBody); + if !blockImplicitLambda then + pure stxBody + else do + group ← match e.binderInfo, ppTypes with + | BinderInfo.default, true => do + -- "default" binder group is the only one that expects binder names + -- as a term, i.e. a single `Term.id` or an application thereof + let curNames := curNames.map mkTermIdFromIdent; + stxCurNames ← if curNames.size > 1 then `($(curNames.get! 0) $(curNames.eraseIdx 0)*) + else pure $ curNames.get! 0; + `(funBinder| ($stxCurNames : $stxT)) + | BinderInfo.default, false => pure $ mkTermIdFromIdent curNames.back -- here `curNames.size == 1` + | BinderInfo.implicit, true => `(funBinder| {$curNames* : $stxT}) + | BinderInfo.implicit, false => `(funBinder| {$curNames*}) + | BinderInfo.instImplicit, _ => `(funBinder| [$curNames.back : $stxT]) -- here `curNames.size == 1` + | _ , _ => unreachable!; + match_syntax stxBody with + | `(fun $binderGroups* => $stxBody) => `(fun $group $binderGroups* => $stxBody) + | _ => `(fun $group => $stxBody) @[builtinDelab forallE] def delabForall : Delab := diff --git a/src/Init/Lean/Elab/Term.lean b/src/Init/Lean/Elab/Term.lean index 1e5e15521d..c9fde4400e 100644 --- a/src/Init/Lean/Elab/Term.lean +++ b/src/Init/Lean/Elab/Term.lean @@ -758,11 +758,15 @@ match_syntax stx with | `(fun $binders* => $body) => binders.any $ fun b => b.isOfKind `Lean.Parser.Term.implicitBinder || b.isOfKind `Lean.Parser.Term.instBinder | _ => false +/-- Block usage of implicit lambdas if `stx` is `@f` or `@f arg1 ...` or `fun` with an implicit binder annotation. -/ +def blockImplicitLambda (stx : Syntax) : Bool := +isExplicit stx || isExplicitApp stx || isLambdaWithImplicit stx + /-- - Return true with `expectedType` is of the form `{a : α} → β` or `[a : α] → β`, and - `stx` is not `@f` nor `@f arg1 ...` -/ -def useImplicitLambda? (stx : Syntax) (expectedType? : Option Expr) (implicitLambda : Bool) : TermElabM (Option Expr) := -if !implicitLambda || isExplicit stx || isExplicitApp stx || isLambdaWithImplicit stx then pure none + Return normalized expected type if it is of the form `{a : α} → β` or `[a : α] → β` and + `blockImplicitLambda stx` is not true, else return `none`. -/ +def useImplicitLambda? (stx : Syntax) (expectedType? : Option Expr) : TermElabM (Option Expr) := +if blockImplicitLambda stx then pure none else match expectedType? with | some expectedType => do expectedType ← whnfForall stx expectedType; @@ -803,7 +807,7 @@ partial def elabTermAux (expectedType? : Option Expr) (catchExPostpone : Bool) ( match stxNew? with | some stxNew => withMacroExpansion stx stxNew $ elabTermAux stxNew | _ => do - implicit? ← useImplicitLambda? stx expectedType? implicitLambda; + implicit? ← if implicitLambda then useImplicitLambda? stx expectedType? else pure none; match implicit? with | some expectedType => elabImplicitLambda stx catchExPostpone expectedType #[] | none => elabUsingElabFns stx expectedType? catchExPostpone diff --git a/tests/lean/Delaborator.lean b/tests/lean/Delaborator.lean index 61644e9323..9816a74a05 100644 --- a/tests/lean/Delaborator.lean +++ b/tests/lean/Delaborator.lean @@ -2,6 +2,7 @@ import Init.Lean open Lean open Lean.Elab open Lean.Elab.Term +open Lean.Format def check (stx : TermElabM Syntax) (optionsPerPos : OptionsPerPos := {}) : TermElabM Unit := do stx ← stx; @@ -11,7 +12,7 @@ stx' ← liftMetaM stx $ delab e opts optionsPerPos; dbgTrace $ toString stx'; e' ← elabTermAndSynthesize stx' none <* throwErrorIfErrors; unlessM (isDefEq stx e e') $ - throwError stx "failed to round-trip" + throwError stx (fmt "failed to round-trip" ++ line ++ fmt e ++ line ++ fmt e') -- #eval check `(?m) -- fails round-trip @@ -38,11 +39,18 @@ section end #eval check `(id (id Nat)) (RBMap.empty.insert 4 $ KVMap.empty.insert `pp.explicit true) +-- specify the expected type of `a` in a way that is not erased by the delaborator +def typeAs.{u} (α : Type u) (a : α) := () + #eval check `(fun (a : Nat) => a) #eval check `(fun (a b : Nat) => a) #eval check `(fun (a : Nat) (b : Bool) => a) -#eval check `(@(fun (a b : Nat) => a)) -#eval check `(@(fun α (s : HasToString α) => true)) +#eval check `(fun {a b : Nat} => a) +-- implicit lambdas work as long as the expected type is preserved +#eval check `(typeAs ({α : Type} → (a : α) → α) (fun a => a)) +section set_option pp.explicit true + #eval check `(fun {α : Type} [HasToString α] (a : α) => toString a) +end #eval check `((α : Type) → α) #eval check `((α β : Type) → α) -- group diff --git a/tests/lean/Delaborator.lean.expected.out b/tests/lean/Delaborator.lean.expected.out index 5064c0ee93..5fd1855633 100644 --- a/tests/lean/Delaborator.lean.expected.out +++ b/tests/lean/Delaborator.lean.expected.out @@ -18,86 +18,56 @@ (Term.app (Term.id `id (null)) (null (Term.app (Term.explicit "@" (Term.id `id (null))) (null (Term.type "Type") (Term.id `Nat (null)))))) -(Term.explicit - "@" - (Term.paren - "(" - (null - (Term.fun - "fun" - (null (Term.paren "(" (null (Term.id `a (null)) (null (Term.typeAscription ":" (Term.id `Nat (null))))) ")")) - "=>" - (Term.id `a (null))) - (null)) - ")")) -(Term.explicit - "@" - (Term.paren - "(" - (null - (Term.fun - "fun" - (null - (Term.paren - "(" - (null (Term.app (Term.id `a (null)) (null (Term.id `b (null)))) (null (Term.typeAscription ":" (Term.id `Nat (null))))) - ")")) - "=>" - (Term.id `a (null))) - (null)) - ")")) -(Term.explicit - "@" - (Term.paren - "(" - (null - (Term.fun - "fun" - (null - (Term.paren "(" (null (Term.id `a (null)) (null (Term.typeAscription ":" (Term.id `Nat (null))))) ")") - (Term.paren "(" (null (Term.id `b (null)) (null (Term.typeAscription ":" (Term.id `Bool (null))))) ")")) - "=>" - (Term.id `a (null))) - (null)) - ")")) -(Term.explicit - "@" - (Term.paren - "(" - (null - (Term.fun - "fun" - (null - (Term.paren - "(" - (null (Term.app (Term.id `a (null)) (null (Term.id `b (null)))) (null (Term.typeAscription ":" (Term.id `Nat (null))))) - ")")) - "=>" - (Term.id `a (null))) - (null)) - ")")) -(Term.explicit - "@" - (Term.paren - "(" - (null - (Term.fun - "fun" - (null - (Term.paren - "(" - (null (Term.id `α (null)) (null (Term.typeAscription ":" (Term.sortApp (Term.type "Type") (Level.hole "_"))))) - ")") - (Term.paren - "(" - (null - (Term.id `s (null)) - (null (Term.typeAscription ":" (Term.app (Term.id `HasToString (null)) (null (Term.id `α (null))))))) - ")")) - "=>" - (Term.id `Bool.true (null))) - (null)) - ")")) +(Term.fun + "fun" + (null (Term.paren "(" (null (Term.id `a (null)) (null (Term.typeAscription ":" (Term.id `Nat (null))))) ")")) + "=>" + (Term.id `a (null))) +(Term.fun + "fun" + (null + (Term.paren + "(" + (null (Term.app (Term.id `a (null)) (null (Term.id `b (null)))) (null (Term.typeAscription ":" (Term.id `Nat (null))))) + ")")) + "=>" + (Term.id `a (null))) +(Term.fun + "fun" + (null + (Term.paren "(" (null (Term.id `a (null)) (null (Term.typeAscription ":" (Term.id `Nat (null))))) ")") + (Term.paren "(" (null (Term.id `b (null)) (null (Term.typeAscription ":" (Term.id `Bool (null))))) ")")) + "=>" + (Term.id `a (null))) +(Term.fun + "fun" + (null (Term.implicitBinder "{" (null `a `b) (null ":" (Term.id `Nat (null))) "}")) + "=>" + (Term.id `a (null))) +(Term.app + (Term.id `typeAs (null)) + (null + (Term.depArrow + (Term.implicitBinder "{" (null `α) (null ":" (Term.type "Type")) "}") + "→" + (Term.arrow (Term.id `α (null)) "→" (Term.id `α (null)))) + (Term.fun + "fun" + (null + (Term.implicitBinder "{" (null `α) (null ":" (Term.type "Type")) "}") + (Term.paren "(" (null (Term.id `a (null)) (null (Term.typeAscription ":" (Term.id `α (null))))) ")")) + "=>" + (Term.id `a (null))))) +(Term.fun + "fun" + (null + (Term.implicitBinder "{" (null `α) (null ":" (Term.type "Type")) "}") + (Term.instBinder "[" (null `_inst_1 ":") (Term.app (Term.id `HasToString (null)) (null (Term.id `α (null)))) "]") + (Term.paren "(" (null (Term.id `a (null)) (null (Term.typeAscription ":" (Term.id `α (null))))) ")")) + "=>" + (Term.app + (Term.explicit "@" (Term.id `HasToString.toString (null))) + (null (Term.id `α (null)) (Term.id `_inst_1 (null)) (Term.id `a (null))))) (Term.depArrow (Term.explicitBinder "(" (null `α) (null ":" (Term.type "Type")) (null) ")") "→" (Term.id `α (null))) (Term.depArrow (Term.explicitBinder "(" (null `α `β) (null ":" (Term.type "Type")) (null) ")") "→" (Term.id `α (null))) (Term.arrow (Term.type "Type") "→" (Term.arrow (Term.type "Type") "→" (Term.type "Type")))