import matplotlib.pyplot as plt
import pandas as pd
import sys, os, pathlib


class eDiscoveryPlot:
    """Plot the metrics logged by an e-Discovery run, either once or in a refresh loop.

    On every refresh the metrics are re-read from a tab-separated file
    (``datapath``) with one row per iteration. The columns used are ``it``,
    ``R``, ``Rhat``, ``RhatCC``, ``te-prev``, ``te-estim``, ``te-estimCC``,
    ``AE``, ``AE_CC``, ``Shift``, ``tr-prev`` and, when ``showCost`` is set,
    ``tr-size`` and ``ICost``.

    Note: the constructor updates the global ``plt.rcParams`` (figure size,
    dpi and, in loop mode, font size). This affects every figure created
    afterwards, not only this one.
    """

    def __init__(self, datapath, outdir='./plots', loop=True, save=True, showYdist=False, showCost=True, refreshEach=10):
        """
        :param datapath: path to the tab-separated results file.
        :param outdir: directory where the png is written (created on demand).
        :param loop: dynamic mode. If True, each drawing is followed by a short
            GUI pause and the axes are cleared for the next refresh. It also
            selects a large, low-dpi figure. If False, a 12x12, 200-dpi figure
            is used.
        :param save: if True, save the figure to ``outdir`` each time it is drawn.
        :param showYdist: if True, add a histogram of the pool posteriors
            (requires ``posteriors`` and ``y`` in :meth:`plot`).
        :param showCost: if True, add the cost panel.
        :param refreshEach: draw only on every ``refreshEach``-th call to :meth:`plot`.
        """
        self.outdir = outdir
        self.datapath = datapath
        self.plotname = self._plot_name(datapath)
        self.loop = loop
        self.save = save
        self.showYdist = showYdist
        self.showCost = showCost
        self.refreshEach = refreshEach

        # recall, pool prevalence, quantification error and train-test shift are always drawn
        nPlots = 4
        if showYdist:
            nPlots += 1
        if showCost:
            nPlots += 1

        # rcParams must be set before the figure is created to take effect on it
        if not loop:
            plt.rcParams['figure.figsize'] = [12, 12]
            plt.rcParams['figure.dpi'] = 200
        else:
            plt.rcParams['figure.figsize'] = [14, 18]
            plt.rcParams['figure.dpi'] = 50
            plt.rcParams.update({'font.size': 15})

        self.fig, self.axs = plt.subplots(nPlots)
        self.calls = 0
        # The first call that actually draws is call number refreshEach-1 (not 0).
        # Track the one-off layout adjustment explicitly instead of testing calls == 0.
        self._layout_done = False

    @staticmethod
    def _plot_name(datapath):
        """Return the png file name associated with ``datapath`` (same stem, ``.png`` suffix)."""
        return pathlib.Path(datapath).with_suffix('.png').name

    def plot(self, posteriors, y):
        """Redraw all panels from the current content of ``datapath``.

        Only every ``refreshEach``-th call draws. The other calls just count
        themselves and return.

        :param posteriors: pool posterior probabilities, used only when
            ``showYdist`` is set. Indexed as ``posteriors[mask, 1]``, so
            presumably an (n, 2) array whose column 1 is the positive-class
            probability. TODO confirm.
        :param y: pool labels (0/1) aligned with ``posteriors``, used only when
            ``showYdist`` is set.
        """
        if (self.calls + 1) % self.refreshEach != 0:
            self.calls += 1
            return

        fig, axs = self.fig, self.axs

        df = pd.read_csv(self.datapath, sep='\t')
        xs = df['it']
        aXn = 0

        # Legends are created once, for every axis, at the end of this method.
        # Raw strings keep the LaTeX backslashes (e.g. \hat, \oplus) out of
        # Python's escape processing.

        # recall: Q and CC estimates vs. the true value
        ax = axs[aXn]
        ax.plot(xs, df['Rhat'], label=r'$\hat{R}_{Q}$')
        ax.plot(xs, df['RhatCC'], label=r'$\hat{R}_{CC}$')
        ax.plot(xs, df['R'], label='$R$')
        ax.grid()
        ax.set_ylabel('Recall')
        ax.set_ylim(0, 1)
        aXn += 1

        # pool prevalence: Q and CC estimates vs. the true value
        ax = axs[aXn]
        ax.plot(xs, df['te-estim'], label=r'te-$\hat{Pr}(\oplus)_{Q}$')
        ax.plot(xs, df['te-estimCC'], label=r'te-$\hat{Pr}(\oplus)_{CC}$')
        ax.plot(xs, df['te-prev'], label=r'te-$Pr(\oplus)$')
        ax.grid()
        ax.set_ylabel('Pool prevalence')
        aXn += 1

        # quantification error of the two estimators (AE columns)
        ax = axs[aXn]
        ax.plot(xs, df['AE'], label='AE$_{Q}$')
        ax.plot(xs, df['AE_CC'], label='AE$_{CC}$')
        ax.grid()
        ax.set_ylabel('Quantification error')
        aXn += 1

        # classifier performance (MF1_Q / MF1_Clf columns) is deliberately not
        # plotted: it was found to be not very reliable

        if self.showCost:
            ax = axs[aXn]
            cost = df['tr-size']
            idealcost = df['ICost']
            ax.plot(xs, cost, label='Cost')
            ax.plot(xs, idealcost, label='IdealCost')
            ax.plot(xs, cost + idealcost, label='TotalCost')
            ax.grid()
            ax.set_ylabel('Cost')
            aXn += 1

        # distribution of posterior probabilities in the pool
        if self.showYdist:
            ax = axs[aXn]
            ax.hist(posteriors[y == 0, 1], bins=50, label='negative', density=True, alpha=.75)
            ax.hist(posteriors[y == 1, 1], bins=50, label='positive', density=True, alpha=.75)
            ax.grid()
            ax.set_xlim(0, 1)
            ax.set_ylabel(r'te-$Pr(\oplus)$ distribution')
            aXn += 1

        ax = axs[aXn]
        ax.plot(xs, df['Shift'], '--k', label='shift (AE)')
        ax.plot(xs, df['tr-prev'], 'y', label=r'tr-$Pr(\oplus)$')
        ax.plot(xs, df['te-prev'], 'r', label=r'te-$Pr(\oplus)$')
        ax.grid()
        ax.set_ylabel('Train-Test Shift')
        aXn += 1

        for ax in axs[:aXn]:
            if not self._layout_done:
                # Shrink the axis width by 20% to make room for the legend on its right.
                # Axis positions survive cla(), so this is done only once.
                box = ax.get_position()
                ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
                fig.tight_layout()

            # Put a legend to the right of the current axis
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        self._layout_done = True

        if self.save:
            os.makedirs(self.outdir, exist_ok=True)
            # save *this* figure: plt.savefig would save whichever figure is current
            fig.savefig(os.path.join(self.outdir, self.plotname))

        if self.loop:
            # plt.pause only refreshes the active figure, which may be another one
            fig.canvas.draw_idle()
            plt.pause(.5)
            for ax in axs[:aXn]:
                ax.cla()

        self.calls += 1


