from pathlib import Path

import matplotlib.pyplot as plt
from os import makedirs
from os.path import join

from ClassifierAccuracy.util.commons import get_method_names, open_results


def plot_diagonal(basedir, cls_name, measure_name, dataset_name='*'):
    """Load stored results and save a true-vs-estimated diagonal plot.

    Results are loaded with ``open_results`` for every method name returned by
    ``get_method_names()``. Each returned entry is expected to expose
    ``'true_acc'`` and ``'estim_acc'`` items, which become the x and y
    coordinates of that method's series.

    The image is written to
    ``plots/<basedir>/<measure_name>/<subdir>/diagonal.png``. Here ``<subdir>``
    is ``'all'`` when ``dataset_name`` is the wildcard ``'*'`` and
    ``dataset_name`` otherwise.

    NOTE(review): ``cls_name`` is only forwarded to ``open_results`` and is not
    part of the output path. Confirm that plots for different classifiers are
    not meant to be kept apart (otherwise they overwrite each other).
    """
    requested_methods = get_method_names()
    results = open_results(
        basedir, cls_name, measure_name,
        dataset_name=dataset_name, method_name=requested_methods,
    )

    # Only methods present in the loaded results are plotted, in results order.
    found_methods = list(results.keys())
    true_accs = [results[name]['true_acc'] for name in found_methods]
    estim_accs = [results[name]['estim_acc'] for name in found_methods]

    subdir = dataset_name if dataset_name != '*' else 'all'
    save_path = join('plots', basedir, measure_name, subdir, 'diagonal.png')
    _plot_diagonal(found_methods, true_accs, estim_accs, save_path, measure_name)


def _plot_diagonal(methods_names, true_xs, estim_ys, save_path, measure_name, title=None):
    """Scatter true vs. estimated values (one series per method) and save to disk.

    A dashed black ``y = x`` reference line is drawn, and both axes are fixed
    to ``[0, 1]``.

    Args:
        methods_names: legend label for each series.
        true_xs: per-method sequences of true values (x coordinates).
        estim_ys: per-method sequences of estimated values (y coordinates),
            aligned with ``true_xs``. Note that ``zip`` silently truncates to
            the shortest of the three inputs.
        save_path: output image path. Missing parent directories are created.
        measure_name: measure name used in the axis labels.
        title: optional plot title.
    """
    makedirs(Path(save_path).parent, exist_ok=True)

    # Use an explicit figure and always close it. The previous code only
    # called plt.cla(), which clears the axes but leaves the figure registered
    # in pyplot. Repeated calls then leaked memory and triggered matplotlib's
    # "too many open figures" warning.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.plot([0, 1], [0, 1], color='black', linestyle='--')

        for method_name, xs, ys in zip(methods_names, true_xs, estim_ys):
            ax.scatter(xs, ys, label=f'{method_name}', alpha=0.5, linewidths=0)

        # Only draw a legend when something labelled was plotted; otherwise
        # matplotlib warns "No artists with labels found to put in legend".
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()

        if title is not None:
            ax.set_title(title)
        ax.set_xlabel(f'True {measure_name}')
        ax.set_ylabel(f'Estimated {measure_name}')

        # Write the figure to disk (nothing is displayed interactively).
        fig.savefig(save_path)
    finally:
        plt.close(fig)