1
0
Fork 0
QuaPy/Ordinal/amazon_prevalence_plotgrid.py

106 lines
3.3 KiB
Python

import gzip
import os
from collections import Counter
from Ordinal.utils import jaggedness
import quapy as qp
import pickle
import numpy as np
import pandas as pd
# Directory that holds one gzip-compressed review file per category
# (read below as <base_path>/<category>.txt.gz).
base_path = '/media/moreo/Volume/Datasets/Amazon/reviews'
# Plain-text file with the category names, one name per line.
categories_path = '/media/moreo/Volume/Datasets/Amazon/raw/amazon_categories.txt'
def get_prevalence_merchandise(category, base_dir=None):
    """Count how many reviews of one category carry each star label.

    Reads ``<base_dir>/<category>.txt.gz`` line by line. Every line is
    expected to consist of exactly two tab-separated fields: the star label
    and the review text. Lines that do not split into exactly two fields are
    reported on stdout and skipped.

    Args:
        category: category name; also the stem of the gzip file to read.
        base_dir: directory containing the review files. When ``None``
            (the default) the module-level ``base_path`` is used.

    Returns:
        A ``collections.Counter`` mapping each star label (kept as the raw
        string found in the file) to the number of lines carrying it.
    """
    if base_dir is None:
        base_dir = base_path
    input_file = os.path.join(base_dir, category + '.txt.gz')
    # Count on the fly instead of first materialising every label in a list.
    counts = Counter()
    print(f'{category} starts')
    with gzip.open(input_file, 'rt') as f:
        for line in f:
            try:
                stars, _ = line.split('\t')
            except ValueError:
                # Raised by the unpacking when the line has no tab or more
                # than one; only this failure is treated as a bad line, so
                # interrupts and real I/O errors still propagate.
                print('error in line: ', line)
            else:
                counts[stars] += 1
    print(f'\t{category} done')
    return counts
# The per-category star counts are cached in a pickle next to the script so
# the review files only have to be scanned when the cache does not exist yet.
target_file = './counters_Amazon_merchandise.pkl'
if not os.path.exists(target_file):
    # One category per line; spaces become underscores to match the file names.
    with open(categories_path, 'rt') as categories_file:
        categories = [c.strip().replace(' ', '_') for c in categories_file]
    # categories = ['Gift_Cards', 'Magazine_Subscriptions']
    counters = qp.util.parallel(get_prevalence_merchandise, categories, n_jobs=-1)
    print('saving pickle')
    with open(target_file, 'wb') as pickle_file:
        pickle.dump((categories, counters), pickle_file, pickle.HIGHEST_PROTOCOL)
else:
    # NOTE: unpickling can execute arbitrary code; this is only acceptable
    # because the file is a cache written by this very script.
    with open(target_file, 'rb') as pickle_file:
        (categories, counters) = pickle.load(pickle_file)

# Drop the Gift_Cards category; the same index is removed from both
# sequences so that they stay aligned.
index_gift_cards = categories.index('Gift_Cards')
del categories[index_gift_cards]
del counters[index_gift_cards]
# Build one [jaggedness, category, prevalence] entry per category and order
# the entries by increasing jaggedness (ties fall back to the category name).
star_labels = [str(star) for star in (1, 2, 3, 4, 5)]
class_smooth = []
for cat, counter in zip(categories, counters):
    # The denominator is the number of counted lines, whatever their label.
    n_reviews = sum(counter.values())
    per_star = [counter[star] for star in star_labels]
    p = np.asarray(per_star) / n_reviews
    class_smooth.append([jaggedness(p), cat, p])
class_smooth.sort()
# df = pd.DataFrame(class_smooth, columns=['smoothness', 'category', 'prevalence'])
import matplotlib.pyplot as plt
import seaborn as sns

# NOTE(review): seaborn documents sns.set() as an alias of set_theme(), whose
# defaults also set context and style, so the last call may override the two
# calls above -- confirm which theme is actually intended.
sns.set_theme('paper')
sns.set_style('dark')
sns.set(font_scale=0.5)

# Grid of small bar charts, one per category, filled row by row in the order
# of class_smooth. The grid holds at most nrows * ncols categories.
nrows = 7
ncols = 4
figure, axis = plt.subplots(nrows, ncols, figsize=(ncols*2, nrows))

# Alongside the figure, dump the plotted values as a tab-separated text file.
with open('categories.txt', 'wt') as foo:
    foo.write('Category\tSmooth\tPrevalence\n')
    for i, (smooth, category, prevalence) in enumerate(class_smooth):
        # Derive the cell from ncols rather than repeating the literal 4.
        row, col = divmod(i, ncols)
        ax = axis[row, col]

        ax.bar([1,2,3,4,5], prevalence, width=1)
        ax.set_ylim(0, 0.75)
        ax.set_facecolor('white')
        for spine in ax.spines.values():
            spine.set_edgecolor('black')
            spine.set_linewidth(0.3)

        # Only the bottom row carries the x-axis label.
        if row == nrows - 1:
            ax.set_xlabel("stars")

        # Every cell, whatever its column, hides the y label and y ticks.
        ax.set_ylabel("")
        ax.set_yticks([])

        # Human-readable title, e.g. 'Home_and_Kitchen' -> 'Home & Kitchen'.
        category = category.replace('_', ' ').title()
        category = category.replace(' And ', ' & ')
        # y=0.75 places the title inside the axes area.
        ax.set_title(f'{category} ({smooth:.4f})', x=0.5, y=0.75)

        foo.write(f'{category}\t{smooth}\t{prevalence}\n')

# plt.show()
# No spacing between the cells, so the charts form a compact grid.
plt.subplots_adjust(wspace=0, hspace=0)
plt.savefig('Amazon_categories_plotgrid.pdf', bbox_inches='tight')