#!/usr/libexec/platform-python

###############################################################################
# core_params: used to store variables required for deploying ceph via cockpit
#              for cockpit-ceph-deploy.
#
# Copyright (C) 2021, Mark Hooper   <mhooper@45drives.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
###############################################################################

from optparse import OptionParser
from contextlib import redirect_stdout
from dataclasses import dataclass, asdict, field
import sys
import io
import subprocess
import os
import json
import copy
from typing import List
from datetime import datetime

#####################################
# DATA CLASSES
#####################################


@dataclass()
class AdInfo:
    """Active Directory related settings, all held as strings."""

    workgroup: str = ""
    id_map_range: str = "100000 - 999999"
    realm: str = ""
    winbind_enum_groups: str = "no"
    winbind_enum_users: str = "no"
    domain_join_user: str = ""
    domain_join_password: str = ""

    def load_dictionary(self, dict_variable: dict):
        """Copy entries of ``dict_variable`` onto matching fields.

        An entry is applied only when its key is a field of this dataclass
        and its value has exactly the same type as the field's current
        value; everything else is silently ignored.

        Returns True when at least one field was assigned.  When the
        argument is not a dict, a JSON error message is printed and False
        is returned.
        """
        if not isinstance(dict_variable, dict):
            print(json.dumps({"error_msg": "unable to load dictionary."}, indent=4))
            return False

        known_fields = asdict(self)
        loaded = False
        for name, value in dict_variable.items():
            if name not in known_fields:
                continue
            # exact type match: a value of any other type never overwrites a field
            if type(value) is type(getattr(self, name)):
                setattr(self, name, value)
                loaded = True
        return loaded


@dataclass()
class CtdbPublicAddress:
    """One CTDB public address entry; all three fields are required."""

    vip_address: str
    vip_interface: str
    subnet_mask: str

    def load_dictionary(self, dict_variable: dict):
        """Copy entries of ``dict_variable`` onto matching fields.

        Only keys that name a field, with a value of exactly the same type
        as the field's current value, are applied.

        Returns True when at least one field was assigned.  A non-dict
        argument prints a JSON error message and yields False.
        """
        if not isinstance(dict_variable, dict):
            print(json.dumps({"error_msg": "unable to load dictionary."}, indent=4))
            return False

        known_fields = asdict(self)
        loaded = False
        for name, value in dict_variable.items():
            if name not in known_fields:
                continue
            if type(value) is type(getattr(self, name)):
                setattr(self, name, value)
                loaded = True
        return loaded


@dataclass()
class Smbs:
    """Settings for the ``smbs`` group."""

    smb_active_directory: bool = False
    smb_local_users: bool = False
    samba_server: bool = True
    samba_cluster: bool = True
    domain_member: bool = False
    configure_shares: bool = False
    ctdb_public_addresses: List[CtdbPublicAddress] = field(default_factory=list)
    # Annotated as AdInfo, but the default factory is ``dict``, so a fresh
    # instance holds an empty dict here and load_dictionary() accepts a dict.
    active_directory_info: AdInfo = field(default_factory=dict)

    def load_dictionary(self, dict_variable: dict):
        """Copy entries of ``dict_variable`` onto matching fields.

        Only keys that name a field, with a value of exactly the same type
        as the field's current value, are applied (so a list replaces the
        list field and a dict replaces the dict field as a whole).

        Returns True when at least one field was assigned.  A non-dict
        argument prints a JSON error message and yields False.
        """
        if not isinstance(dict_variable, dict):
            print(json.dumps({"error_msg": "unable to load dictionary."}, indent=4))
            return False

        known_fields = asdict(self)
        loaded = False
        for name, value in dict_variable.items():
            if name not in known_fields:
                continue
            if type(value) is type(getattr(self, name)):
                setattr(self, name, value)
                loaded = True
        return loaded

