"""Create all databases and run the computations for the templates defined in $HOME/.thyrsis/compute_validation.csv
    This file is a CSV file with ';' separators and must have the following columns:
    - template : str, name of the template
    - point : str, name of the point (from table points_interet)
    - date : str, date with format 'YYYY-MM-DD'
    - unit : str, unit of the value ('m' selects the potential, anything else the concentration)
    - value : float, the reference value to test against
    - comment : str
    The first line is the header with the column names:
    'template;point;date;unit;value;comment'

USAGE:
    python -m thyrsis.simulation.compute_validation
        [-dir=DIRNAME -raz -init -run -plot -results -a] [template] [-template] [-code=CODENAME_LIST]

OPTIONS

   -dir : name of the results directory [default: thyrsis_compute_validation_YYMMDD, suffixed with today's date]
   -raz : delete previous results directory
   -init : create databases
   -run : run simulations (compute)
   -results : calculate results
   -plot : plot results
   -a : all : init, compute, results and plot
   template : process only this template (may be given several times)
   -template : do not process this template
   -code : comma-separated list of computation codes (e.g. metis,openfoam); by default every code found on the PATH
"""
import os
import shutil
import sys
from builtins import str
from datetime import date, datetime
from distutils.spawn import find_executable
from subprocess import Popen

import pandas as pd

from ..database import sqlite
from ..log import logger
from ..settings import Settings
from ..spatialitemeshdataprovider import SpatialiteMeshDataProvider
from ..utilities import Timer
from .compute import compute_mpi


