refactor: we can't elaborate substructure fields using elabTerm
Reason: derived structures may override/set the default value for substructure fields.
This commit is contained in:
parent
d9ca2751c2
commit
9986a653e2
2 changed files with 286 additions and 160 deletions
|
|
@ -5,12 +5,14 @@ Authors: Leonardo de Moura
|
|||
-/
|
||||
prelude
|
||||
import Init.Lean.Elab.Term
|
||||
import Init.Lean.Elab.TermApp
|
||||
import Init.Lean.Elab.TermBinders
|
||||
import Init.Lean.Elab.Quotation
|
||||
|
||||
namespace Lean
|
||||
namespace Elab
|
||||
namespace Term
|
||||
namespace StructInst
|
||||
|
||||
/- parser! symbol "{" appPrec >> optional (try (ident >> " . ")) >> sepBy (structInstField <|> structInstSource) ", " true >> "}" -/
|
||||
|
||||
|
|
@ -62,19 +64,30 @@ end ExpandNonAtomicExplicitSource
|
|||
/-
|
||||
If `stx` is of the form `{ ... .. s }` and `s` is not a local variable, expand into `let src := s; { ... .. src}`.
|
||||
-/
|
||||
def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) :=
|
||||
private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) :=
|
||||
withFreshMacroScope $ (ExpandNonAtomicExplicitSource.main stx).run' {}
|
||||
|
||||
inductive SourceView
|
||||
inductive Source
|
||||
| none -- structure instance source has not been provieded
|
||||
| implicit -- `..`
|
||||
| explicit (stx : Syntax) -- `.. term`
|
||||
| implicit (stx : Syntax) -- `..`
|
||||
| explicit (stx : Syntax) (src : Expr) -- `.. term`
|
||||
|
||||
def SourceView.isNone : SourceView → Bool
|
||||
| SourceView.none => true
|
||||
| _ => false
|
||||
def Source.isNone : Source → Bool
|
||||
| Source.none => true
|
||||
| _ => false
|
||||
|
||||
private def getStructSource (stx : Syntax) : TermElabM SourceView :=
|
||||
instance Source.hasFormat : HasFormat Source :=
|
||||
⟨fun src => match src with
|
||||
| Source.none => ""
|
||||
| Source.implicit _ => " .."
|
||||
| Source.explicit _ src => " .. " ++ format src⟩
|
||||
|
||||
def Source.addSyntax : Source → Array Syntax → Array Syntax
|
||||
| Source.none, acc => acc
|
||||
| Source.implicit stx, acc => acc.push stx
|
||||
| Source.explicit stx _, acc => acc.push stx
|
||||
|
||||
private def getStructSource (stx : Syntax) : TermElabM Source :=
|
||||
let args := (stx.getArg 2).getArgs;
|
||||
args.foldSepByM
|
||||
(fun arg r =>
|
||||
|
|
@ -82,12 +95,16 @@ args.foldSepByM
|
|||
-- parser! ".." >> optional termParser
|
||||
if !r.isNone then throwError arg "source has already been specified"
|
||||
else
|
||||
let arg := arg.getArg 1;
|
||||
if arg.isNone then pure SourceView.implicit
|
||||
else pure $ SourceView.explicit (arg.getArg 0)
|
||||
let optTerm := arg.getArg 1;
|
||||
if optTerm.isNone then pure $ Source.implicit arg
|
||||
else do
|
||||
fvar? ← isLocalTermId? (optTerm.getArg 0);
|
||||
match fvar? with
|
||||
| none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here
|
||||
| some fvar => pure $ Source.explicit arg fvar
|
||||
else
|
||||
pure r)
|
||||
SourceView.none
|
||||
Source.none
|
||||
|
||||
/-
|
||||
We say a `{ ... }` notation is a `modifyOp` if it contains only one
|
||||
|
|
@ -121,11 +138,11 @@ match s? with
|
|||
| none => pure none
|
||||
| some s => if s.getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none
|
||||
|
||||
private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=
|
||||
private def elabModifyOp (stx modifyOp : Syntax) (source : Expr) (expectedType? : Option Expr) : TermElabM Expr :=
|
||||
throwError stx ("WIP " ++ stx)
|
||||
|
||||
/- Get structure name and elaborate explicit source (if avialable) -/
|
||||
private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : SourceView) : TermElabM Name :=
|
||||
private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name :=
|
||||
let arg := stx.getArg 1;
|
||||
if !arg.isNone then do
|
||||
pure $ arg.getIdAt 0
|
||||
|
|
@ -134,17 +151,13 @@ else do
|
|||
tryPostponeIfNoneOrMVar expectedType?;
|
||||
let useSource : Unit → TermElabM Name := fun _ =>
|
||||
match sourceView with
|
||||
| SourceView.explicit sourceStx => do
|
||||
fvar? ← isLocalTermId? sourceStx;
|
||||
match fvar? with
|
||||
| none => unreachable!
|
||||
| some fvar => do
|
||||
fvarType ← inferType stx fvar;
|
||||
fvarType ← whnf stx fvarType;
|
||||
tryPostponeIfMVar fvarType;
|
||||
match fvarType.getAppFn with
|
||||
| Expr.const constName _ _ => pure constName
|
||||
| _ => throwError stx ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr fvarType)
|
||||
| Source.explicit _ src => do
|
||||
srcType ← inferType stx src;
|
||||
srcType ← whnf stx srcType;
|
||||
tryPostponeIfMVar srcType;
|
||||
match srcType.getAppFn with
|
||||
| Expr.const constName _ _ => pure constName
|
||||
| _ => throwError stx ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr srcType)
|
||||
| _ => throwError ref ("invalid {...} notation, expected type is not of the form (C ...)" ++ indentExpr expectedType?.get!);
|
||||
match expectedType? with
|
||||
| none => useSource ()
|
||||
|
|
@ -154,74 +167,138 @@ else do
|
|||
| Expr.const constName _ _ => pure constName
|
||||
| _ => useSource ()
|
||||
|
||||
inductive FieldLHS
|
||||
| fieldName (ref : Syntax) (name : Name)
|
||||
| fieldIndex (ref : Syntax) (idx : Nat)
|
||||
| modifyOp (ref : Syntax) (index : Syntax)
|
||||
|
||||
instance FieldLHS.inhabited : Inhabited FieldLHS := ⟨FieldLHS.fieldName (arbitrary _) (arbitrary _)⟩
|
||||
instance FieldLHS.hasFormat : HasFormat FieldLHS :=
|
||||
⟨fun lhs => match lhs with
|
||||
| FieldLHS.fieldName _ n => fmt n
|
||||
| FieldLHS.fieldIndex _ i => fmt i
|
||||
| FieldLHS.modifyOp _ i => "[" ++ i.prettyPrint ++ "]"⟩
|
||||
|
||||
inductive FieldVal (σ : Type)
|
||||
| term {} (stx : Syntax) : FieldVal
|
||||
| nested (s : σ) : FieldVal
|
||||
|
||||
structure Field (σ : Type) :=
|
||||
mk {} :: (ref : Syntax) (lhs : List FieldLHS) (val : FieldVal σ)
|
||||
|
||||
instance Field.inhabited {σ} : Inhabited (Field σ) := ⟨⟨arbitrary _, [], FieldVal.term (arbitrary _)⟩⟩
|
||||
|
||||
def Field.isSimple {σ} : Field σ → Bool
|
||||
| { lhs := [_], .. } => true
|
||||
| _ => false
|
||||
|
||||
inductive Struct
|
||||
| mk (ref : Syntax) (structName : Name) (fields : List (Field Struct)) (source : Source)
|
||||
|
||||
abbrev Fields := List (Field Struct)
|
||||
|
||||
def Struct.ref : Struct → Syntax
|
||||
| ⟨ref, _, _, _⟩ => ref
|
||||
|
||||
def Struct.structName : Struct → Name
|
||||
| ⟨_, structName, _, _⟩ => structName
|
||||
|
||||
def Struct.fields : Struct → Fields
|
||||
| ⟨_, _, fields, _⟩ => fields
|
||||
|
||||
def Struct.source : Struct → Source
|
||||
| ⟨_, _, _, s⟩ => s
|
||||
|
||||
def formatField (formatStruct : Struct → Format) (field : Field Struct) : Format :=
|
||||
Format.joinSep field.lhs " . " ++ " := " ++
|
||||
match field.val with
|
||||
| FieldVal.term v => v.prettyPrint
|
||||
| FieldVal.nested s => formatStruct s
|
||||
|
||||
partial def formatStruct : Struct → Format
|
||||
| ⟨_, structName, fields, source⟩ =>
|
||||
"{" ++ fmt structName ++ " . " ++ Format.joinSep (fields.map (formatField formatStruct)) ", " ++ fmt source ++ "}"
|
||||
|
||||
instance Struct.hasFormat : HasFormat Struct := ⟨formatStruct⟩
|
||||
instance Struct.hasToString : HasToString Struct := ⟨toString ∘ format⟩
|
||||
|
||||
instance Field.hasFormat : HasFormat (Field Struct) := ⟨formatField formatStruct⟩
|
||||
instance Field.hasToString : HasToString (Field Struct) := ⟨toString ∘ format⟩
|
||||
|
||||
/-
|
||||
Recall that `structInstField` elements have the form
|
||||
```
|
||||
def structInstField := parser! structInstLVal >> " := " >> termParser
|
||||
def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef)
|
||||
def structInstArrayRef := parser! "[" >> termParser >>"]"
|
||||
-/
|
||||
|
||||
/- Given a structure instance element `structInstElem`, prepend the new fields. -/
|
||||
private def prependFields (structInstElem : Syntax) (newFields : List Name) : Syntax :=
|
||||
match newFields with
|
||||
| [] => structInstElem
|
||||
| first :: rest =>
|
||||
let currFirst := structInstElem.getArg 0;
|
||||
let currFirst := if currFirst.isIdent then mkNullNode #[mkAtomFrom currFirst ".", currFirst] else currFirst;
|
||||
let restStx := rest.toArray.map $ fun fieldName => mkNullNode #[mkAtomFrom structInstElem ".", mkIdentFrom structInstElem fieldName];
|
||||
let newManyArgs := restStx.push currFirst ++ (structInstElem.getArg 1).getArgs;
|
||||
let structInstElem := structInstElem.setArg 1 (mkNullNode newManyArgs);
|
||||
structInstElem.setArg 0 (mkIdentFrom structInstElem first)
|
||||
-- Remark: this code relies on the fact that `expandStruct` only transforms `fieldLHS.fieldName`
|
||||
def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax
|
||||
| FieldLHS.modifyOp stx _ => stx
|
||||
| FieldLHS.fieldName stx name => if first then mkIdentFrom stx name else mkNullNode #[mkAtomFrom stx ".", mkIdentFrom stx name]
|
||||
| FieldLHS.fieldIndex stx _ => if first then stx else mkNullNode #[mkAtomFrom stx ".", stx]
|
||||
|
||||
@[inline] private def modifyStructInstFieldsM {m : Type → Type} [Monad m] (stx : Syntax) (f : Syntax → m Syntax) : m Syntax := do
|
||||
let args := (stx.getArg 2).getArgs;
|
||||
args ← args.mapM $ fun arg =>
|
||||
if arg.getKind == `Lean.Parser.Term.structInstField then
|
||||
f arg
|
||||
else
|
||||
pure arg;
|
||||
pure $ stx.setArg 2 (mkNullNode args)
|
||||
def FieldVal.toSyntax : FieldVal Struct → Syntax
|
||||
| FieldVal.term stx => stx
|
||||
| _ => unreachable!
|
||||
|
||||
@[inline] private def modifyStructInstFields (stx : Syntax) (f : Syntax → Syntax) : Syntax :=
|
||||
Id.run $ modifyStructInstFieldsM stx f
|
||||
def Field.toSyntax : Field Struct → Syntax
|
||||
| field =>
|
||||
let stx := field.ref;
|
||||
let stx := stx.setArg 3 field.val.toSyntax;
|
||||
match field.lhs with
|
||||
| first::rest =>
|
||||
let stx := stx.setArg 0 $ first.toSyntax true;
|
||||
let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false);
|
||||
stx
|
||||
| _ => unreachable!
|
||||
|
||||
/- Given a structure instance `stx`, expand the first field of each element if it is a composite name.
|
||||
Example:
|
||||
```
|
||||
(Term.structInstField `x.y (null) ":=" (Term.num (numLit "1")))
|
||||
```
|
||||
is expanded into
|
||||
```
|
||||
(Term.structInstField `x (null (null "." `y)) ":=" (Term.num (numLit "1")))
|
||||
``` -/
|
||||
private def expandCompositeFields (stx : Syntax) : Syntax :=
|
||||
modifyStructInstFields stx $ fun arg =>
|
||||
let field := arg.getArg 0;
|
||||
if field.isIdent then
|
||||
match field.getId with
|
||||
| Name.str Name.anonymous _ _ => arg -- atomic field
|
||||
| Name.str pre s _ =>
|
||||
-- update first with `s`
|
||||
let arg := arg.setArg 0 (mkIdentFrom field (mkNameSimple s));
|
||||
prependFields arg pre.components
|
||||
| _ => unreachable!
|
||||
else
|
||||
arg
|
||||
private def toFieldLHS (stx : Syntax) : FieldLHS :=
|
||||
if stx.getKind == `Lean.Parser.Term.structInstArrayRef then FieldLHS.modifyOp stx (stx.getArg 1)
|
||||
else
|
||||
-- Note that the representation of the first field is different.
|
||||
let stx := if stx.getKind == nullKind then stx.getArg 1 else stx;
|
||||
if stx.isIdent then FieldLHS.fieldName stx stx.getId
|
||||
else match stx.isNatLit? with
|
||||
| some idx => FieldLHS.fieldIndex stx idx
|
||||
| none => unreachable!
|
||||
|
||||
/- Example `{ Prod . 1 := 10, 2 := true }` => `{ Prod . fst := 10, snd := true }` -/
|
||||
private def expandNumLitFields (stx : Syntax) (structName : Name) : TermElabM Syntax := do
|
||||
env ← getEnv;
|
||||
let fieldNames := getStructureFields env structName;
|
||||
modifyStructInstFieldsM stx $ fun arg =>
|
||||
let field := arg.getArg 0;
|
||||
match field.isNatLit? with
|
||||
| none => pure arg
|
||||
| some idx =>
|
||||
if idx == 0 then throwError arg "invalid field index, index must be greater than 0"
|
||||
else if idx > fieldNames.size then throwError arg ("invalid field index, structure has only #" ++ toString fieldNames.size ++ " fields")
|
||||
else
|
||||
let newField := mkIdentFrom field (fieldNames.get! idx);
|
||||
pure $ arg.setArg 0 newField
|
||||
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Struct :=
|
||||
let args := (stx.getArg 2).getArgs;
|
||||
let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField;
|
||||
let fields := fieldsStx.toList.map $ fun fieldStx =>
|
||||
let val := fieldStx.getArg 3;
|
||||
let first := toFieldLHS (fieldStx.getArg 0);
|
||||
let rest := (fieldStx.getArg 1).getArgs.toList.map $ toFieldLHS;
|
||||
({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct);
|
||||
⟨stx, structName, fields, source⟩
|
||||
|
||||
def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct :=
|
||||
match s with
|
||||
| ⟨ref, structName, fields, source⟩ => do fields ← f fields; pure ⟨ref, structName, fields, source⟩
|
||||
|
||||
def Struct.modifyFields (s : Struct) (f : Fields → Fields) : Struct :=
|
||||
Id.run $ s.modifyFieldsM f
|
||||
|
||||
private def expandCompositeFields (s : Struct) : Struct :=
|
||||
s.modifyFields $ fun fields => fields.map $ fun field => match field with
|
||||
| { lhs := FieldLHS.fieldName ref (Name.str Name.anonymous _ _) :: rest, .. } => field
|
||||
| { lhs := FieldLHS.fieldName ref n@(Name.str _ _ _) :: rest, .. } =>
|
||||
let newEntries := n.components.map $ FieldLHS.fieldName ref;
|
||||
{ lhs := newEntries ++ rest, .. field }
|
||||
| _ => field
|
||||
|
||||
private def expandNumLitFields (s : Struct) : TermElabM Struct :=
|
||||
s.modifyFieldsM $ fun fields => do
|
||||
env ← getEnv;
|
||||
let fieldNames := getStructureFields env s.structName;
|
||||
fields.mapM $ fun field => match field with
|
||||
| { lhs := FieldLHS.fieldIndex ref idx :: rest, .. } =>
|
||||
if idx == 0 then throwError ref "invalid field index, index must be greater than 0"
|
||||
else if idx > fieldNames.size then throwError ref ("invalid field index, structure has only #" ++ toString fieldNames.size ++ " fields")
|
||||
else pure { lhs := FieldLHS.fieldName ref (fieldNames.get! $ idx - 1) :: rest, .. field }
|
||||
| _ => pure field
|
||||
|
||||
/- For example, consider the following structures:
|
||||
```
|
||||
|
|
@ -238,97 +315,118 @@ modifyStructInstFieldsM stx $ fun arg =>
|
|||
```
|
||||
{ C . toB.toA.x := 0, toB.y := 0, z := true }
|
||||
``` -/
|
||||
private def expandParentFields (stx : Syntax) (structName : Name) : TermElabM Syntax := do
|
||||
private def expandParentFields (s : Struct) : TermElabM Struct := do
|
||||
env ← getEnv;
|
||||
modifyStructInstFieldsM stx $ fun arg =>
|
||||
let field := arg.getArg 0;
|
||||
if field.isIdent then
|
||||
let fieldName := field.getId;
|
||||
match findField? env structName fieldName with
|
||||
| none => throwError arg ("'" ++ fieldName ++ "' is not a field of structure '" ++ structName ++ "'")
|
||||
s.modifyFieldsM $ fun fields => fields.mapM $ fun field => match field with
|
||||
| { lhs := FieldLHS.fieldName ref fieldName :: rest, .. } =>
|
||||
match findField? env s.structName fieldName with
|
||||
| none => throwError ref ("'" ++ fieldName ++ "' is not a field of structure '" ++ s.structName ++ "'")
|
||||
| some baseStructName =>
|
||||
if baseStructName == structName then pure arg
|
||||
else match getPathToBaseStructure? env baseStructName structName with
|
||||
if baseStructName == s.structName then pure field
|
||||
else match getPathToBaseStructure? env baseStructName s.structName with
|
||||
| some path => do
|
||||
let path := path.map $ fun funName => match funName with
|
||||
| Name.str _ s _ => mkNameSimple s
|
||||
| Name.str _ s _ => FieldLHS.fieldName ref (mkNameSimple s)
|
||||
| _ => unreachable!;
|
||||
pure $ prependFields arg path
|
||||
| _ => throwError arg ("failed to access field '" ++ fieldName ++ "' in parent structure")
|
||||
else
|
||||
pure arg
|
||||
pure { lhs := path ++ field.lhs, .. field }
|
||||
| _ => throwError ref ("failed to access field '" ++ fieldName ++ "' in parent structure")
|
||||
| _ => pure field
|
||||
|
||||
/- We say a `structInstField` is simple if the suffix is empty.
|
||||
That is, the `many` component `many (("." >> (ident <|> numLit)) <|> structInstArrayRef)` is empty. -/
|
||||
private def isSimpleStructInstField (stx : Syntax) : Bool :=
|
||||
(stx.getArg 1).getArgs.isEmpty
|
||||
private abbrev FieldMap := HashMap Name Fields
|
||||
|
||||
private def getStructInstFields (stx : Syntax) : Array Syntax :=
|
||||
(stx.getArg 2).getArgs.filter $ fun elem => elem.getKind == `Lean.Parser.Term.structInstField
|
||||
|
||||
private def getFieldName (structInstField : Syntax) : Name :=
|
||||
(structInstField.getArg 0).getId
|
||||
|
||||
private abbrev FieldMap := HashMap Name (List Syntax)
|
||||
|
||||
private def groupFields (instFields : Array Syntax) : TermElabM FieldMap :=
|
||||
instFields.foldlM
|
||||
(fun fieldMap instField =>
|
||||
let fieldName := getFieldName instField;
|
||||
match fieldMap.find? fieldName with
|
||||
| some (prevInstField::restInstFields) =>
|
||||
if isSimpleStructInstField prevInstField || isSimpleStructInstField instField then
|
||||
throwError instField ("field '" ++ fieldName ++ "' has already beed specified")
|
||||
else
|
||||
pure $ fieldMap.insert fieldName (instField::prevInstField::restInstFields)
|
||||
| _ => pure $ fieldMap.insert fieldName [instField])
|
||||
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
|
||||
fields.foldlM
|
||||
(fun fieldMap field =>
|
||||
match field.lhs with
|
||||
| FieldLHS.fieldName _ fieldName :: rest =>
|
||||
match fieldMap.find? fieldName with
|
||||
| some (prevField::restFields) =>
|
||||
if field.isSimple || prevField.isSimple then
|
||||
throwError field.ref ("field '" ++ fieldName ++ "' has already beed specified")
|
||||
else
|
||||
pure $ fieldMap.insert fieldName (field::prevField::restFields)
|
||||
| _ => pure $ fieldMap.insert fieldName [field]
|
||||
| _ => unreachable!)
|
||||
{}
|
||||
|
||||
private def isSimpleStructInstFieldSingleton? : List Syntax → Option Syntax
|
||||
| [instField] => if isSimpleStructInstField instField then some instField else none
|
||||
| _ => none
|
||||
private def isSimpleField? : Fields → Option (Field Struct)
|
||||
| [field] => if field.isSimple then some field else none
|
||||
| _ => none
|
||||
|
||||
-- def structInstSource := parser! ".." >> optional termParser
|
||||
private def mkStructInstSource (ref : Syntax) (optTermParser : Syntax) : Syntax :=
|
||||
Syntax.node `Lean.Parser.Term.structInstSource #[mkAtomFrom ref "..", optTermParser]
|
||||
private def getFieldIdx (ref : Syntax) (structName : Name) (fieldNames : Array Name) (fieldName : Name) : TermElabM Nat := do
|
||||
match fieldNames.findIdx? $ fun n => n == fieldName with
|
||||
| some idx => pure idx
|
||||
| none => throwError ref ("field '" ++ fieldName ++ "' is not a valid field of '" ++ structName ++ "'")
|
||||
|
||||
private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax :=
|
||||
Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName]
|
||||
|
||||
structure FieldView :=
|
||||
(ref : Syntax)
|
||||
(fieldName : Name)
|
||||
(val: Syntax)
|
||||
private def mkSubstructSource (ref : Syntax) (structName : Name) (fieldNames : Array Name) (fieldName : Name) (src : Source) : TermElabM Source :=
|
||||
match src with
|
||||
| Source.explicit stx src => do
|
||||
idx ← getFieldIdx ref structName fieldNames fieldName;
|
||||
let stx := stx.modifyArg 1 $ fun stx => stx.modifyArg 0 $ fun stx => mkProjStx stx fieldName;
|
||||
pure $ Source.explicit stx (mkProj structName idx src)
|
||||
| s => pure s
|
||||
|
||||
@[specialize] private def groupFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
|
||||
env ← getEnv;
|
||||
let fieldNames := getStructureFields env s.structName;
|
||||
s.modifyFieldsM $ fun fields => do
|
||||
fieldMap ← mkFieldMap fields;
|
||||
fieldMap.toList.mapM $ fun ⟨fieldName, fields⟩ =>
|
||||
match isSimpleField? fields with
|
||||
| some field => pure field
|
||||
| none => do
|
||||
let substructFields := fields.map $ fun field => { lhs := field.lhs.tail!, .. field };
|
||||
substructSource ← mkSubstructSource s.ref s.structName fieldNames fieldName s.source;
|
||||
let field := fields.head!;
|
||||
match Lean.isSubobjectField? env s.structName fieldName with
|
||||
| some substructName => do
|
||||
let substruct := Struct.mk s.ref substructName substructFields substructSource;
|
||||
substruct ← expandStruct substruct;
|
||||
pure { lhs := [field.lhs.head!], val := FieldVal.nested substruct, .. field }
|
||||
| none => do
|
||||
-- It is not a substructure field. Thus, we wrap fields using `Syntax`, and use `elabTerm` to process them.
|
||||
let valStx := s.ref; -- construct substructure syntax using s.ref as template
|
||||
let valStx := valStx.setArg 1 mkNullNode; -- erase optional struct name
|
||||
let args := substructFields.toArray.map $ Field.toSyntax;
|
||||
let args := substructSource.addSyntax args;
|
||||
let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ","));
|
||||
pure { lhs := [field.lhs.head!], val := FieldVal.term valStx, .. field }
|
||||
|
||||
private partial def expandStruct : Struct → TermElabM Struct
|
||||
| s => do
|
||||
let s := expandCompositeFields s;
|
||||
s ← expandNumLitFields s;
|
||||
s ← expandParentFields s;
|
||||
groupFields expandStruct s
|
||||
|
||||
/-
|
||||
namespace ElabFields
|
||||
|
||||
structure Context :=
|
||||
(structPath : List Name)
|
||||
|
||||
structure PendingField :=
|
||||
(ref : Syntax)
|
||||
(structPath : List Name)
|
||||
(fieldName : Name)
|
||||
(mvar : MVarId)
|
||||
|
||||
structure State :=
|
||||
(instMVars : Array MVarId)
|
||||
(pendingFields : List PendingField)
|
||||
|
||||
end ElabFields
|
||||
|
||||
private def getFieldViews (stx : Syntax) (sourceView : SourceView) : TermElabM (List FieldView) := do
|
||||
let instFields := getStructInstFields stx;
|
||||
fieldMap ← groupFields instFields;
|
||||
pure $ fieldMap.toList.map $ fun ⟨fieldName, instFields⟩ =>
|
||||
match isSimpleStructInstFieldSingleton? instFields with
|
||||
| some instField => { ref := instField, fieldName := fieldName, val := instField.getArg 3 }
|
||||
| none =>
|
||||
let newArgs := instFields.toArray.map $ fun instField =>
|
||||
let suffixElems := (instField.getArg 1).getArgs;
|
||||
let newField := suffixElems.get! 0;
|
||||
let newField := if newField.getKind == `Lean.Parser.Term.structInstArrayRef then newField else newField.getArg 1;
|
||||
let newSuffixElems := suffixElems.eraseIdx 0;
|
||||
let instField := instField.setArg 0 newField;
|
||||
let instField := instField.setArg 1 (mkNullNode newSuffixElems);
|
||||
instField;
|
||||
let newArgs := match sourceView with
|
||||
| SourceView.none => newArgs
|
||||
| SourceView.implicit => newArgs.push $ mkStructInstSource stx mkNullNode
|
||||
| SourceView.explicit src => newArgs.push $ mkStructInstSource stx (mkNullNode #[mkProjStx src fieldName]);
|
||||
let newStruct := stx.setArg 1 mkNullNode; -- erase explicit struct name
|
||||
let newStruct := stx.setArg 2 (mkSepStx newArgs (mkAtomFrom stx ","));
|
||||
{ ref := instFields.head!, fieldName := fieldName, val := newStruct }
|
||||
|
||||
structure CtorHeaderResult :=
|
||||
(ctorFn : Expr)
|
||||
(ctorFnType : Expr)
|
||||
(instMVars : Array Expr)
|
||||
(instMVars : Array MVarId)
|
||||
|
||||
private def mkCtorHeaderAux (ref : Syntax) : Nat → Expr → Expr → Array Expr → TermElabM CtorHeaderResult
|
||||
private def mkCtorHeaderAux (ref : Syntax) : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult
|
||||
| 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars }
|
||||
| n+1, type, ctorFn, instMVars => do
|
||||
type ← whnfForall ref type;
|
||||
|
|
@ -337,7 +435,7 @@ private def mkCtorHeaderAux (ref : Syntax) : Nat → Expr → Expr → Array Exp
|
|||
match c.binderInfo with
|
||||
| BinderInfo.instImplicit => do
|
||||
a ← mkFreshExprMVar ref d MetavarKind.synthetic;
|
||||
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a)
|
||||
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!)
|
||||
| _ => do
|
||||
a ← mkFreshExprMVar ref d;
|
||||
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars
|
||||
|
|
@ -370,7 +468,24 @@ r ← mkCtorHeaderAux ref ctorVal.nparams type val #[];
|
|||
propagateExpectedType ref r.ctorFnType ctorVal.nfields expectedType?;
|
||||
pure r
|
||||
|
||||
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sourceView : SourceView) : TermElabM Expr := do
|
||||
private partial def elabFields (ref : Syntax) (structName : Name) (sourceView : Source) : List FieldView → Expr → Expr → TermElabM Expr
|
||||
| fieldViews, type, e => do
|
||||
type ← whnfForall ref type;
|
||||
match type with
|
||||
| Expr.forallE n d b c => do
|
||||
let fieldName := deinternalizeFieldName n;
|
||||
dbgTrace (">> field " ++ toString fieldName);
|
||||
arg ← mkFreshExprMVar ref d; -- TODO
|
||||
let fieldViews := fieldViews; -- TODO
|
||||
let b := b.instantiate1 arg;
|
||||
let e := mkApp e arg;
|
||||
elabFields fieldViews b e
|
||||
| _ =>
|
||||
match fieldViews with
|
||||
| fview :: _ => throwError fview.ref ("'" ++ fview.fieldName ++ "' is not a field of structure '" ++ structName ++ "'")
|
||||
| _ => pure e
|
||||
|
||||
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Expr := do
|
||||
structName ← getStructName stx expectedType? sourceView;
|
||||
env ← getEnv;
|
||||
unless (isStructureLike env structName) $
|
||||
|
|
@ -383,7 +498,17 @@ let ctorVal := getStructureCtor env structName;
|
|||
ctorHeader ← mkCtorHeader stx ctorVal expectedType?;
|
||||
-- fieldViews.forM $ fun v => dbgTrace (toString v.fieldName ++ " := " ++ toString v.val);
|
||||
-- dbgTrace (">> " ++ toString ctorHeader.ctorFn);
|
||||
throwError stx ("WIP")
|
||||
s ← elabFields stx structName sourceView fieldViews ctorHeader.ctorFnType ctorHeader.ctorFn;
|
||||
synthesizeAppInstMVars stx ctorHeader.instMVars;
|
||||
pure s
|
||||
-/
|
||||
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do
|
||||
structName ← getStructName stx expectedType? source;
|
||||
env ← getEnv;
|
||||
unless (isStructureLike env structName) $
|
||||
throwError stx ("invalid {...} notation, '" ++ structName ++ "' is not a structure");
|
||||
struct ← expandStruct $ mkStructView stx structName source;
|
||||
throwError stx ("WIP" ++ Format.line ++ toString struct)
|
||||
|
||||
@[builtinTermElab structInst] def elabStructInst : TermElab :=
|
||||
fun stx expectedType? => do
|
||||
|
|
@ -394,10 +519,11 @@ fun stx expectedType? => do
|
|||
sourceView ← getStructSource stx;
|
||||
modifyOp? ← isModifyOp? stx;
|
||||
match modifyOp?, sourceView with
|
||||
| some modifyOp, SourceView.explicit source => elabModifyOp stx modifyOp source expectedType?
|
||||
| some _, _ => throwError stx ("invalid {...} notation, explicit source is required when using '[<index>] := <value>'")
|
||||
| _, _ => elabStructInstAux stx expectedType? sourceView
|
||||
| some modifyOp, Source.explicit _ source => elabModifyOp stx modifyOp source expectedType?
|
||||
| some _, _ => throwError stx ("invalid {...} notation, explicit source is required when using '[<index>] := <value>'")
|
||||
| _, _ => elabStructInstAux stx expectedType? sourceView
|
||||
|
||||
end StructInst
|
||||
end Term
|
||||
end Elab
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ when (namedArgs.any $ fun namedArg' => namedArg.name == namedArg'.name) $
|
|||
throwError ref ("argument '" ++ toString namedArg.name ++ "' was already set");
|
||||
pure $ namedArgs.push namedArg
|
||||
|
||||
private def synthesizeAppInstMVars (ref : Syntax) (instMVars : Array MVarId) : TermElabM Unit :=
|
||||
def synthesizeAppInstMVars (ref : Syntax) (instMVars : Array MVarId) : TermElabM Unit :=
|
||||
instMVars.forM $ fun mvarId =>
|
||||
unlessM (synthesizeInstMVarCore ref mvarId) $
|
||||
registerSyntheticMVar ref mvarId SyntheticMVarKind.typeClass
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue