All random numbers generated via one PRNG seeded in one place.
[invirt/third/libt4.git] / rpc / connection.cc
1 #include "connection.h"
2 #include "rpc_protocol.h"
3 #include <cerrno>
4 #include <csignal>
5 #include <netinet/tcp.h>
6 #include <unistd.h>
7 #include "marshall.h"
8
9 connection_delegate::~connection_delegate() {}
10
11 connection::connection(connection_delegate * delegate, socket_t && f1, int l1)
12 : fd(std::move(f1)), delegate_(delegate), lossy_(l1)
13 {
14     fd.flags() |= O_NONBLOCK;
15
16     signal(SIGPIPE, SIG_IGN);
17
18     global->shared_mgr.add_callback(fd, CB_RDONLY, this);
19 }
20
21 connection::~connection() {
22     {
23         lock ml(m_);
24         if (dead_)
25             return;
26         dead_ = true;
27         shutdown(fd,SHUT_RDWR);
28     }
29     // after block_remove_fd, select will never wait on fd and no callbacks
30     // will be active
31     global->shared_mgr.block_remove_fd(fd);
32     VERIFY(dead_);
33     VERIFY(wpdu_.status == unused);
34 }
35
36 shared_ptr<connection> connection::to_dst(const sockaddr_in & dst, connection_delegate * delegate, int lossy) {
37     socket_t s = socket(AF_INET, SOCK_STREAM, 0);
38     s.setsockopt(IPPROTO_TCP, TCP_NODELAY, (int)1);
39     if (connect(s, (sockaddr*)&dst, sizeof(dst)) < 0) {
40         IF_LEVEL(1) LOG_NONMEMBER << "failed to " << inet_ntoa(dst.sin_addr) << ":" << ntoh(dst.sin_port);
41         close(s);
42         return nullptr;
43     }
44     IF_LEVEL(2) LOG_NONMEMBER << "connection::to_dst fd=" << s << " to dst " << inet_ntoa(dst.sin_addr) << ":" << ntoh(dst.sin_port);
45     return std::make_shared<connection>(delegate, std::move(s), lossy);
46 }
47
48 bool connection::send(const string & b) {
49     lock ml(m_);
50
51     while (!dead_ && wpdu_.status != unused)
52         send_wait_.wait(ml);
53
54     if (dead_)
55         return false;
56
57     wpdu_ = {inflight, b, 0};
58
59     if (std::bernoulli_distribution(lossy_*.01)(global->random_generator)) {
60         IF_LEVEL(1) LOG << "send LOSSY TEST shutdown fd " << fd;
61         shutdown(fd,SHUT_RDWR);
62     }
63
64     if (!writepdu()) {
65         dead_ = true;
66         ml.unlock();
67         global->shared_mgr.block_remove_fd(fd);
68         ml.lock();
69     } else if (wpdu_.status == inflight && wpdu_.cursor < b.size()) {
70         // should be rare to need to explicitly add write callback
71         global->shared_mgr.add_callback(fd, CB_WRONLY, this);
72         while (!dead_ && wpdu_.status == inflight && wpdu_.cursor < b.size())
73             send_complete_.wait(ml);
74     }
75     bool ret = (!dead_ && wpdu_.status == inflight && wpdu_.cursor == b.size());
76     wpdu_ = {unused, "", 0};
77     send_wait_.notify_all();
78     return ret;
79 }
80
81 // fd is ready to be written
82 void connection::write_cb(int s) {
83     lock ml(m_);
84     VERIFY(!dead_);
85     VERIFY(fd == s);
86     if (wpdu_.status != inflight) {
87         global->shared_mgr.del_callback(fd, CB_WRONLY);
88         return;
89     }
90     if (!writepdu()) {
91         global->shared_mgr.del_callback(fd, CB_RDWR);
92         dead_ = true;
93     } else {
94         VERIFY(wpdu_.status != error);
95         if (wpdu_.cursor < wpdu_.buf.size())
96             return;
97     }
98     send_complete_.notify_one();
99 }
100
101 bool connection::writepdu() {
102     VERIFY(wpdu_.status == inflight);
103     if (wpdu_.cursor == wpdu_.buf.size())
104         return true;
105
106     ssize_t n = write(fd, &wpdu_.buf[wpdu_.cursor], (wpdu_.buf.size()-wpdu_.cursor));
107     if (n < 0) {
108         if (errno != EAGAIN) {
109             IF_LEVEL(1) LOG << "writepdu fd " << fd << " failure errno=" << errno;
110             wpdu_ = {error, "", 0};
111         }
112         return (errno == EAGAIN);
113     }
114     wpdu_.cursor += (size_t)n;
115     return true;
116 }
117
118 // fd is ready to be read
119 void connection::read_cb(int s) {
120     lock ml(m_);
121     VERIFY(fd == s);
122     if (dead_)
123         return;
124
125     IF_LEVEL(5) LOG << "got data on fd " << s;
126
127     if (rpdu_.status == unused || rpdu_.cursor < rpdu_.buf.size()) {
128         if (!readpdu()) {
129             IF_LEVEL(5) LOG << "readpdu on fd " << s << " failed; dying";
130             global->shared_mgr.del_callback(fd, CB_RDWR);
131             dead_ = true;
132             send_complete_.notify_one();
133         }
134     }
135
136     if (rpdu_.status == inflight && rpdu_.buf.size() == rpdu_.cursor) {
137         if (delegate_->got_pdu(shared_from_this(), rpdu_.buf)) {
138             // connection_delegate has successfully consumed the pdu
139             rpdu_ = {unused, "", 0};
140         }
141     }
142 }
143
144 bool connection::readpdu() {
145     IF_LEVEL(5) LOG << "the receive buffer has length " << rpdu_.buf.size();
146     if (rpdu_.status == unused) {
147         rpc_protocol::rpc_sz_t sz1;
148         ssize_t n = fd.read(sz1);
149
150         if (n == 0)
151             return false;
152
153         if (n < 0) {
154             VERIFY(errno!=EAGAIN);
155             return false;
156         }
157
158         if (n > 0 && n != sizeof(sz1)) {
159             IF_LEVEL(0) LOG << "short read of sz";
160             return false;
161         }
162
163         size_t sz = ntoh(sz1);
164
165         if (sz > rpc_protocol::MAX_PDU) {
166             IF_LEVEL(2) LOG << "read pdu TOO BIG " << sz << " network order=" << std::hex << sz1;
167             return false;
168         }
169
170         IF_LEVEL(5) LOG << "read size of datagram = " << sz;
171
172         rpdu_ = {inflight, string(sz+sizeof(sz1), 0), sizeof(sz1)};
173     }
174
175     ssize_t n = fd.read(&rpdu_.buf[rpdu_.cursor], rpdu_.buf.size() - rpdu_.cursor);
176
177     IF_LEVEL(5) LOG << "read " << n << " bytes";
178
179     if (n <= 0) {
180         if (errno == EAGAIN)
181             return true;
182         rpdu_ = {unused, "", 0};
183         return false;
184     }
185     rpdu_.cursor += (size_t)n;
186     return true;
187 }
188
189 connection_listener::connection_listener(connection_delegate * delegate, in_port_t port, int lossytest)
190 : tcp_(socket(AF_INET, SOCK_STREAM, 0)), delegate_(delegate), lossy_(lossytest)
191 {
192     tcp_.setsockopt(SOL_SOCKET, SO_REUSEADDR, (int)1);
193     tcp_.setsockopt(IPPROTO_TCP, TCP_NODELAY, (int)1);
194     tcp_.setsockopt(SOL_SOCKET, SO_RCVTIMEO, timeval{0, 50000});
195     tcp_.setsockopt(SOL_SOCKET, SO_SNDTIMEO, timeval{0, 50000});
196
197     sockaddr_in sin = sockaddr_in(); // zero initialize
198     sin.sin_family = AF_INET;
199     sin.sin_port = hton(port);
200
201     if (bind(tcp_, (sockaddr *)&sin, sizeof(sin)) < 0) {
202         perror("accept_loop bind");
203         VERIFY(0);
204     }
205
206     if (listen(tcp_, 1000) < 0) {
207         perror("accept_loop listen");
208         VERIFY(0);
209     }
210
211     socklen_t addrlen = sizeof(sin);
212     VERIFY(getsockname(tcp_, (sockaddr *)&sin, &addrlen) == 0);
213     port_ = ntoh(sin.sin_port);
214
215     IF_LEVEL(2) LOG << "listen on " << port_ << " " << sin.sin_port;
216
217     global->shared_mgr.add_callback(tcp_, CB_RDONLY, this);
218 }
219
220 connection_listener::~connection_listener() {
221     global->shared_mgr.block_remove_fd(tcp_);
222 }
223
224 void connection_listener::read_cb(int) {
225     sockaddr_in sin;
226     socklen_t slen = sizeof(sin);
227     int s1 = accept(tcp_, (sockaddr *)&sin, &slen);
228     if (s1 < 0) {
229         perror("connection_listener::accept_conn error");
230         throw std::runtime_error("connection listener failure");
231     }
232
233     IF_LEVEL(2) LOG << "accept_loop got connection fd=" << s1 << " " << inet_ntoa(sin.sin_addr) << ":" << ntoh(sin.sin_port);
234
235     // garbage collect dead connections
236     for (auto i = conns_.begin(); i != conns_.end();) {
237         if (i->second->isdead())
238             conns_.erase(i++);
239         else
240             ++i;
241     }
242
243     conns_[s1] = std::make_shared<connection>(delegate_, s1, lossy_);
244 }