#!/usr/bin/env python

import subprocess as sp
import signal
import os
import sys
import time
import getopt
import random

# Shared state for the whole run.  start_nodes() rebinds pid, logs and views
# at the beginning of every test.
pid = []  # PIDs of the lock_server processes started by start_nodes()
logs = []  # log file names recorded by spawn(); cleanup() deletes them
views = [] # expected views
in_views = {} # the number of views a node is expected to be present
p = []  # ports of the nodes; replaced by randports(5) during setup
t = None  # later holds the PID of the running lock_tester
always_kill = 0  # when true, mydie() calls killprocess() before exiting
quit = False  # set by killprocess(); makes wait_for_view_change() stop polling

def killprocess(num, frame):
    print "killprocess: forcestop all spawned processes...%s" % (str(pid),)
    global quit
    quit = True
    for p in pid:
        os.kill(p, signal.SIGKILL)

# Install killprocess() as the handler for the usual termination signals so
# that interrupting the tester also kills the spawned processes.
# NOTE: `sig` and `num` stay bound as module-level names after this loop.
for sig in ['HUP', 'INT', 'ABRT', 'QUIT', 'TERM']:
    # Look the signal number up by name, e.g. signal.SIGHUP.
    num = getattr(signal, 'SIG'+sig)
    signal.signal(num, killprocess)

def paxos_log(port):
    """Return the paxos log file name used for the given port number."""
    pieces = ("paxos-", "%d" % port, ".log")
    return ''.join(pieces)

def die(*s):
    """Write the concatenated message parts to stderr and exit with status 1.

    Args:
        *s: string fragments; they are joined without a separator and
            followed by a newline.

    Raises:
        SystemExit: always, with exit code 1.
    """
    sys.stderr.write(''.join(s) + '\n')
    # sys.exit rather than the bare exit(): the latter is a helper injected
    # by the site module for interactive use and is not guaranteed to exist.
    sys.exit(1)

def mydie(*s):
    """Abort the run like die(), killing spawned processes first if requested.

    When the global `always_kill` is set, every spawned process is killed via
    killprocess() before the message is printed and the script exits.

    Args:
        *s: string fragments of the failure message, passed through to die().
    """
    if always_kill:
        # killprocess has a signal-handler signature (signal number, frame);
        # neither value is used, so pass placeholders explicitly instead of
        # calling it with no arguments, which raised TypeError.
        killprocess(None, None)
    die(*s)

def usleep(us):
    """Suspend execution for `us` microseconds."""
    seconds = us / 1e6
    time.sleep(seconds)

def cleanup():
    """Kill the processes in `pid`, delete the recorded log files, then pause.

    Each process is sent SIGKILL and then reaped with waitpid so that killed
    servers do not linger as zombies for the rest of the run.  The pid list
    is emptied afterwards because its entries no longer refer to live
    children.  Log files that are already missing are ignored.
    """
    for child in pid:
        try:
            os.kill(child, signal.SIGKILL)
            # Blocks only until the just-killed child has actually exited.
            os.waitpid(child, 0)
        except OSError:
            # Process already gone or not ours to wait for; carry on with
            # the remaining ones.
            pass
    # In-place clear: other code refers to the same list object.
    del pid[:]
    for l in logs:
        try:
            os.unlink(l)
        except OSError:
            pass
    usleep(200000)

def spawn(p, *a):
    sa = map(str, a)
    aa = '-'.join(sa)
    try:
        pid = os.fork()
    except OSError, e:
        mydie("Cannot fork: %s" % (repr(e),))
    if pid:
        # parent
        logs.append("%s-%s.log" % (p, aa))
        if 'lock_server' in p:
            logs.append(paxos_log(a[1]))
        return pid
    else:
        # child
        os.close(1)
        sys.stdout = open("%s-%s.log" % (p, aa), 'w')
        os.close(2)
        os.dup(1)
        sys.stderr = sys.stdout
        print "%s %s" % (p, ' '.join(sa))
        try:
            os.execv(p, [p] + sa)
        except OSError, e:
            mydie("Cannot start new %s %s %s" % (p, repr(sa), repr(e)))

