Imported from 6.824 labs
[invirt/third/libt4.git] / rpc / marshall.h
1 #ifndef marshall_h
2 #define marshall_h
3
4 #include <iostream>
5 #include <sstream>
6 #include <string>
7 #include <vector>
8 #include <map>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <cstddef>
12 #include <inttypes.h>
13 #include "lang/verify.h"
14 #include "lang/algorithm.h"
15
// Request header fields. marshall::pack_req_header() writes them into the
// reserved header area of a message and unmarshall::unpack_req_header()
// reads them back, in declaration order.
// NOTE: member order/size feeds RPC_HEADER_SZ via sizeof(req_header).
struct req_header {
        int xid;
        int proc;
        unsigned int clt_nonce;
        unsigned int srv_nonce;
        int xid_rep;

        // Every field defaults to zero; the nonces accept int arguments and
        // are converted to unsigned on initialization.
        req_header(int x=0, int p=0, int c = 0, int s = 0, int xi = 0)
                : xid(x),
                  proc(p),
                  clt_nonce(c),
                  srv_nonce(s),
                  xid_rep(xi)
        {
        }
};
25
// Reply header fields. marshall::pack_reply_header() writes them and
// unmarshall::unpack_reply_header() reads them back, in declaration order.
// NOTE: member order/size feeds RPC_HEADER_SZ via sizeof(reply_header).
struct reply_header {
        int xid;
        int ret;

