diff --git a/src/Lean/Elab/StructInst.lean b/src/Lean/Elab/StructInst.lean index 0bba50274f..91fa6e9e06 100644 --- a/src/Lean/Elab/StructInst.lean +++ b/src/Lean/Elab/StructInst.lean @@ -37,28 +37,39 @@ If `stx` is of the form `{ s with ... }` and `s` is not a local variable, expand Note that this one is not a `Macro` because we need to access the local context. -/ -private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) := - withFreshMacroScope do - let sourceOpt := stx[1] - if sourceOpt.isNone then - pure none - else - let source := sourceOpt[0] - match (← isLocalIdent? source) with - | some _ => pure none - | none => - if source.isMissing then - throwAbortTerm - else - let src ← `(src) - let sourceOpt := sourceOpt.setArg 0 src - let stxNew := stx.setArg 1 sourceOpt - `(let src := $source; $stxNew) +private def expandNonAtomicExplicitSources (stx : Syntax) : TermElabM (Option Syntax) := do + let sourcesOpt := stx[1] + if sourcesOpt.isNone then + pure none + else + let sources := sourcesOpt[0] + if sources.isMissing then + throwAbortTerm + let sources := sources.getSepArgs + if (← sources.allM fun source => return (← isLocalIdent? source).isSome) then + return none + if sources.any (·.isMissing) then + throwAbortTerm + go sources.toList #[] +where + go (sources : List Syntax) (sourcesNew : Array Syntax) : TermElabM Syntax := do + match sources with + | [] => + let sources := Syntax.mkSep sourcesNew (mkAtomFrom stx ", ") + return stx.setArg 1 (stx[1].setArg 0 sources) + | source :: sources => + if (← isLocalIdent? source).isSome then + go sources (sourcesNew.push source) + else + withFreshMacroScope do + let sourceNew ← `(src) + let r ← go sources (sourcesNew.push sourceNew) + `(let src := $source; $r) inductive Source where | none -- structure instance source has not been provieded | implicit (stx : Syntax) -- `..` - | explicit (stx : Syntax) (src : Expr) -- `src with` + | explicit (stx : Syntax) (srcs : Array Expr) -- `src with` deriving Inhabited def Source.isNone : Source → Bool @@ -79,10 +90,9 @@ private def getStructSource (stx : Syntax) : TermElabM Source := else if explicitSource.isNone then return Source.implicit implicitSource else if implicitSource[0].isNone then - let fvar? ← isLocalIdent? explicitSource[0] - match fvar? with - | none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here - | some src => return Source.explicit explicitSource src + let srcs ← explicitSource[0].getSepArgs.mapM fun src => do + if let some fvar ← isLocalIdent? src then fvar else unreachable! + return Source.explicit explicitSource srcs else throwError "invalid structure instance `with` and `..` cannot be used together" @@ -128,10 +138,12 @@ private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do | some s => if s[0][0].getKind == ``Lean.Parser.Term.structInstArrayRef then pure s? else pure none private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do + if source[0].getSepArgs.size > 1 then + throwError "invalid \{...} notation, multiple sources and array update is not supported." let cont (val : Syntax) : TermElabM Expr := do let lval := modifyOp[0][0] let idx := lval[1] - let self := source[0] + let self := source[0][0] let stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val)) trace[Elab.struct.modifyOp] "{stx}\n===>\n{stxNew}" withMacroExpansion stx stxNew <| elabTerm stxNew expectedType? @@ -146,7 +158,7 @@ private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option let restArgs := rest.getArgs let valRest := mkNullNode restArgs[1:restArgs.size] let valField := modifyOp.setArg 0 <| Syntax.node ``Parser.Term.structInstLVal #[valFirst, valRest] - let valSource := source.modifyArg 0 fun _ => s + let valSource := source.modifyArg 0 fun sep => sep.modifyArg 0 fun _ => s let val := stx.setArg 1 valSource let val := val.setArg 2 <| mkNullNode #[mkNullNode #[valField, mkNullNode]] trace[Elab.struct.modifyOp] "{stx}\nval: {val}" @@ -157,8 +169,8 @@ private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceVi tryPostponeIfNoneOrMVar expectedType? let useSource : Unit → TermElabM (Name × Expr) := fun _ => match sourceView, expectedType? with - | Source.explicit _ src, _ => do - let srcType ← inferType src + | Source.explicit _ srcs, _ => do + let srcType ← inferType srcs[0] let srcType ← whnf srcType tryPostponeIfMVar srcType match srcType.getAppFn with @@ -416,10 +428,11 @@ private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax := private def mkSubstructSource (structName : Name) (fieldNames : Array Name) (fieldName : Name) (src : Source) : TermElabM Source := match src with - | Source.explicit stx src => do + | Source.explicit stx srcs => do + -- TODO: handle multiple sources let idx ← getFieldIdx structName fieldNames fieldName - let stx := stx.modifyArg 0 fun stx => mkProjStx stx fieldName - return Source.explicit stx (mkProj structName idx src) + let stx := stx.modifyArg 0 fun stx => stx.modifyArg 0 fun stx => mkProjStx stx fieldName + return Source.explicit stx #[mkProj structName idx srcs[0]] | s => return s private def groupFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do @@ -477,8 +490,8 @@ private def addMissingFields (expandStruct : Struct → TermElabM Struct) (s : S | Source.none => addField FieldVal.default | Source.implicit _ => addField (FieldVal.term (mkHole s.ref)) | Source.explicit stx _ => - -- stx is of the form `optional (try (termParser >> "with"))` - let src := stx[0] + -- stx is of the form `optional (try (sepBy1 termParser ", " >> "with"))` + let src := stx[0][0] -- TODO -- add support for multiple sources let val := mkProjStx src fieldName addField (FieldVal.term val) return s.setFields fields.reverse @@ -801,11 +814,12 @@ private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sour let struct ← expandStruct struct trace[Elab.struct] "{struct}" let (r, struct) ← elabStruct struct expectedType? + trace[Elab.struct] "before propagate {r}" DefaultFields.propagate struct return r @[builtinTermElab structInst] def elabStructInst : TermElab := fun stx expectedType? => do - match (← expandNonAtomicExplicitSource stx) with + match (← expandNonAtomicExplicitSources stx) with | some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType? | none => let sourceView ← getStructSource stx diff --git a/src/Lean/Parser/Term.lean b/src/Lean/Parser/Term.lean index 515e2b65bf..4b7db28d8c 100644 --- a/src/Lean/Parser/Term.lean +++ b/src/Lean/Parser/Term.lean @@ -79,7 +79,7 @@ def structInstLVal := leading_parser (ident <|> fieldIdx <|> structInstArrayRe def structInstField := ppGroup $ leading_parser structInstLVal >> " := " >> termParser def structInstFieldAbbrev := leading_parser atomic (ident >> notFollowedBy ("." <|> ":=" <|> symbol "[") "invalid field abbreviation") -- `x` is an abbreviation for `x := x` def optEllipsis := leading_parser optional ".." -@[builtinTermParser] def structInst := leading_parser "{" >> ppHardSpace >> optional (atomic (termParser >> " with ")) +@[builtinTermParser] def structInst := leading_parser "{" >> ppHardSpace >> optional (atomic (sepBy1 termParser ", " >> " with ")) >> manyIndent (group ((structInstFieldAbbrev <|> structInstField) >> optional ", ")) >> optEllipsis >> optional (" : " >> termParser) >> " }"