def compute_validation(topdir, init, run, resu, plot, results, codes, raz):
    """Create the databases, run the simulations and compare with references.

    For every template of *results* (sorted by name) and every code of
    *codes*, the steps selected by the boolean flags are executed in order:
    database creation (*init*), simulation (*run*), comparison of the
    simulated value with the reference value, written to
    ``<topdir>/report.csv`` (*resu*), then the plots of the template (*plot*).

    :param topdir: results directory, relative to the current directory
    :param init: create a fresh database for every template and code
    :param run: run the simulations (the databases must exist)
    :param resu: extract the simulated values and update the report
    :param plot: draw the plots of every template
    :param results: reference data keyed by template name; each entry is a
        dict with the keys ``point``, ``date``, ``unit``, ``value`` and
        ``comment``
    :param codes: mapping of code name to the command that runs it
    :param raz: delete the results directory before starting
    """
    settings = Settings()
    top_dir = os.path.join(os.getcwd(), topdir)
    if raz and os.path.exists(top_dir):
        logger.notice("removing directory %s" % (top_dir))
        shutil.rmtree(top_dir)
    if not os.path.exists(top_dir):
        logger.notice("creating directory %s" % (top_dir))
        os.mkdir(top_dir)

    thyrsis_dir = os.environ["THYRSIS"]
    data_dir = os.path.join(thyrsis_dir, "data")

    report_path = os.path.join(top_dir, "report.csv")
    if os.path.isfile(report_path):
        logger.notice("reading file %s" % (report_path))
        report = pd.read_csv(report_path, sep=";")
        report.set_index("template", inplace=True)
    else:
        logger.notice("creating file %s" % (report_path))
        templates = sorted(results.keys())
        code_columns = [
            column
            for code in codes
            for column in (code + "_value", code + "_difference_%", code + "_cputime")
        ]
        report = pd.DataFrame(
            columns=["template", "reference"] + code_columns + ["comment"]
        )
        report["template"] = templates
        report["reference"] = [results[t]["value"] for t in templates]
        report["comment"] = [results[t]["comment"] for t in templates]
        report.set_index("template", inplace=True)

    timer = Timer()
    for template in sorted(results.keys()):
        logger.notice("\n" + template)
        reference = results[template]
        parts = template.split("_")
        sitename = parts[0]
        # third '_'-separated field, inspected for TRANSITOIRE / ZNS / ZS;
        # empty when the template name is too short (no IndexError)
        tpltname = parts[2] if len(parts) > 2 else ""

        template_dir = os.path.join(top_dir, template)
        if not os.path.exists(template_dir):
            os.mkdir(template_dir)
        for code in codes:

            cputime = ""

            logger.notice("setting code to", code)
            settings.setValue("General", "codeHydro", code)
            settings.setValue("General", "codeHydroCommand", codes[code])
            settings.save()

            code_dir = os.path.join(template_dir, code)
            dbname = os.path.join(code_dir, template.lower() + ".sqlite")
            if init:
                if os.path.exists(code_dir):
                    shutil.rmtree(code_dir)
                os.makedirs(code_dir)
                if "TRANSITOIRE" in tpltname:
                    u_infiltration = os.path.join(
                        data_dir, sitename, "u_infiltration_transitoire.txt"
                    )
                    if not os.path.exists(u_infiltration):
                        logger.error("infiltration file missing")
                        continue
                    if "ZNS" in tpltname:
                        extension = "insat"
                    elif "ZS" in tpltname:
                        extension = "sat"
                    else:
                        extension = None
                    if extension:
                        shutil.copy2(
                            u_infiltration,
                            os.path.join(
                                code_dir, "u_infiltration_transitoire." + extension
                            ),
                        )

                # Argument list and current interpreter: no shell involved,
                # so paths containing spaces work, and the database is built
                # by the same Python as this script.
                status = Popen(
                    [sys.executable, "-m", "thyrsis.database", template, dbname]
                ).wait()
                if status != 0:
                    logger.error(
                        "database creation failed for %s (exit status %s)"
                        % (template, status)
                    )

            if run:
                timer.reset("")
                if not os.path.exists(code_dir):
                    sys.exit("initializing has to be done first")
                logger.notice("\nrunning in " + code_dir + "\n")
                compute_mpi(dbname, True, True)
                cputime = timer.reset("")

            if resu:
                if not os.path.exists(code_dir):
                    sys.exit("initializing and running have to be done first")
                point = reference["point"]
                day = reference["date"]
                with sqlite.connect(dbname) as conn:
                    cur = conn.cursor()
                    project_SRID = str(
                        cur.execute(
                            "SELECT srid FROM geometry_columns WHERE f_table_name = 'noeuds'"
                        ).fetchone()[0]
                    )
                    # bound parameter rather than string interpolation: point
                    # names containing quotes cannot break the query
                    xy = cur.execute(
                        "SELECT X(GEOMETRY), Y(GEOMETRY) FROM points_interet WHERE nom=?",
                        (point,),
                    ).fetchone()
                if xy is None:
                    logger.error("point %s not found in %s" % (point, dbname))
                    continue
                logger.notice(template)
                logger.notice(
                    point + " " + str(xy[0]) + " " + str(xy[1]) + " " + day
                )
                column = "potentiel" if reference["unit"] == "m" else "concentration"
                provider = SpatialiteMeshDataProvider(
                    "dbname=%s crs=epsg:%s resultColumn=%s"
                    % (dbname, project_SRID, column)
                )
                provider.setUnits(reference["unit"])
                # dates are compared on their 'YYYY-MM-DD' prefix
                date_indices = [
                    i for i, d in enumerate(provider.dates()) if d[:10] == day
                ]
                if len(date_indices) != 1:
                    logger.error(
                        "expected exactly one result at date %s for %s, found %d"
                        % (day, template, len(date_indices))
                    )
                    continue
                value = provider.valuesAt(xy)[date_indices[0]]
                vref = reference["value"]
                # relative difference in %, undefined for a zero reference
                diff = 100 * (value - vref) / vref if vref else float("nan")
                notice = "%.5f;%.1f;%s;" % (value, diff, cputime.strip())
                report.loc[
                    template,
                    [code + "_value", code + "_difference_%", code + "_cputime"],
                ] = [value, diff, cputime.strip()]
                report[code + "_value"] = report[code + "_value"].astype(float)
                report[code + "_difference_%"] = report[
                    code + "_difference_%"
                ].astype(float)
                logger.notice("%.5f;" % (vref) + notice + "\n")
                # rewritten after every result so partial runs keep their data
                report.to_csv(
                    report_path,
                    encoding="utf-8",
                    index=True,
                    header=True,
                    sep=";",
                    float_format="%.5f",
                )

        if plot:
            create_plots(template_dir, codes)


