feat: FromToJson for recursive inductives

This commit is contained in:
Wojciech Nawrocki 2021-07-23 21:50:50 -07:00 committed by Sebastian Ullrich
parent 4073b20b7d
commit 43190e0e63
3 changed files with 56 additions and 27 deletions

View file

@ -1,3 +1,4 @@
{
"files.insertFinalNewline": true
"files.insertFinalNewline": true,
"files.trimTrailingWhitespace": true
}

View file

@ -39,20 +39,27 @@ def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
let indVal ← getConstInfoInduct declNames[0]
let cmds ← liftTermElabM none <| do
let ctx ← mkContext "toJson" declNames[0]
let toJsonFuncId := mkIdent ctx.auxFunNames[0]
-- Return syntax to JSONify `id`, either via `ToJson` or recursively
-- if `id`'s type is the type we're deriving for.
let mkToJson (id : Syntax) (type : Expr) : TermElabM Syntax := do
if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident)
else `(toJson $id:ident)
let header ← mkHeader ctx ``ToJson 1 ctx.typeInfos[0]
let discrs ← mkDiscrs header indVal
let alts ← mkAlts indVal fun ctor args userNames =>
let alts ← mkAlts indVal fun ctor args userNames => do
match args, userNames with
| #[], _ => `(toJson $(quote ctor.name.getString!))
| #[x], none => `(mkObj [($(quote ctor.name.getString!), toJson $x)])
| xs, none => do
let xs ← xs.mapM fun x => `(toJson $x)
`(mkObj [($(quote ctor.name.getString!), toJson #[$[$xs:term],*])])
| xs, some userNames => do
let xs ← xs.mapIdxM fun idx x => `(($(quote userNames[idx].getString!), toJson $x))
| #[(x, t)], none => `(mkObj [($(quote ctor.name.getString!), $(← mkToJson x t))])
| xs, none =>
let xs ← xs.mapM fun (x, t) => mkToJson x t
`(mkObj [($(quote ctor.name.getString!), Json.arr #[$[$xs:term],*])])
| xs, some userNames =>
let xs ← xs.mapIdxM fun idx (x, t) => do
`(($(quote userNames[idx].getString!), $(← mkToJson x t)))
`(mkObj [($(quote ctor.name.getString!), mkObj [$[$xs:term],*])])
let auxCmd ← `(match $[$discrs],* with $alts:matchAlt*)
let auxCmd ← `(private def $(mkIdent ctx.auxFunNames[0]):ident $header.binders:explicitBinder* := $auxCmd)
let auxCmd ← `(private def $toJsonFuncId:ident $header.binders:explicitBinder* := $auxCmd)
return #[auxCmd] ++ (← mkInstanceCmds ctx ``ToJson declNames)
cmds.forM elabCommand
return true
@ -61,7 +68,7 @@ def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
where
mkAlts
(indVal : InductiveVal)
(rhs : ConstructorVal → Array Syntax → (Option $ Array Name) → TermElabM Syntax) : TermElabM (Array Syntax) := do
(rhs : ConstructorVal → Array (Syntax × Expr) → (Option $ Array Name) → TermElabM Syntax) : TermElabM (Array Syntax) := do
indVal.ctors.toArray.mapM fun ctor => do
let ctorInfo ← getConstInfoCtor ctor
forallTelescopeReducing ctorInfo.type fun xs type => do
@ -73,7 +80,8 @@ where
-- add `_` for inductive parameters, they are inaccessible
for i in [:indVal.numParams] do
ctorArgs := ctorArgs.push (← `(_))
let mut identNames := #[]
-- bound constructor arguments and their types
let mut binders := #[]
let mut userNames := #[]
for i in [:ctorInfo.numFields] do
let x := xs[indVal.numParams + i]
@ -81,12 +89,13 @@ where
if !localDecl.userName.hasMacroScopes then
userNames := userNames.push localDecl.userName
let a := mkIdent (← mkFreshUserName `a)
identNames := identNames.push a
binders := binders.push (a, localDecl.type)
ctorArgs := ctorArgs.push a
patterns := patterns.push (← `(@$(mkIdent ctorInfo.name):ident $ctorArgs:term*))
let rhs ← rhs ctorInfo identNames (if userNames.size == identNames.size then some userNames else none)
let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none)
`(matchAltExpr| | $[$patterns:term],* => $rhs:term)
def declModsF := Parser.Command.declModifiers false
def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size == 1 then
if (← isStructure (← getEnv) declNames[0]) then
@ -108,10 +117,12 @@ def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
let cmds ← liftTermElabM none <| do
let ctx ← mkContext "fromJson" declNames[0]
let header ← mkHeader ctx ``FromJson 0 ctx.typeInfos[0]
let fromJsonFuncId := mkIdent ctx.auxFunNames[0]
let discrs ← mkDiscrs header indVal
let alts ← mkAlts indVal
let alts ← mkAlts indVal fromJsonFuncId
let matchCmd ← alts.foldrM (fun xs x => `($xs <|> $x)) (←`(Except.error "no inductive constructor matched"))
let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]):ident $header.binders:explicitBinder* (json : Json)
let declMods ← if indVal.isRec then `(declModsF| private partial) else `(declModsF| private)
let cmd ← `($declMods:declModifiers def $fromJsonFuncId:ident $header.binders:explicitBinder* (json : Json)
: Except String $(← mkInductiveApp ctx.typeInfos[0] header.argNames) :=
$matchCmd )
return #[cmd] ++ (← mkInstanceCmds ctx ``FromJson declNames)
@ -120,12 +131,12 @@ def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
else
return false
where
mkAlts (indVal : InductiveVal) : TermElabM (Array Syntax) := do
mkAlts (indVal : InductiveVal) (fromJsonFuncId : Syntax) : TermElabM (Array Syntax) := do
let alts ←
indVal.ctors.toArray.mapM fun ctor => do
let ctorInfo ← getConstInfoCtor ctor
forallTelescopeReducing ctorInfo.type fun xs type => do
let mut identNames := #[]
let mut binders := #[]
let mut userNames := #[]
for i in [:ctorInfo.numFields] do
let x := xs[indVal.numParams + i]
@ -133,17 +144,25 @@ where
if !localDecl.userName.hasMacroScopes then
userNames := userNames.push localDecl.userName
let a := mkIdent (← mkFreshUserName `a)
identNames := identNames.push a
let jsonAccess ← identNames.mapIdxM (fun idx _ => `(jsons[$(quote idx.val)]))
binders := binders.push (a, localDecl.type)
-- Return syntax to parse `id`, either via `FromJson` or recursively
-- if `id`'s type is the type we're deriving for.
let mkFromJson (idx : Nat) (type : Expr) : TermElabM Syntax :=
if type.isAppOf indVal.name then `(Lean.Parser.Term.doExpr| $fromJsonFuncId:ident jsons[$(quote idx)])
else `(Lean.Parser.Term.doExpr| fromJson? jsons[$(quote idx)])
let identNames := binders.map Prod.fst
let fromJsons ← binders.mapIdxM fun idx (_, type) => mkFromJson idx type
let userNamesOpt ←
if identNames.size == userNames.size then
if binders.size == userNames.size then
`(some #[$[$(userNames.map quote):ident],*])
else `(none)
let stx ←
`((Json.parseTagged json $(quote ctor.getString!) $(quote ctorInfo.numFields) $(quote userNamesOpt)).bind
(fun jsons => do
$[let $identNames:ident ← fromJson? $jsonAccess]*
return $(mkIdent ctor):ident $identNames*))
$[let $identNames:ident ← $fromJsons]*
return $(mkIdent ctor):ident $identNames*))
(stx, ctorInfo.numFields)
-- the smaller cases, especially the ones without fields are likely faster
let alts := alts.qsort (fun (_, x) (_, y) => x < y)

View file

@ -27,7 +27,7 @@ open Json in macro_rules
`(Json.arr #[$fields,*])
def checkToJson [ToJson α] (obj : α) (rhs : Json) : MetaM Unit :=
let lhs := (obj |> toJson).pretty
let lhs := (obj |> toJson).pretty
if lhs == rhs.pretty then
()
else
@ -50,16 +50,16 @@ structure Foo where
deriving ToJson, FromJson, Repr, BEq
#eval checkToJson { x := 1, y := "bla" : Foo} (json { y : "bla", x : 1 })
#eval checkRoundTrip { x := 1, y := "bla" : Foo }
#eval checkRoundTrip { x := 1, y := "bla" : Foo }
-- set_option trace.Elab.command true
-- set_option trace.Elab.command true
structure WInfo where
a : Nat
b : Nat
deriving ToJson, FromJson, Repr, BEq
-- set_option trace.Elab.command true
inductive E
-- set_option trace.Elab.command true
inductive E
| W : WInfo → E
| WAlt (a b : Nat)
| X : Nat → Nat → E
@ -85,3 +85,12 @@ deriving ToJson, FromJson, Repr, BEq
#eval checkToJson E.Z (json "Z")
#eval checkRoundTrip E.Z
inductive ERec
| mk : Nat → ERec
| W : ERec → ERec
deriving ToJson, FromJson, Repr, BEq
#eval checkToJson (ERec.W (ERec.mk 6)) (json { W : { mk : 6 }})
#eval checkRoundTrip (ERec.mk 7)
#eval checkRoundTrip (ERec.W (ERec.mk 8))