Simplifications and clean-ups
[invirt/third/libt4.git] / rpc / rpctest.cc
1 // RPC test and pseudo-documentation.
2 // generates print statements on failures, but eventually says "rpctest OK"
3
4 #include "types.h"
5 #include "rpc.h"
6 #include <arpa/inet.h>
7 #include <getopt.h>
8 #include <unistd.h>
9 #include <string.h>
10
11 #define NUM_CL 2
12
13 char log_thread_prefix = 'r';
14
15 static rpcs *server;  // server rpc object
16 static rpcc *clients[NUM_CL];  // client rpc object
17 static string dst; //server's ip address
18 static in_port_t port;
19
20 using std::cout;
21 using std::endl;
22
23 // server-side handlers. they must be methods of some class
24 // to simplify rpcs::reg(). a server process can have handlers
25 // from multiple classes.
26 class srv {
27     public:
28         int handle_22(string & r, const string a, const string b);
29         int handle_fast(int & r, const int a);
30         int handle_slow(int & r, const int a);
31         int handle_bigrep(string & r, const size_t a);
32 };
33
34 namespace srv_protocol {
35     using status = rpc_protocol::status;
36     REMOTE_PROCEDURE_BASE(0);
37     REMOTE_PROCEDURE(22, _22, (string &, string, string));
38     REMOTE_PROCEDURE(23, fast, (int &, int));
39     REMOTE_PROCEDURE(24, slow, (int &, int));
40     REMOTE_PROCEDURE(25, bigrep, (string &, size_t));
41 }
42
43 // a handler. a and b are arguments, r is the result.
44 // there can be multiple arguments but only one result.
45 // the caller also gets to see the int return value
46 // as the return value from rpcc::call().
47 // rpcs::reg() decides how to unmarshall by looking
48 // at these argument types, so this function definition
49 // does what a .x file does in SunRPC.
50 int srv::handle_22(string & r, const string a, string b) {
51     r = a + b;
52     return 0;
53 }
54
55 int srv::handle_fast(int & r, const int a) {
56     r = a + 1;
57     return 0;
58 }
59
60 int srv::handle_slow(int & r, const int a) {
61     usleep(random() % 500);
62     r = a + 2;
63     return 0;
64 }
65
66 int srv::handle_bigrep(string & r, const size_t len) {
67     r = string(len, 'x');
68     return 0;
69 }
70
71 static srv service;
72
73 static void startserver() {
74     server = new rpcs(port);
75     server->reg(srv_protocol::_22, &srv::handle_22, &service);
76     server->reg(srv_protocol::fast, &srv::handle_fast, &service);
77     server->reg(srv_protocol::slow, &srv::handle_slow, &service);
78     server->reg(srv_protocol::bigrep, &srv::handle_bigrep, &service);
79     server->start();
80 }
81
82 static void testmarshall() {
83     marshall m;
84     rpc_protocol::request_header rh{1,2,3,4,5};
85     m.write_header(rh);
86     VERIFY(((string)m).size()==rpc_protocol::RPC_HEADER_SZ);
87     int i = 12345;
88     unsigned long long l = 1223344455L;
89     size_t sz = 101010101;
90     string s = "hallo....";
91     string bin("\x00\x00\x00\x00\x00\x00\x00\x40\x00\x00\x7f\xe5", 12);
92     m << i;
93     m << l;
94     m << s;
95     m << sz;
96     m << bin;
97
98     string b = m;
99     VERIFY(b.size() == rpc_protocol::RPC_HEADER_SZ+sizeof(i)+sizeof(l)+sizeof(uint32_t)+s.size()+sizeof(uint32_t)+sizeof(uint32_t)+bin.size());
100
101     unmarshall un(b, true);
102     rpc_protocol::request_header rh1;
103     un.read_header(rh1);
104     VERIFY(memcmp(&rh,&rh1,sizeof(rh))==0);
105     int i1;
106     unsigned long long l1;
107     string s1;
108     string bin1;
109     size_t sz1;
110     un >> i1;
111     un >> l1;
112     un >> s1;
113     un >> sz1;
114     un >> bin1;
115     VERIFY(un.okdone());
116     VERIFY(i1==i && l1==l && s1==s && sz1==sz && bin1==bin);
117 }
118
119 static void client1(size_t cl) {
120     // test concurrency.
121     size_t which_cl = cl % NUM_CL;
122
123     for(int i = 0; i < 100; i++){
124         unsigned long arg = (random() % 2000);
125         string rep;
126         int ret = clients[which_cl]->call(srv_protocol::bigrep, rep, arg);
127         VERIFY(ret == 0);
128         if ((unsigned long)rep.size()!=arg)
129             cout << "repsize wrong " << rep.size() << "!=" << arg << endl;
130         VERIFY((unsigned long)rep.size() == arg);
131     }
132
133     // test rpc replies coming back not in the order of
134     // the original calls -- i.e. does xid reply dispatch work.
135     for(int i = 0; i < 100; i++){
136         int which = (random() % 2);
137         int arg = (random() % 1000);
138         int rep = -1;
139
140         auto start = steady_clock::now();
141
142         int ret = clients[which_cl]->call(which ? srv_protocol::fast : srv_protocol::slow, rep, arg);
143         auto end = steady_clock::now();
144         auto diff = duration_cast<milliseconds>(end - start).count();
145         if (ret != 0)
146             cout << diff << " ms have elapsed!!!" << endl;
147         VERIFY(ret == 0);
148         VERIFY(rep == (which ? arg+1 : arg+2));
149     }
150 }
151
152 static void client2(size_t cl) {
153     size_t which_cl = cl % NUM_CL;
154
155     time_t t1;
156     time(&t1);
157
158     while(time(0) - t1 < 10){
159         unsigned long arg = (random() % 2000);
160         string rep;
161         int ret = clients[which_cl]->call(srv_protocol::bigrep, rep, arg);
162         if ((unsigned long)rep.size()!=arg)
163             cout << "ask for " << arg << " reply got " << rep.size() << " ret " << ret << endl;
164         VERIFY((unsigned long)rep.size() == arg);
165     }
166 }
167
168 static void client3(void *xx) {
169     rpcc *c = (rpcc *) xx;
170
171     for(int i = 0; i < 4; i++){
172         int rep = 0;
173         int ret = c->call_timeout(srv_protocol::slow, milliseconds(300), rep, i);
174         VERIFY(ret == rpc_protocol::timeout_failure || rep == i+2);
175     }
176 }
177
178 static void simple_tests(rpcc *c) {
179     cout << "simple_tests" << endl;
180     // an RPC call to procedure #22.
181     // rpcc::call() looks at the argument types to decide how
182     // to marshall the RPC call packet, and how to unmarshall
183     // the reply packet.
184     string rep;
185     int intret = c->call(srv_protocol::_22, rep, (string)"hello", (string)" goodbye");
186     VERIFY(intret == 0); // this is what handle_22 returns
187     VERIFY(rep == "hello goodbye");
188     cout << "   -- string concat RPC .. ok" << endl;
189
190     // small request, big reply (perhaps req via UDP, reply via TCP)
191     intret = c->call_timeout(srv_protocol::bigrep, milliseconds(20000), rep, 70000ul);
192     VERIFY(intret == 0);
193     VERIFY(rep.size() == 70000);
194     cout << "   -- small request, big reply .. ok" << endl;
195
196     // specify a timeout value to an RPC that should succeed (udp)
197     int xx = 0;
198     intret = c->call_timeout(srv_protocol::fast, milliseconds(300), xx, 77);
199     VERIFY(intret == 0 && xx == 78);
200     cout << "   -- no spurious timeout .. ok" << endl;
201
202     // specify a timeout value to an RPC that should succeed (tcp)
203     {
204         string arg(1000, 'x');
205         string rep2;
206         c->call_timeout(srv_protocol::_22, milliseconds(300), rep2, arg, (string)"x");
207         VERIFY(rep2.size() == 1001);
208         cout << "   -- no spurious timeout .. ok" << endl;
209     }
210
211     // huge RPC
212     string big(1000000, 'x');
213     intret = c->call(srv_protocol::_22, rep, big, (string)"z");
214     VERIFY(intret == 0);
215     VERIFY(rep.size() == 1000001);
216     cout << "   -- huge 1M rpc request .. ok" << endl;
217
218     // specify a timeout value to an RPC that should timeout (udp)
219     string non_existent = "127.0.0.1:7661";
220     rpcc *c1 = new rpcc(non_existent);
221     time_t t0 = time(0);
222     intret = c1->bind(milliseconds(300));
223     time_t t1 = time(0);
224     VERIFY(intret < 0 && (t1 - t0) <= 4);
225     cout << "   -- rpc timeout .. ok" << endl;
226     cout << "simple_tests OK" << endl;
227 }
228
229 static void concurrent_test(size_t nt) {
230     // create threads that make lots of calls in parallel,
231     // to test thread synchronization for concurrent calls
232     // and dispatches.
233     cout << "start concurrent_test (" << nt << " threads) ...";
234
235     vector<thread> th(nt);
236
237     for(size_t i = 0; i < nt; i++)
238         th[i] = thread(client1, i);
239
240     for(size_t i = 0; i < nt; i++)
241         th[i].join();
242
243     cout << " OK" << endl;
244 }
245
246 static void lossy_test() {
247     cout << "start lossy_test ...";
248     VERIFY(setenv("RPC_LOSSY", "5", 1) == 0);
249
250     if (server) {
251         delete server;
252         startserver();
253     }
254
255     for (int i = 0; i < NUM_CL; i++) {
256         delete clients[i];
257         clients[i] = new rpcc(dst);
258         VERIFY(clients[i]->bind()==0);
259     }
260
261     size_t nt = 1;
262
263     vector<thread> th(nt);
264
265     for(size_t i = 0; i < nt; i++)
266         th[i] = thread(client2, i);
267
268     for(size_t i = 0; i < nt; i++)
269         th[i].join();
270
271     cout << ".. OK" << endl;
272     VERIFY(setenv("RPC_LOSSY", "0", 1) == 0);
273 }
274
275 static void failure_test() {
276     rpcc *client1;
277     rpcc *client = clients[0];
278
279     cout << "failure_test" << endl;
280
281     delete server;
282
283     client1 = new rpcc(dst);
284     VERIFY (client1->bind(milliseconds(3000)) < 0);
285     cout << "   -- create new client and try to bind to failed server .. failed ok" << endl;
286
287     delete client1;
288
289     startserver();
290
291     string rep;
292     int intret = client->call(srv_protocol::_22, rep, (string)"hello", (string)" goodbye");
293     VERIFY(intret == rpc_protocol::oldsrv_failure);
294     cout << "   -- call recovered server with old client .. failed ok" << endl;
295
296     delete client;
297
298     clients[0] = client = new rpcc(dst);
299     VERIFY (client->bind() >= 0);
300     VERIFY (client->bind() < 0);
301
302     intret = client->call(srv_protocol::_22, rep, (string)"hello", (string)" goodbye");
303     VERIFY(intret == 0);
304     VERIFY(rep == "hello goodbye");
305
306     cout << "   -- delete existing rpc client, create replacement rpc client .. ok" << endl;
307
308
309     size_t nt = 10;
310     cout << "   -- concurrent test on new rpc client w/ " << nt << " threads ..";
311
312     vector<thread> th(nt);
313
314     for(size_t i = 0; i < nt; i++)
315         th[i] = thread(client3, client);
316
317     for(size_t i = 0; i < nt; i++)
318         th[i].join();
319
320     cout << "ok" << endl;
321
322     delete server;
323     delete client;
324
325     startserver();
326     clients[0] = client = new rpcc(dst);
327     VERIFY (client->bind() >= 0);
328     cout << "   -- delete existing rpc client and server, create replacements.. ok" << endl;
329
330     cout << "   -- concurrent test on new client and server w/ " << nt << " threads ..";
331
332     for(size_t i = 0; i < nt; i++)
333         th[i] = thread(client3, client);
334
335     for(size_t i = 0; i < nt; i++)
336         th[i].join();
337
338     cout << "ok" << endl;
339
340     cout << "failure_test OK" << endl;
341 }
342
343 int main(int argc, char *argv[]) {
344
345     setvbuf(stdout, NULL, _IONBF, 0);
346     setvbuf(stderr, NULL, _IONBF, 0);
347     int debug_level = 0;
348
349     bool isclient = false;
350     bool isserver = false;
351
352     srandom((uint32_t)getpid());
353     port = 20000 + (getpid() % 10000);
354
355     int ch = 0;
356     while ((ch = getopt(argc, argv, "csd:p:l"))!=-1) {
357         switch (ch) {
358             case 'c':
359                 isclient = true;
360                 break;
361             case 's':
362                 isserver = true;
363                 break;
364             case 'd':
365                 debug_level = atoi(optarg);
366                 break;
367             case 'p':
368                 port = (in_port_t)atoi(optarg);
369                 break;
370             case 'l':
371                 VERIFY(setenv("RPC_LOSSY", "5", 1) == 0);
372                 break;
373             default:
374                 break;
375         }
376     }
377
378     if (!isserver && !isclient)  {
379         isserver = isclient = true;
380     }
381
382     if (debug_level > 0) {
383         DEBUG_LEVEL = debug_level;
384         IF_LEVEL(1) LOG_NONMEMBER << "DEBUG LEVEL: " << debug_level;
385     }
386
387     testmarshall();
388
389     if (isserver) {
390         cout << "starting server on port " << port << " RPC_HEADER_SZ " << (int)rpc_protocol::RPC_HEADER_SZ << endl;
391         startserver();
392     }
393
394     if (isclient) {
395         // server's address.
396         dst = "127.0.0.1:" + to_string(port);
397
398
399         // start the client.  bind it to the server.
400         // starts a thread to listen for replies and hand them to
401         // the correct waiting caller thread. there should probably
402         // be only one rpcc per process. you probably need one
403         // rpcc per server.
404         for (int i = 0; i < NUM_CL; i++) {
405             clients[i] = new rpcc(dst);
406             VERIFY (clients[i]->bind() == 0);
407         }
408
409         simple_tests(clients[0]);
410         concurrent_test(10);
411         lossy_test();
412         if (isserver) {
413             failure_test();
414         }
415
416         cout << "rpctest OK" << endl;
417
418         exit(0);
419     }
420
421     while (1)
422         usleep(100000);
423 }