diff --git a/src/Init/Lean/Elab/StructInst.lean b/src/Init/Lean/Elab/StructInst.lean index 0bb50105c5..a9dd741bff 100644 --- a/src/Init/Lean/Elab/StructInst.lean +++ b/src/Init/Lean/Elab/StructInst.lean @@ -14,63 +14,29 @@ namespace Elab namespace Term namespace StructInst -/- parser! symbol "{" appPrec >> optional (try (ident >> " . ")) >> sepBy (structInstField <|> structInstSource) ", " true >> "}" -/ - -namespace ExpandNonAtomicExplicitSource - -structure State := -(found : Bool := false) -(source? : Option Syntax := none) - -/- Auxiliary function for `expandNonAtomicExplicitSource` -/ -def main (stx : Syntax) : StateT State TermElabM (Option Syntax) := do -let args := (stx.getArg 2).getArgs; -args ← args.mapM $ fun arg => - if arg.getKind == `Lean.Parser.Term.structInstSource then do - -- parser! ".." >> optional termParser - s ← get; - if s.found then - liftM $ throwError arg "source has already been specified" - else - let optSource := arg.getArg 1; - if optSource.isNone then do - modify $ fun s => { s with found := true }; - pure arg - else do - let source := optSource.getArg 0; - fvar? ← liftM $ isLocalTermId? source; - match fvar? with - | some _ => do - -- it is already a local variable - modify $ fun s => { s with found := true }; - pure arg - | none => do - src ← `(src); - modify $ fun s => { s with found := true, source? := source }; - let optSource := optSource.setArg 0 src; - let arg := arg.setArg 1 optSource; - pure arg - else - pure arg; -s ← get; -match s.source? with -| none => pure none -| some source => do - let newStx := stx.setArg 2 (mkNullNode args); - `(let src := $source; $newStx) - -end ExpandNonAtomicExplicitSource +/- parser! symbol "{" appPrec >> optional (try (termParser >> "with")) >> sepBy structInstField ", " true >> optional ".." >> optional (" : " >> termParser) >> "}" -/ /- -If `stx` is of the form `{ ... .. s }` and `s` is not a local variable, expand into `let src := s; { ... .. src}`. +If `stx` is of the form `{ s with ... }` and `s` is not a local variable, expand into `let src := s; { src with ... }`. -/ private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) := -withFreshMacroScope $ (ExpandNonAtomicExplicitSource.main stx).run' {} +withFreshMacroScope $ + let sourceOpt := stx.getArg 1; + if sourceOpt.isNone then pure none else do + let source := sourceOpt.getArg 0; + fvar? ← isLocalTermId? source; + match fvar? with + | some _ => pure none + | none => do + src ← `(src); + let sourceOpt := sourceOpt.setArg 0 src; + let stxNew := stx.setArg 1 sourceOpt; + `(let src := $source; $stxNew) inductive Source | none -- structure instance source has not been provieded | implicit (stx : Syntax) -- `..` -| explicit (stx : Syntax) (src : Expr) -- `.. term` +| explicit (stx : Syntax) (src : Expr) -- `src with` instance Source.inhabited : Inhabited Source := ⟨Source.none⟩ @@ -78,35 +44,25 @@ def Source.isNone : Source → Bool | Source.none => true | _ => false -instance Source.hasFormat : HasFormat Source := -⟨fun src => match src with - | Source.none => "" - | Source.implicit _ => " .." - | Source.explicit _ src => " .. " ++ format src⟩ - -def Source.addSyntax : Source → Array Syntax → Array Syntax -| Source.none, acc => acc -| Source.implicit stx, acc => acc.push stx -| Source.explicit stx _, acc => acc.push stx +def setStructSourceSyntax (structStx : Syntax) : Source → Syntax +| Source.none => (structStx.setArg 1 mkNullNode).setArg 3 mkNullNode +| Source.implicit stx => (structStx.setArg 1 mkNullNode).setArg 3 stx +| Source.explicit stx _ => (structStx.setArg 1 stx).setArg 3 mkNullNode private def getStructSource (stx : Syntax) : TermElabM Source := -let args := (stx.getArg 2).getArgs; -args.foldSepByM - (fun arg r => - if arg.getKind == `Lean.Parser.Term.structInstSource then - -- parser! ".." >> optional termParser - if !r.isNone then throwError arg "source has already been specified" - else - let optTerm := arg.getArg 1; - if optTerm.isNone then pure $ Source.implicit arg - else do - fvar? ← isLocalTermId? (optTerm.getArg 0); - match fvar? with - | none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here - | some fvar => pure $ Source.explicit arg fvar - else - pure r) - Source.none +let explicitSource := stx.getArg 1; +let implicitSource := stx.getArg 3; +if explicitSource.isNone && implicitSource.isNone then + pure Source.none +else if explicitSource.isNone then + pure $ Source.implicit implicitSource +else if implicitSource.isNone then do + fvar? ← isLocalTermId? (explicitSource.getArg 0); + match fvar? with + | none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here + | some src => pure $ Source.explicit explicitSource src +else + throwError stx "invalid structure instance `with` and `..` cannot be used together" /- We say a `{ ... }` notation is a `modifyOp` if it contains only one @@ -117,34 +73,29 @@ private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do let args := (stx.getArg 2).getArgs; s? ← args.foldSepByM (fun arg s? => - let k := arg.getKind; - if k == `Lean.Parser.Term.structInstSource then pure s? - else if k == `Lean.Parser.Term.structInstField then - /- Remark: the syntax for `structInstField` is - ``` - def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef) - def structInstField := parser! structInstLVal >> " := " >> termParser - ``` -/ - let lval := arg.getArg 0; - let k := lval.getKind; - if k == `Lean.Parser.Term.structInstArrayRef then - match s? with - | none => pure (some arg) - | some s => - if s.getKind == `Lean.Parser.Term.structInstArrayRef then - throwError arg "invalid {...} notation, at most one `[..]` at a given level" - else - throwError arg "invalid {...} notation, can't mix field and `[..]` at a given level" - else - match s? with - | none => pure (some arg) - | some s => - if s.getKind == `Lean.Parser.Term.structInstArrayRef then - throwError arg "invalid {...} notation, can't mix field and `[..]` at a given level" - else - pure s? + /- Remark: the syntax for `structInstField` is + ``` + def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef) + def structInstField := parser! structInstLVal >> " := " >> termParser + ``` -/ + let lval := arg.getArg 0; + let k := lval.getKind; + if k == `Lean.Parser.Term.structInstArrayRef then + match s? with + | none => pure (some arg) + | some s => + if s.getKind == `Lean.Parser.Term.structInstArrayRef then + throwError arg "invalid {...} notation, at most one `[..]` at a given level" + else + throwError arg "invalid {...} notation, can't mix field and `[..]` at a given level" else - throwError arg "unexpected {...} notation") + match s? with + | none => pure (some arg) + | some s => + if s.getKind == `Lean.Parser.Term.structInstArrayRef then + throwError arg "invalid {...} notation, can't mix field and `[..]` at a given level" + else + pure s?) none; match s? with | none => pure none @@ -174,38 +125,27 @@ else do let val := val.setArg 2 $ mkNullNode #[valField, mkAtomFrom stx ", ", valSource]; continue val -/- Get structure name and elaborate explicit source (if avialable) -/ -private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := -let arg := stx.getArg 1; -if !arg.isNone then do - r : List (Name × List String) ← resolveGlobalName (arg.getIdAt 0); - env ← getEnv; - let r := r.filter $ fun p => p.2.isEmpty && isStructureLike env p.1; - let candidates := r.map $ fun p => p.1; - match candidates with - | [c] => pure c - | [] => throwError arg "invalid {...} notation, structure expected" - | _ => throwError arg ("invalid {...} notation, ambiguous " ++ toString candidates) -else do - let ref := stx; - tryPostponeIfNoneOrMVar expectedType?; - let useSource : Unit → TermElabM Name := fun _ => - match sourceView with - | Source.explicit _ src => do - srcType ← inferType stx src; - srcType ← whnf stx srcType; - tryPostponeIfMVar srcType; - match srcType.getAppFn with - | Expr.const constName _ _ => pure constName - | _ => throwError stx ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr srcType) - | _ => throwError ref ("invalid {...} notation, expected type is not of the form (C ...)" ++ indentExpr expectedType?.get!); - match expectedType? with - | none => useSource () - | some expectedType => do - expectedType ← whnf ref expectedType; - match expectedType.getAppFn with +/- Get structure name and elaborate explicit source (if available) -/ +private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do +let ref := stx; +tryPostponeIfNoneOrMVar expectedType?; +let useSource : Unit → TermElabM Name := fun _ => + match sourceView with + | Source.explicit _ src => do + srcType ← inferType stx src; + srcType ← whnf stx srcType; + tryPostponeIfMVar srcType; + match srcType.getAppFn with | Expr.const constName _ _ => pure constName - | _ => useSource () + | _ => throwError stx ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr srcType) + | _ => throwError ref ("invalid {...} notation, expected type is not of the form (C ...)" ++ indentExpr expectedType?.get!); +match expectedType? with +| none => useSource () +| some expectedType => do + expectedType ← whnf ref expectedType; + match expectedType.getAppFn with + | Expr.const constName _ _ => pure constName + | _ => useSource () inductive FieldLHS | fieldName (ref : Syntax) (name : Name) @@ -261,7 +201,11 @@ Format.joinSep field.lhs " . " ++ " := " ++ partial def formatStruct : Struct → Format | ⟨_, structName, fields, source⟩ => - "{" ++ fmt structName ++ " . " ++ Format.joinSep (fields.map (formatField formatStruct)) ", " ++ fmt source ++ "}" + let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", "; + match source with + | Source.none => "{" ++ fieldsFmt ++ "}" + | Source.implicit _ => "{" ++ fieldsFmt ++ " .. }" + | Source.explicit _ src => "{" ++ format src ++ " with " ++ fieldsFmt ++ "}" instance Struct.hasFormat : HasFormat Struct := ⟨formatStruct⟩ instance Struct.hasToString : HasToString Struct := ⟨toString ∘ format⟩ @@ -358,11 +302,11 @@ s.modifyFieldsM $ fun fields => do This method expands parent structure fields using the path to the parent structure. For example, ``` - { C . x := 0, y := 0, z := true } + { x := 0, y := 0, z := true : C } ``` is expanded into ``` - { C . toB.toA.x := 0, toB.y := 0, z := true } + { toB.toA.x := 0, toB.y := 0, z := true : C } ``` -/ private def expandParentFields (s : Struct) : TermElabM Struct := do env ← getEnv; @@ -438,10 +382,10 @@ s.modifyFieldsM $ fun fields => do | none => do -- It is not a substructure field. Thus, we wrap fields using `Syntax`, and use `elabTerm` to process them. let valStx := s.ref; -- construct substructure syntax using s.ref as template - let valStx := valStx.setArg 1 mkNullNode; -- erase optional struct name + let valStx := valStx.setArg 4 mkNullNode; -- erase optional expected type let args := substructFields.toArray.map $ Field.toSyntax; - let args := substructSource.addSyntax args; let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ",")); + let valStx := setStructSourceSyntax valStx substructSource; pure { field with lhs := [field.lhs.head!], val := FieldVal.term valStx } def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) := @@ -473,8 +417,8 @@ fields ← fieldNames.foldlM | Source.none => addField FieldVal.default | Source.implicit _ => addField (FieldVal.term (mkHole s.ref)) | Source.explicit stx _ => - -- stx is of the form `".." >> optional termParser` - let src := (stx.getArg 1).getArg 0; + -- stx is of the form `optional (try (termParser >> "with"))` + let src := (stx.getArg 0).getArg 0; let val := mkProjStx src fieldName; addField (FieldVal.term val)) []; @@ -820,6 +764,7 @@ match mkStructView stx structName source with @[builtinTermElab structInst] def elabStructInst : TermElab := fun stx expectedType? => do + -- TODO: expand expected type syntax at structurInst stxNew? ← expandNonAtomicExplicitSource stx; match stxNew? with | some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?