#!/usr/bin/env python3

# Sample invocation: PYTHONPATH=build/lib/ kdo quentin/root python3.8 ~/Documents/MIT/SIPB/XVM/invirt-deactivate --uri postgresql://postgres:@localhost:1235/invirt --remote xvm-remote.mit.edu --force 2>&1 | tee shutdown-20201020.log

import argparse

from invirt import database
from invirt.database import record, models
import hesiod
import time
from sqlalchemy import func
from sqlalchemy import *
from sqlalchemy import orm
from sqlalchemy.orm import create_session, relation
from subprocess import check_call, check_output
import yaml

# Cache of the Hesiod filsys status of each machine owner's locker; main()
# creates it if missing and refreshes it from Hesiod on every run.
#   name    -- locker name (primary key; compared to Machine.owner and
#              MachineAccess.user in main()'s queries)
#   type    -- "type" field of the first Hesiod filsys entry, restricted to
#              the "AFS"/"ERR" enum; NULL when the lookup raised
#              FileNotFoundError
#   message -- "message" field of that first entry, if any
# NOTE(review): Hesiod can presumably return other filsys types (e.g. NFS,
# LOC); such values would not fit the locker_type enum — confirm.
lockers_table = Table(
    'lockers', models.meta,
    Column('name', String, nullable=False, primary_key=True),
    Column('type', Enum("AFS", "ERR", name='locker_type')),
    Column('message', String),
)

class Locker(record.Record):
    """ORM record for one row of the ``lockers`` table, keyed by locker name."""
    # NOTE(review): presumably consumed by record.Record to identify/display
    # instances — confirm against invirt.database.record.
    _identity_field = 'name'

# Bind Locker to lockers_table through invirt's mapper helper so that
# Locker.query and Locker.<column> expressions are available.
models.mapper(Locker, lockers_table)

def print_list(l):
    """Print every entry of *l* on its own line, ordered by ``str(entry)``.

    Iterable entries (anything with ``__iter__`` other than ``str``/``bytes``)
    are printed as their items joined by tabs; any other entry is printed
    as a single field.
    """
    def _as_row(entry):
        # Strings and bytes are iterable but must be printed whole.
        scalar = isinstance(entry, (str, bytes)) or not hasattr(entry, '__iter__')
        return (entry,) if scalar else entry

    ordered = sorted(l, key=str)
    for entry in ordered:
        print("\t".join(map(str, _as_row(entry))))

def _refresh_lockers():
    """Refresh the ``lockers`` table from Hesiod for every distinct machine owner.

    For each distinct ``Machine.owner``, fetch (or create) its ``Locker`` row
    and store the ``type`` and ``message`` of the first Hesiod filsys entry.
    A ``FileNotFoundError`` from the lookup is recorded as type/message
    ``None``; any other error is reported with the owner name and re-raised.
    """
    owners = database.session.query(database.Machine.owner).distinct()
    for owner, in owners:
        locker = Locker.query.get(owner)
        if not locker:
            # NOTE(review): relies on invirt's mapper/session setup to attach
            # new instances to the session — confirm.
            locker = Locker(name=owner)
        try:
            fs = hesiod.FilsysLookup(owner).filsys
            locker.type = fs[0]['type']
            locker.message = fs[0].get('message')
        except FileNotFoundError:
            # Recorded as "missing locker" (type NULL) by the queries in main().
            locker.type = None
            locker.message = None
        except Exception:
            print("Error looking up", owner)
            raise


def _remctl_each(remote, machines, action):
    """Run ``remctl <remote> control <machine.name> <action>`` for each machine.

    Each command is echoed before it runs. A command that fails (non-zero
    exit or remctl not executable) is reported and skipped, so one bad VM
    does not abort the batch. KeyboardInterrupt/SystemExit still propagate.
    """
    from subprocess import CalledProcessError

    for m in machines:
        c = ['remctl', remote, 'control', m.name, action]
        print(' '.join(c))
        try:
            check_call(c)
        except (CalledProcessError, OSError):
            print("Failed.")