def randports(num):
    """Return `num` distinct random even ports in [10000, 64000], sorted.

    Ports are sampled without replacement, so two nodes can never be handed
    the same port.  Only even numbers are produced; other code in this file
    also addresses `port + 1`, which therefore never collides with another
    node's port.

    Args:
        num: how many ports to generate (at most 27001).

    Raises:
        ValueError: if `num` is negative or exceeds the number of available
            ports.
    """
    slots = random.sample(range(54000 // 2 + 1), num)
    return sorted(slot * 2 + 10000 for slot in slots)

def print_config(ports):
    """Write `ports` to the file "config", one zero-padded port per line.

    Each port is formatted with at least five digits.  Aborts through
    mydie() when the file cannot be opened for writing.
    """
    try:
        config = open("config", 'w')
    except IOError:
        mydie("Couldn't open config for writing")
    for port in ports:
        config.write("%05d" % (port,) + "\n")
    config.close()

def spawn_ls(master, port):
    """Start ./lock_server with the arguments (master, port); return its PID."""
    server_args = (master, port)
    return spawn("./lock_server", *server_args)

def check_views(l, vs, last_v=None):
    """Verify that log file `l` went through the expected sequence of views.

    Only lines starting with 'done' are considered; each is parsed as
    "done <view number> <member> <member> ...".  The i-th such line must have
    the same member set as vs[i]; extra views after the expected ones are
    tolerated.

    Args:
        l: path of the log file to check.
        vs: sequence of expected views, each an iterable of members.
        last_v: if given, the member set the final view in the log must have.

    Any mismatch, a missing view, or an unreadable log aborts via mydie().
    """
    try:
        f = open(l, 'r')
        try:
            log = f.readlines()
        finally:
            # Close the file even when reading fails.
            f.close()
    except IOError:
        mydie("Failed: couldn't read %s" % (l,))
    i = 0
    last_view = None
    for line in log:
        if not line.startswith('done'):
            continue
        # split() without an argument tolerates repeated or trailing
        # whitespace, which split(' ') turns into empty words int() rejects.
        words = line.split()
        num = int(words[1])
        view = map(int, words[2:])
        last_view = view
        if i >= len(vs):
            # let there be extra views
            continue
        expected = vs[i]
        if set(expected) != set(view):
            mydie("Failed: In log %s at view %s is (%s), but expected %s (%s)" %
                  (l, str(num), repr(view), str(i), repr(expected)))
        i += 1
    if i < len(vs):
        mydie("Failed: In log %s, not enough views seen!" % (l,))
    if last_v is not None:
        # A log without any view cannot match; report it as a failure
        # instead of crashing on set(None).
        if last_view is None or set(last_v) != set(last_view):
            mydie("Failed: In log %s last view didn't match, got view %s, but expected %s" %
                  (l, repr(last_view), repr(last_v)))

def get_num_views(log, including):
    """Count the lines of file `log` that mention both 'done ' and `including`.

    `including` is converted with str() and matched as a plain substring.
    Returns 0 when the file cannot be opened.
    """
    try:
        f = open(log, 'r')
    except IOError:
        return 0
    entries = f.readlines()
    f.close()
    needle = str(including)
    count = 0
    for entry in entries:
        if 'done ' in entry and needle in entry:
            count += 1
    return count

def wait_for_view_change(log, num_views, including, timeout):
    start = time.time()
    while get_num_views(log, including) < num_views and (start + timeout > time.time()) and not quit:
        try:
            f = open(log, 'r')
            loglines = f.readlines()
            f.close()
            lastv = [x for x in loglines if 'done' in x][-1].strip()
            print "   Waiting for %s to be present in >=%s views in %s (Last view: %s)" % \
                  (including, str(num_views), log, lastv)
            usleep(100000)
        except IOError:
            continue

    if get_num_views(log, including) < num_views:
        mydie("Failed: Timed out waiting for %s to be in >=%s in log %s" %
              (including, str(num_views), log))
    else:
        print "   Done: %s is in >=%s views in %s" % (including, str(num_views), log)

def waitpid_to(pid, to):
    """Wait up to `to` seconds for child process `pid` to terminate.

    Polls with a non-blocking waitpid every 100 ms.

    Args:
        pid: PID of a child process of this script.
        to: timeout in seconds.

    Returns:
        1 once the child has terminated.

    If the child is still running when the timeout expires it is killed with
    SIGKILL and the run is aborted via mydie().
    """
    start = time.time()
    done_pid = (0, 0)
    while done_pid == (0, 0) and (time.time() - start) < to:
        usleep(100000)
        done_pid = os.waitpid(pid, os.WNOHANG)

    # With WNOHANG, waitpid reports pid 0 while the child is still running.
    # (The previous `done_pid <= 0` compared a tuple with an int and so never
    # detected the timeout.)
    if done_pid[0] == 0:
        os.kill(pid, signal.SIGKILL)
        mydie("Failed: Timed out waiting for process %s" % (str(pid),))
    return 1

def wait_and_check_expected_view(v):
    """Record `v` as the next expected view, wait for it, then verify the logs.

    Appends `v` to the global `views`, increments the expected view count of
    each member, waits (up to 20 seconds per member) until every member's
    paxos log shows that many views, and finally checks every member's log
    against the full list of expected views.
    """
    views.append(v)
    for member in v:
        in_views[member] = in_views[member] + 1
    for member in v:
        member_log = paxos_log(member)
        wait_for_view_change(member_log, in_views[member], member, 20)
    for member in v:
        check_views(paxos_log(member), views)

def start_nodes(n, command):
    global pid, logs, views
    pid = []
    logs = []
    views = []
    for pp in p:
        in_views[pp] = 0

    for i in xrange(n):
        if command == "ls":
            pid.append(spawn_ls(p[0],p[i]))
            print "Start lock_server on %s" % (str(p[i]),)
        usleep(100000)

        wait_and_check_expected_view(p[:i+1])

# Command line:  [-s SEED] [-k] [TEST_NUMBER ...]
#   -s SEED  seed the random number generator (reproducible port choice)
#   -k       kill all spawned processes before exiting on a failure
options, arguments = getopt.getopt(sys.argv[1:], "s:k")
options = dict(options)

# getopt reports option names including the leading dash, so the keys are
# '-s' and '-k' (not 's' / 'k').
if '-s' in options:
    random.seed(options['-s'])

if '-k' in options:
    always_kill = 1

# get a sorted list of random ports
p = randports(5)
print_config(p)

NUM_TESTS = 17
do_run = [0] * NUM_TESTS

# see which tests are set
if len(arguments):
    for t in arguments:
        t = int(t)
        # silently ignore test numbers outside 0..NUM_TESTS-1
        if t < NUM_TESTS and t >= 0:
            do_run[t] = 1
else:
    # turn on all tests
    for i in xrange(NUM_TESTS):
        do_run[i] = 1

# Test 0: bring up three lock servers one at a time (start_nodes waits for
# and checks the view after each one), then tear everything down.
if do_run[0]:
    print "test0: start 3-process lock server"
    start_nodes(3,"ls")
    cleanup()
    usleep(200000)

# Test 1: start three servers, SIGTERM the third, and expect the next view
# to consist of the first two.
if do_run[1]:
    print "test1: start 3-process lock server, kill third server"
    start_nodes(3,"ls")
    print "Kill third server (PID: %s) on port %s" % (str(pid[2]), str(p[2]))
    os.kill(pid[2], signal.SIGTERM)
    usleep(500000)
    # it should go through 4 views
    v4 = [p[0], p[1]]
    wait_and_check_expected_view(v4)
    cleanup()
    usleep(200000)

# Test 2: start three servers, SIGTERM the first, and expect the next view
# to consist of the second and third.
if do_run[2]:
    print "test2: start 3-process lock server, kill first server"
    start_nodes(3,"ls")
    print "Kill first (PID: %d) on port %d" % (pid[0], p[0])
    os.kill(pid[0], signal.SIGTERM)
    usleep(500000)
    # it should go through 4 views
    v4 = [p[1], p[2]]
    wait_and_check_expected_view(v4)
    cleanup()
    usleep(200000)

# Test 3: start three servers, SIGTERM the third (expect a view of the first
# two), restart it, and expect a view with all three again.
if do_run[3]:
    print "test3: start 3-process lock_server, kill a server, restart a server"
    start_nodes(3,"ls")
    print "Kill server (PID: %s) on port %s" % (pid[2], p[2])
    os.kill(pid[2], signal.SIGTERM)
    usleep(500000)
    v4 = (p[0], p[1])
    wait_and_check_expected_view(v4)
    print "Restart killed server on port %s" % (p[2],)
    # the restarted server replaces the killed one's entry in the pid list
    pid[2] = spawn_ls (p[0], p[2])
    usleep(500000)
    v5 = (p[0], p[1], p[2])
    wait_and_check_expected_view(v5)
    cleanup()
    usleep(200000)

if do_run[4]:
    print "test4: 3-process lock_server, kill third server, kill second server, restart third server, kill third server again, restart second server, re-restart third server, check logs"
    start_nodes(3,"ls")
    print "Kill server (PID: %s) on port %s" % (pid[2], p[2])
    os.kill(pid[2], signal.SIGTERM)
    usleep(500000)
    v4 = (p[0], p[1])
    wait_and_check_expected_view(v4)
    print "Kill server (PID: %s) on port %s" % (pid[1], p[1])
    os.kill(pid[1], signal.SIGTERM)
    usleep(500000)
    #no view change can happen because of a lack of majority
    print "Restarting server on port %s" % (p[2],)
    pid[2] = spawn_ls(p[0], p[2])
    usleep(500000)
    #no view change can happen because of a lack of majority
    for port in p[0:1+2]:
        num_v = get_num_views(paxos_log(port), port)
        if num_v != in_views[port]:
            die("%s_v views in ", paxos_log(port), " : no new views should be formed due to the lack of majority" % (num,))
    # kill node 3 again,
    print "Kill server (PID: %s) on port %s" % (pid[2], p[2])
    os.kill(pid[2], signal.SIGTERM)
    usleep(500000)
    print "Restarting server on port %s" % (p[1],)
    pid[1] = spawn_ls(p[0], p[1])
    usleep(700000)
    for port in p[0:1+1]:
        in_views[port] = get_num_views(paxos_log(port), port)
        print "   Node %s is present in " % (port,), in_views[port], " views in ", paxos_log(port)
    print "Restarting server on port %s" % (p[2],)
    pid[2] = spawn_ls(p[0], p[2])
    lastv = (p[0],p[1],p[2])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    # now check the paxos logs and make sure the logs go through the right
    # views
    for port in lastv:
        check_views(paxos_log(port), views, lastv)
    cleanup()

if do_run[5]:
    print "test5: 3-process lock_server, send signal 1 to first server, kill third server, restart third server, check logs"
    start_nodes(3,"ls")
    print "Sending paxos breakpoint 1 to first server on port %s" % (p[0],)
    spawn("./rsm_tester", p[0]+1, "breakpoint", 3)
    usleep(100000)
    print "Kill third server (PID: %s) on port %s" % (pid[2], p[2])
    os.kill(pid[2], signal.SIGTERM)
    usleep(500000)
    for port in p[0:1+2]:
        num_v = get_num_views(paxos_log(port), port)
        if num_v != in_views[port]:
            die("%s_v views in ", paxos_log(port), " : no new views should be formed due to the lack of majority" % (num,))
    print "Restarting third server on port %s" % (p[2],)
    pid[2]= spawn_ls(p[0], p[2])
    lastv = (p[1],p[2])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    usleep(1000000)
    # now check the paxos logs and make sure the logs go through the right
    # views
    for port in lastv:
        check_views(paxos_log(port), views, lastv)
    cleanup()

if do_run[6]:
    print "test6: 4-process lock_server, send signal 2 to first server, kill fourth server, restart fourth server, check logs"
    start_nodes(4,"ls")
    print "Sending paxos breakpoint 2 to first server on port %s" % (p[0],)
    spawn("./rsm_tester", p[0]+1, "breakpoint", 4)
    usleep(100000)
    print "Kill fourth server (PID: %s) on port %s" % (pid[3], p[3])
    os.kill(pid[3], signal.SIGTERM)
    usleep(500000)
    for port in (p[1],p[2]):
        num_v = get_num_views(paxos_log(port), port)
        if num_v != in_views[port]:
            die("%s_v views in ", paxos_log(port), " : no new views should be formed due to the lack of majority" % (num,))
    usleep(500000)
    print "Restarting fourth server on port %s" % (p[3],)
    pid[3] = spawn_ls(p[1], p[3])
    usleep(500000)
    v5 = (p[0],p[1],p[2])
    for port in v5:
        in_views[port]+=1
    views.append(v5)
    usleep(1000000)
    # the 6th view will be (2,3)  or (1,2,3,4)
    v6 = (p[1],p[2])
    for port in v6:
        in_views[port]+=1
    for port in v6:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 30)
    # final will be (2,3,4)
    lastv = (p[1],p[2],p[3])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    for port in lastv:
        check_views(paxos_log(port), views, lastv)
    cleanup()

if do_run[7]:
    print "test7: 4-process lock_server, send signal 2 to first server, kill fourth server, kill other servers, restart other servers, restart fourth server, check logs"
    start_nodes(4,"ls")
    print "Sending paxos breakpoint 2 to first server on port %s" % (p[0],)
    spawn("./rsm_tester", p[0]+1, "breakpoint", 4)
    usleep(300000)
    print "Kill fourth server (PID: %s) on port %s" % (pid[3], p[3])
    os.kill(pid[3], signal.SIGTERM)
    usleep(500000)
    print "Kill third server (PID: %s) on port %s" % (pid[2], p[2])
    os.kill(pid[2], signal.SIGTERM)
    print "Kill second server (PID: %s) on port %s" % (pid[1], p[1])
    os.kill(pid[1], signal.SIGTERM)
    usleep(500000)
    print "Restarting second server on port %s" % (p[1],)
    pid[1] = spawn_ls(p[0], p[1])
    usleep(500000)
    print "Restarting third server on port %s" % (p[2],)
    pid[2] = spawn_ls(p[0], p[2])
    usleep(500000)
    #no view change is possible by now because there is no majority
    for port in (p[1],p[2]):
        num_v = get_num_views(paxos_log(port), port)
        if num_v != in_views[port]:
            die("%s_v views in ", paxos_log(port), " : no new views should be formed due to the lack of majority" % (num,))
    print "Restarting fourth server on port %s" % (p[3],)
    pid[3] = spawn_ls(p[1], p[3])
    usleep(500000)
    v5 = (p[0], p[1], p[2])
    views.append(v5)
    for port in v5:
        in_views[port]+=1
    usleep(1500000)
    lastv = (p[1],p[2],p[3])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    for port in lastv:
        check_views(paxos_log(port), views, lastv)
    cleanup()

# Test 8: run lock_tester against a healthy three-server group and require
# its log to contain the success message.
if do_run[8]:
    print "test8: start 3-process lock service"
    start_nodes(3,"ls")
    print "Start lock_tester %s" % (p[0],)
    t = spawn("./lock_tester", p[0])
    print "   Wait for lock_tester to finish (waitpid %s)" % (t,)
    waitpid_to(t, 600)
    # grep exits non-zero when the success line is missing from the log
    if os.system("grep \"passed all tests successfully\" lock_tester-%s.log" % (p[0],)):
        mydie("Failed lock tester for test 8")
    cleanup()
    usleep(200000)

# Test 9: SIGTERM the third server at a random moment (up to one second in)
# while lock_tester is running; the tester must still succeed.
if do_run[9]:
    print "test9: start 3-process rsm, kill second slave while lock_tester is running"
    start_nodes(3,"ls")
    print "Start lock_tester %s" % (p[0],)
    t = spawn("./lock_tester", p[0])
    usleep(random.randint(1,1000000))
    print "Kill slave (PID: %s) on port %s" % (pid[2], p[2])
    os.kill(pid[2], signal.SIGTERM)
    usleep(300000)
    # it should go through 4 views
    v4 = (p[0], p[1])
    wait_and_check_expected_view(v4)
    print "   Wait for lock_tester to finish (waitpid %s)" % (t,)
    waitpid_to(t, 600)
    # grep exits non-zero when the success line is missing from the log
    if os.system("grep \"passed all tests successfully\" lock_tester-%s.log" % (p[0],)):
        mydie("Failed lock tester for test 9")
    cleanup()
    usleep(200000)

# Test 10: SIGTERM the third server at a random moment while lock_tester is
# running, then restart it; expect a two-node view followed by a three-node
# view, and the tester must still succeed.
if do_run[10]:
    print "test10: start 3-process rsm, kill second slave and restarts it later while lock_tester is running"
    start_nodes(3,"ls")
    print "Start lock_tester %s" % (p[0],)
    t = spawn("./lock_tester", p[0])
    usleep(random.randint(1,1000000))
    print "Kill slave (PID: %s) on port %s" % (pid[2], p[2])
    os.kill(pid[2], signal.SIGTERM)
    usleep(300000)
    # it should go through 4 views
    v4 = (p[0], p[1])
    wait_and_check_expected_view(v4)
    usleep(300000)
    print "Restarting killed lock_server on port %s" % (p[2],)
    pid[2] = spawn_ls(p[0], p[2])
    v5 = (p[0],p[1],p[2])
    wait_and_check_expected_view(v5)
    print "   Wait for lock_tester to finish (waitpid %s)" % (t,)
    waitpid_to(t, 600)
    # grep exits non-zero when the success line is missing from the log
    if os.system("grep \"passed all tests successfully\" lock_tester-%s.log" % (p[0],)):
        mydie("Failed lock tester for test 10")
    cleanup()
    usleep(200000)

# Test 11: SIGTERM the first server (the one lock_tester was pointed at) at
# a random moment while lock_tester is running; expect a view of the other
# two servers, and the tester must still succeed.
if do_run[11]:
    print "test11: start 3-process rsm, kill primary while lock_tester is running"
    start_nodes(3,"ls")
    print "Start lock_tester %s" % (p[0],)
    t = spawn("./lock_tester", p[0])
    usleep(random.randint(1,1000000))
    print "Kill primary (PID: %s) on port %s" % (pid[0], p[0])
    os.kill(pid[0], signal.SIGTERM)
    usleep(300000)
    # it should go through 4 views
    v4 = (p[1], p[2])
    wait_and_check_expected_view(v4)
    print "   Wait for lock_tester to finish (waitpid %s)" % (t,)
    waitpid_to(t, 600)
    # grep exits non-zero when the success line is missing from the log
    if os.system("grep \"passed all tests successfully\" lock_tester-%s.log" % (p[0],)):
        mydie("Failed lock tester for test 11")
    cleanup()
    usleep(200000)

# Test 12: while lock_tester is running, send "breakpoint 1" to the first
# server through rsm_tester (presumably making it exit at that breakpoint --
# the effect is implemented outside this script), expect a view of the other
# two, restart the first server and expect a final view with all three.
if do_run[12]:
    print "test12: start 3-process rsm, kill master at break1 and restart it while lock_tester is running"
    start_nodes(3, "ls")
    print "Start lock_tester %s" % (p[0],)
    t = spawn("./lock_tester", p[0])
    usleep(100000)
    print "Kill master (PID: %s) on port %s at breakpoint 1" % (pid[0], p[0])
    # rsm_tester is addressed at the server's port + 1
    spawn("./rsm_tester", p[0]+1, "breakpoint", 1)
    usleep(100000)
    # it should go through 5 views
    v4 = (p[1], p[2])
    wait_and_check_expected_view(v4)
    print "Restarting killed lock_server on port %s" % (p[0],)
    # the restarted server is given the second server as its master
    pid[0] = spawn_ls(p[1], p[0])
    usleep(300000)
    # the last view should include all nodes
    lastv = (p[0],p[1],p[2])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    for port in lastv:
        check_views(paxos_log(port), views, lastv)
    print "   Wait for lock_tester to finish (waitpid %s)" % (t,)
    waitpid_to(t, 600)
    # grep exits non-zero when the success line is missing from the log
    if os.system("grep \"passed all tests successfully\" lock_tester-%s.log" % (p[0],)):
        mydie("Failed lock tester for test 12")
    cleanup()
    usleep(200000)

# Test 13: while lock_tester is running, send "breakpoint 1" to the third
# server through rsm_tester (presumably making it exit at that breakpoint --
# the effect is implemented outside this script), expect a view of the first
# two, restart the third server and expect a final view with all three.
if do_run[13]:
    print "test13: start 3-process rsm, kill slave at break1 and restart it while lock_tester is running"
    start_nodes(3, "ls")
    print "Start lock_tester %s" % (p[0],)
    t = spawn("./lock_tester", p[0])
    usleep(100000)
    print "Kill slave (PID: %s) on port %s at breakpoint 1" % (pid[2], p[2])
    # rsm_tester is addressed at the server's port + 1
    spawn("./rsm_tester", p[2]+1, "breakpoint", 1)
    usleep(100000)
    # it should go through 4 views
    v4 = (p[0], p[1])
    wait_and_check_expected_view(v4)
    print "Restarting killed lock_server on port %s" % (p[2],)
    pid[2] = spawn_ls(p[0], p[2])
    usleep(300000)
    # the last view should include all nodes
    lastv = (p[0],p[1],p[2])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    for port in lastv:
        check_views(paxos_log(port), views, lastv)
    print "   Wait for lock_tester to finish (waitpid %s)" % (t,)
    waitpid_to(t, 600)
    # grep exits non-zero when the success line is missing from the log
    if os.system("grep \"passed all tests successfully\" lock_tester-%s.log" % (p[0],)):
        mydie("Failed lock tester for test 13")
    cleanup()
    usleep(200000)

# Test 14: five servers with lock_tester running; send "breakpoint 1" to the
# fifth server and "breakpoint 2" to the fourth, then wait for one more view
# on each of the expected survivors.  The tester must still succeed.
if do_run[14]:
    print "test14: start 5-process rsm, kill slave break1, kill slave break2"
    start_nodes(5, "ls")
    print "Start lock_tester %s" % (p[0],)
    t = spawn("./lock_tester", p[0])
    usleep(100000)
    print "Kill slave (PID: %s) on port %s at breakpoint 1" % (pid[4], p[4])
    # rsm_tester is addressed at the server's port + 1
    spawn("./rsm_tester", p[4]+1, "breakpoint", 1)
    print "Kill slave (PID: %s) on port %s at breakpoint 2" % (pid[3], p[3])
    spawn("./rsm_tester", p[3]+1, "breakpoint", 2)
    usleep(100000)
    # two view changes:
    # NOTE: in_views is not incremented between the two waits, so both wait
    # for the same count (in_views[port]+1) on the ports they share.
    print "first view change wait"
    lastv = (p[0],p[1],p[2],p[3])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    print "second view change wait"
    lastv = (p[0],p[1],p[2])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    print "   Wait for lock_tester to finish (waitpid %s)" % (t,)
    waitpid_to(t, 600)
    # grep exits non-zero when the success line is missing from the log
    if os.system("grep \"passed all tests successfully\" lock_tester-%s.log" % (p[0],)):
        mydie("Failed lock tester for test 14")
    cleanup()
    usleep(200000)

# Test 15: five servers with lock_tester running; send "breakpoint 1" to the
# fifth server and "breakpoint 2" to the first, then wait for one more view
# on each of the expected survivors.  The tester must still succeed.
if do_run[15]:
    print "test15: start 5-process rsm, kill slave break1, kill primary break2"
    start_nodes(5, "ls")
    print "Start lock_tester %s" % (p[0],)
    t = spawn("./lock_tester", p[0])
    usleep(100000)
    print "Kill slave (PID: %s) on port %s at breakpoint 1" % (pid[4], p[4])
    # rsm_tester is addressed at the server's port + 1
    spawn("./rsm_tester", p[4]+1, "breakpoint", 1)
    print "Kill primary (PID: %s) on port %s at breakpoint 2" % (pid[0], p[0])
    spawn("./rsm_tester", p[0]+1, "breakpoint", 2)
    usleep(100000)
    # two view changes:
    # NOTE: in_views is not incremented between the two waits, so both wait
    # for the same count (in_views[port]+1) on the ports they share.
    print "first view change wait"
    lastv = (p[0],p[1],p[2],p[3])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    print "second view change wait"
    lastv = (p[1],p[2],p[3])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    print "   Wait for lock_tester to finish (waitpid %s)" % (t,)
    waitpid_to(t, 600)
    # grep exits non-zero when the success line is missing from the log
    if os.system("grep \"passed all tests successfully\" lock_tester-%s.log" % (p[0],)):
        mydie("Failed lock tester for test 15")
    cleanup()
    usleep(200000)

# Test 16: with lock_tester running, ask the first server through rsm_tester
# to partition itself ("partition 0"), wait for a view change on the other
# two, heal the partition ("partition 1") and wait for a view change on all
# three.  The tester must still succeed.
if do_run[16]:
    print "test16: start 3-process rsm, partition primary, heal it"
    start_nodes(3, "ls")
    print "Start lock_tester %s" % (p[0],)
    t = spawn("./lock_tester", p[0])
    usleep(100000)
    print "Partition primary (PID: %s) on port %s at breakpoint" % (pid[0], p[0])
    # rsm_tester is addressed at the server's port + 1
    spawn("./rsm_tester", p[0]+1, "partition", 0)
    usleep(300000)
    print "first view change wait"
    lastv = (p[1],p[2])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    usleep(100000)
    print "Heal partition primary (PID: %s) on port %s at breakpoint" % (pid[0], p[0])
    spawn("./rsm_tester", p[0]+1, "partition", 1)
    usleep(100000)
    # TODO: this should also verify that the resulting view is the 5th one
    print "second view change wait"
    lastv = (p[0], p[1],p[2])
    for port in lastv:
        wait_for_view_change(paxos_log(port), in_views[port]+1, port, 20)
    print "   Wait for lock_tester to finish (waitpid %s)" % (t,)
    waitpid_to(t, 600)
    # grep exits non-zero when the success line is missing from the log
    if os.system("grep \"passed all tests successfully\" lock_tester-%s.log" % (p[0],)):
        mydie("Failed lock tester for test 16")
    cleanup()
    usleep(200000)

# Reached only when no selected test aborted the run.
print "tests done OK"

# Remove the config file written by print_config(); it is fine if it is
# already gone.
try:
    os.unlink("config")
except OSError:
    pass
