from typing import List, Tuple

import numpy as np
import scipy.sparse as sp
from quapy.data import LabelledCollection

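# Extended classes (binary case)
#
# Each extended label combines the true class with the class predicted by the
# classifier, encoded as `true * n_classes + pred`:
#
# 0 ~ True 0   (true 0, predicted 0)
# 1 ~ False 1  (true 0, predicted 1)
# 2 ~ False 0  (true 1, predicted 0)
# 3 ~ True 1   (true 1, predicted 1)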
#      _____________________
#     |          |          |
#     |  True 0  |  False 1 |
#     |__________|__________|
#     |          |          |
#     |  False 0 |  True 1  |
#     |__________|__________|
#


class ExtensionPolicy:
    """Options that control how extended labels are encoded.

    Attributes:
        collapse_false: when true, ``ExtendedLabels.y`` is computed as
            ``true + pred`` instead of ``true * ncl + pred``.
    """

    def __init__(self, collapse_false: bool = False) -> None:
        # Stored verbatim; consumers read the attribute directly.
        self.collapse_false = collapse_false


class ExtendedData:
    def __init__(
        self,
        instances: np.ndarray | sp.csr_matrix,
        pred_proba: np.ndarray,
        ext: np.ndarray = None,
        extpol=None,
    ):
        self.extpol = ExtensionPolicy() if extpol is None else extpol
        self.b_instances_ = instances
        self.pred_proba_ = pred_proba
        self.ext_ = ext
        self.instances = self.__extend_instances(instances, pred_proba, ext=ext)

    def __extend_instances(
        self,
        instances: np.ndarray | sp.csr_matrix,
        pred_proba: np.ndarray,
        ext: np.ndarray = None,
    ) -> np.ndarray | sp.csr_matrix:
        to_append = ext
        if ext is None:
            to_append = pred_proba

        if isinstance(instances, sp.csr_matrix):
            _to_append = sp.csr_matrix(to_append)
            n_x = sp.hstack([instances, _to_append])
        elif isinstance(instances, np.ndarray):
            n_x = np.concatenate((instances, to_append), axis=1)
        else:
            raise ValueError("Unsupported matrix format")

        return n_x

    @property
    def X(self):
        return self.instances

    def __split_index_by_pred(self) -> List[np.ndarray]:
        _pred_label = np.argmax(self.pred_proba_, axis=1)

        return [
            (_pred_label == cl).nonzero()[0]
            for cl in np.arange(self.pred_proba_.shape[1])
        ]

    def split_by_pred(self, return_indexes=False):
        def _empty_matrix():
            if isinstance(self.instances, np.ndarray):
                return np.asarray([], dtype=int)
            elif isinstance(self.instances, sp.csr_matrix):
                return sp.csr_matrix(np.empty((0, 0), dtype=int))

        _indexes = self.__split_index_by_pred()
        _instances = [
            self.instances[ind] if ind.shape[0] > 0 else _empty_matrix()
            for ind in _indexes
        ]

        if return_indexes:
            return _instances, _indexes

        return _instances

    def __len__(self):
        return self.instances.shape[0]


class ExtendedLabels:
    """Pairs of (true label, predicted label) exposed as a single label array.

    Attributes:
        extpol: the extension policy (a default ``ExtensionPolicy`` when none
            is passed); its ``collapse_false`` flag selects the encoding.
        true: the true labels.
        pred: the predicted labels.
        ncl: the number of original classes.
    """

    def __init__(
        self,
        true: np.ndarray,
        pred: np.ndarray,
        ncl: int,
        extpol: "ExtensionPolicy | None" = None,
    ):
        self.extpol = ExtensionPolicy() if extpol is None else extpol
        self.true = true
        self.pred = pred
        self.ncl = ncl

    @property
    def y(self):
        """Extended labels.

        Without collapsing, each (true, pred) pair gets its own label
        ``true * ncl + pred``. With ``collapse_false`` the label is
        ``true + pred``; with two classes this maps both kinds of error to
        label 1.
        """
        # NOTE(review): for ncl > 2, true + pred can exceed ncl and distinct
        # pairs can share a label -- confirm collapsing is binary-only.
        if self.extpol.collapse_false:
            return self.true + self.pred
        else:
            return self.true * self.ncl + self.pred

    @property
    def classes(self):
        """All extended class values: ``ncl + 1`` when collapsing, else ``ncl**2``."""
        if self.extpol.collapse_false:
            return np.arange(self.ncl + 1)
        else:
            return np.arange(self.ncl**2)

    def __getitem__(self, idx):
        """Return the labels selected by ``idx``, keeping the same policy."""
        # The policy must be carried over: otherwise a slice of collapsed
        # labels would silently fall back to the default (non-collapsed)
        # encoding.
        return ExtendedLabels(
            self.true[idx], self.pred[idx], self.ncl, extpol=self.extpol
        )


class ExtendedCollection(LabelledCollection):
    """Labelled collection extended with the predictions of a classifier.

    The instances handed to the parent class are the extended matrix built by
    ``ExtendedData`` and the labels are the extended labels built by
    ``ExtendedLabels``; the predicted class is the row-wise argmax of
    ``pred_proba``.

    Attributes:
        extpol: the extension policy (a default ``ExtensionPolicy`` when none
            is passed).
        e_data_: the ``ExtendedData`` built from the inputs.
        e_labels_: the ``ExtendedLabels`` built from the inputs.
    """

    def __init__(
        self,
        instances: np.ndarray | sp.csr_matrix,
        labels: np.ndarray,
        pred_proba: np.ndarray | None = None,
        ext: np.ndarray | None = None,
        extpol=None,
    ):
        """Build the extended collection.

        Raises:
            ValueError: if ``pred_proba`` is not provided.
        """
        # pred_proba keeps its None default for signature compatibility, but
        # the number of classes and the predictions are derived from it, so
        # fail with a clear message rather than an AttributeError on None.
        if pred_proba is None:
            raise ValueError("pred_proba is required to build an ExtendedCollection")

        self.extpol = ExtensionPolicy() if extpol is None else extpol
        e_data, e_labels = self.__extend_collection(
            instances=instances,
            labels=labels,
            pred_proba=pred_proba,
            ext=ext,
        )
        self.e_data_ = e_data
        self.e_labels_ = e_labels
        super().__init__(e_data.X, e_labels.y, classes=e_labels.classes)

    @classmethod
    def from_lc(
        cls,
        lc: LabelledCollection,
        pred_proba: np.ndarray,
        ext: np.ndarray | None = None,
        extpol=None,
    ):
        """Alternate constructor: extend an existing ``LabelledCollection``."""
        # Use cls so that subclasses get an instance of their own type.
        return cls(lc.X, lc.y, pred_proba=pred_proba, ext=ext, extpol=extpol)

    @property
    def pred_proba(self):
        """The ``pred_proba`` matrix this collection was built with."""
        return self.e_data_.pred_proba_

    @property
    def ext(self):
        """The optional ``ext`` columns this collection was built with."""
        return self.e_data_.ext_

    @property
    def eX(self):
        """The underlying ``ExtendedData`` object."""
        return self.e_data_

    @property
    def ey(self):
        """The underlying ``ExtendedLabels`` object."""
        return self.e_labels_

    def counts(self):
        """Return the parent's class counts, padded when labels are collapsed.

        With ``collapse_false`` a zero is inserted at position 2 of the counts.
        """
        _counts = super().counts()
        if self.extpol.collapse_false:
            # NOTE(review): presumably pads the 3 collapsed binary classes to
            # the 4-slot layout of the non-collapsed encoding -- confirm.
            _counts = np.insert(_counts, 2, 0)

        return _counts

    def split_by_pred(self):
        """Split the collection by predicted class.

        Returns:
            One ``LabelledCollection`` per class (in class order) holding the
            extended instances predicted as that class, labelled with their
            true (non-extended) labels.
        """
        _ncl = self.pred_proba.shape[1]
        _instances, _indexes = self.e_data_.split_by_pred(return_indexes=True)
        _labels = [self.ey[ind] for ind in _indexes]
        return [
            LabelledCollection(inst, lbl.true, classes=range(0, _ncl))
            for inst, lbl in zip(_instances, _labels)
        ]

    def __extend_collection(
        self,
        instances: sp.csr_matrix | np.ndarray,
        labels: np.ndarray,
        pred_proba: np.ndarray,
        ext: np.ndarray | None = None,
        extpol=None,
    ) -> Tuple[ExtendedData, ExtendedLabels]:
        """Build the extended data and labels for this collection.

        ``extpol`` is accepted but unused; ``self.extpol`` is what gets passed
        on to both objects.
        """
        n_classes = pred_proba.shape[1]
        # n_X = [ X | predicted probs. ]  (or [ X | ext ] when ext is given)
        e_instances = ExtendedData(instances, pred_proba, ext=ext, extpol=self.extpol)

        # n_y = (expected y, predicted y)
        preds = np.argmax(pred_proba, axis=-1)
        e_labels = ExtendedLabels(labels, preds, n_classes, extpol=self.extpol)

        return e_instances, e_labels