        // Both fields default to zero.
        reply_header(int x=0, int r=0)
                : xid(x),
                  ret(r)
        {
        }
};
31
32 typedef uint64_t rpc_checksum_t;
33 typedef int rpc_sz_t;
34
35 enum {
36         //size of initial buffer allocation 
37         DEFAULT_RPC_SZ = 1024,
38 #if RPC_CHECKSUMMING
39         //size of rpc_header includes a 4-byte int to be filled by tcpchan and uint64_t checksum
40         RPC_HEADER_SZ = static_max<sizeof(req_header), sizeof(reply_header)>::value + sizeof(rpc_sz_t) + sizeof(rpc_checksum_t)
41 #else
42                 RPC_HEADER_SZ = static_max<sizeof(req_header), sizeof(reply_header)>::value + sizeof(rpc_sz_t)
43 #endif
44 };
45
46 class marshall {
47         private:
48                 char *_buf;     // Base of the raw bytes buffer (dynamically readjusted)
49                 int _capa;      // Capacity of the buffer
50                 int _ind;       // Read/write head position
51
52         public:
53                 marshall() {
54                         _buf = (char *) malloc(sizeof(char)*DEFAULT_RPC_SZ);
55                         VERIFY(_buf);
56                         _capa = DEFAULT_RPC_SZ;
57                         _ind = RPC_HEADER_SZ;
58                 }
59
60                 ~marshall() { 
61                         if (_buf) 
62                                 free(_buf); 
63                 }
64
65                 int size() { return _ind;}
66                 char *cstr() { return _buf;}
67
68                 void rawbyte(unsigned char);
69                 void rawbytes(const char *, int);
70
71                 // Return the current content (excluding header) as a string
72                 std::string get_content() { 
73                         return std::string(_buf+RPC_HEADER_SZ,_ind-RPC_HEADER_SZ);
74                 }
75
76                 // Return the current content (excluding header) as a string
77                 std::string str() {
78                         return get_content();
79                 }
80
81                 void pack(int i);
82
83                 void pack_req_header(const req_header &h) {
84                         int saved_sz = _ind;
85                         //leave the first 4-byte empty for channel to fill size of pdu
86                         _ind = sizeof(rpc_sz_t); 
87 #if RPC_CHECKSUMMING
88                         _ind += sizeof(rpc_checksum_t);
89 #endif
90                         pack(h.xid);
91                         pack(h.proc);
92                         pack((int)h.clt_nonce);
93                         pack((int)h.srv_nonce);
94                         pack(h.xid_rep);
95                         _ind = saved_sz;
96                 }
97
98                 void pack_reply_header(const reply_header &h) {
99                         int saved_sz = _ind;
100                         //leave the first 4-byte empty for channel to fill size of pdu
101                         _ind = sizeof(rpc_sz_t); 
102 #if RPC_CHECKSUMMING
103                         _ind += sizeof(rpc_checksum_t);
104 #endif
105                         pack(h.xid);
106                         pack(h.ret);
107                         _ind = saved_sz;
108                 }
109
110                 void take_buf(char **b, int *s) {
111                         *b = _buf;
112                         *s = _ind;
113                         _buf = NULL;
114                         _ind = 0;
115                         return;
116                 }
117 };
118 marshall& operator<<(marshall &, bool);
119 marshall& operator<<(marshall &, unsigned int);
120 marshall& operator<<(marshall &, int);
121 marshall& operator<<(marshall &, unsigned char);
122 marshall& operator<<(marshall &, char);
123 marshall& operator<<(marshall &, unsigned short);
124 marshall& operator<<(marshall &, short);
125 marshall& operator<<(marshall &, unsigned long long);
126 marshall& operator<<(marshall &, const std::string &);
127
128 class unmarshall {
129         private:
130                 char *_buf;
131                 int _sz;
132                 int _ind;
133                 bool _ok;
134         public:
135                 unmarshall(): _buf(NULL),_sz(0),_ind(0),_ok(false) {}
136                 unmarshall(char *b, int sz): _buf(b),_sz(sz),_ind(),_ok(true) {}
137                 unmarshall(const std::string &s) : _buf(NULL),_sz(0),_ind(0),_ok(false) 
138                 {
139                         //take the content which does not exclude a RPC header from a string
140                         take_content(s);
141                 }
142                 ~unmarshall() {
143                         if (_buf) free(_buf);
144                 }
145
146                 //take contents from another unmarshall object
147                 void take_in(unmarshall &another);
148
149                 //take the content which does not exclude a RPC header from a string
150                 void take_content(const std::string &s) {
151                         _sz = s.size()+RPC_HEADER_SZ;
152                         _buf = (char *)realloc(_buf,_sz);
153                         VERIFY(_buf);
154                         _ind = RPC_HEADER_SZ;
155                         memcpy(_buf+_ind, s.data(), s.size());
156                         _ok = true;
157                 }
158
159                 bool ok() { return _ok; }
160                 char *cstr() { return _buf;}
161                 bool okdone();
162                 unsigned int rawbyte();
163                 void rawbytes(std::string &s, unsigned int n);
164
165                 int ind() { return _ind;}
166                 int size() { return _sz;}
167                 void unpack(int *); //non-const ref
168                 void take_buf(char **b, int *sz) {
169                         *b = _buf;
170                         *sz = _sz;
171                         _sz = _ind = 0;
172                         _buf = NULL;
173                 }
174
175                 void unpack_req_header(req_header *h) {
176                         //the first 4-byte is for channel to fill size of pdu
177                         _ind = sizeof(rpc_sz_t); 
178 #if RPC_CHECKSUMMING
179                         _ind += sizeof(rpc_checksum_t);
180 #endif
181                         unpack(&h->xid);
182                         unpack(&h->proc);
183                         unpack((int *)&h->clt_nonce);
184                         unpack((int *)&h->srv_nonce);
185                         unpack(&h->xid_rep);
186                         _ind = RPC_HEADER_SZ;
187                 }
188
189                 void unpack_reply_header(reply_header *h) {
190                         //the first 4-byte is for channel to fill size of pdu
191                         _ind = sizeof(rpc_sz_t); 
192 #if RPC_CHECKSUMMING
193                         _ind += sizeof(rpc_checksum_t);
194 #endif
195                         unpack(&h->xid);
196                         unpack(&h->ret);
197                         _ind = RPC_HEADER_SZ;
198                 }
199 };
200
201 unmarshall& operator>>(unmarshall &, bool &);
202 unmarshall& operator>>(unmarshall &, unsigned char &);
203 unmarshall& operator>>(unmarshall &, char &);
204 unmarshall& operator>>(unmarshall &, unsigned short &);
205 unmarshall& operator>>(unmarshall &, short &);
206 unmarshall& operator>>(unmarshall &, unsigned int &);
207 unmarshall& operator>>(unmarshall &, int &);
208 unmarshall& operator>>(unmarshall &, unsigned long long &);
209 unmarshall& operator>>(unmarshall &, std::string &);
210
211 template <class C> marshall &
212 operator<<(marshall &m, std::vector<C> v)
213 {
214         m << (unsigned int) v.size();
215         for(unsigned i = 0; i < v.size(); i++)
216                 m << v[i];
217         return m;
218 }
219
220 template <class C> unmarshall &
221 operator>>(unmarshall &u, std::vector<C> &v)
222 {
223         unsigned n;
224         u >> n;
225         for(unsigned i = 0; i < n; i++){
226                 C z;
227                 u >> z;
228                 v.push_back(z);
229         }
230         return u;
231 }
232
233 template <class A, class B> marshall &
234 operator<<(marshall &m, const std::map<A,B> &d) {
235         typename std::map<A,B>::const_iterator i;
236
237         m << (unsigned int) d.size();
238
239         for (i = d.begin(); i != d.end(); i++) {
240                 m << i->first << i->second;
241         }
242         return m;
243 }
244
245 template <class A, class B> unmarshall &
246 operator>>(unmarshall &u, std::map<A,B> &d) {
247         unsigned int n;
248         u >> n;
249
250         d.clear();
251
252         for (unsigned int lcv = 0; lcv < n; lcv++) {
253                 A a;
254                 B b;
255                 u >> a >> b;
256                 d[a] = b;
257         }
258         return u;
259 }
260
261 #endif