def create_plots(template_dir, codes):
    """Draw all the figures of one template directory.

    A gnuplot palette ``codes.pal`` (line styles 1 to 4: metis, openfoam,
    openfoam post-processing, metis post-processing) is written first in
    *template_dir*. The ZNS plots are then produced by ``create_plots_zns``,
    and the probe plots by ``create_plots_probes``.

    :param template_dir: directory holding the per-code results of a template
    :param codes: codes to plot, forwarded unchanged to both plotting functions
    """
    palette_lines = [
        "set style line 1 pt 7 lw 3 lt 1 lc rgb 'red' # METIS\n",
        "set style line 2 pt 9 lw 2 lt 1 lc rgb 'blue' # OPENFOAM\n",
        "set style line 3 pt 2 lw 4 lt 1 lc rgb 'green' # OPENFOAM.POSTPROCESSING\n",
        "set style line 4 pt 2 lw 4 lt 1 lc rgb 'black' # METIS.POSTPROCESSING\n",
    ]
    palette_path = os.path.join(template_dir, "codes.pal")
    with open(palette_path, "w") as palette:
        palette.writelines(palette_lines)

    for plotter in (create_plots_zns, create_plots_probes):
        plotter(template_dir, codes)


def _plot_zns_curves(template_dir, name, ylabel, curves):
    """Write ``<name>.gp`` in *template_dir* and run gnuplot on it.

    The script draws *curves* (gnuplot plot specifications whose paths are
    relative to *template_dir*) into ``<name>.png``, with *name* as in-graph
    label. Nothing is written when *curves* is empty, because a bare
    ``plot`` command is a gnuplot error.
    """
    if not curves:
        logger.notice("nothing to plot for %s" % (name))
        return
    plotname = name + ".gp"
    with open(os.path.join(template_dir, plotname), "w") as fil:
        fil.write("set term png size 1200, 800\n")
        fil.write("set output '%s.png'\n" % (name))
        fil.write(
            "set label '%s' at graph 0.95,0.05 right font 'LiberationSans-Regular,24'\n"
            % (name)
        )
        fil.write("set xlabel 'seconds'\n")
        fil.write("set ylabel '%s'\n" % (ylabel))
        fil.write("load 'codes.pal'\n")
        # joining avoids missing/dangling commas when some codes have no data
        fil.write("plot " + ", \\\n".join(curves) + "\n")
    Popen(["gnuplot", plotname], cwd=template_dir).wait()


def create_plots_zns(template_dir, codes):
    """Plot the ZNS results of every code of a template.

    For each code, the ``insat*`` directories of
    ``<code>/<template>_tmp/sample00001`` are collected. The outputs to
    plot are the ``debit*`` files (excluding ``debit_entrant*``) of the
    first metis ZNS directory; only ``debit_eau`` (water flow) and ``debit``
    (mass flow) are supported. For each output, one plot of the incoming
    flux (``sature/<flux>_plot.sat``) and one plot per ZNS directory are
    drawn.

    :param template_dir: directory holding the per-code results of a template
    :param codes: codes to plot (iterated for their names)
    """
    logger.notice("creating plots for zns")
    dzns = {}
    sat_dir = {}
    lzns = None
    loutput = None
    for code in codes:
        logger.notice("code %s" % (code))
        sample_dir = os.path.join(
            code, os.path.basename(template_dir).lower() + "_tmp", "sample00001"
        )
        sat_dir[code] = os.path.join(sample_dir, "sature")
        full_sample_dir = os.path.join(template_dir, sample_dir)
        if not os.path.isdir(full_sample_dir):
            dzns[code] = None
            continue
        # sorted so that the i-th ZNS directory is the same for every code
        names = sorted(x for x in os.listdir(full_sample_dir) if "insat" in x)
        if not names:
            dzns[code] = None
            continue
        if not lzns:
            lzns = names
        if code == "metis" and not loutput:
            loutput = sorted(
                x.split(".")[0]
                for x in os.listdir(os.path.join(full_sample_dir, names[0]))
                if "debit" in x and "debit_entrant" not in x
            )
        dzns[code] = [os.path.join(sample_dir, x) for x in names]

    if lzns is None:
        logger.notice("no zns to plot")
        return
    if not loutput:
        # outputs are only discovered from metis results
        logger.notice("no zns output to plot")
        return

    for output in loutput:
        logger.notice("zns output =", output)
        if output == "debit_eau":
            name_mapping = {
                "metis": "debit_eau.insat",
                "openfoam": "waterMassBalance.csv",
            }
            column_mapping = {"metis": 2, "openfoam": 2}
            ylabel = "water flow (m3/s)"
            flux_entrant = "debit_eau_entrant"
        elif output == "debit":
            name_mapping = {"metis": "debit.insat", "openfoam": "CmassBalance.csv"}
            column_mapping = {"metis": 2, "openfoam": 3}
            ylabel = "mass flow (kg/s)"
            flux_entrant = "debit_entrant"
        else:
            logger.notice("output : %s unknown for zns" % (output))
            continue

        # incoming flux: one curve per code whose file exists
        flux_file = flux_entrant + "_plot.sat"
        curves = []
        for ic, code in enumerate(codes):
            full_path = os.path.join(template_dir, sat_dir[code], flux_file)
            if not os.path.isfile(full_path):
                logger.notice("No file %s" % (full_path))
                continue
            curves.append(
                "     '%s/%s' u 1:2 w linespoints ls %d ti '%s'"
                % (sat_dir[code], flux_file, ic + 1, code)
            )
        _plot_zns_curves(template_dir, flux_entrant, ylabel, curves)

        for iz, zns in enumerate(lzns):
            logger.notice("zns = ", zns)
            curves = [
                "     '%s/%s' u 1:%d w linespoints ls %d ti '%s'"
                % (
                    dzns[code][iz],
                    name_mapping[code],
                    column_mapping[code],
                    ic + 1,
                    code,
                )
                for ic, code in enumerate(codes)
                if dzns[code] and iz < len(dzns[code]) and code in name_mapping
            ]
            # zns[5:] drops the leading 'insat' of the directory name
            _plot_zns_curves(template_dir, output + zns[5:], ylabel, curves)


