import gzip
import os
from collections import Counter
from Ordinal.utils import jaggedness
import pickle
import numpy as np
import pandas as pd



# Grid layout of the output figure: nrows x ncols panels, one sample per panel.
nrows = 3
ncols = 4
n_panels = nrows * ncols

# The first row returned by genfromtxt is dropped (presumably a header line,
# which genfromtxt parses as NaNs -- confirm against the CSV). Each remaining
# row is treated as one sample's prevalence vector.
prevalences = np.genfromtxt('fact_real_prevalences.csv', delimiter=',')[1:]
print(prevalences)

# number of columns (classes / bins) in each prevalence vector
n = prevalences.shape[1]

# One [jaggedness score, label, prevalence vector] record per sample.
class_smooth = [
    [jaggedness(p), f'Sample {i+1}', p]
    for i, p in enumerate(prevalences)
]

# Pick n_panels samples spread evenly over the jaggedness ranking, from the
# least jagged to the most jagged (both ends included).
# - Sorting on the score alone avoids falling back to comparing label strings
#   on ties; Python's sort is stable, so tied samples keep file order.
# - Evenly spaced indices from np.linspace always include the last (most
#   jagged) sample, unlike a floored slice step, and there is no zero-step
#   crash when there are no more samples than panels (then all are kept).
class_smooth = sorted(class_smooth, key=lambda record: record[0])
if len(class_smooth) > n_panels:
    picks = np.linspace(0, len(class_smooth) - 1, n_panels).round().astype(int)
    class_smooth = [class_smooth[j] for j in picks]

import matplotlib.pyplot as plt
import seaborn as sns

# Configure seaborn in a single call. `sns.set(...)` is an alias of
# `sns.set_theme(...)` with default arguments, so calling it after
# set_theme('paper') / set_style('dark') resets the context to "notebook" and
# the style to "darkgrid", discarding both earlier settings.
sns.set_theme(context='paper', style='dark', font_scale=0.5)

# Common y-limit for every panel so bar heights are comparable across panels;
# the extra 0.1 presumably leaves headroom for the title drawn inside the panel.
maxy = np.max(prevalences) + 0.1
class_labels = np.arange(1, n+1)

# squeeze=False keeps `axis` a 2-D array even when nrows or ncols is 1,
# so axis[row, col] indexing is always valid.
figure, axis = plt.subplots(nrows, ncols, figsize=(ncols*2, nrows), squeeze=False)
n_shown = len(class_smooth)
for i, (smooth, category, prevalence) in enumerate(class_smooth):
    row, col = divmod(i, ncols)
    ax = axis[row, col]

    ax.bar(class_labels, prevalence, width=1)
    ax.set_ylim(0, maxy)
    ax.set_facecolor('white')
    for spine in ax.spines.values():
        spine.set_edgecolor('black')
        spine.set_linewidth(0.3)

    # Panels are packed with no spacing (subplots_adjust below), so x-axis
    # ticks/label go only on panels with nothing drawn underneath them --
    # for a full grid this is exactly the bottom row.
    if i + ncols >= n_shown:
        ax.set_xlabel("energy bin")
        ax.set_xticks(class_labels)
    else:
        ax.set_xlabel("")
        ax.set_xticks([])
    ax.set_ylabel("")
    ax.set_yticks([])

    category = category.replace('_', ' ').title()
    category = category.replace(' And ', ' & ')
    # title drawn inside the panel (y=0.75 in axes coordinates)
    ax.set_title(f'{category} ({smooth:.4f})', x=0.5, y=0.75)

    print(smooth, category, prevalence)

# Hide grid cells left over when fewer samples were selected than panels exist.
for ax in axis.ravel()[n_shown:]:
    ax.set_visible(False)

plt.subplots_adjust(wspace=0, hspace=0)
plt.savefig('Telescope_sample_plotgrid.pdf', bbox_inches='tight')