From c3715bb5a00a20759103ead1d5d40b01867c9c31 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 5 Feb 2020 13:07:14 -0800 Subject: [PATCH] feat: add `expandParentFields` --- src/Init/Lean/Elab/StructInst.lean | 53 ++++++++++++++++++++++++++++++ src/Init/Lean/Parser/Term.lean | 2 +- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/Init/Lean/Elab/StructInst.lean b/src/Init/Lean/Elab/StructInst.lean index 24b424cc56..4ccfc40959 100644 --- a/src/Init/Lean/Elab/StructInst.lean +++ b/src/Init/Lean/Elab/StructInst.lean @@ -154,11 +154,64 @@ else do | Expr.const constName _ _ => pure constName | _ => useSource () +/- Convert a path such as `[N.C.toB, N.B.toA]` into `#[ "." toB, "." toA]` -/ +private def mkParentFieldNameFromPath (ref : Syntax) (path : List Name) : TermElabM (Array Syntax) := +path.toArray.mapM $ fun toFunName => + match toFunName with + | Name.str _ s _ => pure $ mkNullNode #[mkAtomFrom ref ".", mkIdentFrom ref (mkNameSimple s)] + | _ => throwError ref "invalid field name to parent structure" + +/- For example, consider the following structures: + ``` + structure A := (x : Nat) + structure B extends A := (y : Nat) + structure C extends B := (z : Bool) + ``` + This method expands parent structure fields using the path to the parent structure. + For example, + ``` + { C . x := 0, y := 0, z := true } + ``` + is expanded into + ``` + { C . toB.toA.x := 0, toB.y := 0, z := true } + ``` -/ +private def expandParentFields (stx : Syntax) (structName : Name) : TermElabM Syntax := do +env ← getEnv; +let args := (stx.getArg 2).getArgs; +args ← args.mapM $ fun arg => + if arg.getKind == `Lean.Parser.Term.structInstField then + /- arg is of the form + def structInstField := parser! structInstLVal >> " := " >> termParser + def structInstLVal := (ident <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef) -/ + let field := arg.getArg 0; + if field.isIdent then + let fieldName := field.getId; + match findField? env structName fieldName with + | none => throwError arg ("'" ++ fieldName ++ "' is not a field of structure '" ++ structName ++ "'") + | some baseStructName => + if baseStructName == structName then pure arg + else match getPathToBaseStructure? env baseStructName structName with + | some (Name.str _ firstField _ :: rest) => do + let newFieldStx := mkIdentFrom arg (mkNameSimple firstField); + let arg := arg.setArg 0 newFieldStx; + newFieldsStx ← mkParentFieldNameFromPath arg (rest ++ [fieldName]); + let newManyArgs := newFieldsStx ++ (arg.getArg 1).getArgs; + let arg := arg.setArg 1 (mkNullNode newManyArgs); + pure arg + | _ => throwError arg ("failed to access field '" ++ fieldName ++ "' in parent structure") + else + pure arg + else + pure arg; +pure $ stx.setArg 2 (mkNullNode args) + private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sourceView : SourceView) : TermElabM Expr := do structName ← getStructName stx expectedType? sourceView; env ← getEnv; unless (isStructureLike env structName) $ throwError stx ("invalid {...} notation, '" ++ structName ++ "' is not a structure"); +stx ← expandParentFields stx structName; throwError stx ("WIP " ++ toString structName ++ toString stx) @[builtinTermElab structInst] def elabStructInst : TermElab := diff --git a/src/Init/Lean/Parser/Term.lean b/src/Init/Lean/Parser/Term.lean index 957e44d8c3..36f255580a 100644 --- a/src/Init/Lean/Parser/Term.lean +++ b/src/Init/Lean/Parser/Term.lean @@ -61,7 +61,7 @@ def haveAssign := parser! " := " >> termParser @[builtinTermParser] def «show» := parser! symbol "show " leadPrec >> termParser >> fromTerm @[builtinTermParser] def «fun» := parser! unicodeSymbol "λ" "fun" leadPrec >> many1 (termParser appPrec) >> darrow >> termParser def structInstArrayRef := parser! "[" >> termParser >>"]" -def structInstLVal := (ident <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef) +def structInstLVal := (ident <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef) def structInstField := parser! structInstLVal >> " := " >> termParser def structInstSource := parser! ".." >> optional termParser @[builtinTermParser] def structInst := parser! symbol "{" appPrec >> optional (try (ident >> " . ")) >> sepBy (structInstField <|> structInstSource) ", " true >> "}"