def main():
    """Report VMs whose owner and administrators lost their lockers; optionally shut them down.

    Steps:
      1. Refresh the ``lockers`` table from Hesiod for every machine owner
         (creating the table first if needed) and commit.
      2. Report machines with no ACL user whose locker is non-ERR, machines
         whose owner's locker is missing (type NULL), and machines whose
         owner's locker is ERR.
      3. Intersect "no admin users" with "broken owner locker", then with the
         VMs that ``remctl <remote> web listvms`` reports as running.
      4. Ask for an interactive "yes". With ``--force``, mark those machines
         adminable, commit, then ``shutdown`` each one, wait 30 seconds, and
         ``destroy`` each one via remctl. Without ``--force`` nothing is
         changed.
    """
    parser = argparse.ArgumentParser(
        description='Find VMs whose owner locker is broken and whose admins '
                    'all have ERR lockers, and optionally shut them down')

    parser.add_argument('-u', '--uri', type=str, dest='uri',
                        help='Database URI (e.g. postgresql://postgres:@localhost:1234/invirt)')
    parser.add_argument('-r', '--remote-host', type=str, dest='remote',
                        default='xvm-remote-dev.mit.edu',
                        help='Remote host')
    parser.add_argument('-f', '--force', action='store_true', dest='force',
                        help='Shut VMs down')

    args = parser.parse_args()

    database.connect(args.uri)

    # Transaction 1: (re)build the locker cache.
    database.session.begin()
    lockers_table.create(checkfirst=True)
    _refresh_lockers()
    database.session.commit()

    # Transaction 2: analysis and (optionally) marking machines adminable.
    database.session.begin()
    access = database.MachineAccess
    # Machines that still have at least one ACL user whose locker is not ERR.
    # Users without a Locker row count too (outer join yields NULL type).
    acl_locker_no_err = (
        database.session.query(access.machine_id)
        .outerjoin(Locker, access.user == Locker.name)
        .filter((Locker.type == None) | (Locker.type != "ERR"))  # noqa: E711
        .group_by(access.machine_id)
        .subquery()
    )
    # Anti-join: machines with no such ACL entry.
    machines_no_access = (
        database.session.query(database.Machine)
        .outerjoin(acl_locker_no_err)
        .filter(acl_locker_no_err.c.machine_id == None)  # noqa: E711
        .all()
    )

    machines_no_locker = (
        database.session.query(database.Machine)
        .join(Locker, database.Machine.owner == Locker.name)
        .filter(Locker.type == None)  # noqa: E711
        .all()
    )
    machines_err_locker = (
        database.session.query(database.Machine, Locker.message)
        .join(Locker, database.Machine.owner == Locker.name)
        .filter(Locker.type == 'ERR')
        .all()
    )

    print("Machines with no admin users:\n")
    print_list(machines_no_access)
    print("\nMachines with missing locker:\n")
    print_list(machines_no_locker)
    print("\nMachines with ERR locker:\n")
    print_list(machines_err_locker)

    listvms = yaml.safe_load(check_output(
        ['remctl', args.remote, 'web', 'listvms'],
    ))

    machines_running = set(listvms)
    machines_broken_locker = set(machines_no_locker) | set(m for (m, msg) in machines_err_locker)
    machines_deactivate = set(machines_no_access) & machines_broken_locker
    print("\nMachines with no admin users AND broken locker:\n")
    print_list(machines_deactivate)
    running_names = machines_running & set(m.name for m in machines_deactivate)
    machines_to_shut_down = database.Machine.query.filter(
        database.Machine.name.in_(running_names)).all()
    print("\nMachines to shutdown%s:\n" % (" (WILL SHUTDOWN)" if args.force else ""))
    print_list(machines_to_shut_down)

    if input("Are you sure (yes/NO)?") != "yes":
        return

    if args.force:
        for m in machines_to_shut_down:
            m.adminable = True

    database.session.commit()

    # Nothing to do (and no reason to wait) when no machine qualifies.
    if args.force and machines_to_shut_down:
        _remctl_each(args.remote, machines_to_shut_down, 'shutdown')
        print("Waiting 30 seconds for VMs to exit")
        time.sleep(30)
        _remctl_each(args.remote, machines_to_shut_down, 'destroy')

# Run main() only when executed as a script, not when imported.
if __name__ == '__main__':
    main()