@dataclass()
class Nfss:
    """Settings for the ``nfss`` group."""

    nfs_active_active: bool = False
    nfs_active_passive: bool = False
    ceph_nfs_floating_ip_address: str = ""
    ceph_nfs_floating_ip_cidr: str = ""
    ceph_nfs_rados_backend_driver: str = ""

    def load_dictionary(self, dict_variable: dict):
        """Copy entries of ``dict_variable`` onto matching fields.

        Only keys that name a field, with a value of exactly the same type
        as the field's current value, are applied.

        Returns True when at least one field was assigned.  A non-dict
        argument prints a JSON error message and yields False.
        """
        if not isinstance(dict_variable, dict):
            print(json.dumps({"error_msg": "unable to load dictionary."}, indent=4))
            return False

        known_fields = asdict(self)
        loaded = False
        for name, value in dict_variable.items():
            if name not in known_fields:
                continue
            if type(value) is type(getattr(self, name)):
                setattr(self, name, value)
                loaded = True
        return loaded


@dataclass()
class Rgwloadbalancers:
    """Settings for the ``rgwloadbalancers`` group."""

    haproxy_frontend_port: str = ""
    virtual_ips: List = field(default_factory=list)
    virtual_ip_netmask: str = ""
    virtual_ip_interface: str = ""
    haproxy_frontend_ssl_port: str = ""
    haproxy_frontend_ssl_certificate: str = ""
    haproxy_ssl_dh_param: str = ""
    haproxy_ssl_ciphers: List = field(default_factory=list)
    haproxy_ssl_options: List = field(default_factory=list)
    enable_ssl: bool = False

    def load_dictionary(self, dict_variable: dict):
        """Copy entries of ``dict_variable`` onto matching fields.

        Only keys that name a field, with a value of exactly the same type
        as the field's current value, are applied; list fields are replaced
        as a whole, their elements are not inspected.

        Returns True when at least one field was assigned.  A non-dict
        argument prints a JSON error message and yields False.
        """
        if not isinstance(dict_variable, dict):
            print(json.dumps({"error_msg": "unable to load dictionary."}, indent=4))
            return False

        known_fields = asdict(self)
        loaded = False
        for name, value in dict_variable.items():
            if name not in known_fields:
                continue
            if type(value) is type(getattr(self, name)):
                setattr(self, name, value)
                loaded = True
        return loaded


@dataclass()
class Host:
    """Per-host settings stored under the ``hosts`` section."""

    hostname: str = ""
    radosgw_address: str = ""
    monitor_interface: str = ""

    def load_dictionary(self, dict_variable: dict):
        """Copy entries of ``dict_variable`` onto matching fields.

        Only keys that name a field, with a value of exactly the same type
        as the field's current value, are applied.

        Returns True when at least one field was assigned.  A non-dict
        argument prints a JSON error message and yields False.
        """
        if not isinstance(dict_variable, dict):
            print(json.dumps({"error_msg": "unable to load dictionary."}, indent=4))
            return False

        known_fields = asdict(self)
        loaded = False
        for name, value in dict_variable.items():
            if name not in known_fields:
                continue
            if type(value) is type(getattr(self, name)):
                setattr(self, name, value)
                loaded = True
        return loaded


@dataclass()
class Roles:
    """Role name -> list of hostnames assigned to that role."""

    mons: List = field(default_factory=list)
    mgrs: List = field(default_factory=list)
    osds: List = field(default_factory=list)
    metrics: List = field(default_factory=list)
    mdss: List = field(default_factory=list)
    smbs: List = field(default_factory=list)
    nfss: List = field(default_factory=list)
    iscsigws: List = field(default_factory=list)
    rgws: List = field(default_factory=list)
    rgwloadbalancers: List = field(default_factory=list)
    client: List = field(default_factory=list)

    def load_dictionary(self, dict_variable: dict):
        """Append hostnames from ``dict_variable`` to the matching role lists.

        For every key that names a role and maps to a list, each string
        element is appended to that role (existing entries are kept and
        duplicates are not filtered).  Non-string elements, non-list values
        and unknown keys are ignored.

        Returns True when at least one hostname was appended.  A non-dict
        argument prints a JSON error message and yields False.
        """
        if not isinstance(dict_variable, dict):
            print(json.dumps({"error_msg": "unable to load dictionary."}, indent=4))
            return False

        role_names = asdict(self)
        appended = False
        for role, members in dict_variable.items():
            if role not in role_names or not isinstance(members, list):
                continue
            target = getattr(self, role)
            for member in members:
                if isinstance(member, str):
                    target.append(member)
                    appended = True
        return appended