def create_plots_probes(template_dir, codes, metis_postprocessing=False):
    """Plot, for every probe point, the time series computed by each code.

    For each code, the first (sorted) file of ``<code>/`` whose name contains
    ``probes`` is used. Its base name gives the output (``concentration`` or
    ``potentiel``), and the first line of the first such file gives the point
    names. When available, the openfoam ``postProcessing/probes`` result is
    added as an extra curve. One ``<point>.gp`` / ``<point>.png`` pair is
    produced per point.

    :param template_dir: directory holding the per-code results of a template
    :param codes: codes to plot (iterated for their names)
    :param metis_postprocessing: also plot the metis ``<output>_probes.txt``
        post-processing file (disabled by default)
    """
    logger.notice("creating plots at probes")
    probes = {}
    points = None
    output = None
    for code in codes:
        code_dir = os.path.join(template_dir, code)
        if not os.path.isdir(code_dir):
            probes[code] = None
            continue
        lprobes = sorted(x for x in os.listdir(code_dir) if "probes" in x)
        probes[code] = lprobes[0] if lprobes else None
        if probes[code] and not points:
            output = probes[code].split(".")[0]
            with open(os.path.join(code_dir, probes[code])) as fil:
                # header line: first token skipped, the rest are point names
                points = fil.readline().split()[1:]

    logger.notice("probes output =", output)
    if output == "concentration":
        unit = "µg/L"
        coef_unit = 1e6
        of_name = "C"
    elif output == "potentiel":
        unit = "m"
        coef_unit = 1
        of_name = "potential"
    else:
        logger.notice("no result to plot")
        return

    sature_dir = os.path.join(
        os.path.basename(template_dir).lower() + "_tmp", "sample00001", "sature"
    )

    # post-processing files do not depend on the point: look them up once
    of_probe_file = None
    if "openfoam" in codes:
        of_probe_dir = os.path.join(
            "openfoam", sature_dir, "postProcessing", "probes"
        )
        full_of_probe_dir = os.path.join(template_dir, of_probe_dir)
        if os.path.isdir(full_of_probe_dir):
            time_dirs = sorted(os.listdir(full_of_probe_dir))
            if time_dirs:
                candidate = os.path.join(of_probe_dir, time_dirs[0], of_name)
                if os.path.isfile(os.path.join(template_dir, candidate)):
                    of_probe_file = candidate

    metis_probe_file = None
    if metis_postprocessing and "metis" in codes:
        candidate = os.path.join("metis", sature_dir, output + "_probes.txt")
        if os.path.isfile(os.path.join(template_dir, candidate)):
            metis_probe_file = candidate

    for ipt, point in enumerate(points):
        # column ipt + 2: column 1 is time, point columns follow
        curves = []
        if of_probe_file:
            curves.append(
                "     '%s' u 1:($%d*%f) w linespoints ls %d ti 'openfoam.postprocessing'"
                % (of_probe_file, ipt + 2, coef_unit, 3)
            )
        if metis_probe_file:
            curves.append(
                "     '%s' u 1:($%d*%f) w linespoints ls %d ti 'metis.postprocessing'"
                % (metis_probe_file, ipt + 2, coef_unit, 4)
            )
        for ic, code in enumerate(codes):
            if probes[code]:
                curves.append(
                    "     '%s/%s' u 1:($%d*%f) w linespoints ls %d ti '%s'"
                    % (
                        code,
                        probes[code],
                        ipt + 2,
                        coef_unit,
                        ic + 1,
                        code + ".thyrsis",
                    )
                )
        if not curves:
            continue

        with open(
            os.path.join(template_dir, point + ".gp"), "w", encoding="utf-8"
        ) as fil:
            fil.write("set term png size 1200, 800\n")
            fil.write("set output '%s.png'\n" % (point))
            fil.write(
                "set label '%s' at graph 0.95,0.05 right font 'LiberationSans-Regular,24'\n"
                % (point.replace("_", "\\_"))
            )
            fil.write("set xlabel 'seconds'\n")
            fil.write("set ylabel '%s (%s)'\n" % (output, unit))
            fil.write("load 'codes.pal'\n")
            # joining avoids the dangling comma of the previous version
            fil.write("plot " + ", \\\n".join(curves) + "\n")

        Popen(["gnuplot", point + ".gp"], cwd=template_dir).wait()


