diff --git a/src/Lean/Elab/StructInst.lean b/src/Lean/Elab/StructInst.lean index bfa99d8ef1..336cf1c698 100644 --- a/src/Lean/Elab/StructInst.lean +++ b/src/Lean/Elab/StructInst.lean @@ -185,6 +185,13 @@ instance Struct.inhabited : Inhabited Struct := ⟨⟨arbitrary _, arbitrary _, abbrev Fields := List (Field Struct) +/- true if all fields of the given structure are marked as `default` -/ +partial def Struct.allDefault : Struct → Bool +| ⟨_, _, fields, _⟩ => fields.all fun ⟨_, _, val, _⟩ => match val with + | FieldVal.term _ => false + | FieldVal.default => true + | FieldVal.nested s => Struct.allDefault s + def Struct.ref : Struct → Syntax | ⟨ref, _, _, _⟩ => ref @@ -493,6 +500,11 @@ annotation? `structInstDefault e def throwFailedToElabField {α} (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α := throwError ("failed to elaborate field '" ++ fieldName ++ "' of '" ++ structName ++ ", " ++ msgData) +def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) := +if !s.allDefault then pure none +else + catch (synthInstance? expectedType) (fun _ => pure none) + private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × Struct) | s, expectedType? => withRef s.ref do env ← getEnv; @@ -514,7 +526,11 @@ private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × S }; match field.val with | FieldVal.term stx => do val ← elabTermEnsuringType stx d; continue val field - | FieldVal.nested s => do (val, sNew) ← elabStruct s (some d); val ← ensureHasType d val; continue val { field with val := FieldVal.nested sNew } + | FieldVal.nested s => do + val? ← trySynthStructInstance? s d; -- if all fields of `s` are marked as `default`, then try to synthesize instance + match val? with + | some val => continue val { field with val := FieldVal.term (mkHole field.ref) } + | none => do(val, sNew) ← elabStruct s (some d); val ← ensureHasType d val; continue val { field with val := FieldVal.nested sNew } | FieldVal.default => do val ← withRef field.ref $ mkFreshExprMVar (some d); continue (markDefaultMissing val) field | _ => withRef field.ref $ throwFailedToElabField fieldName s.structName ("unexpected constructor type" ++ indentExpr type) | _ => throwErrorAt field.ref "unexpected unexpanded structure field") diff --git a/tests/lean/run/unexpected_result_with_bind.lean b/tests/lean/run/unexpected_result_with_bind.lean index 14fc46b635..4d454bc6f4 100644 --- a/tests/lean/run/unexpected_result_with_bind.lean +++ b/tests/lean/run/unexpected_result_with_bind.lean @@ -1,3 +1,4 @@ +new_frontend namespace Repro def FooM (α : Type) : Type := Unit → α @@ -6,7 +7,7 @@ def FooM.run {α : Type} (ψ : FooM α) (x : Unit) : α := ψ x def bind {α β : Type} : ∀ (ψ₁ : FooM α) (ψ₂ : α → FooM β), FooM β | ψ₁, ψ₂ => λ _ => ψ₂ (ψ₁.run ()) () -instance : HasPure FooM := ⟨λ _ x => λ _ => x⟩ +instance : HasPure FooM := ⟨λ x => λ _ => x⟩ instance : HasBind FooM := ⟨@bind⟩ instance : Monad FooM := {}