"""Stacked-panel plots (recall, prevalence, quantification error, classifier
performance, train/test shift) for a tab-separated eDiscovery results file."""

import os
import pathlib
import sys

import matplotlib.pyplot as plt
import pandas as pd


class eDiscoveryPlot:
    """Five-panel figure drawn from a tab-separated results file.

    The file is read with ``pandas.read_csv`` using a tab separator and must
    provide the columns ``it``, ``R``, ``Rhat``, ``RhatCC``, ``te-prev``,
    ``te-estim``, ``te-estimCC``, ``AE``, ``AE_CC``, ``MF1_Q``, ``MF1_Clf``,
    ``Shift`` and ``tr-prev``.
    """

    # One entry per stacked panel: (y-axis label, y-limits or None, series).
    # Each series item is (column, matplotlib format string or None, label).
    # Lines are drawn in list order, which fixes colour-cycle and legend order.
    # Labels are raw strings so the LaTeX backslashes are not parsed as
    # (invalid) Python escape sequences; their values are unchanged.
    _PANELS = (
        ('Recall', (0, 1), [
            ('Rhat', None, r'$\hat{R}_{Q}$'),
            ('RhatCC', None, r'$\hat{R}_{CC}$'),
            ('R', None, r'$R$'),
        ]),
        ('Prevalence', None, [
            ('te-estim', None, r'te-$\hat{Pr}(\oplus)_{Q}$'),
            ('te-estimCC', None, r'te-$\hat{Pr}(\oplus)_{CC}$'),
            ('te-prev', None, r'te-$Pr(\oplus)$'),
        ]),
        ('Quantification error', None, [
            ('AE', None, r'AE$_{Q}$'),
            ('AE_CC', None, r'AE$_{CC}$'),
        ]),
        ('Classifiers performance', None, [
            ('MF1_Q', None, r'$F_1(clf(Q))$'),
            ('MF1_Clf', None, r'$F_1(clf(CC))$'),
        ]),
        ('Train-Test Shift', None, [
            ('Shift', '--k', 'tr-te shift (AE)'),
            ('tr-prev', 'y', r'tr-$Pr(\oplus)$'),
            ('te-prev', 'r', r'te-$Pr(\oplus)$'),
        ]),
    )

    def __init__(self, datapath, outdir='./plots', loop=True, save=True):
        """Create the figure and remember where to read data / write plots.

        Args:
            datapath: path to the tab-separated results file.
            outdir: directory the PNG is written to when ``save`` is true.
            loop: if true, ``plot`` calls ``plt.pause`` after drawing so an
                interactive window can refresh between repeated calls.
            save: if true, ``plot`` writes the figure to ``outdir``.
        """
        self.outdir = outdir
        self.datapath = datapath
        # Derive the PNG name from the stem so any input extension maps to
        # ".png"; replacing only ".csv" left e.g. "x.tsv" unchanged, which
        # made savefig reject the unsupported format.
        self.plotname = pathlib.Path(datapath).stem + '.png'
        self.loop = loop
        self.save = save

        # NOTE: these mutate matplotlib's global rcParams (kept for backward
        # compatibility); loop mode uses a larger, lower-dpi figure.
        if loop:
            plt.rcParams['figure.figsize'] = [17, 17]
            plt.rcParams['figure.dpi'] = 60
        else:
            plt.rcParams['figure.figsize'] = [12, 12]
            plt.rcParams['figure.dpi'] = 200

        self.fig, self.axs = plt.subplots(len(self._PANELS))

    def plot(self):
        """Re-read ``datapath`` and redraw every panel.

        Each axis is cleared before drawing, so repeated calls replace the
        previous content instead of stacking duplicate lines. When ``save``
        is true the figure is written to ``outdir/plotname`` (creating
        ``outdir`` if needed). When ``loop`` is true, ``plt.pause(.5)`` gives
        the GUI event loop time to display the update.
        """
        # Read before clearing so a failed read leaves the last drawing intact.
        df = pd.read_csv(self.datapath, sep='\t')
        xs = df['it']

        for ax, (ylabel, ylim, series) in zip(self.axs, self._PANELS):
            ax.cla()
            for column, fmt, label in series:
                style = () if fmt is None else (fmt,)
                ax.plot(xs, df[column], *style, label=label)
            ax.legend()
            # Explicit True: a bare grid() toggles, which would hide the grid
            # on every second call if the axes were not freshly cleared.
            ax.grid(True)
            ax.set_ylabel(ylabel)
            if ylim is not None:
                ax.set_ylim(*ylim)

        if self.save:
            os.makedirs(self.outdir, exist_ok=True)
            # Save this object's figure, not whatever pyplot considers current.
            self.fig.savefig(os.path.join(self.outdir, self.plotname))

        if self.loop:
            plt.pause(.5)


if __name__ == '__main__':
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    if len(sys.argv) != 3:
        sys.exit(f'wrong args, syntax is: python {sys.argv[0]} <result_input_path> <loop (0|1)>')

    datapath = str(sys.argv[1])
    loop = bool(int(sys.argv[2]))

    # plot() takes no arguments; the loop mode belongs to the constructor.
    figure = eDiscoveryPlot(datapath, loop=loop)

    try:
        figure.plot()
        # In loop mode keep re-reading and redrawing the file until Ctrl-C.
        while loop:
            figure.plot()
    except KeyboardInterrupt:
        print('\n[stop]')