feat: FromToJson for recursive inductives
This commit is contained in:
parent
4073b20b7d
commit
43190e0e63
3 changed files with 56 additions and 27 deletions
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
|
|
@ -1,3 +1,4 @@
|
|||
{
|
||||
"files.insertFinalNewline": true
|
||||
"files.insertFinalNewline": true,
|
||||
"files.trimTrailingWhitespace": true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,20 +39,27 @@ def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
|||
let indVal ← getConstInfoInduct declNames[0]
|
||||
let cmds ← liftTermElabM none <| do
|
||||
let ctx ← mkContext "toJson" declNames[0]
|
||||
let toJsonFuncId := mkIdent ctx.auxFunNames[0]
|
||||
-- Return syntax to JSONify `id`, either via `ToJson` or recursively
|
||||
-- if `id`'s type is the type we're deriving for.
|
||||
let mkToJson (id : Syntax) (type : Expr) : TermElabM Syntax := do
|
||||
if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident)
|
||||
else `(toJson $id:ident)
|
||||
let header ← mkHeader ctx ``ToJson 1 ctx.typeInfos[0]
|
||||
let discrs ← mkDiscrs header indVal
|
||||
let alts ← mkAlts indVal fun ctor args userNames =>
|
||||
let alts ← mkAlts indVal fun ctor args userNames => do
|
||||
match args, userNames with
|
||||
| #[], _ => `(toJson $(quote ctor.name.getString!))
|
||||
| #[x], none => `(mkObj [($(quote ctor.name.getString!), toJson $x)])
|
||||
| xs, none => do
|
||||
let xs ← xs.mapM fun x => `(toJson $x)
|
||||
`(mkObj [($(quote ctor.name.getString!), toJson #[$[$xs:term],*])])
|
||||
| xs, some userNames => do
|
||||
let xs ← xs.mapIdxM fun idx x => `(($(quote userNames[idx].getString!), toJson $x))
|
||||
| #[(x, t)], none => `(mkObj [($(quote ctor.name.getString!), $(← mkToJson x t))])
|
||||
| xs, none =>
|
||||
let xs ← xs.mapM fun (x, t) => mkToJson x t
|
||||
`(mkObj [($(quote ctor.name.getString!), Json.arr #[$[$xs:term],*])])
|
||||
| xs, some userNames =>
|
||||
let xs ← xs.mapIdxM fun idx (x, t) => do
|
||||
`(($(quote userNames[idx].getString!), $(← mkToJson x t)))
|
||||
`(mkObj [($(quote ctor.name.getString!), mkObj [$[$xs:term],*])])
|
||||
let auxCmd ← `(match $[$discrs],* with $alts:matchAlt*)
|
||||
let auxCmd ← `(private def $(mkIdent ctx.auxFunNames[0]):ident $header.binders:explicitBinder* := $auxCmd)
|
||||
let auxCmd ← `(private def $toJsonFuncId:ident $header.binders:explicitBinder* := $auxCmd)
|
||||
return #[auxCmd] ++ (← mkInstanceCmds ctx ``ToJson declNames)
|
||||
cmds.forM elabCommand
|
||||
return true
|
||||
|
|
@ -61,7 +68,7 @@ def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
|||
where
|
||||
mkAlts
|
||||
(indVal : InductiveVal)
|
||||
(rhs : ConstructorVal → Array Syntax → (Option $ Array Name) → TermElabM Syntax) : TermElabM (Array Syntax) := do
|
||||
(rhs : ConstructorVal → Array (Syntax × Expr) → (Option $ Array Name) → TermElabM Syntax) : TermElabM (Array Syntax) := do
|
||||
indVal.ctors.toArray.mapM fun ctor => do
|
||||
let ctorInfo ← getConstInfoCtor ctor
|
||||
forallTelescopeReducing ctorInfo.type fun xs type => do
|
||||
|
|
@ -73,7 +80,8 @@ where
|
|||
-- add `_` for inductive parameters, they are inaccessible
|
||||
for i in [:indVal.numParams] do
|
||||
ctorArgs := ctorArgs.push (← `(_))
|
||||
let mut identNames := #[]
|
||||
-- bound constructor arguments and their types
|
||||
let mut binders := #[]
|
||||
let mut userNames := #[]
|
||||
for i in [:ctorInfo.numFields] do
|
||||
let x := xs[indVal.numParams + i]
|
||||
|
|
@ -81,12 +89,13 @@ where
|
|||
if !localDecl.userName.hasMacroScopes then
|
||||
userNames := userNames.push localDecl.userName
|
||||
let a := mkIdent (← mkFreshUserName `a)
|
||||
identNames := identNames.push a
|
||||
binders := binders.push (a, localDecl.type)
|
||||
ctorArgs := ctorArgs.push a
|
||||
patterns := patterns.push (← `(@$(mkIdent ctorInfo.name):ident $ctorArgs:term*))
|
||||
let rhs ← rhs ctorInfo identNames (if userNames.size == identNames.size then some userNames else none)
|
||||
let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none)
|
||||
`(matchAltExpr| | $[$patterns:term],* => $rhs:term)
|
||||
|
||||
def declModsF := Parser.Command.declModifiers false
|
||||
def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
||||
if declNames.size == 1 then
|
||||
if (← isStructure (← getEnv) declNames[0]) then
|
||||
|
|
@ -108,10 +117,12 @@ def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
|||
let cmds ← liftTermElabM none <| do
|
||||
let ctx ← mkContext "fromJson" declNames[0]
|
||||
let header ← mkHeader ctx ``FromJson 0 ctx.typeInfos[0]
|
||||
let fromJsonFuncId := mkIdent ctx.auxFunNames[0]
|
||||
let discrs ← mkDiscrs header indVal
|
||||
let alts ← mkAlts indVal
|
||||
let alts ← mkAlts indVal fromJsonFuncId
|
||||
let matchCmd ← alts.foldrM (fun xs x => `($xs <|> $x)) (←`(Except.error "no inductive constructor matched"))
|
||||
let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]):ident $header.binders:explicitBinder* (json : Json)
|
||||
let declMods ← if indVal.isRec then `(declModsF| private partial) else `(declModsF| private)
|
||||
let cmd ← `($declMods:declModifiers def $fromJsonFuncId:ident $header.binders:explicitBinder* (json : Json)
|
||||
: Except String $(← mkInductiveApp ctx.typeInfos[0] header.argNames) :=
|
||||
$matchCmd )
|
||||
return #[cmd] ++ (← mkInstanceCmds ctx ``FromJson declNames)
|
||||
|
|
@ -120,12 +131,12 @@ def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
|||
else
|
||||
return false
|
||||
where
|
||||
mkAlts (indVal : InductiveVal) : TermElabM (Array Syntax) := do
|
||||
mkAlts (indVal : InductiveVal) (fromJsonFuncId : Syntax) : TermElabM (Array Syntax) := do
|
||||
let alts ←
|
||||
indVal.ctors.toArray.mapM fun ctor => do
|
||||
let ctorInfo ← getConstInfoCtor ctor
|
||||
forallTelescopeReducing ctorInfo.type fun xs type => do
|
||||
let mut identNames := #[]
|
||||
let mut binders := #[]
|
||||
let mut userNames := #[]
|
||||
for i in [:ctorInfo.numFields] do
|
||||
let x := xs[indVal.numParams + i]
|
||||
|
|
@ -133,17 +144,25 @@ where
|
|||
if !localDecl.userName.hasMacroScopes then
|
||||
userNames := userNames.push localDecl.userName
|
||||
let a := mkIdent (← mkFreshUserName `a)
|
||||
identNames := identNames.push a
|
||||
let jsonAccess ← identNames.mapIdxM (fun idx _ => `(jsons[$(quote idx.val)]))
|
||||
binders := binders.push (a, localDecl.type)
|
||||
|
||||
-- Return syntax to parse `id`, either via `FromJson` or recursively
|
||||
-- if `id`'s type is the type we're deriving for.
|
||||
let mkFromJson (idx : Nat) (type : Expr) : TermElabM Syntax :=
|
||||
if type.isAppOf indVal.name then `(Lean.Parser.Term.doExpr| $fromJsonFuncId:ident jsons[$(quote idx)])
|
||||
else `(Lean.Parser.Term.doExpr| fromJson? jsons[$(quote idx)])
|
||||
let identNames := binders.map Prod.fst
|
||||
let fromJsons ← binders.mapIdxM fun idx (_, type) => mkFromJson idx type
|
||||
|
||||
let userNamesOpt ←
|
||||
if identNames.size == userNames.size then
|
||||
if binders.size == userNames.size then
|
||||
`(some #[$[$(userNames.map quote):ident],*])
|
||||
else `(none)
|
||||
let stx ←
|
||||
`((Json.parseTagged json $(quote ctor.getString!) $(quote ctorInfo.numFields) $(quote userNamesOpt)).bind
|
||||
(fun jsons => do
|
||||
$[let $identNames:ident ← fromJson? $jsonAccess]*
|
||||
return $(mkIdent ctor):ident $identNames*))
|
||||
$[let $identNames:ident ← $fromJsons]*
|
||||
return $(mkIdent ctor):ident $identNames*))
|
||||
(stx, ctorInfo.numFields)
|
||||
-- the smaller cases, especially the ones without fields are likely faster
|
||||
let alts := alts.qsort (fun (_, x) (_, y) => x < y)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ open Json in macro_rules
|
|||
`(Json.arr #[$fields,*])
|
||||
|
||||
def checkToJson [ToJson α] (obj : α) (rhs : Json) : MetaM Unit :=
|
||||
let lhs := (obj |> toJson).pretty
|
||||
let lhs := (obj |> toJson).pretty
|
||||
if lhs == rhs.pretty then
|
||||
()
|
||||
else
|
||||
|
|
@ -50,16 +50,16 @@ structure Foo where
|
|||
deriving ToJson, FromJson, Repr, BEq
|
||||
|
||||
#eval checkToJson { x := 1, y := "bla" : Foo} (json { y : "bla", x : 1 })
|
||||
#eval checkRoundTrip { x := 1, y := "bla" : Foo }
|
||||
#eval checkRoundTrip { x := 1, y := "bla" : Foo }
|
||||
|
||||
-- set_option trace.Elab.command true
|
||||
-- set_option trace.Elab.command true
|
||||
structure WInfo where
|
||||
a : Nat
|
||||
b : Nat
|
||||
deriving ToJson, FromJson, Repr, BEq
|
||||
|
||||
-- set_option trace.Elab.command true
|
||||
inductive E
|
||||
-- set_option trace.Elab.command true
|
||||
inductive E
|
||||
| W : WInfo → E
|
||||
| WAlt (a b : Nat)
|
||||
| X : Nat → Nat → E
|
||||
|
|
@ -85,3 +85,12 @@ deriving ToJson, FromJson, Repr, BEq
|
|||
|
||||
#eval checkToJson E.Z (json "Z")
|
||||
#eval checkRoundTrip E.Z
|
||||
|
||||
inductive ERec
|
||||
| mk : Nat → ERec
|
||||
| W : ERec → ERec
|
||||
deriving ToJson, FromJson, Repr, BEq
|
||||
|
||||
#eval checkToJson (ERec.W (ERec.mk 6)) (json { W : { mk : 6 }})
|
||||
#eval checkRoundTrip (ERec.mk 7)
|
||||
#eval checkRoundTrip (ERec.W (ERec.mk 8))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue