from sklearn.base import BaseEstimator
import numpy as np
import quapy as qp
import quapy.functional as F
from data import LabelledCollection
from method.aggregative import ACC
from method.base import BaseQuantifier
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt


# Load quapy's 'kindle' reviews dataset as TF-IDF vectors. min_df=10 is forwarded
# to the loader (presumably a minimum document frequency for vocabulary terms —
# confirm against quapy's fetch_reviews documentation).
data = qp.datasets.fetch_reviews(
    'kindle',
    tfidf=True,
    min_df=10,
)

class DecisionStump(BaseEstimator):
    """Parameter-free binary classifier that looks at one feature only.

    A row is labelled 1 when its value in column ``feat_id`` is strictly
    positive and 0 otherwise. ``fit`` learns nothing.
    """

    def __init__(self, feat_id):
        self.feat_id = feat_id
        # the label set is fixed up front because there is no training step
        self.classes_ = np.asarray([0, 1], dtype=int)

    def fit(self, X, y):
        """No-op; present only to satisfy the estimator interface. Returns self."""
        return self

    def predict(self, X):
        """Return an int array with 1 where column ``feat_id`` of ``X`` is > 0, else 0.

        ``X`` must expose ``toarray()`` on its column slice (e.g. a scipy
        sparse matrix, as produced by the TF-IDF loader).
        """
        feature_column = X[:, self.feat_id].toarray()
        is_active = feature_column.ravel() > 0
        return is_active.astype(int)


class QuantificationStump(BaseQuantifier):
    """Quantifier that wraps a single-feature ``DecisionStump`` in quapy's ACC.

    The stump itself is never trained. ACC is fitted with ``fit_learner=False``
    and with the training collection itself passed as ``val_split``, so the
    adjustment ACC applies is derived from that same collection.
    """

    def __init__(self, feat_id):
        self.feat_id = feat_id

    def fit(self, data: LabelledCollection):
        """Build and fit the ACC wrapper around a stump on feature ``feat_id``.

        :param data: labelled collection; also used as ACC's validation split.
        :return: self
        """
        self.qs = ACC(DecisionStump(self.feat_id))
        # the stump has no parameters, so only ACC's own correction is fitted
        self.qs.fit(data, fit_learner=False, val_split=data)
        self.classes = data.classes_
        return self

    def quantify(self, instances):
        """Return ACC's class-prevalence estimate for ``instances`` (requires ``fit``)."""
        return self.qs.quantify(instances)

    def set_params(self, **parameters):
        # `NotImplemented` is a comparison sentinel, not an exception class:
        # `raise NotImplemented()` would itself fail with a TypeError.
        raise NotImplementedError('QuantificationStump does not support set_params')

    def get_params(self, deep=True):
        raise NotImplementedError('QuantificationStump does not support get_params')

    @property
    def classes_(self):
        """Class labels recorded from the collection passed to ``fit``."""
        return self.classes


# Split the training collection into a part used to fit the stumps and a
# development part used for model selection (split proportion: quapy default).
train, dev = data.training.split_stratified()
# One test sample of 1000 instances; (0.3, 0.7) are presumably the requested
# class prevalences — TODO confirm against quapy's LabelledCollection.sampling.
test = data.test.sampling(1000, 0.3, 0.7)

print(f'test prevalence = {F.strprev(test.prevalence())}')

nF = train.instances.shape[1]  # number of features = number of candidate stumps

# One fitted quantifier per feature. dtype=object stops numpy from trying to
# interpret the quantifier objects as array data.
qs = np.asarray([QuantificationStump(i).fit(train) for i in tqdm(range(nF))], dtype=object)

# Model selection: score each stump by its mean absolute error over dev samples
# drawn at a grid of prevalences.
DEV_SAMPLE_SIZE = 500
N_PREVALENCES = 11
REPEATS = 5
# expected number of dev samples for a binary problem; used only for the progress bar
n_dev_samples = N_PREVALENCES * REPEATS

dev_errors = []  # row j holds the absolute error of every stump on the j-th dev sample
dev_samples = dev.artificial_sampling_generator(
    DEV_SAMPLE_SIZE, n_prevalences=N_PREVALENCES, repeats=REPEATS
)
for dev_sample in tqdm(dev_samples, total=n_dev_samples):
    sample_prev = dev_sample.prevalence()
    # quantify the sample itself (not the whole dev set), otherwise every
    # estimate is identical across samples and selection is meaningless
    dev_errors.append(
        [qp.error.ae(sample_prev, qs_i.quantify(dev_sample.instances)) for qs_i in qs]
    )

k = 250  # number of best-scoring stumps to keep
scores = np.asarray(dev_errors).mean(axis=0)  # mean error per stump
order = np.argsort(scores)
qs = qs[order[:k]]

# Prevalence of class 1 (index 1 of the estimate) predicted by each kept stump.
prevs = np.asarray([qs_i.quantify(test.instances)[1] for qs_i in tqdm(qs)])

print(f'test estimation mean {prevs.mean():.3f}, median = {np.median(prevs):.3f}')

# sns.histplot(data=prevs, binwidth=3)
# An "interface" to matplotlib.axes.Axes.hist() method
# n, bins, patches = plt.hist(x=prevs, bins='auto', alpha=0.7)
# plt.grid(axis='y', alpha=0.75)
# plt.xlabel('Value')
# plt.ylabel('Frequency')
# plt.title('My Very Own Histogram')
# maxfreq = n.max()
# Set a clean upper y-axis limit.
# plt.ylim(ymax=np.ceil(maxfreq / 10) * 10 if maxfreq % 10 else maxfreq + 10)
# plt.show()