Rewrote threaded log code to be more idiomatic.
[invirt/third/libt4.git] / rpc / rpctest.cc
1 // RPC test and pseudo-documentation.
2 // generates print statements on failures, but eventually says "rpctest OK"
3
4 #include "types.h"
5 #include "rpc.h"
6 #include <arpa/inet.h>
7 #include <getopt.h>
8 #include <unistd.h>
9 #include <string.h>
10
11 #define NUM_CL 2
12
13 char log_thread_prefix = 'r';
14
15 static rpcs *server;  // server rpc object
16 static rpcc *clients[NUM_CL];  // client rpc object
17 static string dst; //server's ip address
18 static in_port_t port;
19
20 // server-side handlers. they must be methods of some class
21 // to simplify rpcs::reg(). a server process can have handlers
22 // from multiple classes.
23 class srv {
24     public:
25         int handle_22(string & r, const string a, const string b);
26         int handle_fast(int & r, const int a);
27         int handle_slow(int & r, const int a);
28         int handle_bigrep(string & r, const size_t a);
29 };
30
31 namespace srv_protocol {
32     using status = rpc_protocol::status;
33     REMOTE_PROCEDURE_BASE(0);
34     REMOTE_PROCEDURE(22, _22, (string &, string, string));
35     REMOTE_PROCEDURE(23, fast, (int &, int));
36     REMOTE_PROCEDURE(24, slow, (int &, int));
37     REMOTE_PROCEDURE(25, bigrep, (string &, size_t));
38 }
39
40 // a handler. a and b are arguments, r is the result.
41 // there can be multiple arguments but only one result.
42 // the caller also gets to see the int return value
43 // as the return value from rpcc::call().
44 // rpcs::reg() decides how to unmarshall by looking
45 // at these argument types, so this function definition
46 // does what a .x file does in SunRPC.
47 int srv::handle_22(string & r, const string a, string b) {
48     r = a + b;
49     return 0;
50 }
51
52 int srv::handle_fast(int & r, const int a) {
53     r = a + 1;
54     return 0;
55 }
56
57 int srv::handle_slow(int & r, const int a) {
58     usleep(random() % 500);
59     r = a + 2;
60     return 0;
61 }
62
63 int srv::handle_bigrep(string & r, const size_t len) {
64     r = string(len, 'x');
65     return 0;
66 }
67
68 static srv service;
69
70 static void startserver() {
71     server = new rpcs(port);
72     server->reg(srv_protocol::_22, &srv::handle_22, &service);
73     server->reg(srv_protocol::fast, &srv::handle_fast, &service);
74     server->reg(srv_protocol::slow, &srv::handle_slow, &service);
75     server->reg(srv_protocol::bigrep, &srv::handle_bigrep, &service);
76     server->start();
77 }
78
79 static void testmarshall() {
80     marshall m;
81     rpc_protocol::request_header rh{1,2,3,4,5};
82     m.pack_header(rh);
83     VERIFY(((string)m).size()==rpc_protocol::RPC_HEADER_SZ);
84     int i = 12345;
85     unsigned long long l = 1223344455L;
86     size_t sz = 101010101;
87     string s = "hallo....";
88     string bin("\x00\x00\x00\x00\x00\x00\x00\x40\x00\x00\x7f\xe5", 12);
89     m << i;
90     m << l;
91     m << s;
92     m << sz;
93     m << bin;
94
95     string b = m;
96     VERIFY(b.size() == rpc_protocol::RPC_HEADER_SZ+sizeof(i)+sizeof(l)+sizeof(uint32_t)+s.size()+sizeof(uint32_t)+sizeof(uint32_t)+bin.size());
97
98     unmarshall un(b, true);
99     rpc_protocol::request_header rh1;
100     un.unpack_header(rh1);
101     VERIFY(memcmp(&rh,&rh1,sizeof(rh))==0);
102     int i1;
103     unsigned long long l1;
104     string s1;
105     string bin1;
106     size_t sz1;
107     un >> i1;
108     un >> l1;
109     un >> s1;
110     un >> sz1;
111     un >> bin1;
112     VERIFY(un.okdone());
113     VERIFY(i1==i && l1==l && s1==s && sz1==sz && bin1==bin);
114 }
115
116 static void client1(size_t cl) {
117     // test concurrency.
118     size_t which_cl = cl % NUM_CL;
119
120     for(int i = 0; i < 100; i++){
121         unsigned long arg = (random() % 2000);
122         string rep;
123         int ret = clients[which_cl]->call(srv_protocol::bigrep, rep, arg);
124         VERIFY(ret == 0);
125         if ((unsigned long)rep.size()!=arg)
126             cout << "repsize wrong " << rep.size() << "!=" << arg << endl;
127         VERIFY((unsigned long)rep.size() == arg);
128     }
129
130     // test rpc replies coming back not in the order of
131     // the original calls -- i.e. does xid reply dispatch work.
132     for(int i = 0; i < 100; i++){
133         int which = (random() % 2);
134         int arg = (random() % 1000);
135         int rep;
136
137         auto start = steady_clock::now();
138
139         int ret = clients[which_cl]->call(which ? srv_protocol::fast : srv_protocol::slow, rep, arg);
140         auto end = steady_clock::now();
141         auto diff = duration_cast<milliseconds>(end - start).count();
142         if (ret != 0)
143             cout << diff << " ms have elapsed!!!" << endl;
144         VERIFY(ret == 0);
145         VERIFY(rep == (which ? arg+1 : arg+2));
146     }
147 }
148
149 static void client2(size_t cl) {
150     size_t which_cl = cl % NUM_CL;
151
152     time_t t1;
153     time(&t1);
154
155     while(time(0) - t1 < 10){
156         unsigned long arg = (random() % 2000);
157         string rep;
158         int ret = clients[which_cl]->call(srv_protocol::bigrep, rep, arg);
159         if ((unsigned long)rep.size()!=arg)
160             cout << "ask for " << arg << " reply got " << rep.size() << " ret " << ret << endl;
161         VERIFY((unsigned long)rep.size() == arg);
162     }
163 }
164
165 static void client3(void *xx) {
166     rpcc *c = (rpcc *) xx;
167
168     for(int i = 0; i < 4; i++){
169         int rep = 0;
170         int ret = c->call_timeout(srv_protocol::slow, milliseconds(300), rep, i);
171         VERIFY(ret == rpc_protocol::timeout_failure || rep == i+2);
172     }
173 }
174
175 static void simple_tests(rpcc *c) {
176     cout << "simple_tests" << endl;
177     // an RPC call to procedure #22.
178     // rpcc::call() looks at the argument types to decide how
179     // to marshall the RPC call packet, and how to unmarshall
180     // the reply packet.
181     string rep;
182     int intret = c->call(srv_protocol::_22, rep, (string)"hello", (string)" goodbye");
183     VERIFY(intret == 0); // this is what handle_22 returns
184     VERIFY(rep == "hello goodbye");
185     cout << "   -- string concat RPC .. ok" << endl;
186
187     // small request, big reply (perhaps req via UDP, reply via TCP)
188     intret = c->call_timeout(srv_protocol::bigrep, milliseconds(20000), rep, 70000ul);
189     VERIFY(intret == 0);
190     VERIFY(rep.size() == 70000);
191     cout << "   -- small request, big reply .. ok" << endl;
192
193     // specify a timeout value to an RPC that should succeed (udp)
194     int xx = 0;
195     intret = c->call_timeout(srv_protocol::fast, milliseconds(300), xx, 77);
196     VERIFY(intret == 0 && xx == 78);
197     cout << "   -- no spurious timeout .. ok" << endl;
198
199     // specify a timeout value to an RPC that should succeed (tcp)
200     {
201         string arg(1000, 'x');
202         string rep2;
203         c->call_timeout(srv_protocol::_22, milliseconds(300), rep2, arg, (string)"x");
204         VERIFY(rep2.size() == 1001);
205         cout << "   -- no spurious timeout .. ok" << endl;
206     }
207
208     // huge RPC
209     string big(1000000, 'x');
210     intret = c->call(srv_protocol::_22, rep, big, (string)"z");
211     VERIFY(intret == 0);
212     VERIFY(rep.size() == 1000001);
213     cout << "   -- huge 1M rpc request .. ok" << endl;
214
215     // specify a timeout value to an RPC that should timeout (udp)
216     string non_existent = "127.0.0.1:7661";
217     rpcc *c1 = new rpcc(non_existent);
218     time_t t0 = time(0);
219     intret = c1->bind(milliseconds(300));
220     time_t t1 = time(0);
221     VERIFY(intret < 0 && (t1 - t0) <= 4);
222     cout << "   -- rpc timeout .. ok" << endl;
223     cout << "simple_tests OK" << endl;
224 }
225
226 static void concurrent_test(size_t nt) {
227     // create threads that make lots of calls in parallel,
228     // to test thread synchronization for concurrent calls
229     // and dispatches.
230     cout << "start concurrent_test (" << nt << " threads) ...";
231
232     vector<thread> th(nt);
233
234     for(size_t i = 0; i < nt; i++)
235         th[i] = thread(client1, i);
236
237     for(size_t i = 0; i < nt; i++)
238         th[i].join();
239
240     cout << " OK" << endl;
241 }
242
243 static void lossy_test() {
244     cout << "start lossy_test ...";
245     VERIFY(setenv("RPC_LOSSY", "5", 1) == 0);
246
247     if (server) {
248         delete server;
249         startserver();
250     }
251
252     for (int i = 0; i < NUM_CL; i++) {
253         delete clients[i];
254         clients[i] = new rpcc(dst);
255         VERIFY(clients[i]->bind()==0);
256     }
257
258     size_t nt = 1;
259
260     vector<thread> th(nt);
261
262     for(size_t i = 0; i < nt; i++)
263         th[i] = thread(client2, i);
264
265     for(size_t i = 0; i < nt; i++)
266         th[i].join();
267
268     cout << ".. OK" << endl;
269     VERIFY(setenv("RPC_LOSSY", "0", 1) == 0);
270 }
271
272 static void failure_test() {
273     rpcc *client1;
274     rpcc *client = clients[0];
275
276     cout << "failure_test" << endl;
277
278     delete server;
279
280     client1 = new rpcc(dst);
281     VERIFY (client1->bind(milliseconds(3000)) < 0);
282     cout << "   -- create new client and try to bind to failed server .. failed ok" << endl;
283
284     delete client1;
285
286     startserver();
287
288     string rep;
289     int intret = client->call(srv_protocol::_22, rep, (string)"hello", (string)" goodbye");
290     VERIFY(intret == rpc_protocol::oldsrv_failure);
291     cout << "   -- call recovered server with old client .. failed ok" << endl;
292
293     delete client;
294
295     clients[0] = client = new rpcc(dst);
296     VERIFY (client->bind() >= 0);
297     VERIFY (client->bind() < 0);
298
299     intret = client->call(srv_protocol::_22, rep, (string)"hello", (string)" goodbye");
300     VERIFY(intret == 0);
301     VERIFY(rep == "hello goodbye");
302
303     cout << "   -- delete existing rpc client, create replacement rpc client .. ok" << endl;
304
305
306     size_t nt = 10;
307     cout << "   -- concurrent test on new rpc client w/ " << nt << " threads ..";
308
309     vector<thread> th(nt);
310
311     for(size_t i = 0; i < nt; i++)
312         th[i] = thread(client3, client);
313
314     for(size_t i = 0; i < nt; i++)
315         th[i].join();
316
317     cout << "ok" << endl;
318
319     delete server;
320     delete client;
321
322     startserver();
323     clients[0] = client = new rpcc(dst);
324     VERIFY (client->bind() >= 0);
325     cout << "   -- delete existing rpc client and server, create replacements.. ok" << endl;
326
327     cout << "   -- concurrent test on new client and server w/ " << nt << " threads ..";
328
329     for(size_t i = 0; i < nt; i++)
330         th[i] = thread(client3, client);
331
332     for(size_t i = 0; i < nt; i++)
333         th[i].join();
334
335     cout << "ok" << endl;
336
337     cout << "failure_test OK" << endl;
338 }
339
340 int main(int argc, char *argv[]) {
341
342     setvbuf(stdout, NULL, _IONBF, 0);
343     setvbuf(stderr, NULL, _IONBF, 0);
344     int debug_level = 0;
345
346     bool isclient = false;
347     bool isserver = false;
348
349     srandom((uint32_t)getpid());
350     port = 20000 + (getpid() % 10000);
351
352     int ch = 0;
353     while ((ch = getopt(argc, argv, "csd:p:l"))!=-1) {
354         switch (ch) {
355             case 'c':
356                 isclient = true;
357                 break;
358             case 's':
359                 isserver = true;
360                 break;
361             case 'd':
362                 debug_level = atoi(optarg);
363                 break;
364             case 'p':
365                 port = (in_port_t)atoi(optarg);
366                 break;
367             case 'l':
368                 VERIFY(setenv("RPC_LOSSY", "5", 1) == 0);
369                 break;
370             default:
371                 break;
372         }
373     }
374
375     if (!isserver && !isclient)  {
376         isserver = isclient = true;
377     }
378
379     if (debug_level > 0) {
380         DEBUG_LEVEL = debug_level;
381         IF_LEVEL(1) LOG_NONMEMBER << "DEBUG LEVEL: " << debug_level;
382     }
383
384     testmarshall();
385
386     if (isserver) {
387         cout << "starting server on port " << port << " RPC_HEADER_SZ " << (int)rpc_protocol::RPC_HEADER_SZ << endl;
388         startserver();
389     }
390
391     if (isclient) {
392         // server's address.
393         dst = "127.0.0.1:" + to_string(port);
394
395
396         // start the client.  bind it to the server.
397         // starts a thread to listen for replies and hand them to
398         // the correct waiting caller thread. there should probably
399         // be only one rpcc per process. you probably need one
400         // rpcc per server.
401         for (int i = 0; i < NUM_CL; i++) {
402             clients[i] = new rpcc(dst);
403             VERIFY (clients[i]->bind() == 0);
404         }
405
406         simple_tests(clients[0]);
407         concurrent_test(10);
408         lossy_test();
409         if (isserver) {
410             failure_test();
411         }
412
413         cout << "rpctest OK" << endl;
414
415         exit(0);
416     }
417
418     while (1)
419         usleep(100000);
420 }