More renaming
[invirt/third/libt4.git] / lock_client.cc
index a71a206..d996b40 100644 (file)
-// RPC stubs for clients to talk to lock_server
+// RPC stubs for clients to talk to lock_server, plus client-side caching of the locks.
 
 #include "lock_client.h"
 #include "rpc/rpc.h"
-#include <arpa/inet.h>
-
 #include <sstream>
 #include <iostream>
+#include <algorithm>
 #include <stdio.h>
+#include "tprintf.h"
+#include <arpa/inet.h>
 
-lock_client::lock_client(std::string dst)
+#include "rsm_client.h"
+#include "lock.h"
+
+using std::ostringstream;
+
+// Default constructor: a freshly created lock_state starts with `state` set to `none`.
+lock_state::lock_state():
+    state(none)
 {
+}
+
+// Parks the calling thread on its own entry in `c` (keyed by thread id) until
+// another thread calls notify_one() on that entry, which signal() does.
+// The wait is issued without a predicate, so callers must re-check whatever
+// condition they are waiting for once this returns.
+// NOTE(review): the adopt_lock wrapper suggests the caller already holds `m`
+// on entry and keeps holding it on return -- confirm against lock.h.
+void lock_state::wait() {
+    const auto me = std::this_thread::get_id();
+    {
+        adopt_lock held(m);
+        c[me].wait(held);
+    }
+    // The per-thread entry is dropped again once this thread has woken up.
+    c.erase(me);
+}
+
+// Wakes at most one waiter: the thread whose entry comes first in `c`.
+// Does nothing when `c` holds no entries.
+void lock_state::signal() {
+    auto first = c.begin();
+    if (first != c.end())
+        first->second.notify_one();
+}
+
+// Wakes the thread `who` if it currently has an entry in `c`; otherwise a no-op.
+void lock_state::signal(std::thread::id who) {
+    auto entry = c.find(who);
+    if (entry != c.end())
+        entry->second.notify_one();
+}
+
+// Callback port most recently chosen by a lock_client constructor; the
+// constructor XORs it into the srand() seed before picking the next port.
+int lock_client::last_port = 0;
+
+// Returns a reference to the lock_state record for `lid`. lock_table_lock is
+// held only for the duration of the table lookup.
+lock_state & lock_client::get_lock_state(lock_protocol::lockid_t lid) {
+    lock table_guard(lock_table_lock);
+    // Relies on std::map semantics: operator[] inserts a default-constructed
+    // lock_state the first time a given id is seen.
+    return lock_table[lid];
+}
+
+// Sets up a lock client for the server at `xdst`: creates and binds the rpcc
+// `cl`, constructs an rpcs on a pseudo-randomly chosen port with the revoke
+// and retry handlers registered, creates the rsm_client, and launches the
+// releaser thread. `_lu` is stored in `lu`.
+lock_client::lock_client(string xdst, class lock_release_user *_lu) : lu(_lu) {
     sockaddr_in dstsock;
-    make_sockaddr(dst.c_str(), &dstsock);
+    make_sockaddr(xdst.c_str(), &dstsock);
     cl = new rpcc(dstsock);
+    // NOTE(review): a failed bind is only reported on stdout; construction continues.
     if (cl->bind() < 0) {
         printf("lock_client: call bind\n");
     }
+
+    // Seed with the clock mixed with the previously chosen port.
+    srand(time(NULL)^last_port);
+    // OR-ing in bit 10 keeps the port at 1024 or above; the result stays below 32768.
+    rlock_port = ((rand()%32000) | (0x1 << 10));
+    const char *hname;
+    // VERIFY(gethostname(hname, 100) == 0);
+    // The host part is hard-wired to loopback, so `id` becomes "127.0.0.1:<port>".
+    hname = "127.0.0.1";
+    ostringstream host;
+    host << hname << ":" << rlock_port;
+    id = host.str();
+    last_port = rlock_port;
+    // NOTE(review): rlsrpc lives only in this local; nothing visible here stores or frees it.
+    rpcs *rlsrpc = new rpcs(rlock_port);
+    rlsrpc->reg(rlock_protocol::revoke, &lock_client::revoke_handler, this);
+    rlsrpc->reg(rlock_protocol::retry, &lock_client::retry_handler, this);
+    {
+        // The xid counter is reset under xid_mutex, the same mutex acquire() takes to bump it.
+        lock sl(xid_mutex);
+        xid = 0;
+    }
+    rsmc = new rsm_client(xdst);
+    // The releaser thread starts running immediately and uses release_fifo and rsmc.
+    releaser_thread = std::thread(&lock_client::releaser, this);
 }
 