def parse_arguments(argv, references, codes_all):
    """Parse the command-line arguments described in the module docstring.

    :param argv: command-line arguments, without the program name
    :param references: reference data keyed by template name (not modified)
    :param codes_all: available codes, mapping code name to its command
    :return: dict with the option keys ``topdir``, ``init``, ``run``,
        ``resu``, ``plot`` and ``raz``, plus ``results`` (selected
        references) and ``codes`` (selected codes)
    :raises ValueError: for an unrecognized argument, given as first arg
    """
    options = {
        "topdir": "thyrsis_compute_validation_" + date.today().strftime("%y%m%d"),
        "init": False,
        "run": False,
        "resu": False,
        "plot": False,
        "raz": False,
    }
    flags = {
        "-init": ("init",),
        "-run": ("run",),
        "-raz": ("raz",),
        "-results": ("resu",),
        "-plot": ("plot",),
        "-a": ("init", "run", "resu", "plot"),
    }
    # work on a copy so that excluding a template never alters `references`
    results = dict(references)
    results_raz = False
    codes = dict(codes_all)

    for arg in argv:
        if arg in flags:
            for key in flags[arg]:
                options[key] = True
        elif arg.startswith("-dir="):
            options["topdir"] = arg[5:]
        elif arg.startswith("-") and arg[1:] in references:
            # the template may already be absent when a selection was given
            results.pop(arg[1:], None)
        elif arg in references:
            # the first explicit template replaces the default "all templates"
            if not results_raz:
                results = {}
                results_raz = True
            results[arg] = references[arg]
        elif arg.startswith("-code="):
            # keep the detected commands (e.g. "" for groundwaterFoam)
            codes = {c: codes_all[c] for c in arg[6:].split(",") if c in codes_all}
        else:
            raise ValueError(arg)

    options["results"] = results
    options["codes"] = codes
    return options


if __name__ == "__main__":
    logger.enable_console(True)
    logger.set_level("debug")

    if len(sys.argv) == 2 and sys.argv[1] in ["-h", "--help"]:
        help(sys.modules[__name__])
        sys.exit(0)

    references_file = os.path.join(
        os.path.expanduser("~"), ".thyrsis", "compute_validation.csv"
    )
    if not os.path.isfile(references_file):
        sys.exit("File %s does not exist. Stop" % (references_file))
    logger.notice("reading file %s" % (references_file))
    ref = pd.read_csv(references_file, sep=";")
    REFERENCES = ref.set_index("template").to_dict("index")

    # codes available on the PATH, mapped to their command
    codes_all = {}
    for code_name in ("metis", "openfoam"):
        executable = shutil.which(code_name)
        if executable:
            codes_all[code_name] = executable
    # openfoam is registered with an empty command when groundwaterFoam exists
    if shutil.which("groundwaterFoam"):
        codes_all["openfoam"] = ""

    try:
        options = parse_arguments(sys.argv[1:], REFERENCES, codes_all)
    except ValueError as err:
        logger.error("argument not recognized", err.args[0])
        sys.exit(1)

    compute_validation(
        options["topdir"],
        options["init"],
        options["run"],
        options["resu"],
        options["plot"],
        options["results"],
        options["codes"],
        options["raz"],
    )
