More clean-ups
[invirt/third/libt4.git] / rpc / rpc.h
index 84c12f3..7b65101 100644 (file)
--- a/rpc/rpc.h
+++ b/rpc/rpc.h
@@ -5,6 +5,7 @@
 #include <sys/socket.h>
 #include <netinet/in.h>
 
+#include "rpc_protocol.h"
 #include "thr_pool.h"
 #include "marshall.h"
 #include "marshall_wrap.h"
@@ -15,23 +16,27 @@ namespace rpc {
     static constexpr milliseconds to_min{100};
 }
 
-class rpc_const {
-    public:
-        static const unsigned int bind = 1;   // handler number reserved for bind
-        static const int timeout_failure = -1;
-        static const int unmarshal_args_failure = -2;
-        static const int unmarshal_reply_failure = -3;
-        static const int atmostonce_failure = -4;
-        static const int oldsrv_failure = -5;
-        static const int bind_failure = -6;
-        static const int cancel_failure = -7;
-};
+template<class P, class R, class ...Args> struct is_valid_call : false_type {};
+
+template<class S, class R, class ...Args>
+struct is_valid_call<S(R &, Args...), R, Args...> : true_type {};
+
+template<class P, class F> struct is_valid_registration : false_type {};
+
+template<class S, class R, class ...Args>
+struct is_valid_registration<S(R &, typename decay<Args>::type...), S(R &, Args...)> : true_type {};
+
+template<class P, class C, class S, class R, class ...Args>
+struct is_valid_registration<P, S(C::*)(R &, Args...)> : is_valid_registration<P, S(R &, Args...)> {};
 
 // rpc client endpoint.
 // manages a xid space per destination socket
 // threaded: multiple threads can be sending RPCs,
 class rpcc : private connection_delegate {
     private:
+        using proc_id_t = rpc_protocol::proc_id_t;
+        template <class S>
+        using proc_t = rpc_protocol::proc_t<S>;
 
         // manages per rpc info
         struct caller {
@@ -78,20 +83,20 @@ class rpcc : private connection_delegate {
         request dup_req_;
         int xid_rep_done_;
 
-        int call1(proc_t proc, marshall &req, string &rep, milliseconds to);
+        int call1(proc_id_t proc, marshall &req, string &rep, milliseconds to);
 
         template<class R>
-        int call_m(proc_t proc, marshall &req, R & r, milliseconds to) {
+        int call_m(proc_id_t proc, marshall &req, R & r, milliseconds to) {
             string rep;
             int intret = call1(proc, req, rep, to);
             unmarshall u(rep, true);
             if (intret < 0) return intret;
             u >> r;
             if (u.okdone() != true) {
-                cerr << "rpcc::call_m: failed to unmarshall the reply.  You are probably " <<
-                    "calling RPC 0x" << hex << proc << " with the wrong return type." << endl;
+                LOG("rpcc::call_m: failed to unmarshall the reply.  You are probably " <<
+                    "calling RPC 0x" << hex << proc << " with the wrong return type.");
                 VERIFY(0);
-                return rpc_const::unmarshal_reply_failure;
+                return rpc_protocol::unmarshal_reply_failure;
             }
             return intret;
         }
@@ -111,21 +116,25 @@ class rpcc : private connection_delegate {
 
         void cancel();
 
-        template<class R, typename ...Args>
-        inline int call(proc_t proc, R & r, const Args&... args) {
+        template<class P, class R, typename ...Args>
+        inline int call(proc_t<P> proc, R & r, const Args&... args) {
             return call_timeout(proc, rpc::to_max, r, args...);
         }
 
-        template<class R, typename ...Args>
-        inline int call_timeout(proc_t proc, milliseconds to, R & r, const Args&... args) {
+        template<class P, class R, typename ...Args>
+        inline int call_timeout(proc_t<P> proc, milliseconds to, R & r, const Args&... args) {
+            static_assert(is_valid_call<P, R, Args...>::value, "RPC called with incorrect argument types");
             marshall m{args...};
-            return call_m(proc, m, r, to);
+            return call_m(proc.id, m, r, to);
         }
 };
 
 // rpc server endpoint.
 class rpcs : private connection_delegate {
     private:
+        using proc_id_t = rpc_protocol::proc_id_t;
+        template <class S>
+        using proc_t = rpc_protocol::proc_t<S>;
 
         typedef enum {
             NEW,  // new RPC, not a duplicate
@@ -160,7 +169,7 @@ class rpcs : private connection_delegate {
         rpcstate_t checkduplicate_and_update(unsigned int clt_nonce, 
                 int xid, int rep_xid, string & b);
 
-        void updatestat(proc_t proc);
+        void updatestat(proc_id_t proc);
 
         // latest connection to the client
         map<unsigned int, shared_ptr<connection>> conns_;
@@ -168,12 +177,12 @@ class rpcs : private connection_delegate {
         // counting
         const size_t counting_;
         size_t curr_counts_;
-        map<proc_t, size_t> counts_;
+        map<proc_id_t, size_t> counts_;
 
         bool reachable_;
 
         // map proc # to function
-        map<proc_t, handler *> procs_;
+        map<proc_id_t, handler *> procs_;
 
         mutex procs_m_; // protect insert/delete to procs[]
         mutex count_m_;  // protect modification of counts
@@ -183,13 +192,13 @@ class rpcs : private connection_delegate {
         void dispatch(shared_ptr<connection> c, const string & buf);
 
         // internal handler registration
-        void reg1(proc_t proc, handler *);
+        void reg1(proc_id_t proc, handler *);
 
         unique_ptr<thread_pool> dispatchpool_;
         unique_ptr<tcpsconn> listener_;
 
         // RPC handler for clients binding
-        int rpcbind(unsigned int &r, int a);
+        rpc_protocol::status rpcbind(unsigned int &r, int a);
 
         bool got_pdu(const shared_ptr<connection> & c, const string & b);
 
@@ -200,13 +209,14 @@ class rpcs : private connection_delegate {
 
         void set_reachable(bool r) { reachable_ = r; }
 
-        template<class F, class C=void> void reg(proc_t proc, F f, C *c=nullptr) {
+        template<class P, class F, class C=void> void reg(proc_t<P> proc, F f, C *c=nullptr) {
+            static_assert(is_valid_registration<P, F>::value, "RPC handler registered with incorrect argument types");
             struct ReturnOnFailure {
                 static inline int unmarshall_args_failure() {
-                    return rpc_const::unmarshal_args_failure;
+                    return rpc_protocol::unmarshal_args_failure;
                 }
             };
-            reg1(proc, marshalled_func<F, ReturnOnFailure>::wrap(f, c));
+            reg1(proc.id, marshalled_func<F, ReturnOnFailure>::wrap(f, c));
         }
 
         void start();