class InOutDistPlot:
    """Histograms of posterior probabilities for the training set and the pool, split by true label.

    The figure has two stacked axes (training on top, pool below). In every
    refresh the axes are drawn, shown for a short GUI pause, and cleared.
    """

    def __init__(self, refreshEach=1):
        """
        :param refreshEach: draw only on every ``refreshEach``-th call to :meth:`plot`.
        """
        self.refreshEach = refreshEach

        # one axis for the training distribution, one for the pool distribution
        self.fig, self.axs = plt.subplots(2)
        self.calls = 0
        # The first call that actually draws is call number refreshEach-1 (not 0).
        # Track the one-off layout adjustment explicitly instead of testing calls == 0.
        self._layout_done = False

    def _plot_dist(self, posteriors, y, aXn, title):
        """Draw, on axis ``aXn``, the histograms of ``posteriors[:, 1]`` for y == 0 and y == 1.

        ``posteriors`` is indexed as ``posteriors[mask, 1]``, so it is presumably
        an (n, 2) array whose column 1 is the positive-class probability
        (TODO confirm). ``y`` holds the 0/1 labels aligned with it.
        """
        ax = self.axs[aXn]
        positive_posteriors = posteriors[y == 1, 1]
        negative_posteriors = posteriors[y == 0, 1]
        # raw strings: keep the LaTeX backslashes out of Python's escape processing
        ax.hist(negative_posteriors, bins=50, label=r'$Pr(x|\ominus)$', density=False, alpha=.75)
        ax.hist(positive_posteriors, bins=50, label=r'$Pr(x|\oplus)$', density=False, alpha=.75)
        ax.legend()
        ax.grid()
        ax.set_xlim(0, 1)
        ax.set_ylabel(title)

    def plot(self, in_posteriors, in_y, out_posteriors, out_y):
        """Redraw both distributions, pause briefly, then clear the axes.

        Only every ``refreshEach``-th call draws. The other calls just count
        themselves and return.

        :param in_posteriors: training-set posteriors (see :meth:`_plot_dist`).
        :param in_y: training-set labels (0/1).
        :param out_posteriors: pool posteriors (see :meth:`_plot_dist`).
        :param out_y: pool labels (0/1).
        """
        if (self.calls + 1) % self.refreshEach != 0:
            self.calls += 1
            return

        fig, axs = self.fig, self.axs

        self._plot_dist(in_posteriors, in_y, 0, title='training distribution')
        self._plot_dist(out_posteriors, out_y, 1, title='pool distribution')

        for ax in axs:
            if not self._layout_done:
                # Shrink the axis width by 20% to make room for the legend on its right.
                # Axis positions survive cla(), so this is done only once.
                box = ax.get_position()
                ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
                fig.tight_layout()

            # Put a legend to the right of the current axis
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        self._layout_done = True

        # plt.pause only refreshes the active figure, which may be another one
        fig.canvas.draw_idle()
        plt.pause(.5)
        for ax in axs:
            ax.cla()

        self.calls += 1


if __name__ == '__main__':

    usage = f'wrong args, syntax is: python {sys.argv[0]} <result_input_path> <dynamic (0|1)>'

    # explicit checks instead of `assert`, which is stripped under `python -O`
    if len(sys.argv) != 3:
        sys.exit(usage)

    file = str(sys.argv[1])
    try:
        loop = bool(int(sys.argv[2]))
    except ValueError:
        sys.exit(usage)

    # refreshEach=1: every call should draw. With the default (10) a single
    # call would return without drawing anything.
    figure = eDiscoveryPlot(file, loop=loop, refreshEach=1)

    try:
        # showYdist is off by default, so posteriors/y are not needed here.
        # In dynamic mode, keep re-reading the results file until interrupted.
        while True:
            figure.plot(None, None)
            if not loop:
                break
    except KeyboardInterrupt:
        print('\n[stop]')