from abc import abstractmethod, ABC
from copy import deepcopy
from typing import List, Iterable

import numpy as np

import quapy as qp
from quapy.method.aggregative import AggregativeQuantifier
from quapy.method.non_aggregative import MaximumLikelihoodPrevalenceEstimation as MLPE
from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier


class AreaQuantifier:
    """
    Associates a quantifier with the identifier of the area it belongs to.

    Prevalence estimation is delegated unchanged to the wrapped quantifier.
    """

    def __init__(self, area:int, quantifier: BaseQuantifier):
        # identifier of the area this quantifier is associated with
        self.area = area
        # wrapped quantifier that performs the actual estimation
        self.quantifier = quantifier

    def quantify(self, X):
        """Return whatever the wrapped quantifier's ``quantify`` returns for `X`."""
        wrapped = self.quantifier
        return wrapped.quantify(X)


class CombinationRule(ABC):
    """
    Abstract strategy for estimating an area's prevalence from a set of area quantifiers.

    Subclasses decide which quantifiers take part for a given area
    (:meth:`select_quantifiers`) and how their outputs are merged (:meth:`combination`);
    :meth:`predict` chains the two.
    """

    def __init__(self, area_quantifiers: List[AreaQuantifier]):
        # pool of candidate quantifiers the rule may choose from
        self.area_quantifiers = area_quantifiers

    @abstractmethod
    def select_quantifiers(self, area:int, X):
        """Return the quantifiers to be used when estimating the prevalence of `area`."""

    @abstractmethod
    def combination(self, choice, X):
        """Merge the estimates that the quantifiers in `choice` produce for `X`."""

    def predict(self, area:int, X):
        """Select the quantifiers for `area`, then combine their estimates on `X`."""
        return self.combination(self.select_quantifiers(area, X), X)


def optimize_ensemble(area_data: Iterable, q: BaseQuantifier, Madj=None, hyper=None, error='mae'):
    """
    Fit one quantifier per area, tuning it on the data of the other areas.

    For every area ``A``, :func:`optim` trains a copy of the template `q` on ``A``'s
    labelled data. It selects hyperparameters using, as validation samples, the labelled
    collections of all other areas. When `Madj` is given, only the areas it reports as
    adjacent to ``A`` are used.

    :param area_data: iterable of ``(area, X, y)`` triples
    :param q: template quantifier; every area starts from this same (unfitted) template
    :param Madj: optional object exposing ``get_adjacent(area)``, which returns a container
        of area identifiers; restricts the validation areas to the adjacent ones
    :param hyper: hyperparameter grid for model selection; when None, a grid over
        ``classifier__C`` (``np.logspace(-4, 4, 9)``) and ``classifier__class_weight``
        (``'balanced'`` / ``None``) is used
    :param error: name of the error measure used during model selection
    :return: list of :class:`AreaQuantifier`, one per area, in the order of `area_data`
    """
    if hyper is None:
        hyper = {
            'classifier__C': np.logspace(-4, 4, 9),
            'classifier__class_weight': ['balanced', None]
        }

    labelled_collections = [(A, LabelledCollection(X, y)) for A, X, y in area_data]

    area_quantifiers = []
    for A, lc in labelled_collections:
        if Madj is None:
            rest = [lc_j for Aj, lc_j in labelled_collections if Aj != A]
        else:
            # query the adjacency once per area instead of once per candidate area
            adjacent = Madj.get_adjacent(A)
            rest = [lc_j for Aj, lc_j in labelled_collections if Aj != A and Aj in adjacent]
        # Bind the result to a separate name so that `q` remains the untouched template
        # for every area. Rebinding `q` would make later areas start from the previous
        # area's fitted model and its tuned hyperparameters.
        fitted_q = optim(q, lc, rest, hyper, error)
        area_quantifiers.append(AreaQuantifier(A, fitted_q))

    return area_quantifiers


class AggregationRule(CombinationRule):
    """
    Combination rule that aggregates the prevalence estimates of several area quantifiers.

    The quantifiers used for an area are either all of them (no adjacency matrix) or only
    those whose area is adjacent to the queried one. Their estimates are combined
    element-wise by median or mean and renormalised to sum to 1.
    """

    def __init__(self, area_quantifiers: List[AreaQuantifier], adjacent_matrix: 'AdjMatrix' = None, aggr='median'):
        """
        :param area_quantifiers: the per-area quantifiers to aggregate
        :param adjacent_matrix: optional object exposing ``get_adjacent(area)``; when None,
            every quantifier contributes to every prediction
        :param aggr: aggregation function, either 'mean' or 'median'
        """
        assert aggr in ['mean', 'median'], f'unknown {aggr=}'
        # reuse the parent initialisation rather than duplicating it
        super().__init__(area_quantifiers)
        self.adjacent_matrix = adjacent_matrix
        self.aggr = aggr

    def select_quantifiers(self, area:int, X):
        """Return all quantifiers, or only those of areas adjacent to `area` if a matrix is set."""
        if self.adjacent_matrix is None:
            chosen = self.area_quantifiers
        else:
            adjacent = self.adjacent_matrix.get_adjacent(area)
            chosen = [q_i for q_i in self.area_quantifiers if q_i.area in adjacent]
        return chosen

    def combination(self, choice, X):
        """
        Aggregate the prevalence estimates that the quantifiers in `choice` produce for `X`.

        :raises ValueError: if `choice` is empty (e.g. no adjacent quantifier was found)
        :raises NotImplementedError: if ``self.aggr`` is not a supported aggregation
        """
        if len(choice) == 0:
            # aggregating an empty set would silently yield nan
            raise ValueError('no quantifier was selected; cannot aggregate prevalence estimates')
        prevs = np.asarray([q.quantify(X) for q in choice])
        if self.aggr == 'median':
            prev = np.median(prevs, axis=0)
        elif self.aggr == 'mean':
            prev = np.mean(prevs, axis=0)
        else:
            raise NotImplementedError(f'{self.aggr=} not implemented')
        # The element-wise median of prevalence vectors does not in general sum to 1,
        # so renormalise (a no-op, up to rounding, for the mean of valid prevalences).
        total = np.sum(prev)
        if total > 0:
            prev = prev / total
        return prev


def optim(q: BaseQuantifier, train: LabelledCollection, labelled_collections: Iterable[LabelledCollection], hyper:dict, error='mae'):
    """
    Train a copy of `q` on `train`, choosing its hyperparameters by grid search.

    Each element of `labelled_collections` is served as one validation sample through
    ``qp.protocol.IterateProtocol``. The grid `hyper` is explored with
    ``qp.model_selection.GridSearchQ`` (``refit=False``, ``n_jobs=-1``), and that search's
    ``best_model_`` is returned. If the search raises ``ValueError``, a message is printed
    and the copy is instead fitted directly on `train`.

    :param q: quantifier template; only a deep copy of it is handed to the search / fit
    :param train: labelled data the quantifier is trained on
    :param labelled_collections: validation samples used for model selection
    :param hyper: hyperparameter grid for the search
    :param error: name of the error measure minimised by the search
    :return: the fitted quantifier
    """
    candidate = deepcopy(q)
    validation = qp.protocol.IterateProtocol(labelled_collections)

    try:
        search = qp.model_selection.GridSearchQ(
            model=candidate,
            param_grid=hyper,
            protocol=validation,
            error=error,
            refit=False,
            n_jobs=-1,
        )
        return search.fit(train).best_model_
    except ValueError:
        # best-effort fallback: keep going with the copy's current hyperparameters
        print(f'method {candidate} failed; training without model selection')
        return candidate.fit(train)