From adc215dab98daf5a683587e7750bb1332f762f88 Mon Sep 17 00:00:00 2001 From: Mario Carneiro Date: Sun, 11 Sep 2022 23:40:46 -0400 Subject: [PATCH] feat: support `{s with ..}` --- src/Lean/Elab/StructInst.lean | 102 ++++++++++++++------------------ tests/lean/run/structInst3.lean | 2 + 2 files changed, 47 insertions(+), 57 deletions(-) diff --git a/src/Lean/Elab/StructInst.lean b/src/Lean/Elab/StructInst.lean index 7b4ddb6195..3e93a9582d 100644 --- a/src/Lean/Elab/StructInst.lean +++ b/src/Lean/Elab/StructInst.lean @@ -86,15 +86,14 @@ structure ExplicitSourceInfo where structName : Name deriving Inhabited -inductive Source where - | none -- structure instance source has not been provieded - | implicit (stx : Syntax) -- `..` - | explicit (sources : Array ExplicitSourceInfo) -- `s₁ ... sₙ with` +structure Source where + explicit : Array ExplicitSourceInfo -- `s₁ ... sₙ with` + implicit : Option Syntax -- `..` deriving Inhabited def Source.isNone : Source → Bool - | .none => true - | _ => false + | { explicit := #[], implicit := none } => true + | _ => false /-- `optional (atomic (sepBy1 termParser ", " >> " with ")` -/ private def mkSourcesWithSyntax (sources : Array Syntax) : Syntax := @@ -106,21 +105,18 @@ private def getStructSource (structStx : Syntax) : TermElabM Source := withRef structStx do let explicitSource := structStx[1] let implicitSource := structStx[3] - if explicitSource.isNone && implicitSource[0].isNone then - return .none - else if explicitSource.isNone then - return .implicit implicitSource - else if implicitSource[0].isNone then - let sources ← explicitSource[0].getSepArgs.mapM fun stx => do + let explicit ← if explicitSource.isNone then + pure #[] + else + explicitSource[0].getSepArgs.mapM fun stx => do let some src ← isLocalIdent? stx | unreachable! addTermInfo' stx src let srcType ← whnf (← inferType src) tryPostponeIfMVar srcType let structName ← getStructureName srcType return { stx, structName } - return .explicit sources - else - throwError "invalid structure instance `with` and `..` cannot be used together" + let implicit := if implicitSource[0].isNone then none else implicitSource + return { explicit, implicit } /-- We say a `{ ... }` notation is a `modifyOp` if it contains only one @@ -197,13 +193,15 @@ private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSource private def getStructName (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do tryPostponeIfNoneOrMVar expectedType? let useSource : Unit → TermElabM Name := fun _ => do - match sourceView, expectedType? with - | .explicit sources, _ => - if sources.size > 1 then - throwErrorAt sources[1]!.stx "invalid \{...} notation, expected type is not known, using the type of the first source, extra sources are not needed" - return sources[0]!.structName - | _, some expectedType => throwUnexpectedExpectedType expectedType - | _, none => throwUnknownExpectedType + match sourceView.explicit.size with + | 0 => + match expectedType? with + | some expectedType => throwUnexpectedExpectedType expectedType + | none => throwUnknownExpectedType + | 1 => return sourceView.explicit[0]!.structName + | _ => + throwErrorAt sourceView.explicit[1]!.stx + "invalid \{...} notation, expected type is not known, using the type of the first source, extra sources are not needed" match expectedType? with | none => useSource () | some expectedType => @@ -292,10 +290,11 @@ def formatField (formatStruct : Struct → Format) (field : Field Struct) : Form partial def formatStruct : Struct → Format | ⟨_, _, _, fields, source⟩ => let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", " - match source with - | .none => "{" ++ fieldsFmt ++ "}" - | .implicit _ => "{" ++ fieldsFmt ++ " .. }" - | .explicit sources => "{" ++ format (sources.map (·.stx)) ++ " with " ++ fieldsFmt ++ "}" + let implicitFmt := if source.implicit.isSome then " .. " else "" + if source.explicit.isEmpty then + "{" ++ fieldsFmt ++ implicitFmt ++ "}" + else + "{" ++ format (source.explicit.map (·.stx)) ++ " with " ++ fieldsFmt ++ implicitFmt ++ "}" instance : ToFormat Struct := ⟨formatStruct⟩ instance : ToString Struct := ⟨toString ∘ format⟩ @@ -487,15 +486,10 @@ mutual pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct } | none => let updateSource (structStx : Syntax) : TermElabM Syntax := do - match s.source with - | .none => return (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode - | .implicit stx => return (structStx.setArg 1 mkNullNode).setArg 3 stx - | .explicit sources => - let sourcesNew ← sources.filterMapM fun source => mkProjStx? source.stx source.structName fieldName - if sourcesNew.isEmpty then - return (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode - else - return (structStx.setArg 1 (mkSourcesWithSyntax sourcesNew)).setArg 3 mkNullNode + let sourcesNew ← s.source.explicit.filterMapM fun source => mkProjStx? source.stx source.structName fieldName + let explicitSourceStx := if sourcesNew.isEmpty then mkNullNode else mkSourcesWithSyntax sourcesNew + let implicitSourceStx := s.source.implicit.getD mkNullNode + return (structStx.setArg 1 explicitSourceStx).setArg 3 implicitSourceStx let valStx := s.ref -- construct substructure syntax using s.ref as template let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type let args := substructFields.toArray.map (·.toSyntax) @@ -516,28 +510,20 @@ mutual return { ref, lhs := [FieldLHS.fieldName ref fieldName], val := val } :: fields match Lean.isSubobjectField? env s.structName fieldName with | some substructName => - let addSubstruct : TermElabM Fields := do + -- If one of the sources has the subobject field, use it + if let some val ← s.source.explicit.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then + addField (FieldVal.term val) + else let substruct := Struct.mk ref substructName #[] [] s.source let substruct ← expandStruct substruct addField (FieldVal.nested substruct) - match s.source with - | .none => addSubstruct - | .implicit _ => addSubstruct - | .explicit sources => - -- If one of the sources has the subobject field, use it - if let some val ← sources.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then - addField (FieldVal.term val) - else - addSubstruct | none => - match s.source with - | .none => addField FieldVal.default - | .implicit _ => addField (FieldVal.term (mkHole ref)) - | .explicit sources => - if let some val ← sources.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then - addField (FieldVal.term val) - else - addField FieldVal.default + if let some val ← s.source.explicit.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then + addField (FieldVal.term val) + else if s.source.implicit.isSome then + addField (FieldVal.term (mkHole ref)) + else + addField FieldVal.default return s.setFields fields.reverse private partial def expandStruct (s : Struct) : TermElabM Struct := do @@ -922,10 +908,12 @@ private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sour | some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType? | none => let sourceView ← getStructSource stx - match (← isModifyOp? stx), sourceView with - | some modifyOp, .explicit sources => elabModifyOp stx modifyOp sources expectedType? - | some _, _ => throwError "invalid \{...} notation, explicit source is required when using '[] := '" - | _, _ => elabStructInstAux stx expectedType? sourceView + if let some modifyOp ← isModifyOp? stx then + if sourceView.explicit.isEmpty then + throwError "invalid \{...} notation, explicit source is required when using '[] := '" + elabModifyOp stx modifyOp sourceView.explicit expectedType? + else + elabStructInstAux stx expectedType? sourceView builtin_initialize registerTraceClass `Elab.struct diff --git a/tests/lean/run/structInst3.lean b/tests/lean/run/structInst3.lean index b80d12edc2..136a250408 100644 --- a/tests/lean/run/structInst3.lean +++ b/tests/lean/run/structInst3.lean @@ -37,3 +37,5 @@ def c2 : C (Nat × Nat) := { z := (1, 1) } #check { c2 with x.fst := 2 } #check { c2 with x.1 := 3 } + +#check show C _ from { c2.toB with .. }