@dataclass()
class Options:
    """Cluster-wide options stored under the ``options`` section."""

    monitor_interface: str = ""
    public_network: str = ""
    cluster_network: str = ""
    radosgw_interface: str = ""
    ip_version: str = ""
    radosgw_civetweb_port: str = ""
    radosgw_frontend_type: str = ""
    radosgw_civetweb_num_threads: str = ""
    # Raw string: "\{" and "\}" are not valid escape sequences, so the
    # non-raw literal triggers an invalid-escape warning on newer Pythons.
    # The runtime value (backslashes included) is unchanged.
    radosgw_civetweb_options: str = r"num_threads=\{\{ radosgw_civetweb_num_threads \}\}"
    hybrid_cluster: bool = False
    offline_install: bool = False

    def load_dictionary(self, dict_variable: dict):
        """Copy entries of ``dict_variable`` onto matching fields.

        Only keys that name a field, with a value of exactly the same type
        as the field's current value, are applied.

        Returns True when at least one field was assigned.  A non-dict
        argument prints a JSON error message and yields False.
        """
        if not isinstance(dict_variable, dict):
            print(json.dumps({"error_msg": "unable to load dictionary."}, indent=4))
            return False

        known_fields = asdict(self)
        loaded = False
        for name, value in dict_variable.items():
            if name not in known_fields:
                continue
            if type(value) is type(getattr(self, name)):
                setattr(self, name, value)
                loaded = True
        return loaded


#####################################
# GLOBALS
#####################################
# Declared `global` in main() but not otherwise read or written by the code
# visible in this file.
g_ceph_core_params = {}
# Default parameter set; main() assigns the result of get_default_params().
g_default_params = {}
# Parsed content of the parameter file; assigned in main() and refreshed by
# write_to_parameter_file() after each write.
g_param_file_content = None
# Location of the JSON parameter file and of the directory that holds it.
g_param_file_path = "/usr/share/cockpit/ceph-deploy/params/core_params.json"
g_param_file_dir = "/usr/share/cockpit/ceph-deploy/params"
# When True, debug_print() produces output; main() sets it for -d/--debug.
g_debug_flag = False

#####################################
# FUNCTIONS
#####################################


def print_to_string(item):
    """Return the text ``print(item)`` would emit, without trailing newlines.

    stdout is redirected into an in-memory buffer for the duration of the
    print call, so nothing reaches the real stdout.
    """
    captured = io.StringIO()
    with redirect_stdout(captured):
        print(item)
    text = captured.getvalue()
    return text.rstrip("\n")


def debug_print(item, tag):
    """Print ``item`` inside a starred banner headed by ``tag``.

    Does nothing unless the module-level ``g_debug_flag`` is truthy.
    """
    if not g_debug_flag:
        return
    rule = "****************************************************************"
    header = "* {t}".format(t=tag)
    # one print call, newline-separated: same five lines as separate prints
    print(rule, header, rule, item, rule, sep="\n")


