#!/usr/bin/python3

import argparse
import collections
import copy
import json
import socket

import eventlet
from six.moves.urllib.parse import urlparse
from swift.cli import recon

# Command-line interface: every option is a plain on/off switch, declared in
# the order it should appear in --help.
parser = argparse.ArgumentParser(
    description="Script that monitor swift cluster state", prog="swift_status"
)
_cli_switches = (
    (("-s", "--status"), "Return cluster state"),
    (("-j", "--json"), "Format output as json"),
    (("-d", "--diskusage"), "Return disk usage for object ring"),
    (("-u", "--unmounted"), "Show unmounted drives"),
    (("--ntp-drift",), "Show NTP drift"),
    (("--validate-ring",), "Validate the swift ring"),
    (("-f", "--diskfilling"), "Return disk filling for object ring"),
    (("-r", "--reason"), "Return cluster state reason"),
)
for _flags, _help in _cli_switches:
    parser.add_argument(*_flags, action="store_true", help=_help)
args = parser.parse_args()

# Rings that are inspected and the directory passed to recon as swift_dir.
ring_names = ["object", "container", "account"]
swift_directory = "/etc/swift"
swiftrecon = recon.SwiftRecon()

# Green thread pool shared by every function that fans requests out to hosts.
pool_size = 30
pool = eventlet.GreenPool(pool_size)

# Overall verdict; the checks at the bottom of the script may overwrite both.
cluster_state = "OK"
reason = None


def get_hosts(swiftrecon, swift_directory, ring_names):
    """Collect the recon hosts of every requested ring.

    Args:
        swiftrecon: Object exposing ``get_hosts(region_filter, zone_filter,
            swift_dir, ring_names)``.
        swift_directory: Value forwarded to ``get_hosts`` as ``swift_dir``.
        ring_names: Iterable of ring names; ``get_hosts`` is called once per
            name, with no region or zone filter.

    Returns:
        dict mapping each ring name to whatever ``swiftrecon.get_hosts``
        returned for that single ring.
    """
    return {
        name: swiftrecon.get_hosts(
            region_filter=None,
            zone_filter=None,
            swift_dir=swift_directory,
            ring_names=[name],
        )
        for name in ring_names
    }


def unmounted(hosts, scout):
    """Map each host name to the devices its recon answer lists.

    Args:
        hosts: Iterable handed to ``pool.imap`` together with ``scout.scout``.
        scout: Object whose ``scout`` callable produces
            ``(url, response, status, ts_start, ts_end)`` for one host.

    Returns:
        dict mapping a reverse-resolved host name to a list of device names.
        Hosts that did not answer with status 200, or that reported no
        device, are left out. When several URLs resolve to the same host
        name, the devices of the last such URL are kept.

    Raises:
        Whatever ``socket.gethostbyaddr`` raises when the reverse lookup of
        a reporting host fails.
    """
    devices_by_url = {}
    for url, response, status, _ts_start, _ts_end in pool.imap(scout.scout, hosts):
        if status != 200:
            continue
        # Only entries whose "mounted" value is a real boolean are collected;
        # any other value type is skipped (presumably an error description --
        # confirm against the recon middleware).
        devices_by_url[url] = [
            item["device"] for item in response if isinstance(item["mounted"], bool)
        ]

    unmounted_data = {}
    for url, devices in devices_by_url.items():
        if not devices:
            continue
        # One reverse lookup per reporting URL instead of one per device.
        hostname = socket.gethostbyaddr(urlparse(url).hostname)[0]
        unmounted_data[hostname] = list(devices)
    return unmounted_data


