From b7c3ff6e6d6ffa52fac42e333a39ab7950e1e7d7 Mon Sep 17 00:00:00 2001 From: Arthur Adjedj Date: Tue, 9 Jan 2024 08:42:06 +0100 Subject: [PATCH] fix: manage all declarations in a given derive (#3058) Closes #3057 --- src/Lean/Elab/Deriving/BEq.lean | 24 +-- src/Lean/Elab/Deriving/DecEq.lean | 11 +- src/Lean/Elab/Deriving/FromToJson.lean | 201 +++++++++++++------------ src/Lean/Elab/Deriving/Hashable.lean | 13 +- src/Lean/Elab/Deriving/Ord.lean | 13 +- src/Lean/Elab/Deriving/Repr.lean | 13 +- src/Lean/Elab/Deriving/SizeOf.lean | 5 +- tests/lean/3057.lean | 26 ++++ tests/lean/3057.lean.expected.out | 10 ++ 9 files changed, 182 insertions(+), 134 deletions(-) create mode 100644 tests/lean/3057.lean create mode 100644 tests/lean/3057.lean.expected.out diff --git a/src/Lean/Elab/Deriving/BEq.lean b/src/Lean/Elab/Deriving/BEq.lean index a3963f056b..7ebe5cbbdf 100644 --- a/src/Lean/Elab/Deriving/BEq.lean +++ b/src/Lean/Elab/Deriving/BEq.lean @@ -91,9 +91,9 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do $auxDefs:command* end) -private def mkBEqInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "beq" declNames[0]! - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq declNames) +private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do + let ctx ← mkContext "beq" declName + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName]) trace[Elab.Deriving.beq] "\n{cmds}" return cmds @@ -109,14 +109,18 @@ private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do open Command +def mkBEqInstance (declName : Name) : CommandElabM Unit := do + let cmds ← liftTermElabM <| + if (← isEnumType declName) then + mkBEqEnumCmd declName + else + mkBEqInstanceCmds declName + cmds.forM elabCommand + def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - if declNames.size == 1 && (← isEnumType declNames[0]!) then - let cmds ← liftTermElabM <| mkBEqEnumCmd declNames[0]! - cmds.forM elabCommand - return true - else if (← declNames.allM isInductive) && declNames.size > 0 then - let cmds ← liftTermElabM <| mkBEqInstanceCmds declNames - cmds.forM elabCommand + if (← declNames.allM isInductive) then + for declName in declNames do + mkBEqInstance declName return true else return false diff --git a/src/Lean/Elab/Deriving/DecEq.lean b/src/Lean/Elab/Deriving/DecEq.lean index a42f1590a3..09b0a83cdf 100644 --- a/src/Lean/Elab/Deriving/DecEq.lean +++ b/src/Lean/Elab/Deriving/DecEq.lean @@ -186,12 +186,15 @@ def mkDecEqEnum (declName : Name) : CommandElabM Unit := do trace[Elab.Deriving.decEq] "\n{cmd}" elabCommand cmd -def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - if (← isEnumType declNames[0]!) then - mkDecEqEnum declNames[0]! +def mkDecEqInstance (declName : Name) : CommandElabM Bool := do + if (← isEnumType declName) then + mkDecEqEnum declName return true else - mkDecEq declNames[0]! + mkDecEq declName + +def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do + declNames.foldlM (fun b n => andM (pure b) (mkDecEqInstance n)) true builtin_initialize registerDerivingHandler `DecidableEq mkDecEqInstanceHandler diff --git a/src/Lean/Elab/Deriving/FromToJson.lean b/src/Lean/Elab/Deriving/FromToJson.lean index 5674084289..c91dd4e234 100644 --- a/src/Lean/Elab/Deriving/FromToJson.lean +++ b/src/Lean/Elab/Deriving/FromToJson.lean @@ -19,60 +19,58 @@ def mkJsonField (n : Name) : CoreM (Bool × Term) := do let s₁ := s.dropRightWhile (· == '?') return (s != s₁, Syntax.mkStrLit s₁) -def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - if declNames.size == 1 then - if isStructure (← getEnv) declNames[0]! then - let cmds ← liftTermElabM do - let ctx ← mkContext "toJson" declNames[0]! - let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]! - let fields := getStructureFieldsFlattened (← getEnv) declNames[0]! (includeSubobjectFields := false) - let fields ← fields.mapM fun field => do - let (isOptField, nm) ← mkJsonField field - let target := mkIdent header.targetNames[0]! - if isOptField then ``(opt $nm ($target).$(mkIdent field)) - else ``([($nm, toJson ($target).$(mkIdent field))]) - let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* : Json := - mkObj <| List.join [$fields,*]) - return #[cmd] ++ (← mkInstanceCmds ctx ``ToJson declNames) - cmds.forM elabCommand - return true - else - let indVal ← getConstInfoInduct declNames[0]! - let cmds ← liftTermElabM do - let ctx ← mkContext "toJson" declNames[0]! - let toJsonFuncId := mkIdent ctx.auxFunNames[0]! - -- Return syntax to JSONify `id`, either via `ToJson` or recursively - -- if `id`'s type is the type we're deriving for. - let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do - if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident) - else ``(toJson $id:ident) - let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]! - let discrs ← mkDiscrs header indVal - let alts ← mkAlts indVal fun ctor args userNames => do - let ctorStr := ctor.name.eraseMacroScopes.getString! - match args, userNames with - | #[], _ => ``(toJson $(quote ctorStr)) - | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))]) - | xs, none => - let xs ← xs.mapM fun (x, t) => mkToJson x t - ``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])]) - | xs, some userNames => - let xs ← xs.mapIdxM fun idx (x, t) => do - `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t))) - ``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])]) - let auxTerm ← `(match $[$discrs],* with $alts:matchAlt*) - let auxCmd ← - if ctx.usePartial then - let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames - let auxTerm ← mkLet letDecls auxTerm - `(private partial def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm) - else - `(private def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm) - return #[auxCmd] ++ (← mkInstanceCmds ctx ``ToJson declNames) - cmds.forM elabCommand - return true +def mkToJsonInstance (declName : Name) : CommandElabM Bool := do + if isStructure (← getEnv) declName then + let cmds ← liftTermElabM do + let ctx ← mkContext "toJson" declName + let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]! + let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false) + let fields ← fields.mapM fun field => do + let (isOptField, nm) ← mkJsonField field + let target := mkIdent header.targetNames[0]! + if isOptField then ``(opt $nm ($target).$(mkIdent field)) + else ``([($nm, toJson ($target).$(mkIdent field))]) + let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* : Json := + mkObj <| List.join [$fields,*]) + return #[cmd] ++ (← mkInstanceCmds ctx ``ToJson #[declName]) + cmds.forM elabCommand + return true else - return false + let indVal ← getConstInfoInduct declName + let cmds ← liftTermElabM do + let ctx ← mkContext "toJson" declName + let toJsonFuncId := mkIdent ctx.auxFunNames[0]! + -- Return syntax to JSONify `id`, either via `ToJson` or recursively + -- if `id`'s type is the type we're deriving for. + let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do + if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident) + else ``(toJson $id:ident) + let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]! + let discrs ← mkDiscrs header indVal + let alts ← mkAlts indVal fun ctor args userNames => do + let ctorStr := ctor.name.eraseMacroScopes.getString! + match args, userNames with + | #[], _ => ``(toJson $(quote ctorStr)) + | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))]) + | xs, none => + let xs ← xs.mapM fun (x, t) => mkToJson x t + ``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])]) + | xs, some userNames => + let xs ← xs.mapIdxM fun idx (x, t) => do + `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t))) + ``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])]) + let auxTerm ← `(match $[$discrs],* with $alts:matchAlt*) + let auxCmd ← + if ctx.usePartial then + let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames + let auxTerm ← mkLet letDecls auxTerm + `(private partial def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm) + else + `(private def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm) + return #[auxCmd] ++ (← mkInstanceCmds ctx ``ToJson #[declName]) + cmds.forM elabCommand + return true + where mkAlts (indVal : InductiveVal) @@ -103,54 +101,51 @@ where let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none) `(matchAltExpr| | $[$patterns:term],* => $rhs:term) -def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - if declNames.size == 1 then - let declName := declNames[0]! - if isStructure (← getEnv) declName then - let cmds ← liftTermElabM do - let ctx ← mkContext "fromJson" declName - let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]! - let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false) - let getters ← fields.mapM (fun field => do - let getter ← `(getObjValAs? j _ $(Prod.snd <| ← mkJsonField field)) - let getter ← `(doElem| Except.mapError (fun s => (toString $(quote declName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter) - return getter - ) - let fields := fields.map mkIdent - let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* (j : Json) - : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := do - $[let $fields:ident ← $getters]* - return { $[$fields:ident := $(id fields)],* }) - return #[cmd] ++ (← mkInstanceCmds ctx ``FromJson declNames) - cmds.forM elabCommand - return true - else - let indVal ← getConstInfoInduct declName - let cmds ← liftTermElabM do - let ctx ← mkContext "fromJson" declName - let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]! - let fromJsonFuncId := mkIdent ctx.auxFunNames[0]! - let alts ← mkAlts indVal fromJsonFuncId - let mut auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched")) - if ctx.usePartial then - let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames - auxTerm ← mkLet letDecls auxTerm - -- FromJson is not structurally recursive even non-nested recursive inductives, - -- so we also use `partial` then. - let auxCmd ← - if ctx.usePartial || indVal.isRec then - `(private partial def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json) - : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := - $auxTerm) - else - `(private def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json) - : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := - $auxTerm) - return #[auxCmd] ++ (← mkInstanceCmds ctx ``FromJson declNames) - cmds.forM elabCommand - return true +def mkFromJsonInstance (declName : Name) : CommandElabM Bool := do + if isStructure (← getEnv) declName then + let cmds ← liftTermElabM do + let ctx ← mkContext "fromJson" declName + let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]! + let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false) + let getters ← fields.mapM (fun field => do + let getter ← `(getObjValAs? j _ $(Prod.snd <| ← mkJsonField field)) + let getter ← `(doElem| Except.mapError (fun s => (toString $(quote declName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter) + return getter + ) + let fields := fields.map mkIdent + let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* (j : Json) + : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := do + $[let $fields:ident ← $getters]* + return { $[$fields:ident := $(id fields)],* }) + return #[cmd] ++ (← mkInstanceCmds ctx ``FromJson #[declName]) + cmds.forM elabCommand + return true else - return false + let indVal ← getConstInfoInduct declName + let cmds ← liftTermElabM do + let ctx ← mkContext "fromJson" declName + let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]! + let fromJsonFuncId := mkIdent ctx.auxFunNames[0]! + let alts ← mkAlts indVal fromJsonFuncId + let mut auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched")) + if ctx.usePartial then + let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames + auxTerm ← mkLet letDecls auxTerm + -- FromJson is not structurally recursive even non-nested recursive inductives, + -- so we also use `partial` then. + let auxCmd ← + if ctx.usePartial || indVal.isRec then + `(private partial def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json) + : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := + $auxTerm) + else + `(private def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json) + : Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := + $auxTerm) + return #[auxCmd] ++ (← mkInstanceCmds ctx ``FromJson #[declName]) + cmds.forM elabCommand + return true + where mkAlts (indVal : InductiveVal) (fromJsonFuncId : Ident) : TermElabM (Array Term) := do let alts ← @@ -188,6 +183,12 @@ where let alts := alts.qsort (fun (_, x) (_, y) => x < y) return alts.map Prod.fst +def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do + declNames.foldlM (fun b n => andM (pure b) (mkToJsonInstance n)) true + +def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do + declNames.foldlM (fun b n => andM (pure b) (mkFromJsonInstance n)) true + builtin_initialize registerDerivingHandler ``ToJson mkToJsonInstanceHandler registerDerivingHandler ``FromJson mkFromJsonInstanceHandler diff --git a/src/Lean/Elab/Deriving/Hashable.lean b/src/Lean/Elab/Deriving/Hashable.lean index 434d2bfa59..e6b4612178 100644 --- a/src/Lean/Elab/Deriving/Hashable.lean +++ b/src/Lean/Elab/Deriving/Hashable.lean @@ -75,16 +75,17 @@ def mkHashFuncs (ctx : Context) : TermElabM Syntax := do auxDefs := auxDefs.push (← mkAuxFunction ctx i) `(mutual $auxDefs:command* end) -private def mkHashableInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "hash" declNames[0]! - let cmds := #[← mkHashFuncs ctx] ++ (← mkInstanceCmds ctx `Hashable declNames) +private def mkHashableInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do + let ctx ← mkContext "hash" declName + let cmds := #[← mkHashFuncs ctx] ++ (← mkInstanceCmds ctx `Hashable #[declName]) trace[Elab.Deriving.hashable] "\n{cmds}" return cmds def mkHashableHandler (declNames : Array Name) : CommandElabM Bool := do - if (← declNames.allM isInductive) && declNames.size > 0 then - let cmds ← liftTermElabM <| mkHashableInstanceCmds declNames - cmds.forM elabCommand + if (← declNames.allM isInductive) then + for declName in declNames do + let cmds ← liftTermElabM <| mkHashableInstanceCmds declName + cmds.forM elabCommand return true else return false diff --git a/src/Lean/Elab/Deriving/Ord.lean b/src/Lean/Elab/Deriving/Ord.lean index a59c371009..5b239e9ad3 100644 --- a/src/Lean/Elab/Deriving/Ord.lean +++ b/src/Lean/Elab/Deriving/Ord.lean @@ -86,18 +86,19 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do $auxDefs:command* end) -private def mkOrdInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "ord" declNames[0]! - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord declNames) +private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do + let ctx ← mkContext "ord" declName + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName]) trace[Elab.Deriving.ord] "\n{cmds}" return cmds open Command def mkOrdInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - if (← declNames.allM isInductive) && declNames.size > 0 then - let cmds ← liftTermElabM <| mkOrdInstanceCmds declNames - cmds.forM elabCommand + if (← declNames.allM isInductive) then + for declName in declNames do + let cmds ← liftTermElabM <| mkOrdInstanceCmds declName + cmds.forM elabCommand return true else return false diff --git a/src/Lean/Elab/Deriving/Repr.lean b/src/Lean/Elab/Deriving/Repr.lean index 461d21036a..871e6be3f1 100644 --- a/src/Lean/Elab/Deriving/Repr.lean +++ b/src/Lean/Elab/Deriving/Repr.lean @@ -104,18 +104,19 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do $auxDefs:command* end) -private def mkReprInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext "repr" declNames[0]! - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr declNames) +private def mkReprInstanceCmd (declName : Name) : TermElabM (Array Syntax) := do + let ctx ← mkContext "repr" declName + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr #[declName]) trace[Elab.Deriving.repr] "\n{cmds}" return cmds open Command def mkReprInstanceHandler (declNames : Array Name) : CommandElabM Bool := do - if (← declNames.allM isInductive) && declNames.size > 0 then - let cmds ← liftTermElabM <| mkReprInstanceCmds declNames - cmds.forM elabCommand + if (← declNames.allM isInductive) then + for declName in declNames do + let cmds ← liftTermElabM <| mkReprInstanceCmd declName + cmds.forM elabCommand return true else return false diff --git a/src/Lean/Elab/Deriving/SizeOf.lean b/src/Lean/Elab/Deriving/SizeOf.lean index 0d904db525..e082f8869b 100644 --- a/src/Lean/Elab/Deriving/SizeOf.lean +++ b/src/Lean/Elab/Deriving/SizeOf.lean @@ -16,8 +16,9 @@ namespace Lean.Elab.Deriving.SizeOf open Command def mkSizeOfHandler (declNames : Array Name) : CommandElabM Bool := do - if (← declNames.allM isInductive) && declNames.size > 0 then - liftTermElabM <| Meta.mkSizeOfInstances declNames[0]! + if (← declNames.allM isInductive) then + for declName in declNames do + liftTermElabM <| Meta.mkSizeOfInstances declName return true else return false diff --git a/tests/lean/3057.lean b/tests/lean/3057.lean new file mode 100644 index 0000000000..ad92ee37c4 --- /dev/null +++ b/tests/lean/3057.lean @@ -0,0 +1,26 @@ +/- +The derive handlers should manage both inductives and not ignore the second one, +Fixes `#3057` +-/ + +mutual + inductive Tree : Type := + | node : ListTree → Tree + deriving Repr, DecidableEq, BEq, Hashable, Ord + + inductive ListTree : Type := + | nil : ListTree + | cons : Tree → ListTree → ListTree + deriving Repr, DecidableEq, BEq, Hashable, Ord +end + +#synth Repr Tree +#synth Repr ListTree +#synth DecidableEq Tree +#synth DecidableEq ListTree +#synth BEq Tree +#synth BEq ListTree +#synth Hashable Tree +#synth Hashable ListTree +#synth Ord Tree +#synth Ord ListTree diff --git a/tests/lean/3057.lean.expected.out b/tests/lean/3057.lean.expected.out new file mode 100644 index 0000000000..ccac3ae309 --- /dev/null +++ b/tests/lean/3057.lean.expected.out @@ -0,0 +1,10 @@ +instReprTree +instReprListTree +fun a b => instDecidableEqTree a b +fun a b => instDecidableEqListTree a b +instBEqTree +instBEqListTree +instHashableTree +instHashableListTree +instOrdTree +instOrdListTree