def get_default_params():
    """Build a fresh dictionary holding the default parameter file content.

    The result has the sections ``hosts`` (empty), ``roles`` and ``options``
    (field defaults of the Roles and Options dataclasses), ``groups`` (one
    empty entry per known group) and a ``time_stamp`` of the current local
    time formatted as ``%Y-%m-%d-%H-%M``.
    """
    # asdict() already returns independent copies, so every call yields
    # structures that share nothing with earlier results.
    return {
        "hosts": {},
        "roles": asdict(Roles()),
        "options": asdict(Options()),
        "groups": {"rgwloadbalancers": {}, "nfss": {}, "smbs": {}},
        "time_stamp": datetime.today().strftime("%Y-%m-%d-%H-%M"),
    }


def load_param_file():
    """Return the parameter file content as a parsed JSON object.

    When the file at ``g_param_file_path`` does not exist, it is created
    (together with ``g_param_file_dir`` if needed) from get_default_params()
    and the defaults are returned.  Otherwise the file is read and parsed.

    Exits the process with status 1, after printing a JSON ``error_msg``,
    when the existing file does not contain valid JSON.
    """
    if not os.path.exists(g_param_file_path):
        # exist_ok: no failure if the directory already exists or appears
        # between an existence check and its creation.
        os.makedirs(g_param_file_dir, exist_ok=True)
        d_params = get_default_params()
        serialized = json.dumps(d_params, indent=4)
        # context manager guarantees the handle is closed even if write fails
        with open(g_param_file_path, "w") as param_file:
            param_file.write(serialized)
        debug_print(json.dumps(json.loads(serialized), indent=4),
                    "PARAM_FILE_CREATED")
        # round-trip through JSON so the caller gets exactly what a later
        # read of the file would produce
        return json.loads(serialized)

    # file exists on disk, load contents and return json object.
    with open(g_param_file_path, "r") as param_file:
        # print_to_string() strips trailing newlines from the raw text
        param_file_content = print_to_string(param_file.read())
    try:
        parsed_json = json.loads(param_file_content)
    except ValueError as err:
        error_string = print_to_string(err)
        error_msg = {"error_msg": "JSON Parse Error ({p}): ".format(
            p=g_param_file_path) + error_string}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)
    debug_print(json.dumps(parsed_json, indent=4),
                "PARAM_FILE_LOADED_FROM_DISK")
    return parsed_json


