feat: support {s with ..}

This commit is contained in:
Mario Carneiro 2022-09-11 23:40:46 -04:00 committed by Leonardo de Moura
parent d67546e388
commit adc215dab9
2 changed files with 47 additions and 57 deletions

View file

@ -86,15 +86,14 @@ structure ExplicitSourceInfo where
structName : Name
deriving Inhabited
inductive Source where
| none -- structure instance source has not been provieded
| implicit (stx : Syntax) -- `..`
| explicit (sources : Array ExplicitSourceInfo) -- `s₁ ... sₙ with`
structure Source where
explicit : Array ExplicitSourceInfo -- `s₁ ... sₙ with`
implicit : Option Syntax -- `..`
deriving Inhabited
def Source.isNone : Source → Bool
| .none => true
| _ => false
| { explicit := #[], implicit := none } => true
| _ => false
/-- `optional (atomic (sepBy1 termParser ", " >> " with ")` -/
private def mkSourcesWithSyntax (sources : Array Syntax) : Syntax :=
@ -106,21 +105,18 @@ private def getStructSource (structStx : Syntax) : TermElabM Source :=
withRef structStx do
let explicitSource := structStx[1]
let implicitSource := structStx[3]
if explicitSource.isNone && implicitSource[0].isNone then
return .none
else if explicitSource.isNone then
return .implicit implicitSource
else if implicitSource[0].isNone then
let sources ← explicitSource[0].getSepArgs.mapM fun stx => do
let explicit ← if explicitSource.isNone then
pure #[]
else
explicitSource[0].getSepArgs.mapM fun stx => do
let some src ← isLocalIdent? stx | unreachable!
addTermInfo' stx src
let srcType ← whnf (← inferType src)
tryPostponeIfMVar srcType
let structName ← getStructureName srcType
return { stx, structName }
return .explicit sources
else
throwError "invalid structure instance `with` and `..` cannot be used together"
let implicit := if implicitSource[0].isNone then none else implicitSource
return { explicit, implicit }
/--
We say a `{ ... }` notation is a `modifyOp` if it contains only one
@ -197,13 +193,15 @@ private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSource
private def getStructName (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do
tryPostponeIfNoneOrMVar expectedType?
let useSource : Unit → TermElabM Name := fun _ => do
match sourceView, expectedType? with
| .explicit sources, _ =>
if sources.size > 1 then
throwErrorAt sources[1]!.stx "invalid \{...} notation, expected type is not known, using the type of the first source, extra sources are not needed"
return sources[0]!.structName
| _, some expectedType => throwUnexpectedExpectedType expectedType
| _, none => throwUnknownExpectedType
match sourceView.explicit.size with
| 0 =>
match expectedType? with
| some expectedType => throwUnexpectedExpectedType expectedType
| none => throwUnknownExpectedType
| 1 => return sourceView.explicit[0]!.structName
| _ =>
throwErrorAt sourceView.explicit[1]!.stx
"invalid \{...} notation, expected type is not known, using the type of the first source, extra sources are not needed"
match expectedType? with
| none => useSource ()
| some expectedType =>
@ -292,10 +290,11 @@ def formatField (formatStruct : Struct → Format) (field : Field Struct) : Form
partial def formatStruct : Struct → Format
| ⟨_, _, _, fields, source⟩ =>
let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", "
match source with
| .none => "{" ++ fieldsFmt ++ "}"
| .implicit _ => "{" ++ fieldsFmt ++ " .. }"
| .explicit sources => "{" ++ format (sources.map (·.stx)) ++ " with " ++ fieldsFmt ++ "}"
let implicitFmt := if source.implicit.isSome then " .. " else ""
if source.explicit.isEmpty then
"{" ++ fieldsFmt ++ implicitFmt ++ "}"
else
"{" ++ format (source.explicit.map (·.stx)) ++ " with " ++ fieldsFmt ++ implicitFmt ++ "}"
instance : ToFormat Struct := ⟨formatStruct⟩
instance : ToString Struct := ⟨toString ∘ format⟩
@ -487,15 +486,10 @@ mutual
pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct }
| none =>
let updateSource (structStx : Syntax) : TermElabM Syntax := do
match s.source with
| .none => return (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode
| .implicit stx => return (structStx.setArg 1 mkNullNode).setArg 3 stx
| .explicit sources =>
let sourcesNew ← sources.filterMapM fun source => mkProjStx? source.stx source.structName fieldName
if sourcesNew.isEmpty then
return (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode
else
return (structStx.setArg 1 (mkSourcesWithSyntax sourcesNew)).setArg 3 mkNullNode
let sourcesNew ← s.source.explicit.filterMapM fun source => mkProjStx? source.stx source.structName fieldName
let explicitSourceStx := if sourcesNew.isEmpty then mkNullNode else mkSourcesWithSyntax sourcesNew
let implicitSourceStx := s.source.implicit.getD mkNullNode
return (structStx.setArg 1 explicitSourceStx).setArg 3 implicitSourceStx
let valStx := s.ref -- construct substructure syntax using s.ref as template
let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type
let args := substructFields.toArray.map (·.toSyntax)
@ -516,28 +510,20 @@ mutual
return { ref, lhs := [FieldLHS.fieldName ref fieldName], val := val } :: fields
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName =>
let addSubstruct : TermElabM Fields := do
-- If one of the sources has the subobject field, use it
if let some val ← s.source.explicit.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then
addField (FieldVal.term val)
else
let substruct := Struct.mk ref substructName #[] [] s.source
let substruct ← expandStruct substruct
addField (FieldVal.nested substruct)
match s.source with
| .none => addSubstruct
| .implicit _ => addSubstruct
| .explicit sources =>
-- If one of the sources has the subobject field, use it
if let some val ← sources.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then
addField (FieldVal.term val)
else
addSubstruct
| none =>
match s.source with
| .none => addField FieldVal.default
| .implicit _ => addField (FieldVal.term (mkHole ref))
| .explicit sources =>
if let some val ← sources.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then
addField (FieldVal.term val)
else
addField FieldVal.default
if let some val ← s.source.explicit.findSomeM? fun source => mkProjStx? source.stx source.structName fieldName then
addField (FieldVal.term val)
else if s.source.implicit.isSome then
addField (FieldVal.term (mkHole ref))
else
addField FieldVal.default
return s.setFields fields.reverse
private partial def expandStruct (s : Struct) : TermElabM Struct := do
@ -922,10 +908,12 @@ private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sour
| some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
| none =>
let sourceView ← getStructSource stx
match (← isModifyOp? stx), sourceView with
| some modifyOp, .explicit sources => elabModifyOp stx modifyOp sources expectedType?
| some _, _ => throwError "invalid \{...} notation, explicit source is required when using '[<index>] := <value>'"
| _, _ => elabStructInstAux stx expectedType? sourceView
if let some modifyOp ← isModifyOp? stx then
if sourceView.explicit.isEmpty then
throwError "invalid \{...} notation, explicit source is required when using '[<index>] := <value>'"
elabModifyOp stx modifyOp sourceView.explicit expectedType?
else
elabStructInstAux stx expectedType? sourceView
builtin_initialize registerTraceClass `Elab.struct

View file

@ -37,3 +37,5 @@ def c2 : C (Nat × Nat) := { z := (1, 1) }
#check { c2 with x.fst := 2 }
#check { c2 with x.1 := 3 }
#check show C _ from { c2.toB with .. }