diff --git a/src/Lean/Server/Rpc/Deriving.lean b/src/Lean/Server/Rpc/Deriving.lean index 73eeeea4f1..94b4e615c4 100644 --- a/src/Lean/Server/Rpc/Deriving.lean +++ b/src/Lean/Server/Rpc/Deriving.lean @@ -25,9 +25,7 @@ private def deriveWithRefInstance (typeNm : Name) : CommandElabM Bool := do rpcDecode := WithRpcRef.decodeUnsafeAs $typeId:ident $(quote typeNm) @[implementedBy unsafeInst] - instance : RpcEncoding (WithRpcRef $typeId:ident) Lsp.RpcRef where - rpcEncode r := pure ⟨0⟩ - rpcDecode r := throw "unreachable" + instance : RpcEncoding (WithRpcRef $typeId:ident) Lsp.RpcRef := default ) elabCommand cmds return true @@ -73,10 +71,16 @@ def withFieldsFlattened (indVal : InductiveVal) (params : Array Expr) end -def isOptField (n : Name) : Bool := - n.toString.endsWith "?" +private def getRpcPacketFor (ty : Expr) : MetaM Expr := do + let packetTy ← mkFreshExprMVar (Expr.sort levelOne) + let _ ← synthInstance (mkApp2 (mkConst ``RpcEncoding) ty packetTy) + instantiateMVars packetTy -private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr) : TermElabM Command := +private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr) + (paramBinders packetParamBinders encInstBinders : Array (TSyntax ``Parser.Term.bracketedBinder)) : TermElabM Command := do + let packetParamNames := packetParamBinders.map fun + | `(bracketedBinder| ($t:ident : $_)) => t + | _ => unreachable! withFields indVal params fun fields => do trace[Elab.Deriving.RpcEncoding] "for structure {indVal.name} with params {params}" -- Postulate that every field have a rpc encoding, storing the encoding type ident @@ -84,34 +88,16 @@ private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr -- as otherwise typeclass synthesis fails. let mut binders := #[] let mut fieldIds := #[] - let mut fieldEncIds : Array Term := #[] - let mut uniqFieldEncIds : Array Ident := #[] - let mut fieldEncIds' : DiscrTree Ident := {} + let mut fieldEncTypeStxs := #[] for (fieldName, fieldTp) in fields do let mut fieldTp := fieldTp - if isOptField fieldName then - if !fieldTp.isAppOf ``Option then - throwError "optional field '{fieldName}' has type{indentD m!"{fieldTp}"}\nbut is expected to have type{indentD "Option _"}" --" - fieldTp := fieldTp.getArg! 0 - -- postulate that the field has an encoding and remember the encoding's binder name + let fieldEncTypeStx ← PrettyPrinter.delab (← getRpcPacketFor fieldTp) + let stx ← PrettyPrinter.delab fieldTp fieldIds := fieldIds.push <| mkIdent fieldName - let mut fieldEncId : Ident := ⟨Syntax.missing⟩ - match (← fieldEncIds'.getMatch fieldTp).back? with - | none => - fieldEncId ← mkIdent <$> mkFreshUserName fieldName - binders := binders.push (← `(bracketedBinder| ( $fieldEncId:ident ))) - let stx ← PrettyPrinter.delab fieldTp - binders := binders.push - (← `(bracketedBinder| [ RpcEncoding $stx $fieldEncId:ident ])) - fieldEncIds' ← fieldEncIds'.insert fieldTp fieldEncId - uniqFieldEncIds := uniqFieldEncIds.push fieldEncId - | some fid => fieldEncId := fid - - if isOptField fieldName then - fieldEncIds := fieldEncIds.push <| ← ``(Option $fieldEncId:ident) - else - fieldEncIds := fieldEncIds.push fieldEncId + fieldEncTypeStxs := fieldEncTypeStxs.push fieldEncTypeStx + binders := binders.push + (← `(bracketedBinder| [ RpcEncoding $stx $fieldEncTypeStx ])) -- helpers for field initialization syntax let fieldInits (func : Name) := fieldIds.mapM fun fid => @@ -123,114 +109,92 @@ private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds let packetId := mkIdent <| indVal.name ++ `RpcEncodingPacket - let packetAppliedId := Syntax.mkApp packetId uniqFieldEncIds + let packetAppliedId := Syntax.mkApp packetId packetParamNames - `(variable $binders* - - protected structure $packetId:ident where - $[($fieldIds : $fieldEncIds)]* + `(protected structure $packetId:ident $packetParamBinders* where + $[($fieldIds : $fieldEncTypeStxs)]* deriving FromJson, ToJson + variable $(paramBinders ++ packetParamBinders ++ encInstBinders)* in instance : RpcEncoding $typeId $packetAppliedId where - rpcEncode a := return { - $[$encInits],* - } - rpcDecode a := return { - $[$decInits],* - } + rpcEncode a := return { $[$encInits],* } + rpcDecode a := return { $[$decInits],* } ) private structure CtorState where - -- names of encoded argument types in the RPC packet - encArgTypes : DiscrTree Name := {} - uniqEncArgTypes : Array Name := #[] - -- binders for `encArgTypes` as well as the relevant `RpcEncoding`s - binders : Array (TSyntax ``Parser.Term.bracketedBinder) := #[] -- the syntax of each constructor in the packet ctors : Array (TSyntax ``Parser.Command.ctor) := #[] -- syntax of each arm of the `rpcEncode` pattern-match encodes : Array (TSyntax ``Parser.Term.matchAlt) := #[] -- syntax of each arm of the `rpcDecode` pattern-match decodes : Array (TSyntax ``Parser.Term.matchAlt) := #[] - deriving Inhabited private def matchF := Lean.Parser.Term.matchAlt (rhsParser := Lean.Parser.termParser) -private def deriveInductiveInstance (indVal : InductiveVal) (params : Array Expr) : TermElabM Command := do +private def deriveInductiveInstance (indVal : InductiveVal) (params packetParams : Array Expr) + (paramBinders packetParamBinders encInstBinders : Array (TSyntax ``Parser.Term.bracketedBinder)) : TermElabM Command := do trace[Elab.Deriving.RpcEncoding] "for inductive {indVal.name} with params {params}" + let packetParamNames := packetParamBinders.map fun + | `(bracketedBinder| ($t:ident : $_)) => t + | _ => unreachable! - -- produce all encoding types and binders for them - let st ← foldWithConstructors indVal params (init := { : CtorState}) fun acc ctor argVars tp => do - trace[Elab.Deriving.RpcEncoding] "{ctor} : {argVars} → {tp}" - let mut acc := acc - let argFVars ← argVars.mapM (LocalDecl.fvarId <$> getFVarLocalDecl ·) - for arg in argVars do + withoutModifyingEnv do + let packetNm := indVal.name ++ `RpcEncodingPacket + addDecl <| .axiomDecl { + name := packetNm + levelParams := [] + type := ← mkForallFVars packetParams (mkSort levelOne) + isUnsafe := true + } + let pktCtorTp := mkAppN (mkConst packetNm) packetParams + let recInstTp := mkApp2 (mkConst ``RpcEncoding) (mkAppN (mkConst indVal.name) params) pktCtorTp + withLocalDecl `inst .instImplicit recInstTp fun _ => do + let st ← foldWithConstructors indVal params (init := { : CtorState }) fun acc ctor argVars _ => do + let mut pktCtorTp := pktCtorTp + -- create the constructor + for arg in argVars.reverse do let argTp ← inferType arg - if (← findExprDependsOn argTp (pf := fun fv => argFVars.contains fv)) then - throwError "cross-argument dependencies are not supported ({arg} : {argTp})" + let encTp ← getRpcPacketFor argTp + pktCtorTp := mkForall (← getFVarLocalDecl arg).userName BinderInfo.default encTp pktCtorTp + -- TODO(WN): this relies on delab printing non-macro-scoped user names in non-dependent foralls + -- to generate the expected JSON encoding + let pktCtorTpStx ← PrettyPrinter.delab pktCtorTp + let pktCtor ← `(Lean.Parser.Command.ctor| | $(mkIdent ctor.getString!):ident : $pktCtorTpStx:term) - if (← acc.encArgTypes.getMatch argTp).isEmpty then - let tid ← mkFreshUserName (← getFVarLocalDecl arg).userName - let argTpStx ← PrettyPrinter.delab argTp - acc := { acc with encArgTypes := ← acc.encArgTypes.insert argTp tid - uniqEncArgTypes := acc.uniqEncArgTypes.push tid - binders := acc.binders.append #[ - (← `(bracketedBinder| ( $(mkIdent tid):ident ))), - (← `(bracketedBinder| [ RpcEncoding $argTpStx $(mkIdent tid):ident ])) - ] } - return acc + -- create encoder and decoder match arms + let nms ← argVars.mapM fun _ => mkIdent <$> mkFreshBinderName + let mkPattern (src : Name) := Syntax.mkApp (mkIdent <| Name.mkStr src ctor.getString!) nms + let mkBody (tgt : Name) (func : Name) : TermElabM Term := do + let items ← nms.mapM fun nm => `(← $(mkIdent func) $nm) + let tm := Syntax.mkApp (mkIdent <| Name.mkStr tgt ctor.getString!) items + `(return $tm:term) - -- introduce encoding types into the local context so that we can use the delaborator to print them - withLocalDecls - (st.uniqEncArgTypes.map fun tid => (tid, BinderInfo.default, fun _ => pure <| mkSort levelOne)) - fun ts => do - trace[Elab.Deriving.RpcEncoding] m!"RpcEncoding type binders : {ts}" + let encArm ← `(matchF| | $(mkPattern indVal.name):term => $(← mkBody packetNm ``rpcEncode)) + let decArm ← `(matchF| | $(mkPattern packetNm):term => $(← mkBody indVal.name ``rpcDecode)) - let packetNm := indVal.name ++ `RpcEncodingPacket - let st ← foldWithConstructors indVal params (init := st) fun acc ctor argVars _ => do - -- create the constructor - let mut pktCtorTp := Lean.mkConst packetNm - for arg in argVars.reverse do - let argTp ← inferType arg - let encTpNm := (← acc.encArgTypes.getMatch argTp).back - let encTp ← elabTerm (mkIdent encTpNm) none - pktCtorTp := mkForall (← getFVarLocalDecl arg).userName BinderInfo.default encTp pktCtorTp - -- TODO(WN): this relies on delab printing non-macro-scoped user names in non-dependent foralls - -- to generate the expected JSON encoding - let pktCtorTpStx ← PrettyPrinter.delab pktCtorTp - let pktCtor ← `(Lean.Parser.Command.ctor| | $(mkIdent ctor.getString!):ident : $pktCtorTpStx:term) + return { acc with ctors := acc.ctors.push pktCtor + encodes := acc.encodes.push ⟨encArm⟩ + decodes := acc.decodes.push ⟨decArm⟩ } - -- create encoder and decoder match arms - let nms ← argVars.mapM fun _ => mkIdent <$> mkFreshBinderName - let mkPattern (src : Name) := Syntax.mkApp (mkIdent <| Name.mkStr src ctor.getString!) nms - let mkBody (tgt : Name) (func : Name) : TermElabM Term := do - let items ← nms.mapM fun nm => `(← $(mkIdent func) $nm) - let tm := Syntax.mkApp (mkIdent <| Name.mkStr tgt ctor.getString!) items - `(return $tm:term) + -- helpers for type name syntax + let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName + let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds + let packetId := Syntax.mkApp (mkIdent packetNm) packetParamNames - let encArm ← `(matchF| | $(mkPattern indVal.name):term => $(← mkBody packetNm ``rpcEncode)) - let decArm ← `(matchF| | $(mkPattern packetNm):term => $(← mkBody indVal.name ``rpcDecode)) + `(inductive $(mkIdent packetNm) $packetParamBinders:bracketedBinder* where + $[$(st.ctors):ctor]* + deriving FromJson, ToJson - return { acc with ctors := acc.ctors.push pktCtor - encodes := acc.encodes.push ⟨encArm⟩ - decodes := acc.decodes.push ⟨decArm⟩ } - - -- helpers for type name syntax - let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName - let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds - let packetAppliedId := Syntax.mkApp (mkIdent packetNm) (st.uniqEncArgTypes.map (mkIdent ·)) - - `(variable $st.binders* - - protected inductive $(mkIdent packetNm) where - $[$(st.ctors):ctor]* - deriving FromJson, ToJson - - instance : RpcEncoding $typeId $packetAppliedId where - rpcEncode := fun x => match x with - $[$(st.encodes):matchAlt]* - rpcDecode := fun x => match x with - $[$(st.decodes):matchAlt]* - ) + variable $(paramBinders ++ packetParamBinders ++ encInstBinders)* in + partial instance : RpcEncoding $typeId $packetId := + { rpcEncode, rpcDecode } + where + rpcEncode {m} [Monad m] [MonadRpcSession m] (x : $typeId) : ExceptT String m $packetId := + have inst : RpcEncoding $typeId $packetId := { rpcEncode, rpcDecode } + match x with $[$(st.encodes):matchAlt]* + rpcDecode {m} [Monad m] [MonadRpcSession m] (x : $packetId) : ExceptT String m $typeId := + have inst : RpcEncoding $typeId $packetId := { rpcEncode, rpcDecode } + match x with $[$(st.decodes):matchAlt]* + ) /-- Creates an `RpcEncodingPacket` for `typeName`. For structures, the packet is a structure with the same field names. For inductives, it mirrors the inductive structure with every field @@ -243,30 +207,33 @@ private def deriveInstance (typeName : Name) : CommandElabM Bool := do if indVal.numIndices ≠ 0 then throwError "indexed inductive families are not supported" - let cmd ← liftTermElabM none do + let (paramBinders, packetParamBinders, encInstBinders) ← liftTermElabM none do -- introduce fvars for all the parameters forallTelescopeReducing indVal.type fun params _ => do - assert! params.size == indVal.numParams + let mut paramBinders := #[] -- input parameters + let mut packetParamBinders := #[] -- RPC encoding packets for type input parameters + let mut encInstBinders := #[] -- RPC encoding instance binders corresponding to packetParamBinders - -- bind every parameter and *some* (not named) `RpcEncoding` for it - let mut binders := #[] for param in params do - let paramNm := (←getFVarLocalDecl param).userName - binders := binders.push (← `(bracketedBinder| ( $(mkIdent paramNm) ))) + let paramNm := (← getFVarLocalDecl param).userName + paramBinders := paramBinders.push (← `(bracketedBinder| ($(mkIdent paramNm)))) -- only look for encodings for `Type` parameters - if !(← inferType param).isType then continue - binders := binders.push - (← `(bracketedBinder| [ RpcEncoding $(mkIdent paramNm) _ ])) + if (← inferType param).isType then + let packet := mkIdent (← mkFreshUserName (paramNm.appendAfter "Packet")) + packetParamBinders := packetParamBinders.push (← `(bracketedBinder| ($packet : Type))) + encInstBinders := encInstBinders.push (← `(bracketedBinder| [RpcEncoding $(mkIdent paramNm) $packet])) - `(section - variable $binders* - $(← if isStructure (← getEnv) typeName then - deriveStructureInstance indVal params - else - deriveInductiveInstance indVal params):command - end) + return (paramBinders, packetParamBinders, encInstBinders) + + elabCommand <| ← liftTermElabM none do + Term.elabBinders (paramBinders ++ packetParamBinders ++ encInstBinders) fun locals => do + let params := locals[:paramBinders.size] + let packetParams := locals[paramBinders.size:paramBinders.size+packetParamBinders.size] + if isStructure (← getEnv) typeName then + deriveStructureInstance indVal params paramBinders packetParamBinders encInstBinders + else + deriveInductiveInstance indVal params packetParams paramBinders packetParamBinders encInstBinders - elabCommand cmd return true private unsafe def dispatchDeriveInstanceUnsafe (declNames : Array Name) (args? : Option (TSyntax ``Parser.Term.structInst)) : CommandElabM Bool := do