def check_root():
    """Exit with an error unless ``ls /root`` succeeds.

    Listing /root is used as the probe for root access.  On failure a JSON
    ``error_msg`` is printed and the process exits with the return code of
    the ``ls`` command.
    """
    probe = subprocess.run(
        ["ls", "/root"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    exit_code = probe.returncode
    if exit_code == 0:
        return

    message = "/usr/share/cockpit/ceph-deploy/helper_scripts/core_params must be run with root privileges"
    print(json.dumps({"error_msg": message}, indent=4))
    sys.exit(exit_code)


def load_groups(group_string):
    """Parse and normalise the groups JSON given on the command line.

    ``group_string`` must be a JSON object mapping group names to JSON
    objects.  Each known group (``rgwloadbalancers``, ``nfss``, ``smbs``) is
    loaded into its dataclass so that the result carries every field of that
    group, with defaults for anything not supplied or of the wrong type.
    Unknown group names are ignored.

    Returns a dict ``{group_name: {field: value, ...}}``.

    Exits the process with status 1, after printing a JSON ``error_msg``,
    when the text is not valid JSON, is not a JSON object, is empty, or
    contains a group whose value is not a JSON object.
    """
    try:
        groups_obj = json.loads(group_string)
    except ValueError as err:
        # failed to get json from command line arguments
        error_string = print_to_string(err)
        error_msg = {"error_msg": "Error parsing groups JSON" + error_string}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    # Valid JSON that is not an object (list, number, string, null) has no
    # .keys(); report it in the tool's JSON error format instead of crashing.
    if not isinstance(groups_obj, dict):
        error_msg = {
            "error_msg": "Invalid format, groups must be a JSON object."}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    if not groups_obj:
        error_msg = {"error_msg": "Cannot operate on empty group (JSON)"}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    group_classes = {
        "rgwloadbalancers": Rgwloadbalancers,
        "nfss": Nfss,
        "smbs": Smbs,
    }
    groups_dict = {}
    for groupname, group_values in groups_obj.items():
        if not isinstance(group_values, dict):
            error_msg = {
                "error_msg": "Invalid format, group must be a JSON object."}
            print(json.dumps(error_msg, indent=4))
            sys.exit(1)
        group_class = group_classes.get(groupname)
        if group_class is None:
            # unknown group names are skipped without an error
            continue
        group = group_class()
        group.load_dictionary(group_values)
        # asdict() returns an independent copy of the group's fields
        groups_dict[groupname] = asdict(group)

    return groups_dict


def load_hosts(host_string):
    """Parse and normalise the hosts JSON given on the command line.

    ``host_string`` must be a JSON object mapping each hostname to a JSON
    object of host settings.  Every host is loaded into a Host dataclass, so
    the result carries all Host fields; the ``hostname`` field is always set
    to the key under which the host was supplied.

    Returns a dict ``{hostname: {field: value, ...}}``.

    Exits the process with status 1, after printing a JSON ``error_msg``,
    when the text is not valid JSON, is not a JSON object, is empty, contains
    a host whose value is not a JSON object, or contains a ``hostname``
    entry that differs from its key.
    """
    try:
        hosts_obj = json.loads(host_string)
    except ValueError as err:
        # failed to get json from command line arguments
        error_string = print_to_string(err)
        error_msg = {"error_msg": "Error parsing hosts JSON" + error_string}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    # Valid JSON that is not an object (list, number, string, null) has no
    # .keys(); report it in the tool's JSON error format instead of crashing.
    if not isinstance(hosts_obj, dict):
        error_msg = {
            "error_msg": "Invalid format, hosts must be a JSON object."}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    if not hosts_obj:
        error_msg = {"error_msg": "Cannot operate on empty hosts (JSON)"}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    hosts_dict = {}
    for hostname, host_values in hosts_obj.items():
        if not isinstance(host_values, dict):
            error_msg = {
                "error_msg": "Invalid format, host must be a JSON object."}
            print(json.dumps(error_msg, indent=4))
            sys.exit(1)
        if "hostname" in host_values and host_values["hostname"] != hostname:
            error_msg = {
                "error_msg": "Invalid hostname parameter provided for {h} (JSON)".format(h=hostname)}
            print(json.dumps(error_msg, indent=4))
            sys.exit(1)
        new_host = Host()
        new_host.load_dictionary(host_values)
        valid_host_dict = asdict(new_host)
        # the key is authoritative even when no "hostname" entry was supplied
        valid_host_dict["hostname"] = hostname
        hosts_dict[hostname] = valid_host_dict

    return hosts_dict


def load_roles(roles_string, hosts, param_file_content):
    """Parse and validate the roles JSON given on the command line.

    Args:
        roles_string: JSON object mapping role names to lists of hostnames.
        hosts: hosts supplied on this invocation (dict keyed by hostname).
        param_file_content: parsed parameter file; its ``hosts`` section
            lists the hosts that already exist.

    Returns a dict ``{role_name: [hostname, ...]}`` with the supplied
    assignments.

    Exits the process with status 1, after printing a JSON ``error_msg``,
    when the text is not valid JSON, is not a JSON object, is empty, names
    a role that is not a field of Roles, holds a role value that is not a
    list of strings, or assigns a role to a host that is neither in
    ``hosts`` nor in the parameter file.
    """
    debug_print(json.dumps(param_file_content, indent=4),
                "PARAM FILE CONTENT IN load_roles()")
    try:
        roles_obj = json.loads(roles_string)
    except ValueError as err:
        # failed to get json from command line arguments
        error_string = print_to_string(err)
        error_msg = {"error_msg": "Error parsing roles (JSON)" + error_string}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    # Valid JSON that is not an object (list, number, string, null) has no
    # .keys(); report it in the tool's JSON error format instead of crashing.
    if not isinstance(roles_obj, dict):
        error_msg = {
            "error_msg": "Invalid format, roles must be a JSON object."}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    if not roles_obj:
        error_msg = {
            "error_msg": "Cannot operate on empty roles object (JSON)"}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    valid_role_names = asdict(Roles())
    for role_name in roles_obj:
        if role_name not in valid_role_names:
            error_msg = {
                "error_msg": "Invalid role name {r}".format(r=role_name)}
            print(json.dumps(error_msg, indent=4))
            sys.exit(1)

    # A string would otherwise be iterated character by character and a
    # number would raise TypeError, so insist on a list of strings.
    for role_name, hostnames in roles_obj.items():
        if not isinstance(hostnames, list) or not all(
                isinstance(hostname, str) for hostname in hostnames):
            error_msg = {
                "error_msg": "Invalid format, role '{r}' must be a list of hostnames.".format(r=role_name)}
            print(json.dumps(error_msg, indent=4))
            sys.exit(1)

    # tolerate a parameter file that has no "hosts" section
    existing_hosts = param_file_content.get("hosts", {})
    for role_name, hostnames in roles_obj.items():
        for hostname in hostnames:
            if hostname not in hosts and hostname not in existing_hosts:
                error_msg = {"error_msg": "Cannot give host '{h}' a '{r}' role. '{h}' does not exist.".format(
                    h=hostname, r=role_name)}
                print(json.dumps(error_msg, indent=4))
                sys.exit(1)

    return copy.deepcopy(roles_obj)


def load_opts(options_string):
    """Parse and validate the options JSON given on the command line.

    ``options_string`` must be a JSON object whose keys are all fields of
    the Options dataclass.  Values are passed through as supplied; their
    types are not checked here.

    Returns a dict ``{option_name: value}``.

    Exits the process with status 1, after printing a JSON ``error_msg``,
    when the text is not valid JSON, is not a JSON object, is empty, or
    contains an unknown option name.
    """
    try:
        opts_obj = json.loads(options_string)
    except ValueError as err:
        # failed to get json from command line arguments
        error_string = print_to_string(err)
        error_msg = {
            "error_msg": "Error parsing options (JSON)" + error_string}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    # Valid JSON that is not an object (list, number, string, null) has no
    # .keys(); report it in the tool's JSON error format instead of crashing.
    if not isinstance(opts_obj, dict):
        error_msg = {
            "error_msg": "Invalid format, options must be a JSON object."}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    if not opts_obj:
        # message previously said "roles" (copy/paste from load_roles)
        error_msg = {
            "error_msg": "Cannot operate on empty options object (JSON)"}
        print(json.dumps(error_msg, indent=4))
        sys.exit(1)

    valid_option_names = asdict(Options())
    for opt_name in opts_obj:
        if opt_name not in valid_option_names:
            error_msg = {
                "error_msg": "Invalid option name {o}".format(o=opt_name)}
            print(json.dumps(error_msg, indent=4))
            sys.exit(1)

    return copy.deepcopy(opts_obj)


def delete_items(hosts, roles, opts, groups, param_file_content):
    """Remove the given items from ``param_file_content`` and write it out.

    Args:
        hosts: hosts to delete; each is also removed from every role.
        roles: ``{role_name: [hostname, ...]}`` assignments to remove.
        opts: options to reset; each named option is set to ``""``.
        groups: groups whose entry is deleted from the ``groups`` section.
        param_file_content: parameter content to modify (mutated in place).

    After a successful write the result message is printed as JSON and the
    process exits with status 0.
    """
    # setdefault: a parameter file that lacks a section (hand-edited or
    # written by another version) no longer raises KeyError.
    options_section = param_file_content.setdefault("options", {})
    roles_section = param_file_content.setdefault("roles", {})
    hosts_section = param_file_content.setdefault("hosts", {})
    groups_section = param_file_content.setdefault("groups", {})

    if opts:
        for option_field in opts:
            # NOTE(review): boolean options are also reset to "" here,
            # not to False; confirm that consumers expect this.
            options_section[option_field] = ""
    if roles:
        for role_name, hostnames in roles.items():
            # a role key missing from the file simply has nothing to remove
            assigned = roles_section.get(role_name, [])
            for hostname in hostnames:
                if hostname in assigned:
                    assigned.remove(hostname)
    if hosts:
        for host in hosts:
            hosts_section.pop(host, None)
            # a deleted host must not keep any role
            for assigned in roles_section.values():
                if host in assigned:
                    assigned.remove(host)
    if groups:
        for group in groups:
            groups_section.pop(group, None)

    result_msg = write_to_parameter_file(param_file_content)
    if "success_msg" in result_msg:
        result_msg["success_msg"] = "Items have been removed"
        print(json.dumps(result_msg, indent=4))
        sys.exit(0)


def write_to_parameter_file(new_content):
    """Write ``new_content`` to the parameter file and report the change.

    ``new_content["time_stamp"]`` is overwritten with the current local time
    (``%Y-%m-%d-%H-%M``) before writing, so the caller's dict is mutated.
    The module-level ``g_param_file_content`` is refreshed by re-reading the
    file after the write.

    Returns a dict with ``success_msg``, ``old_file_content`` (a copy of
    ``g_param_file_content`` as it was before this call) and
    ``new_file_content`` (the content read back from disk).
    """
    global g_param_file_content

    new_content["time_stamp"] = datetime.today().strftime("%Y-%m-%d-%H-%M")

    if not os.path.exists(g_param_file_path):
        # exist_ok: no failure if the directory already exists or appears
        # between an existence check and its creation.
        os.makedirs(g_param_file_dir, exist_ok=True)
    # context manager guarantees the handle is closed even if write fails
    with open(g_param_file_path, "w") as param_file:
        param_file.write(json.dumps(new_content, indent=4))

    success_msg = {"success_msg": "Parameter file has been updated"}
    success_msg["old_file_content"] = copy.deepcopy(g_param_file_content)
    g_param_file_content = load_param_file()
    success_msg["new_file_content"] = g_param_file_content
    return success_msg


def merge_with_existing(hosts, roles, opts, groups, param_file_content):
    """Merge the given items into ``param_file_content`` and write it out.

    Args:
        hosts: ``{hostname: settings}``; each entry replaces the stored one.
        roles: ``{role_name: [hostname, ...]}``; hostnames are added to the
            stored list, which is then de-duplicated and sorted.
        opts: ``{option_name: value}``; each value replaces the stored one.
        groups: ``{group_name: settings}``; each entry replaces the stored one.
        param_file_content: parameter content to modify (mutated in place).

    After a successful write the result message is printed as JSON and the
    process exits with status 0.
    """
    debug_print(json.dumps(param_file_content, indent=4),
                "param_file_content (before merge)")

    # setdefault: a parameter file that lacks a section (hand-edited or
    # written by another version) no longer raises KeyError.
    hosts_section = param_file_content.setdefault("hosts", {})
    roles_section = param_file_content.setdefault("roles", {})
    options_section = param_file_content.setdefault("options", {})
    groups_section = param_file_content.setdefault("groups", {})

    for hostname, host_values in hosts.items():
        hosts_section[hostname] = copy.deepcopy(host_values)
    for role_name, new_members in roles.items():
        # a role key missing from the file starts from an empty list
        existing_members = roles_section.get(role_name, [])
        roles_section[role_name] = sorted(set(existing_members + new_members))
    for opt_name, opt_value in opts.items():
        options_section[opt_name] = copy.deepcopy(opt_value)
    for group_name, group_values in groups.items():
        groups_section[group_name] = copy.deepcopy(group_values)

    debug_print(json.dumps(param_file_content, indent=4),
                "param_file_content (after merge)")

    result_msg = write_to_parameter_file(param_file_content)
    if "success_msg" in result_msg:
        print(json.dumps(result_msg, indent=4))
        sys.exit(0)


def show_existing_file(param_file_content):
    """Print the given parameter file content as JSON and exit with status 0."""
    report = {
        "success_msg": " parameter file loaded.",
        "old_file_content": param_file_content,
    }
    print(json.dumps(report, indent=4))
    sys.exit(0)

def show_default(default_content):
    """Print the given default parameter content as JSON and exit with status 0."""
    report = {
        "success_msg": " default parameter file",
        "default_content": default_content,
    }
    print(json.dumps(report, indent=4))
    sys.exit(0)


def print_args_to_string():
    """Return the command-line arguments as numbered lines.

    Each argument after the program name produces one line of the form
    ``Parameter <position>: <value>`` terminated by a newline; the result is
    an empty string when no arguments were given.

    The string was previously built but never returned, so callers received
    None.
    """
    return "".join(
        "Parameter %i: %s\n" % (position, argument)
        for position, argument in enumerate(sys.argv[1:], start=1)
    )

#####################################
# MAIN
#####################################


def main():
    """Command-line entry point.

    Verifies root access, parses the command line, loads the parameter file
    and then performs the requested action: show the existing file
    (-s), show the defaults (-e), delete the supplied items (-x) or merge
    the supplied items into the file (-w).  Hosts, roles, options and
    groups are supplied as JSON strings via -h, -r, -o and -g.
    """
    global g_default_params
    global g_param_file_content
    global g_debug_flag

    check_root()

    # -h is used for --hosts, so optparse's built-in help option is disabled
    parser = OptionParser(add_help_option=False)
    json_options = (
        ("-h", "--hosts", "hosts"),
        ("-r", "--roles", "roles"),
        ("-o", "--options", "opts"),
        ("-g", "--groups", "groups"),
    )
    for short_flag, long_flag, dest in json_options:
        parser.add_option(short_flag, long_flag, action="store",
                          dest=dest, default=None)
    switches = (
        ("-s", "--show-existing", "show_existing"),
        ("-w", "--write", "write"),
        ("-x", "--delete", "delete_items"),
        ("-d", "--debug", "debug"),
        ("-e", "--show-default", "show_default"),
    )
    for short_flag, long_flag, dest in switches:
        parser.add_option(short_flag, long_flag, action="store_true",
                          dest=dest, default=False)
    options, _ = parser.parse_args()

    if options.debug:
        g_debug_flag = True
        debug_print(print_args_to_string(), "command line args: ")

    # load existing param file content
    g_param_file_content = load_param_file()
    g_default_params = get_default_params()

    if options.show_existing:
        # show existing file content and exit
        show_existing_file(g_param_file_content)

    if options.show_default:
        # show default content and exit
        show_default(g_default_params)

    hosts, roles, opts, groups = {}, {}, {}, {}

    if options.hosts is not None:
        hosts = load_hosts(options.hosts)
        debug_print(json.dumps(hosts, indent=4), "hosts")

    if options.roles is not None:
        # roles are validated against the hosts supplied on this call as
        # well as those already stored in the parameter file
        roles = load_roles(options.roles, hosts, g_param_file_content)
        debug_print(json.dumps(roles, indent=4), "roles")

    if options.opts is not None:
        opts = load_opts(options.opts)
        debug_print(json.dumps(opts, indent=4), "opts")

    if options.groups is not None:
        groups = load_groups(options.groups)
        debug_print(json.dumps(groups, indent=4), "groups")

    # both actions work on a copy so the loaded content stays untouched
    if options.delete_items:
        delete_items(hosts, roles, opts, groups,
                     copy.deepcopy(g_param_file_content))

    if options.write:
        merge_with_existing(hosts, roles, opts, groups,
                            copy.deepcopy(g_param_file_content))


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
