feat: allow multiple sources in the structure instance parser

This commit also fixes some macros, and make sure the elaborator still
works, but it does not support multiple sources yet.
This commit is contained in:
Leonardo de Moura 2021-08-11 13:07:56 -07:00
parent efb3f528a6
commit 09c2b668e6
2 changed files with 47 additions and 33 deletions

View file

@ -37,28 +37,39 @@ If `stx` is of the form `{ s with ... }` and `s` is not a local variable, expand
Note that this one is not a `Macro` because we need to access the local context.
-/
private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) :=
withFreshMacroScope do
let sourceOpt := stx[1]
if sourceOpt.isNone then
pure none
else
let source := sourceOpt[0]
match (← isLocalIdent? source) with
| some _ => pure none
| none =>
if source.isMissing then
throwAbortTerm
else
let src ← `(src)
let sourceOpt := sourceOpt.setArg 0 src
let stxNew := stx.setArg 1 sourceOpt
`(let src := $source; $stxNew)
private def expandNonAtomicExplicitSources (stx : Syntax) : TermElabM (Option Syntax) := do
let sourcesOpt := stx[1]
if sourcesOpt.isNone then
pure none
else
let sources := sourcesOpt[0]
if sources.isMissing then
throwAbortTerm
let sources := sources.getSepArgs
if (← sources.allM fun source => return (← isLocalIdent? source).isSome) then
return none
if sources.any (·.isMissing) then
throwAbortTerm
go sources.toList #[]
where
go (sources : List Syntax) (sourcesNew : Array Syntax) : TermElabM Syntax := do
match sources with
| [] =>
let sources := Syntax.mkSep sourcesNew (mkAtomFrom stx ", ")
return stx.setArg 1 (stx[1].setArg 0 sources)
| source :: sources =>
if (← isLocalIdent? source).isSome then
go sources (sourcesNew.push source)
else
withFreshMacroScope do
let sourceNew ← `(src)
let r ← go sources (sourcesNew.push sourceNew)
`(let src := $source; $r)
inductive Source where
| none -- structure instance source has not been provieded
| implicit (stx : Syntax) -- `..`
| explicit (stx : Syntax) (src : Expr) -- `src with`
| explicit (stx : Syntax) (srcs : Array Expr) -- `src with`
deriving Inhabited
def Source.isNone : Source → Bool
@ -79,10 +90,9 @@ private def getStructSource (stx : Syntax) : TermElabM Source :=
else if explicitSource.isNone then
return Source.implicit implicitSource
else if implicitSource[0].isNone then
let fvar? ← isLocalIdent? explicitSource[0]
match fvar? with
| none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here
| some src => return Source.explicit explicitSource src
let srcs ← explicitSource[0].getSepArgs.mapM fun src => do
if let some fvar ← isLocalIdent? src then fvar else unreachable!
return Source.explicit explicitSource srcs
else
throwError "invalid structure instance `with` and `..` cannot be used together"
@ -128,10 +138,12 @@ private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
| some s => if s[0][0].getKind == ``Lean.Parser.Term.structInstArrayRef then pure s? else pure none
private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
if source[0].getSepArgs.size > 1 then
throwError "invalid \{...} notation, multiple sources and array update is not supported."
let cont (val : Syntax) : TermElabM Expr := do
let lval := modifyOp[0][0]
let idx := lval[1]
let self := source[0]
let self := source[0][0]
let stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val))
trace[Elab.struct.modifyOp] "{stx}\n===>\n{stxNew}"
withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
@ -146,7 +158,7 @@ private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option
let restArgs := rest.getArgs
let valRest := mkNullNode restArgs[1:restArgs.size]
let valField := modifyOp.setArg 0 <| Syntax.node ``Parser.Term.structInstLVal #[valFirst, valRest]
let valSource := source.modifyArg 0 fun _ => s
let valSource := source.modifyArg 0 fun sep => sep.modifyArg 0 fun _ => s
let val := stx.setArg 1 valSource
let val := val.setArg 2 <| mkNullNode #[mkNullNode #[valField, mkNullNode]]
trace[Elab.struct.modifyOp] "{stx}\nval: {val}"
@ -157,8 +169,8 @@ private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceVi
tryPostponeIfNoneOrMVar expectedType?
let useSource : Unit → TermElabM (Name × Expr) := fun _ =>
match sourceView, expectedType? with
| Source.explicit _ src, _ => do
let srcType ← inferType src
| Source.explicit _ srcs, _ => do
let srcType ← inferType srcs[0]
let srcType ← whnf srcType
tryPostponeIfMVar srcType
match srcType.getAppFn with
@ -416,10 +428,11 @@ private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax :=
private def mkSubstructSource (structName : Name) (fieldNames : Array Name) (fieldName : Name) (src : Source) : TermElabM Source :=
match src with
| Source.explicit stx src => do
| Source.explicit stx srcs => do
-- TODO: handle multiple sources
let idx ← getFieldIdx structName fieldNames fieldName
let stx := stx.modifyArg 0 fun stx => mkProjStx stx fieldName
return Source.explicit stx (mkProj structName idx src)
let stx := stx.modifyArg 0 fun stx => stx.modifyArg 0 fun stx => mkProjStx stx fieldName
return Source.explicit stx #[mkProj structName idx srcs[0]]
| s => return s
private def groupFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
@ -477,8 +490,8 @@ private def addMissingFields (expandStruct : Struct → TermElabM Struct) (s : S
| Source.none => addField FieldVal.default
| Source.implicit _ => addField (FieldVal.term (mkHole s.ref))
| Source.explicit stx _ =>
-- stx is of the form `optional (try (termParser >> "with"))`
let src := stx[0]
-- stx is of the form `optional (try (sepBy1 termParser ", " >> "with"))`
let src := stx[0][0] -- TODO -- add support for multiple sources
let val := mkProjStx src fieldName
addField (FieldVal.term val)
return s.setFields fields.reverse
@ -801,11 +814,12 @@ private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sour
let struct ← expandStruct struct
trace[Elab.struct] "{struct}"
let (r, struct) ← elabStruct struct expectedType?
trace[Elab.struct] "before propagate {r}"
DefaultFields.propagate struct
return r
@[builtinTermElab structInst] def elabStructInst : TermElab := fun stx expectedType? => do
match (← expandNonAtomicExplicitSource stx) with
match (← expandNonAtomicExplicitSources stx) with
| some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
| none =>
let sourceView ← getStructSource stx

View file

@ -79,7 +79,7 @@ def structInstLVal := leading_parser (ident <|> fieldIdx <|> structInstArrayRe
def structInstField := ppGroup $ leading_parser structInstLVal >> " := " >> termParser
def structInstFieldAbbrev := leading_parser atomic (ident >> notFollowedBy ("." <|> ":=" <|> symbol "[") "invalid field abbreviation") -- `x` is an abbreviation for `x := x`
def optEllipsis := leading_parser optional ".."
@[builtinTermParser] def structInst := leading_parser "{" >> ppHardSpace >> optional (atomic (termParser >> " with "))
@[builtinTermParser] def structInst := leading_parser "{" >> ppHardSpace >> optional (atomic (sepBy1 termParser ", " >> " with "))
>> manyIndent (group ((structInstFieldAbbrev <|> structInstField) >> optional ", "))
>> optEllipsis
>> optional (" : " >> termParser) >> " }"