from sklearn.linear_model import LogisticRegression
import numpy as np

import quapy as qp
from data import LabelledCollection
from method.base import BaseQuantifier
from quapy.method.aggregative import AggregativeQuantifier, AggregativeProbabilisticQuantifier, CC, ACC, PCC, PACC



class ClassWeightPCC(BaseQuantifier):
    """PCC whose classifier is re-weighted using a preliminary prevalence estimate.

    At fit time a PACC quantifier (``self.prompt``) is trained on the labelled
    data, and the data itself is stored. At quantification time the prompt
    produces a first estimate of the test prevalence. That estimate is turned
    into per-class weights (estimated test prevalence / training prevalence).
    A fresh PCC over a class-weighted LogisticRegression is then trained on the
    stored training data and applied to the test instances.

    Note: ``quantify`` re-trains a classifier on every call.
    """

    def __init__(self):
        # Kept for interface compatibility; not read anywhere in this class.
        self.learner = None

    def fit(self, data: LabelledCollection, fit_learner=True):
        """Store the training data and fit the preliminary PACC estimator.

        :param data: labelled training collection; kept so ``quantify`` can re-train on it
        :param fit_learner: accepted for interface compatibility; ignored
        :return: self
        """
        self.train = data
        self.prompt = PACC(LogisticRegression()).fit(self.train)
        return self

    def quantify(self, instances):
        """Estimate class prevalence for ``instances``.

        Uses the PACC estimate to derive class weights, then trains and applies
        a class-weighted PCC on the stored training data.

        :param instances: the unlabelled instances to quantify
        :return: the prevalence estimate returned by the class-weighted PCC
        """
        guessed_prevalence = self.prompt.quantify(instances)
        class_weight = self._get_class_weight(guessed_prevalence)
        weighted_pcc = PCC(LogisticRegression(class_weight=class_weight))
        return weighted_pcc.fit(self.train).quantify(instances)

    def _get_class_weight(self, prevalence):
        """Turn an estimated test prevalence into a ``class_weight`` dict.

        Each class is weighted by the ratio of its estimated test prevalence
        to its training prevalence. The ratios are then divided by the smallest
        one, so the minimum weight becomes 1 whenever that minimum is positive.

        :param prevalence: estimated prevalence, one value per class, in the same
            order as ``self.train.prevalence()``
        :return: dict mapping each class label in ``self.train.classes_`` to its weight
        """
        weights = prevalence / self.train.prevalence()
        normfactor = weights.min()
        if normfactor <= 0:
            # A class estimated as absent gives a zero ratio. Fall back to a small
            # constant to avoid dividing by zero; that class keeps weight 0.
            normfactor = 1E-3
        weights = weights / normfactor
        # Key by the real class labels (sklearn expects labels as dict keys),
        # rather than assuming a binary problem labelled 0/1.
        return dict(zip(self.train.classes_, weights))

    def set_params(self, **parameters):
        """Accept and ignore hyper-parameters; this quantifier has none to set."""
        pass

    def get_params(self, deep=True):
        """Return the parameters of the preliminary PACC estimator.

        Requires ``fit`` to have been called (``self.prompt`` is created there).
        """
        return self.prompt.get_params()

    @property
    def classes_(self):
        """Class labels of the training data; requires ``fit`` to have been called."""
        return self.train.classes_