From cb4a73a48761950ec8ad93f6c9bd10346313bf35 Mon Sep 17 00:00:00 2001 From: Arthur Adjedj Date: Tue, 10 Sep 2024 10:55:52 +0200 Subject: [PATCH] refactor: `Lean.Elab.Deriving.FromToJson` (#5292) Refactors the derive handlers for `ToJson` and `FromJson` in preparation for #3160. This splits up the different parts of the handler according to how other similar handlers are implemented while keeping the original logic intact. This makes the changes necessary to adapt the file in #3160 much easier. --- src/Lean/Elab/Deriving/FromToJson.lean | 316 ++++++++++++++----------- 1 file changed, 184 insertions(+), 132 deletions(-) diff --git a/src/Lean/Elab/Deriving/FromToJson.lean b/src/Lean/Elab/Deriving/FromToJson.lean index 2f8f2ce5f8..c68e477d04 100644 --- a/src/Lean/Elab/Deriving/FromToJson.lean +++ b/src/Lean/Elab/Deriving/FromToJson.lean @@ -15,145 +15,109 @@ open Lean.Json open Lean.Parser.Term open Lean.Meta +def mkToJsonHeader (indVal : InductiveVal) : TermElabM Header := do + mkHeader ``ToJson 1 indVal + +def mkFromJsonHeader (indVal : InductiveVal) : TermElabM Header := do + let header ← mkHeader ``FromJson 0 indVal + let jsonArg ← `(bracketedBinderF|(json : Json)) + return {header with + binders := header.binders.push jsonArg} + def mkJsonField (n : Name) : CoreM (Bool × Term) := do let .str .anonymous s := n | throwError "invalid json field name {n}" let s₁ := s.dropRightWhile (· == '?') return (s != s₁, Syntax.mkStrLit s₁) -def mkToJsonInstance (declName : Name) : CommandElabM Bool := do - if isStructure (← getEnv) declName then - let cmds ← liftTermElabM do - let ctx ← mkContext "toJson" declName - let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]! - let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false) - let fields ← fields.mapM fun field => do - let (isOptField, nm) ← mkJsonField field - let target := mkIdent header.targetNames[0]! - if isOptField then ``(opt $nm ($target).$(mkIdent field)) - else ``([($nm, toJson ($target).$(mkIdent field))]) - let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* : Json := - mkObj <| List.join [$fields,*]) - return #[cmd] ++ (← mkInstanceCmds ctx ``ToJson #[declName]) - cmds.forM elabCommand - return true - else - let indVal ← getConstInfoInduct declName - let cmds ← liftTermElabM do - let ctx ← mkContext "toJson" declName - let toJsonFuncId := mkIdent ctx.auxFunNames[0]! - -- Return syntax to JSONify `id`, either via `ToJson` or recursively - -- if `id`'s type is the type we're deriving for. - let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do +def mkToJsonBodyForStruct (header : Header) (indName : Name) : TermElabM Term := do + let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false) + let fields ← fields.mapM fun field => do + let (isOptField, nm) ← mkJsonField field + let target := mkIdent header.targetNames[0]! + if isOptField then ``(opt $nm $target.$(mkIdent field)) + else ``([($nm, toJson ($target).$(mkIdent field))]) + `(mkObj <| List.join [$fields,*]) + +def mkToJsonBodyForInduct (ctx : Context) (header : Header) (indName : Name) : TermElabM Term := do + let indVal ← getConstInfoInduct indName + let toJsonFuncId := mkIdent ctx.auxFunNames[0]! + -- Return syntax to JSONify `id`, either via `ToJson` or recursively + -- if `id`'s type is the type we're deriving for. + let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident) else ``(toJson $id:ident) - let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]! - let discrs ← mkDiscrs header indVal - let alts ← mkAlts indVal fun ctor args userNames => do - let ctorStr := ctor.name.eraseMacroScopes.getString! - match args, userNames with - | #[], _ => ``(toJson $(quote ctorStr)) - | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))]) - | xs, none => - let xs ← xs.mapM fun (x, t) => mkToJson x t - ``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])]) - | xs, some userNames => - let xs ← xs.mapIdxM fun idx (x, t) => do - `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t))) - ``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])]) - let auxTerm ← `(match $[$discrs],* with $alts:matchAlt*) - let auxCmd ← - if ctx.usePartial then - let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames - let auxTerm ← mkLet letDecls auxTerm - `(private partial def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm) - else - `(private def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm) - return #[auxCmd] ++ (← mkInstanceCmds ctx ``ToJson #[declName]) - cmds.forM elabCommand - return true + let discrs ← mkDiscrs header indVal + let alts ← mkAlts indVal fun ctor args userNames => do + let ctorStr := ctor.name.eraseMacroScopes.getString! + match args, userNames with + | #[], _ => ``(toJson $(quote ctorStr)) + | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))]) + | xs, none => + let xs ← xs.mapM fun (x, t) => mkToJson x t + ``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])]) + | xs, some userNames => + let xs ← xs.mapIdxM fun idx (x, t) => do + `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t))) + ``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])]) + `(match $[$discrs],* with $alts:matchAlt*) where mkAlts (indVal : InductiveVal) - (rhs : ConstructorVal → Array (Ident × Expr) → Option (Array Name) → TermElabM Term) : TermElabM (Array (TSyntax ``matchAlt)) := do - indVal.ctors.toArray.mapM fun ctor => do - let ctorInfo ← getConstInfoCtor ctor - forallTelescopeReducing ctorInfo.type fun xs _ => do - let mut patterns := #[] - -- add `_` pattern for indices - for _ in [:indVal.numIndices] do - patterns := patterns.push (← `(_)) - let mut ctorArgs := #[] - -- add `_` for inductive parameters, they are inaccessible - for _ in [:indVal.numParams] do - ctorArgs := ctorArgs.push (← `(_)) - -- bound constructor arguments and their types - let mut binders := #[] - let mut userNames := #[] - for i in [:ctorInfo.numFields] do - let x := xs[indVal.numParams + i]! - let localDecl ← x.fvarId!.getDecl - if !localDecl.userName.hasMacroScopes then - userNames := userNames.push localDecl.userName - let a := mkIdent (← mkFreshUserName `a) - binders := binders.push (a, localDecl.type) - ctorArgs := ctorArgs.push a - patterns := patterns.push (← `(@$(mkIdent ctorInfo.name):ident $ctorArgs:term*)) - let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none) - `(matchAltExpr| | $[$patterns:term],* => $rhs:term) + (rhs : ConstructorVal → Array (Ident × Expr) → Option (Array Name) → TermElabM Term): TermElabM (Array (TSyntax ``matchAlt)) := do + let mut alts := #[] + for ctorName in indVal.ctors do + let ctorInfo ← getConstInfoCtor ctorName + let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do + let mut patterns := #[] + -- add `_` pattern for indices + for _ in [:indVal.numIndices] do + patterns := patterns.push (← `(_)) + let mut ctorArgs := #[] + -- add `_` for inductive parameters, they are inaccessible + for _ in [:indVal.numParams] do + ctorArgs := ctorArgs.push (← `(_)) + -- bound constructor arguments and their types + let mut binders := #[] + let mut userNames := #[] + for i in [:ctorInfo.numFields] do + let x := xs[indVal.numParams + i]! + let localDecl ← x.fvarId!.getDecl + if !localDecl.userName.hasMacroScopes then + userNames := userNames.push localDecl.userName + let a := mkIdent (← mkFreshUserName `a) + binders := binders.push (a, localDecl.type) + ctorArgs := ctorArgs.push a + patterns := patterns.push (← `(@$(mkIdent ctorInfo.name):ident $ctorArgs:term*)) + let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none) + `(matchAltExpr| | $[$patterns:term],* => $rhs:term) + alts := alts.push alt + return alts -def mkFromJsonInstance (declName : Name) : CommandElabM Bool := do - if isStructure (← getEnv) declName then - let cmds ← liftTermElabM do - let ctx ← mkContext "fromJson" declName - let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]! - let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false) - let getters ← fields.mapM (fun field => do - let getter ← `(getObjValAs? j _ $(Prod.snd <| ← mkJsonField field)) - let getter ← `(doElem| Except.mapError (fun s => (toString $(quote declName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter) - return getter - ) - let fields := fields.map mkIdent - let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* (j : Json) - : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := do - $[let $fields:ident ← $getters]* - return { $[$fields:ident := $(id fields)],* }) - return #[cmd] ++ (← mkInstanceCmds ctx ``FromJson #[declName]) - cmds.forM elabCommand - return true - else - let indVal ← getConstInfoInduct declName - let cmds ← liftTermElabM do - let ctx ← mkContext "fromJson" declName - let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]! - let fromJsonFuncId := mkIdent ctx.auxFunNames[0]! - let alts ← mkAlts indVal fromJsonFuncId - let mut auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched")) - if ctx.usePartial then - let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames - auxTerm ← mkLet letDecls auxTerm - -- FromJson is not structurally recursive even non-nested recursive inductives, - -- so we also use `partial` then. - let auxCmd ← - if ctx.usePartial || indVal.isRec then - `(private partial def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json) - : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := - $auxTerm) - else - `(private def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json) - : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := - $auxTerm) - return #[auxCmd] ++ (← mkInstanceCmds ctx ``FromJson #[declName]) - cmds.forM elabCommand - return true +def mkFromJsonBodyForStruct (indName : Name) : TermElabM Term := do + let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false) + let getters ← fields.mapM (fun field => do + let getter ← `(getObjValAs? json _ $(Prod.snd <| ← mkJsonField field)) + let getter ← `(doElem| Except.mapError (fun s => (toString $(quote indName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter) + return getter + ) + let fields := fields.map mkIdent + `(do + $[let $fields:ident ← $getters]* + return { $[$fields:ident := $(id fields)],* }) +def mkFromJsonBodyForInduct (ctx : Context) (indName : Name) : TermElabM Term := do + let indVal ← getConstInfoInduct indName + let alts ← mkAlts indVal + let auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched")) + `($auxTerm) where - mkAlts (indVal : InductiveVal) (fromJsonFuncId : Ident) : TermElabM (Array Term) := do - let alts ← - indVal.ctors.toArray.mapM fun ctor => do - let ctorInfo ← getConstInfoCtor ctor - forallTelescopeReducing ctorInfo.type fun xs _ => do - let mut binders := #[] + mkAlts (indVal : InductiveVal) : TermElabM (Array Term) := do + let mut alts := #[] + for ctorName in indVal.ctors do + let ctorInfo ← getConstInfoCtor ctorName + let alt ← do forallTelescopeReducing ctorInfo.type fun xs _ => do + let mut binders := #[] let mut userNames := #[] for i in [:ctorInfo.numFields] do let x := xs[indVal.numParams + i]! @@ -162,7 +126,7 @@ where userNames := userNames.push localDecl.userName let a := mkIdent (← mkFreshUserName `a) binders := binders.push (a, localDecl.type) - + let fromJsonFuncId := mkIdent ctx.auxFunNames[0]! -- Return syntax to parse `id`, either via `FromJson` or recursively -- if `id`'s type is the type we're deriving for. let mkFromJson (idx : Nat) (type : Expr) : TermElabM (TSyntax ``doExpr) := @@ -175,23 +139,111 @@ where else ``(none) let stx ← - `((Json.parseTagged json $(quote ctor.eraseMacroScopes.getString!) $(quote ctorInfo.numFields) $(quote userNamesOpt)).bind + `((Json.parseTagged json $(quote ctorName.eraseMacroScopes.getString!) $(quote ctorInfo.numFields) $(quote userNamesOpt)).bind (fun jsons => do $[let $identNames:ident ← $fromJsons:doExpr]* - return $(mkIdent ctor):ident $identNames*)) + return $(mkIdent ctorName):ident $identNames*)) pure (stx, ctorInfo.numFields) + alts := alts.push alt -- the smaller cases, especially the ones without fields are likely faster - let alts := alts.qsort (fun (_, x) (_, y) => x < y) - return alts.map Prod.fst + let alts' := alts.qsort (fun (_, x) (_, y) => x < y) + return alts'.map Prod.fst + +def mkToJsonBody (ctx : Context) (header : Header) (e : Expr): TermElabM Term := do + let indName := e.getAppFn.constName! + if isStructure (← getEnv) indName then + mkToJsonBodyForStruct header indName + else + mkToJsonBodyForInduct ctx header indName + +def mkToJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do + let auxFunName := ctx.auxFunNames[i]! + let header ← mkToJsonHeader ctx.typeInfos[i]! + let binders := header.binders + Term.elabBinders binders fun _ => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkToJsonBody ctx header type + if ctx.usePartial then + let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames + body ← mkLet letDecls body + `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term) + +def mkFromJsonBody (ctx : Context) (e : Expr) : TermElabM Term := do + let indName := e.getAppFn.constName! + if isStructure (← getEnv) indName then + mkFromJsonBodyForStruct indName + else + mkFromJsonBodyForInduct ctx indName + +def mkFromJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do + let auxFunName := ctx.auxFunNames[i]! + let indval := ctx.typeInfos[i]! + let header ← mkFromJsonHeader indval --TODO fix header info + let binders := header.binders + Term.elabBinders binders fun _ => do + let type ← Term.elabTerm header.targetType none + let mut body ← mkFromJsonBody ctx type + if ctx.usePartial || indval.isRec then + let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames + body ← mkLet letDecls body + `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← mkInductiveApp ctx.typeInfos[i]! header.argNames) := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← mkInductiveApp ctx.typeInfos[i]! header.argNames) := $body:term) + + +def mkToJsonMutualBlock (ctx : Context) : TermElabM Command := do + let mut auxDefs := #[] + for i in [:ctx.typeInfos.size] do + auxDefs := auxDefs.push (← mkToJsonAuxFunction ctx i) + `(mutual + $auxDefs:command* + end) + +def mkFromJsonMutualBlock (ctx : Context) : TermElabM Command := do + let mut auxDefs := #[] + for i in [:ctx.typeInfos.size] do + auxDefs := auxDefs.push (← mkFromJsonAuxFunction ctx i) + `(mutual + $auxDefs:command* + end) + +private def mkToJsonInstance (declName : Name) : TermElabM (Array Command) := do + let ctx ← mkContext "toJson" declName + let cmds := #[← mkToJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``ToJson #[declName]) + trace[Elab.Deriving.toJson] "\n{cmds}" + return cmds + +private def mkFromJsonInstance (declName : Name) : TermElabM (Array Command) := do + let ctx ← mkContext "fromJson" declName + let cmds := #[← mkFromJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``FromJson #[declName]) + trace[Elab.Deriving.fromJson] "\n{cmds}" + return cmds def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - declNames.foldlM (fun b n => andM (pure b) (mkToJsonInstance n)) true + if (← declNames.allM isInductive) && declNames.size > 0 then + for declName in declNames do + let cmds ← liftTermElabM <| mkToJsonInstance declName + cmds.forM elabCommand + return true + else + return false def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - declNames.foldlM (fun b n => andM (pure b) (mkFromJsonInstance n)) true + if (← declNames.allM isInductive) && declNames.size > 0 then + for declName in declNames do + let cmds ← liftTermElabM <| mkFromJsonInstance declName + cmds.forM elabCommand + return true + else + return false builtin_initialize registerDerivingHandler ``ToJson mkToJsonInstanceHandler registerDerivingHandler ``FromJson mkFromJsonInstanceHandler + registerTraceClass `Elab.Deriving.toJson + registerTraceClass `Elab.Deriving.fromJson + end Lean.Elab.Deriving.FromToJson