import gzip
import os
from collections import Counter
from Ordinal.utils import jaggedness
import quapy as qp
import pickle
import numpy as np
import pandas as pd

# Root of the local Amazon dataset dump; both inputs below live under it.
_AMAZON_ROOT = '/media/moreo/Volume/Datasets/Amazon'
# Directory holding one '<Category>.txt.gz' review file per product category.
base_path = f'{_AMAZON_ROOT}/reviews'
# Plain-text list of category names, one per line.
categories_path = f'{_AMAZON_ROOT}/raw/amazon_categories.txt'


def get_prevalence_merchandise(category, base_dir=None):
    """Count the star labels found in one Amazon category file.

    Reads ``<base_dir>/<category>.txt.gz`` as gzip-compressed text. Each line
    is expected to look like ``<stars>\\t<review text>``, and the ``<stars>``
    field of every line is tallied.

    Args:
        category: category name, used as the file stem (e.g. ``'Gift_Cards'``).
        base_dir: directory containing the ``.txt.gz`` files. Defaults to the
            module-level ``base_path`` when ``None``.

    Returns:
        collections.Counter mapping each star label (the raw string read from
        the file) to the number of lines carrying it. Lines without a tab
        separator are reported on stdout and skipped.
    """
    if base_dir is None:
        base_dir = base_path
    input_file = os.path.join(base_dir, category + '.txt.gz')
    counts = Counter()
    print(f'{category} starts')
    # Explicit encoding so the result does not depend on the machine locale
    # (assumes the dumps are UTF-8 -- confirm for the dataset in use).
    with gzip.open(input_file, 'rt', encoding='utf-8') as f:
        for line in f:
            # Only the label is needed. Partition on the first tab so reviews
            # whose text itself contains tabs are still counted.
            stars, sep, _doc = line.partition('\t')
            if not sep:
                print('error in line: ', line)
                continue
            counts[stars] += 1
    print(f'\t{category} done')
    return counts

# Cache of (categories, per-category star counters); recomputed only if missing.
target_file = './counters_Amazon_merchandise.pkl'

if not os.path.exists(target_file):
    # One category per line; spaces become underscores to match the file stems.
    # Blank lines are skipped: an empty name would make the worker open
    # '<base_path>/.txt.gz' and abort the whole parallel run.
    with open(categories_path, 'rt') as f:
        categories = [c.strip().replace(' ', '_') for c in f if c.strip()]

    # Uncomment to run on a small subset only:
    # categories = ['Gift_Cards', 'Magazine_Subscriptions']
    counters = qp.util.parallel(get_prevalence_merchandise, categories, n_jobs=-1)

    print('saving pickle')
    with open(target_file, 'wb') as f:
        pickle.dump((categories, counters), f, pickle.HIGHEST_PROTOCOL)

else:
    # NOTE: pickle.load can execute arbitrary code; only load a cache this script wrote.
    with open(target_file, 'rb') as f:
        (categories, counters) = pickle.load(f)

# `del counters[...]` below needs a mutable list. Depending on the quapy
# version, qp.util.parallel may return a NumPy array instead (verify for the
# installed release); list() is a harmless copy when it is already a list.
counters = list(counters)

# Gift_Cards is excluded from the analysis that follows; guard so a
# categories file without it does not crash the script.
if 'Gift_Cards' in categories:
    index_gift_cards = categories.index('Gift_Cards')
    del categories[index_gift_cards]
    del counters[index_gift_cards]

# For every category build [jaggedness, category, prevalence], where the
# prevalence vector holds the fraction of reviews with 1..5 stars.
class_smooth = []
for cat, counter in zip(categories, counters):
    # Denominator counts every label present in the counter, not only '1'..'5'.
    n_reviews = sum(counter.values())
    star_counts = np.asarray([counter[str(star)] for star in range(1, 6)])
    prev = star_counts / n_reviews
    class_smooth.append([jaggedness(prev), cat, prev])

# Ascending jaggedness; equal scores fall back to comparing category names.
class_smooth.sort()

import matplotlib.pyplot as plt
import seaborn as sns

# seaborn.set() is an alias of set_theme(), which resets every theme argument
# it is not given. The previous sequence set_theme('paper') ->
# set_style('dark') -> set(font_scale=0.5) therefore ended up with
# context='notebook', style='darkgrid'. Apply all three settings in one call.
sns.set_theme(context='paper', style='dark', font_scale=0.5)

# Grid `ncols` wide with as many rows as needed (7 rows for 28 categories).
ncols = 4
nrows = max(1, (len(class_smooth) + ncols - 1) // ncols)
# squeeze=False keeps `axis` 2-D even when a single row is needed.
figure, axis = plt.subplots(nrows, ncols, figsize=(ncols*2, nrows), squeeze=False)
stars = [1, 2, 3, 4, 5]
with open('categories.txt', 'wt') as foo:
    foo.write('Category\tSmooth\tPrevalence\n')
    for i, (smooth, category, prevalence) in enumerate(class_smooth):
        row, col = divmod(i, ncols)
        ax = axis[row, col]
        ax.bar(stars, prevalence, width=1)
        ax.set_ylim(0, 0.75)
        ax.set_facecolor('white')
        for spine in ax.spines.values():
            spine.set_edgecolor('black')
            spine.set_linewidth(0.3)
        # Label the x axis only on the bottom-most chart of each column.
        if i + ncols >= len(class_smooth):
            ax.set_xlabel("stars")
        # No y label or y ticks on any panel.
        ax.set_ylabel("")
        ax.set_yticks([])

        category = category.replace('_', ' ').title()
        category = category.replace(' And ', ' & ')
        # Title placed inside the panel (y is in axes coordinates).
        ax.set_title(f'{category} ({smooth:.4f})', x=0.5, y=0.75)

        foo.write(f'{category}\t{smooth}\t{prevalence}\n')

# Hide grid cells left empty when the category count is not a multiple of ncols.
for ax in axis.flat[len(class_smooth):]:
    ax.set_visible(False)

plt.subplots_adjust(wspace=0, hspace=0)
plt.savefig('Amazon_categories_plotgrid.pdf', bbox_inches='tight')