def time_check(hosts, scout, jitter=0.0):
    """List the hosts whose reported time falls outside the request window.

    Args:
        hosts: Iterable handed to ``pool.imap`` together with ``scout.scout``.
        scout: Object whose ``scout`` callable produces
            ``(url, remote_timestamp, status, ts_start, ts_end)`` for one host.
        jitter: Tolerance added on both sides of the window; its absolute
            value is used.

    Returns:
        A list of unique reverse-resolved host names, in the order they were
        detected, or ``None`` when no host is out of the window.
    """
    jitter = abs(jitter)
    ntpdrift_data = []
    for url, ts_remote, status, ts_start, ts_end in pool.imap(scout.scout, hosts):
        if status != 200:
            continue
        # Round to a tenth of a second because the request itself takes a
        # few milliseconds.
        ts_remote = round(ts_remote, 1)
        ts_start = round(ts_start, 1)
        ts_end = round(ts_end, 1)
        if ts_remote + jitter < ts_start or ts_remote - jitter > ts_end:
            hostname = socket.gethostbyaddr(urlparse(url).hostname)[0]
            if hostname not in ntpdrift_data:
                ntpdrift_data.append(hostname)
    # Callers treat the result as a truth value, so "nothing found" stays None.
    return ntpdrift_data or None


def validate_servers(hosts, scout, server_type):
    """List the hosts that do not identify as the expected server type.

    Args:
        hosts: Iterable handed to ``pool.imap`` together with
            ``scout.scout_server_type``.
        scout: Object whose ``scout_server_type`` callable produces
            ``(url, response, status)`` for one host.
        server_type: Ring name; a host is valid when it answers
            ``"<server_type>-server"``.

    Returns:
        A list of unique reverse-resolved host names, or ``None`` when every
        host that answered 200 reported the expected type.
    """
    mismatches = {}
    for url, response, status in pool.imap(scout.scout_server_type, hosts):
        if status == 200 and response != server_type + "-server":
            mismatches[url] = response

    validate_data = []
    for bad_url in mismatches:
        name = socket.gethostbyaddr(urlparse(bad_url).hostname)[0]
        if name in validate_data:
            continue
        validate_data.append(name)

    if not validate_data:
        return None
    return validate_data


def disk_usage(hosts, scout):
    """Sum used and available space over the mounted devices of each host.

    Args:
        hosts: Iterable handed to ``pool.imap`` together with ``scout.scout``.
        scout: Object whose ``scout`` callable produces
            ``(url, response, status, ts_start, ts_end)`` for one host, where
            ``response`` is a list of per-device dicts with at least the
            keys ``mounted``, ``used`` and ``avail``.

    Returns:
        dict with ``used`` (sum of ``used``), ``total`` (``used`` plus sum of
        ``avail``) and ``percent`` (used share of total, rounded to two
        decimals; ``0.0`` when nothing was collected).
    """
    raw_used = 0
    raw_avail = 0
    counted_hosts = set()
    for url, response, status, _ts_start, _ts_end in pool.imap(scout.scout, hosts):
        if status != 200:
            # A failed answer must not prevent another URL of the same host
            # from being counted later on.
            continue
        hostname = socket.gethostbyaddr(urlparse(url).hostname)[0]
        if hostname in counted_hosts:
            # Each host is counted once, whatever the number of URLs.
            continue
        counted_hosts.add(hostname)
        for entry in response:
            # Same rule as disk_filling: only a literal True counts as
            # mounted, so a non-boolean "mounted" value is skipped instead
            # of being mistaken for a mounted device.
            if entry["mounted"] is True:
                raw_used += entry["used"]
                raw_avail += entry["avail"]
    raw_total = raw_used + raw_avail
    # Guard against an empty cluster view (no answer, nothing mounted).
    avg_used = round(100.0 * raw_used / raw_total, 2) if raw_total else 0.0
    return {"used": raw_used, "total": raw_total, "percent": avg_used}


def disk_filling(hosts, scout):
    """Build a histogram of device fill levels, in whole percents.

    Args:
        hosts: Iterable handed to ``pool.imap`` together with ``scout.scout``.
        scout: Object whose ``scout`` callable produces
            ``(url, response, status, ts_start, ts_end)`` for one host, where
            ``response`` is a list of per-device dicts with at least the
            keys ``mounted``, ``used`` and ``size``.

    Returns:
        dict mapping an integer percentage (fill level truncated with
        ``int``) to the number of mounted devices at that level. When several
        URLs resolve to the same host name, only the devices of the last
        such URL are counted.
    """
    usage_by_url = {}
    for url, response, status, _ts_start, _ts_end in pool.imap(scout.scout, hosts):
        if status != 200:
            continue
        # Devices whose "mounted" value is not the boolean True are ignored.
        usage_by_url[url] = [
            round(float(device["used"]) / float(device["size"]) * 100.0, 2)
            for device in response
            if isinstance(device["mounted"], bool) and device["mounted"]
        ]

    # Re-key by host name so that a host reachable through several URLs is
    # represented a single time.
    usage_by_host = {}
    for url, usages in usage_by_url.items():
        usage_by_host[socket.gethostbyaddr(urlparse(url).hostname)[0]] = usages

    percents = {}
    for usages in usage_by_host.values():
        for value in usages:
            bucket = int(value)
            percents[bucket] = percents.get(bucket, 0) + 1
    return percents