-int
-lock_client::stat(lock_protocol::lockid_t lid)
-{
+// Body of the background thread started by the constructor. It loops forever:
+// take a lock id off release_fifo, send the `release` RPC for it, mark the
+// lock `none`, and wake one waiter.
+void lock_client::releaser() {
+    while (true) {
+        lock_protocol::lockid_t lid;
+        release_fifo.deq(&lid);
+        LOG("Releaser: " << lid);
+
+        lock_state &st = get_lock_state(lid);
+        lock guard(st.m);
+        // Whoever queued this id must already have marked the lock as held by
+        // the releaser thread.
+        VERIFY(st.state == lock_state::locked && st.held_by == releaser_thread.get_id());
+        st.state = lock_state::releasing;
+
+        // The per-lock mutex is dropped while the RPC is in flight and
+        // re-taken afterwards. The RPC's return value is not inspected.
+        guard.unlock();
+        int r;
+        rsmc->call(lock_protocol::release, r, lid, id, st.xid);
+        guard.lock();
+
+        st.state = lock_state::none;
+        LOG("Lock " << lid << ": none");
+        st.signal();
+    }
+}
+
+// Sends the `stat` RPC for `lid` through `cl` and returns the reply value.
+// NOTE(review): VERIFY(0) is the first statement, so -- assuming VERIFY aborts
+// on a false condition -- the RPC code below it is never reached.
+int lock_client::stat(lock_protocol::lockid_t lid) {
+    VERIFY(0);
     int r;
     lock_protocol::status ret = cl->call(lock_protocol::stat, r, cl->id(), lid);
     VERIFY (ret == lock_protocol::OK);
     return r;
 }
 
-lock_protocol::status
-lock_client::acquire(lock_protocol::lockid_t lid)
-{
-    int r;
-    return cl->call(lock_protocol::acquire, r, cl->id(), lid);
+// Blocks until the calling thread holds lock `lid`, then returns OK.
+//
+// The caller is appended to st.wanted_by and then loops:
+//   * state `none` or `retrying`: send the `acquire` RPC (a fresh xid is drawn
+//     only when starting from `none`); an OK reply moves the lock to `free`.
+//   * state `free`: the thread at the front of wanted_by gets the lock. If
+//     that is the releaser thread, the lock is marked as held by it and queued
+//     on release_fifo; if it is another thread, that thread is signalled.
+//   * otherwise: sleep in st.wait() until signalled, then re-evaluate.
+//
+// Re-entrant acquisition, and queuing the same thread twice, trip a VERIFY.
+lock_protocol::status lock_client::acquire(lock_protocol::lockid_t lid) {
+    lock_state &st = get_lock_state(lid);
+    lock sl(st.m);
+    auto self = std::this_thread::get_id();
+
+    // check for reentrancy
+    VERIFY(st.state != lock_state::locked || st.held_by != self);
+    VERIFY(std::find(st.wanted_by.begin(), st.wanted_by.end(), self) == st.wanted_by.end());
+
+    st.wanted_by.push_back(self);
+
+    while (1) {
+        if (st.state != lock_state::free)
+            LOG("Lock " << lid << ": not free");
+
+        if (st.state == lock_state::none || st.state == lock_state::retrying) {
+            if (st.state == lock_state::none) {
+                // Distinct name so the per-lock guard `sl` is not shadowed.
+                lock xl(xid_mutex);
+                st.xid = xid++;
+            }
+            st.state = lock_state::acquiring;
+            LOG("Lock " << lid << ": acquiring");
+            lock_protocol::status result;
+            {
+                // The per-lock mutex is released while the RPC is in flight,
+                // so the revoke/retry handlers can run in this window.
+                sl.unlock();
+                int r;
+                result = rsmc->call(lock_protocol::acquire, r, lid, id, st.xid);
+                sl.lock();
+            }
+            LOG("acquire returned " << result);
+            if (result == lock_protocol::OK) {
+                st.state = lock_state::free;
+                LOG("Lock " << lid << ": free");
+            } else if (st.state == lock_state::retrying) {
+                // retry_handler ran while the mutex was released: it set
+                // `retrying` and signalled before this thread was waiting, so
+                // that wake-up is gone. Sleeping now could block forever;
+                // go round again and re-issue the acquire instead.
+                continue;
+            }
+        }
+
+        VERIFY(st.wanted_by.size() != 0);
+        if (st.state == lock_state::free) {
+            // is it for me?
+            auto front = st.wanted_by.front();
+            if (front == releaser_thread.get_id()) {
+                // A revoke is queued ahead of us: hand the lock to the
+                // releaser thread and keep waiting for our own turn.
+                st.wanted_by.pop_front();
+                st.state = lock_state::locked;
+                st.held_by = releaser_thread.get_id();
+                LOG("Queuing " << lid << " for release");
+                release_fifo.enq(lid);
+            } else if (front == self) {
+                st.wanted_by.pop_front();
+                st.state = lock_state::locked;
+                st.held_by = self;
+                break;
+            } else {
+                // Someone else is first in line; make sure they are awake.
+                st.signal(front);
+            }
+        }
+
+        LOG("waiting...");
+        st.wait();
+        LOG("wait ended");
+    }
+
+    LOG("Lock " << lid << ": locked");
+    return lock_protocol::OK;
 }
 
-lock_protocol::status
-lock_client::release(lock_protocol::lockid_t lid)
-{
-    int r;
-    return cl->call(lock_protocol::release, r, cl->id(), lid);
+// Gives up lock `lid`, which the calling thread must currently hold (checked
+// with VERIFY). The lock becomes `free`; if a thread is queued in wanted_by,
+// the one at the front is served. The releaser thread gets the lock handed
+// over and queued on release_fifo; any other thread is signalled. Always
+// returns OK.
+lock_protocol::status lock_client::release(lock_protocol::lockid_t lid) {
+    lock_state &st = get_lock_state(lid);
+    lock guard(st.m);
+    auto self = std::this_thread::get_id();
+    VERIFY(st.state == lock_state::locked && st.held_by == self);
+
+    st.state = lock_state::free;
+    LOG("Lock " << lid << ": free");
+
+    if (st.wanted_by.size() != 0) {
+        const auto next = st.wanted_by.front();
+        if (next != releaser_thread.get_id()) {
+            // An ordinary thread is first in line: wake it so it can take the lock.
+            st.signal(next);
+        } else {
+            // A revoke is first in line: pass the lock straight to the releaser.
+            st.wanted_by.pop_front();
+            st.state = lock_state::locked;
+            st.held_by = releaser_thread.get_id();
+            LOG("Queuing " << lid << " for release");
+            release_fifo.enq(lid);
+        }
+    }
+    LOG("Finished signaling.");
+    return lock_protocol::OK;
+}
+
+// RPC handler for `revoke` (registered in the constructor). Always returns OK.
+//   * lock already `none` or `releasing`: nothing to do.
+//   * lock `free` with nobody queued, or with the releaser already at the
+//     front of wanted_by: hand it to the releaser thread immediately.
+//   * anything else: append the releaser thread's id to wanted_by so the lock
+//     is given back once the threads ahead of it are done.
+// NOTE(review): `xid` is only logged; it is not compared with st.xid.
+rlock_protocol::status lock_client::revoke_handler(int &, lock_protocol::lockid_t lid, lock_protocol::xid_t xid) {
+    LOG("Revoke handler " << lid << " " << xid);
+    lock_state &st = get_lock_state(lid);
+    lock guard(st.m);
+
+    if (st.state == lock_state::none || st.state == lock_state::releasing)
+        return rlock_protocol::OK;
+
+    const auto releaser_id = releaser_thread.get_id();
+    const bool queue_empty = st.wanted_by.size() == 0;
+    const bool release_now = st.state == lock_state::free &&
+        (queue_empty || st.wanted_by.front() == releaser_id);
+
+    if (!release_now) {
+        // get in line
+        st.wanted_by.push_back(releaser_id);
+        return rlock_protocol::OK;
+    }
+
+    // gimme
+    st.state = lock_state::locked;
+    st.held_by = releaser_id;
+    if (!queue_empty)
+        st.wanted_by.pop_front();
+    release_fifo.enq(lid);
+    return rlock_protocol::OK;
+}
+
+// RPC handler for `retry` (registered in the constructor). The lock must be
+// in `acquiring` (checked with VERIFY); it is moved to `retrying` and one
+// waiting thread, if any, is woken so that acquire() can send the request
+// again. Always returns OK.
+// NOTE(review): `xid` is accepted but not used.
+rlock_protocol::status lock_client::retry_handler(int &, lock_protocol::lockid_t lid, lock_protocol::xid_t xid) {
+    lock_state &st = get_lock_state(lid);
+    lock sl(st.m);
+    VERIFY(st.state == lock_state::acquiring);
+    st.state = lock_state::retrying;
+    // Log the state that was actually set (this line used to say "none").
+    LOG("Lock " << lid << ": retrying");
+    st.signal(); // only one thread needs to wake up
+    return rlock_protocol::OK;
 }
 
 t4_lock_client *t4_lock_client_new(const char *dst) {
@@ -60,4 +245,3 @@ t4_status t4_lock_client_release(t4_lock_client *client, t4_lockid_t lid) {
+// Casts the t4_lock_client handle back to lock_client* and forwards to its stat(lid).
 t4_status t4_lock_client_stat(t4_lock_client *client, t4_lockid_t lid) {
     return ((lock_client *)client)->stat(lid);
 }
-