import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sys, os, pathlib

# Layout of the figure: one entry per subplot, top to bottom.  Each entry is
# (y-axis label, y-limits or None, curves) and each curve is
# (column name in the results file, legend label[, matplotlib format string]).
# Labels containing LaTeX commands are raw strings so that sequences such as
# \hat and \oplus are not parsed as (invalid) Python escape sequences.
PANELS = (
    ('Recall', (0, 1), (
        ('Rhat', r'$\hat{R}_{Q}$'),
        ('RhatCC', r'$\hat{R}_{CC}$'),
        ('R', '$R$'),
    )),
    ('Prevalence', None, (
        ('te-estim', r'te-$\hat{Pr}(\oplus)_{Q}$'),
        ('te-estimCC', r'te-$\hat{Pr}(\oplus)_{CC}$'),
        ('te-prev', r'te-$Pr(\oplus)$'),
    )),
    ('Quantification error', None, (
        ('AE', 'AE$_{Q}$'),
        ('AE_CC', 'AE$_{CC}$'),
    )),
    ('Classifiers performance', None, (
        ('MF1_Q', '$F_1(clf(Q))$'),
        ('MF1_Clf', '$F_1(clf(CC))$'),
    )),
    ('Train-Test Shift', None, (
        ('Shift', 'tr-te shift (AE)', '--k'),
        ('tr-prev', r'tr-$Pr(\oplus)$', 'y'),
        ('te-prev', r'te-$Pr(\oplus)$', 'r'),
    )),
)


def parse_args(argv):
    """Validate the command line and return ``(result_path, dynamic)``.

    ``argv`` is the full argument vector (program name first).  Exactly two
    user arguments are expected: the path of the results file and an integer
    flag; any non-zero value enables the continuously refreshing plot.

    Exits with the usage message (instead of relying on ``assert``, which is
    stripped under ``python -O``) when the argument count is wrong or the
    flag is not an integer.
    """
    prog = argv[0] if argv else '<script>'
    usage = f'wrong args, syntax is: python {prog} <result_input_path> <dynamic (0|1)>'
    if len(argv) != 3:
        sys.exit(usage)
    try:
        dynamic = bool(int(argv[2]))
    except ValueError:
        sys.exit(usage)
    return str(argv[1]), dynamic


def plot_filename(result_path) -> str:
    """Return the image file name (no directory) for a results file.

    A trailing ``.csv`` extension (any letter case) is swapped for ``.png``;
    any other name simply gets ``.png`` appended, so the output always has an
    extension that ``savefig`` can map to an image format.
    """
    path = pathlib.Path(result_path)
    if path.suffix.lower() == '.csv':
        return path.with_suffix('.png').name
    return path.name + '.png'


def draw_panels(axs, df):
    """Draw every panel described in ``PANELS`` onto the matching axes.

    ``axs`` is the sequence of axes (one per panel) and ``df`` the results
    table; the x axis of every panel is the ``'it'`` column.  A missing
    column raises ``KeyError``.
    """
    xs = df['it']
    for ax, (ylabel, ylim, curves) in zip(axs, PANELS):
        for column, label, *fmt in curves:
            # fmt is empty or holds a single matplotlib format string.
            ax.plot(xs, df[column], *fmt, label=label)
        ax.legend()
        ax.grid()
        ax.set_ylabel(ylabel)
        if ylim is not None:
            ax.set_ylim(*ylim)


def main(argv=None):
    """Plot the results file once, or keep refreshing it until Ctrl-C.

    The file is read as tab-separated text.  The figure is written to
    ``./plots/<name>.png`` after every redraw.  In dynamic mode the file is
    re-read and the axes are redrawn roughly every half second.
    """
    file, loop = parse_args(sys.argv if argv is None else argv)

    print(file)

    plotname = plot_filename(file)

    if not loop:
        # Large, high-resolution figure only for the one-shot render.
        plt.rcParams['figure.figsize'] = [12, 12]
        plt.rcParams['figure.dpi'] = 200

    fig, axs = plt.subplots(len(PANELS))

    # The output directory does not change between redraws; create it once.
    os.makedirs('./plots', exist_ok=True)

    try:
        while True:
            df = pd.read_csv(file, sep='\t')
            draw_panels(axs, df)
            fig.savefig(f'./plots/{plotname}')

            if not loop:
                break

            plt.pause(.5)
            # Wipe the axes so the next iteration redraws from fresh data.
            for ax in axs:
                ax.cla()
    except KeyboardInterrupt:
        print("\n[exit]")


if __name__ == '__main__':
    main()