def parse_diskusage(hosts):
    """Compute the disk usage summary of every ring.

    Args:
        hosts: dict mapping a ring name to the hosts of that ring.

    Returns:
        dict mapping each ring name to the result of ``disk_usage`` for its
        hosts, each queried through a fresh "diskusage" scout.
    """
    return {
        ring_name: disk_usage(members, recon.Scout(recon_type="diskusage"))
        for ring_name, members in hosts.items()
    }


def parse_diskfilling(hosts):
    """Return the fill-level histogram of the object ring, sorted by level.

    Args:
        hosts: dict of hosts per ring; only the ``"object"`` entry is read.

    Returns:
        ``collections.OrderedDict`` of the ``disk_filling`` result, with keys
        in ascending order.
    """
    scout = recon.Scout(recon_type="diskusage")
    histogram = disk_filling(hosts["object"], scout)
    by_level = sorted(histogram.items())
    return collections.OrderedDict(by_level)


def parse_unmountedregion(swiftrecon, swift_directory, ring_names):
    """Collect the unmounted devices of every ring, region by region.

    Regions are probed from 1 upwards and the walk of a ring stops at the
    first region for which ``get_hosts`` returns no host.

    Args:
        swiftrecon: Object exposing ``get_hosts(region_filter, zone_filter,
            swift_dir, ring_names)``.
        swift_directory: Value forwarded to ``get_hosts`` as ``swift_dir``.
        ring_names: Iterable of ring names to inspect.

    Returns:
        dict mapping each ring name to a dict that maps a region number to
        the result of ``unmounted`` for the hosts of that region.
    """

    def populated_regions(ring_name):
        # Yield (region number, hosts) until a region turns out to be empty.
        number = 1
        while True:
            members = swiftrecon.get_hosts(
                region_filter=number,
                zone_filter=None,
                swift_dir=swift_directory,
                ring_names=[ring_name],
            )
            if len(members) == 0:
                return
            yield number, members
            number += 1

    return {
        ring_name: {
            number: unmounted(members, recon.Scout(recon_type="unmounted"))
            for number, members in populated_regions(ring_name)
        }
        for ring_name in ring_names
    }


def parse_ntpdrift(hosts):
    """Run the time check against every ring.

    Args:
        hosts: dict mapping a ring name to the hosts of that ring.

    Returns:
        dict mapping each ring name to the result of ``time_check`` for its
        hosts, each queried through a fresh "time" scout.
    """
    return {
        ring_name: time_check(members, recon.Scout(recon_type="time"))
        for ring_name, members in hosts.items()
    }


def parse_validate(hosts):
    """Check, ring by ring, that hosts run the matching server type.

    Args:
        hosts: dict mapping a ring name to the hosts of that ring.

    Returns:
        dict mapping each ring name to the result of ``validate_servers``
        for its hosts.
    """
    # The ring name is passed as the expected server type so that servers
    # are not all validated as object servers.
    return {
        ring_name: validate_servers(
            members, recon.Scout(recon_type="server_type_check"), ring_name
        )
        for ring_name, members in hosts.items()
    }


# Gather every metric up front; the report below only formats these results.
hosts = get_hosts(swiftrecon, swift_directory, ring_names)
diskusage = parse_diskusage(hosts)
ntp = parse_ntpdrift(hosts)
validateservers = parse_validate(hosts)
unmountedperregion = parse_unmountedregion(swiftrecon, swift_directory, ring_names)
diskfilling = parse_diskfilling(hosts)
diskusagehuman = copy.deepcopy(diskusage)

