diff --git a/src/Lean/Elab/StructInst.lean b/src/Lean/Elab/StructInst.lean index 62267966c3..28546caf98 100644 --- a/src/Lean/Elab/StructInst.lean +++ b/src/Lean/Elab/StructInst.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ import Lean.Util.FindExpr import Lean.Parser.Term +import Lean.Meta.Structure import Lean.Elab.App import Lean.Elab.Binders @@ -81,10 +82,15 @@ where let r ← go sources (sourcesNew.push sourceNew) `(let src := $source; $r) +structure ExplicitSourceInfo where + stx : Syntax + structName : Name + deriving Inhabited + inductive Source where | none -- structure instance source has not been provieded | implicit (stx : Syntax) -- `..` - | explicit (sources : Array Syntax) -- `s₁ ... sₙ with` + | explicit (sources : Array ExplicitSourceInfo) -- `s₁ ... sₙ with` deriving Inhabited def Source.isNone : Source → Bool @@ -101,7 +107,7 @@ def setStructSourceSyntax (structStx : Syntax) : Source → Syntax | Source.none => (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode | Source.implicit stx => (structStx.setArg 1 mkNullNode).setArg 3 stx | Source.explicit sources => - let stx := mkSourcesWithSyntax sources + let stx := mkSourcesWithSyntax (sources.map (·.stx)) (structStx.setArg 1 stx).setArg 3 mkNullNode private def getStructSource (structStx : Syntax) : TermElabM Source := @@ -113,7 +119,13 @@ private def getStructSource (structStx : Syntax) : TermElabM Source := else if explicitSource.isNone then return Source.implicit implicitSource else if implicitSource[0].isNone then - return Source.explicit explicitSource[0].getSepArgs + let sources ← explicitSource[0].getSepArgs.mapM fun stx => do + let some src ← isLocalIdent? stx | unreachable! + let srcType ← whnf (← inferType src) + tryPostponeIfMVar srcType + let structName ← getStructureName srcType + return { stx, structName } + return Source.explicit sources else throwError "invalid structure instance `with` and `..` cannot be used together" @@ -158,17 +170,16 @@ private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do | none => pure none | some s => if s[0][0].getKind == ``Lean.Parser.Term.structInstArrayRef then pure s? else pure none -private def elabModifyOp (stx modifyOp : Syntax) (sources : Array Syntax) (expectedType? : Option Expr) : TermElabM Expr := do +private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSourceInfo) (expectedType? : Option Expr) : TermElabM Expr := do if sources.size > 1 then throwError "invalid \{...} notation, multiple sources and array update is not supported." let cont (val : Syntax) : TermElabM Expr := do let lval := modifyOp[0][0] let idx := lval[1] - let self := sources[0] + let self := sources[0].stx let stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val)) trace[Elab.struct.modifyOp] "{stx}\n===>\n{stxNew}" withMacroExpansion stx stxNew <| elabTerm stxNew expectedType? - trace[Elab.struct.modifyOp] "{modifyOp}\nSource: {sources}" let rest := modifyOp[0][1] if rest.isNone then cont modifyOp[2] @@ -186,23 +197,16 @@ private def elabModifyOp (stx modifyOp : Syntax) (sources : Array Syntax) (expec cont val /-- - Get structure name and type. + Get structure name. This method triest to postpone execution if the expected type is not available. If the expected type is available and it is a structure, then we use it. Otherwise, we use the type of the first source. -/ -private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM (Name × Expr) := do +private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do tryPostponeIfNoneOrMVar expectedType? - let useSource : Unit → TermElabM (Name × Expr) := fun _ => + let useSource : Unit → TermElabM Name := fun _ => match sourceView, expectedType? with - | Source.explicit sources, _ => do - let some src ← isLocalIdent? sources[0] | unreachable! - let srcType ← inferType src - let srcType ← whnf srcType - tryPostponeIfMVar srcType - match srcType.getAppFn with - | Expr.const constName _ _ => return (constName, srcType) - | _ => throwUnexpectedExpectedType srcType "source" + | Source.explicit sources, _ => return sources[0].structName | _, some expectedType => throwUnexpectedExpectedType expectedType | _, none => throwUnknownExpectedType match expectedType? with @@ -210,7 +214,10 @@ private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceVi | some expectedType => let expectedType ← whnf expectedType match expectedType.getAppFn with - | Expr.const constName _ _ => return (constName, expectedType) + | Expr.const constName _ _ => + unless isStructure (← getEnv) constName do + throwError "invalid \{...} notation, structure type expected{indentExpr expectedType}" + return constName | _ => useSource () where throwUnknownExpectedType := @@ -287,9 +294,9 @@ partial def formatStruct : Struct → Format | ⟨_, structName, fields, source⟩ => let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", " match source with - | Source.none => "{" ++ fieldsFmt ++ "}" - | Source.implicit _ => "{" ++ fieldsFmt ++ " .. }" - | Source.explicit stx => "{" ++ format stx ++ " with " ++ fieldsFmt ++ "}" + | Source.none => "{" ++ fieldsFmt ++ "}" + | Source.implicit _ => "{" ++ fieldsFmt ++ " .. }" + | Source.explicit sources => "{" ++ format (sources.map (·.stx)) ++ " with " ++ fieldsFmt ++ "}" instance : ToFormat Struct := ⟨formatStruct⟩ instance : ToString Struct := ⟨toString ∘ format⟩ @@ -453,7 +460,9 @@ private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax := private def mkSubstructSource (structName : Name) (fieldName : Name) (src : Source) : TermElabM Source := match src with | Source.explicit sources => do - let sources := sources.map fun sources => mkProjStx sources fieldName + -- Remark: we are not updating the source `structName` here. It is fine for now since the + -- updated value will only be used after we delete this code. + let sources := sources.map fun source => { source with stx := mkProjStx source.stx fieldName } return Source.explicit sources | s => return s @@ -520,7 +529,7 @@ mutual | Source.explicit sources => /- TODO: find the first source that field `fieldName`, and add a path to it here. -/ -- stx is of the form `optional (try (sepBy1 termParser ", " >> "with"))` - let src := sources[0] -- TODO -- add support for multiple sources + let src := sources[0].stx -- TODO -- add support for multiple sources let val := mkProjStx src fieldName addField (FieldVal.term val) return s.setFields fields.reverse @@ -838,9 +847,7 @@ def propagate (struct : Struct) : TermElabM Unit := end DefaultFields private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do - let (structName, structType) ← getStructName stx expectedType? source - unless isStructure (← getEnv) structName do - throwError "invalid \{...} notation, structure type expected{indentExpr structType}" + let structName ← getStructName stx expectedType? source let struct ← liftMacroM <| mkStructView stx structName source let struct ← expandStruct struct trace[Elab.struct] "{struct}"