# Count each region once, even when several rings report unmounted drives in
# it; otherwise a single region seen through several rings would be reported
# as several regions and trip the "Critical" threshold.
_regions_with_unmounted = {
    region
    for regions in unmountedperregion.values()
    for region, drives in regions.items()
    if drives
}
region_with_unmounted = len(_regions_with_unmounted)

# Checks are ordered by severity and the first one that fires decides both
# the state and the reason; later checks only apply while the state is "OK".
if region_with_unmounted > 1:
    cluster_state = "Critical"
    reason = "More than 1 region with unmounted drives"
# A ring maps to a non-empty list when it has offending servers and to None
# otherwise, so one offending ring is enough to warn.
if any(validateservers.values()) and cluster_state == "OK":
    cluster_state = "Warn"
    reason = "Some servers cannot be validate"
if any(ntp.values()) and cluster_state == "OK":
    cluster_state = "Warn"
    reason = "Some servers have ntpdrift"
# Warn when at least two distinct fill levels sit above the most common one.
# NOTE(review): this compares positions in the sorted list of observed
# levels, not the percentage values themselves -- confirm that this is the
# intended reading of "more than 2%".
_fill_levels = list(diskfilling)
if _fill_levels and cluster_state == "OK":
    _most_common_level = max(diskfilling, key=diskfilling.get)
    if _fill_levels.index(_most_common_level) + 2 < len(_fill_levels):
        cluster_state = "Warn"
        reason = "Some servers have drive with more than 2% filling than medium fill"

# Pretty output: fill levels are displayed as "<n>%".
diskfilling = {f"{level}%": count for level, count in diskfilling.items()}

# Human readable copy of the usage figures: divide by 1024**4 and label "TB".
_BYTES_PER_TB = 1024 ** 4
for _usage in diskusagehuman.values():
    _usage["used"] = str(round(_usage["used"] / _BYTES_PER_TB, 2)) + " TB"
    _usage["total"] = str(round(_usage["total"] / _BYTES_PER_TB, 2)) + " TB"

formated_output = {
    "status": cluster_state,
    "reason": reason,
    "region_with_unmounted": region_with_unmounted,
    "unmounted_per_region": unmountedperregion,
    "ntp_drift": ntp,
    "validation": validateservers,
    "disk_usage": diskusage,
    "disk_usage_human": diskusagehuman,
    "disk_filling": diskfilling,
}

# Without any section switch, every section of the text report is printed.
no_args = not any(
    (
        args.status,
        args.reason,
        args.unmounted,
        args.ntp_drift,
        args.validate_ring,
        args.diskusage,
        args.diskfilling,
    )
)

if args.json:
    # JSON output always carries the full report, whatever the other switches.
    print(json.dumps(formated_output))
else:
    if args.status or no_args:
        print(f"Cluster status: {cluster_state}")
    if args.reason or no_args:
        print(f"Cluster state reason: {reason}")
    if args.unmounted or no_args:
        print(f"Region with umounted drives: {region_with_unmounted}")
        print("\nUnmounted disks output per region:")
        for ring, regions in unmountedperregion.items():
            print(f"  {ring}:")
            for region, drives in regions.items():
                print(f"    {region}: {drives}")
    if args.ntp_drift or no_args:
        print("\nNtp drift:")
        for ring, drifting in ntp.items():
            print(f"  {ring}: {drifting}")
    if args.validate_ring or no_args:
        print("\nValidate server output:")
        for ring, invalid in validateservers.items():
            print(f"  {ring}: {invalid}")
    if args.diskusage or no_args:
        print("\nDisk usage output:")
        for ring, usage in diskusage.items():
            print(
                f"  {ring}: {usage['used']} used on {usage['total']} {usage['percent']}%"
            )
        print("\nDisk usage humean readable output:")
        for ring, usage in diskusagehuman.items():
            print(
                f"  {ring}: {usage['used']} used on {usage['total']} {usage['percent']}%"
            )
    if args.diskfilling or no_args:
        print(f"\nDisk filling for object ring: {diskfilling}")
