import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
import math
from sklearn.model_selection import train_test_split
from model.early_stop import EarlyStop
from model.transformations import FFProjection


class AuthorshipAttributionClassifier(nn.Module):
    def __init__(self, projector, num_authors, pad_index, pad_length=500, device='cpu'):
        super(AuthorshipAttributionClassifier, self).__init__()
        self.projector = projector.to(device)
        #self.ff = FFProjection(input_size=projector.space_dimensions(),
        #                       hidden_sizes=[1024],
        #                       output_size=num_authors).to(device)
        self.ff = FFProjection(input_size=projector.space_dimensions(),
                               hidden_sizes=[],
                               output_size=num_authors).to(device)
        self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
        self.device = device

    def fit(self, X, y, batch_size, epochs, patience=10, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
        assert 0 <= alpha <= 1, 'alpha must be in [0, 1]'
        early_stop = EarlyStop(patience)
        batcher = Batch(batch_size=batch_size, n_epochs=epochs)
        #batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=X.shape[0]//batch_size)
        batcher_val = Batch(batch_size=batch_size, n_epochs=epochs, shuffle=False)
        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        optim = torch.optim.Adam(self.parameters(), lr=lr)

        X, Xval, y, yval = train_test_split(X, y, test_size=val_prop, stratify=y)

        with open(log, 'wt') as foo:
            foo.write('epoch\ttr-loss\tval-loss\tval-acc\tval-Mf1\tval-mf1\n')
            tr_loss, val_loss = -1, -1
            pbar = tqdm(range(1, batcher.n_epochs+1))
            for epoch in pbar:
                # training
                self.train()
                losses, attr_losses, sav_losses = [], [], []
                for xi, yi in batcher.epoch(X, y):
                    optim.zero_grad()
                    xi = self.padder.transform(xi)
                    phi = self.projector(xi)

                    loss_attr = loss_sav = 0
                    loss_attr_value = loss_sav_value = -1

                    if alpha > 0:
                        logits = self.ff(phi)
                        loss_attr = criterion(logits, torch.as_tensor(yi).to(self.device))
                        loss_attr_value = loss_attr.item()

                    if alpha < 1:
                        # todo: optimize (compute only the upper triangle)
                        kernel = torch.matmul(phi, phi.T)
                        # ideal kernel: 1 iff the two examples share the author label, 0 otherwise
                        ideal_kernel = torch.as_tensor(1. * (yi[:, None] == yi[None, :])).to(self.device)
                        # todo: maybe the KALoss should take into consideration the balance (it is more likely to have
                        # a pair of negative examples than positives)
                        loss_sav = KernelAlignmentLoss(kernel, ideal_kernel)
                        loss_sav_value = loss_sav.item()

                    loss = loss_attr*alpha + loss_sav*(1.-alpha)

                    loss.backward()
                    optim.step()

                    attr_losses.append(loss_attr_value)
                    sav_losses.append(loss_sav_value)
                    losses.append(loss.item())
                    tr_loss = np.mean(losses)
                    pbar.set_description(f'training epoch={epoch} '
                                         f'loss={tr_loss:.5f} '
                                         f'attr-loss={np.mean(attr_losses):.5f} '
                                         f'sav-loss={np.mean(sav_losses):.5f} '
                                         f'val_loss={val_loss:.5f}'
                                         )

                # validation
                self.eval()
                predictions, losses = [], []
                with torch.no_grad():
                    for xi, yi in batcher_val.epoch(Xval, yval):
                        xi = self.padder.transform(xi)
                        logits = self.forward(xi)
                        loss = criterion(logits, torch.as_tensor(yi).to(self.device))
                        losses.append(loss.item())
                        logits = nn.functional.log_softmax(logits, dim=1)
                        prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
                        predictions.append(prediction)
                val_loss = np.mean(losses)
                predictions = np.concatenate(predictions)
                acc = accuracy_score(yval, predictions)
                macrof1 = f1_score(yval, predictions, average='macro')
                microf1 = f1_score(yval, predictions, average='micro')

                foo.write(f'{epoch}\t{tr_loss:.8f}\t{val_loss:.8f}\t{acc:.3f}\t{macrof1:.3f}\t{microf1:.3f}\n')
                foo.flush()

                early_stop(microf1, epoch)
                if early_stop.IMPROVED:
                    torch.save(self.state_dict(), checkpointpath)
                elif early_stop.STOP:
                    break
        print(f'training ended; loading best model parameters from {checkpointpath} (best epoch: {early_stop.best_epoch})')
        self.load_state_dict(torch.load(checkpointpath))
        return early_stop.best_score

    def predict(self, x, batch_size=100):
        self.eval()
        batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
        predictions = []
        with torch.no_grad():
            for xi in tqdm(batcher.epoch(x), desc='test'):
                xi = self.padder.transform(xi)
                logits = self.forward(xi)
                logits = nn.functional.log_softmax(logits, dim=1)
                prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
                predictions.append(prediction)
        return np.concatenate(predictions)

    def forward(self, x):
        phi = self.projector(x)
        return self.ff(phi)
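

# A minimal usage sketch for AuthorshipAttributionClassifier (everything below is hypothetical and for
# illustration only: the toy projector simply averages learned embeddings and exposes the space_dimensions()
# method the classifier expects; the data are random index sequences, not real documents).
def _attribution_usage_example():
    class MeanEmbeddingProjector(nn.Module):
        def __init__(self, vocab_size, dim):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, dim)
            self.dim = dim

        def space_dimensions(self):
            return self.dim

        def forward(self, x):
            return self.emb(x).mean(dim=1)

    docs = [list(np.random.randint(1, 100, size=np.random.randint(20, 50))) for _ in range(40)]
    X = np.empty(len(docs), dtype=object)
    X[:] = docs
    y = np.repeat([0, 1, 2, 3], 10)
    classifier = AuthorshipAttributionClassifier(MeanEmbeddingProjector(vocab_size=100, dim=32),
                                                 num_authors=4, pad_index=0)
    classifier.fit(X, y, batch_size=8, epochs=2, log='tmp_log.csv', checkpointpath='tmp_model.dat')
    print(classifier.predict(X))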


class SameAuthorClassifier(nn.Module):
    def __init__(self, projector, num_authors, pad_index, pad_length=500, device='cpu'):
        super(SameAuthorClassifier, self).__init__()
        self.projector = projector.to(device)
        self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
        self.device = device

    def fit(self, X, y, batch_size, epochs, lr=0.001, steps_per_epoch=100):
        self.train()
        batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=steps_per_epoch)
        optim = torch.optim.Adam(self.parameters(), lr=lr)

        pbar = tqdm(range(batcher.n_epochs))
        for epoch in pbar:
            losses = []
            for xi, yi in batcher.epoch(X, y):
                optim.zero_grad()
                xi = self.padder.transform(xi)
                phi = self.projector(xi)
                #normalize phi to have norm 1? maybe better as the last step of projector
                kernel = torch.matmul(phi, phi.T)
                # ideal kernel: 1 iff the two examples share the author label, 0 otherwise
                ideal_kernel = torch.as_tensor(1. * (yi[:, None] == yi[None, :])).to(self.device)
                loss = KernelAlignmentLoss(kernel, ideal_kernel)
                loss.backward()
                #clip_gradient(model)
                optim.step()
                losses.append(loss.item())
                pbar.set_description(f'training epoch={epoch} loss={np.mean(losses):.5f}')

    def predict(self, x, z, batch_size=100):
        self.eval()
        batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
        predictions = []
        with torch.no_grad():
            for xi, zi in tqdm(batcher.epoch(x, z), desc='test'):
                xi = self.padder.transform(xi)
                zi = self.padder.transform(zi)
                inners = self.forward(xi, zi)
                prediction = tensor2numpy(inners) > 0.5  # is this correct? should it be > 0 and the ideal kernel in field {-1,+1}?
                predictions.append(prediction)
        return np.concatenate(predictions)

    def forward(self, x, z):
        assert x.shape == z.shape, 'shape mismatch between matrices x and z'
        phi_x = self.projector(x)
        phi_z = self.projector(z)
        rows, cols = phi_x.shape
        pairwise_inners = torch.bmm(phi_x.view(rows, 1, cols), phi_z.view(rows, cols, 1)).squeeze()
        return pairwise_inners
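

# A sketch (toy tensors only) of the row-wise inner product that forward() computes with torch.bmm: each row of
# phi_x is paired with the matching row of phi_z, yielding one scalar per pair, equivalent to (phi_x * phi_z).sum(dim=1).
def _pairwise_inner_example():
    phi_x = torch.randn(4, 3)
    phi_z = torch.randn(4, 3)
    rows, cols = phi_x.shape
    inners = torch.bmm(phi_x.view(rows, 1, cols), phi_z.view(rows, cols, 1)).squeeze()
    assert torch.allclose(inners, (phi_x * phi_z).sum(dim=1))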


class FullAuthorClassifier(nn.Module):
    def __init__(self, projector, num_authors, pad_index, pad_length=500, device='cpu'):
        super(FullAuthorClassifier, self).__init__()
        self.projector = projector.to(device)
        self.ff = FFProjection(input_size=projector.space_dimensions(),
                               hidden_sizes=[1024],
                               output_size=num_authors).to(device)
        self.padder = Padding(pad_index=pad_index, max_length=pad_length, dynamic=True, pad_at_end=False, device=device)
        self.device = device

    def fit(self, X, y, batch_size, epochs, lr=0.001, steps_per_epoch=100):
        self.train()
        batcher = TwoClassBatch(batch_size=batch_size, n_epochs=epochs, steps_per_epoch=steps_per_epoch)
        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        optim = torch.optim.Adam(self.parameters(), lr=lr)
        alpha = 0.5

        pbar = tqdm(range(batcher.n_epochs))
        for epoch in pbar:
            losses, sav_losses, attr_losses = [], [], []
            for xi, yi in batcher.epoch(X, y):
                optim.zero_grad()
                xi = self.padder.transform(xi)
                phi = self.projector(xi)
                #normalize phi to have norm 1? maybe better as the last step of projector

                #sav-loss
                kernel = torch.matmul(phi, phi.T)
                # ideal kernel: 1 iff the two examples share the author label, 0 otherwise
                ideal_kernel = torch.as_tensor(1. * (yi[:, None] == yi[None, :])).to(self.device)
                sav_loss = KernelAlignmentLoss(kernel, ideal_kernel)
                sav_losses.append(sav_loss.item())

                #attr-loss
                logits = self.ff(phi)
                attr_loss = criterion(logits, torch.as_tensor(yi).to(self.device))
                attr_losses.append(attr_loss.item())

                #loss
                loss = (alpha)*sav_loss + (1-alpha)*attr_loss
                losses.append(loss.item())

                loss.backward()
                #clip_gradient(model)
                optim.step()
                pbar.set_description(
                    f'training epoch={epoch} '
                    f'sav-loss={np.mean(sav_losses):.5f} '
                    f'attr-loss={np.mean(attr_losses):.5f} '
                    f'loss={np.mean(losses):.5f}'
                )

    def predict_sav(self, x, z, batch_size=100):
        self.eval()
        batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
        predictions = []
        with torch.no_grad():
            for xi, zi in tqdm(batcher.epoch(x, z), desc='test'):
                xi = self.padder.transform(xi)
                zi = self.padder.transform(zi)
                phi_xi = self.projector(xi)
                phi_zi = self.projector(zi)
                rows, cols = phi_xi.shape
                pairwise_inners = torch.bmm(phi_xi.view(rows, 1, cols), phi_zi.view(rows, cols, 1)).squeeze()
                prediction = tensor2numpy(pairwise_inners) > 0.5  # is this correct? should it be > 0 and the ideal kernel in field {-1,+1}?
                predictions.append(prediction)
        return np.concatenate(predictions)

    def predict_labels(self, x, batch_size=100):
        self.eval()
        batcher = Batch(batch_size=batch_size, n_epochs=1, shuffle=False)
        predictions = []
        with torch.no_grad():
            for xi in tqdm(batcher.epoch(x), desc='test'):
                xi = self.padder.transform(xi)
                phi = self.projector(xi)
                logits = self.ff(phi)
                prediction = tensor2numpy(torch.argmax(logits, dim=1).view(-1))
                predictions.append(prediction)
        return np.concatenate(predictions)


def KernelAlignmentLoss(K, Y):
    """
    Normalized Frobenius distance between the batch Gram matrix K and the ideal kernel Y.
    Note: this is not exactly the kernel-target alignment defined in Nello's paper.
    """
    n_el = K.shape[0] * K.shape[1]
    loss = torch.norm(K - Y, p='fro')
    loss = loss / n_el  # normalize by the number of entries so the loss does not grow with the batch size
    return loss
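

# A small sanity-check sketch for KernelAlignmentLoss (toy values only, not part of the training pipeline):
# the ideal kernel has a 1 where two examples share the author label and 0 elsewhere, and the loss is the
# Frobenius distance between it and the actual Gram matrix, normalized by the number of entries.
def _kernel_alignment_example():
    yi = np.array([0, 0, 1])
    ideal_kernel = torch.as_tensor(1. * (yi[:, None] == yi[None, :]))  # [[1,1,0],[1,1,0],[0,0,1]]
    phi = torch.randn(3, 5, dtype=torch.float64)
    K = torch.matmul(phi, phi.T)
    print(KernelAlignmentLoss(K, ideal_kernel).item())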



class Batch:
    def __init__(self, batch_size, n_epochs=1, shuffle=True):
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.shuffle = shuffle
        self.current_epoch = 0

    def epoch(self, *args):
        lengths = list(map(len, args))
        assert max(lengths) == min(lengths), 'inconsistent sizes in args'
        n_batches = math.ceil(lengths[0] / self.batch_size)
        offset = 0
        if self.shuffle:
            index = np.random.permutation(len(args[0]))
            args = [arg[index] for arg in args]
        for b in range(n_batches):
            batch_idx = slice(offset, offset+self.batch_size)
            batch = [arg[batch_idx] for arg in args]
            yield batch if len(batch) > 1 else batch[0]
            offset += self.batch_size
        self.current_epoch += 1
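

# Usage sketch for Batch (toy arrays only): shuffling permutes all arguments with the same random index
# before slicing them into consecutive chunks of batch_size.
def _batch_usage_example():
    X = np.arange(10)
    y = np.arange(10) % 2
    batcher = Batch(batch_size=4, n_epochs=1)
    for xi, yi in batcher.epoch(X, y):
        print(xi, yi)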


class TwoClassBatch:
    """
    given X and y (single-label, multiclass), produces batches of elements of X, y drawn from exactly two classes
    (e.g., c1, c2) in equal proportion, i.e., each batch is [(x1,c1), ..., (xn,c1), (xn+1,c2), ..., (x2n,c2)]
    """
    def __init__(self, batch_size, n_epochs, steps_per_epoch):
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.steps_per_epoch = steps_per_epoch
        self.current_epoch = 0
        if self.batch_size % 2 != 0:
            raise ValueError('batch_size must be even')

    def epoch(self, X, y):
        n_el = len(y)
        assert X.shape[0] == n_el, 'inconsistent sizes in X, y'
        classes = np.unique(y)
        groups = {ci: X[y==ci] for ci in classes}
        class_prevalences = [len(groups[ci])/n_el for ci in classes]
        n_choices = self.batch_size // 2

        for b in range(self.steps_per_epoch):
            class1, class2 = np.random.choice(classes, p=class_prevalences, size=2, replace=False)
            X1 = np.random.choice(groups[class1], size=n_choices)
            X2 = np.random.choice(groups[class2], size=n_choices)
            X_batch = np.concatenate([X1,X2])
            y_batch = np.repeat([class1, class2], repeats=[n_choices,n_choices])
            yield X_batch, y_batch
        self.current_epoch += 1
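

# Usage sketch for TwoClassBatch (toy object array of variable-length documents): each yielded batch contains
# batch_size/2 documents from each of two classes sampled at random according to their prevalence.
def _two_class_batch_usage_example():
    docs = [list(range(i + 3)) for i in range(12)]
    X = np.empty(len(docs), dtype=object)
    X[:] = docs
    y = np.repeat([0, 1, 2], 4)
    batcher = TwoClassBatch(batch_size=4, n_epochs=1, steps_per_epoch=3)
    for X_batch, y_batch in batcher.epoch(X, y):
        print(y_batch)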


class Padding:
    def __init__(self, pad_index, max_length, dynamic=True, pad_at_end=True, device='cpu'):
        """
        :param pad_index: the index representing the PAD token
        :param max_length: the length that defines the padding
        :param dynamic: if True (default) pads at min(max_length, max_local_length) where max_local_length is the
        length of the longest example
        :param pad_at_end: if True, the pad tokens are added at the end of the lists, if otherwise they are added
        at the beginning
        """
        self.pad = pad_index
        self.max_length = max_length
        self.dynamic = dynamic
        self.pad_at_end = pad_at_end
        self.device = device

    def transform(self, X):
        """
        :param X: a list of lists of indexes (integers)
        :return: a ndarray of shape (n,m) where n is the number of elements in X and m is the pad length (the maximum
        in elements of X if dynamic, or self.max_length if otherwise)
        """
        X = [x[:self.max_length] for x in X]
        lengths = list(map(len, X))
        pad_length = min(max(lengths), self.max_length) if self.dynamic else self.max_length
        if self.pad_at_end:
            padded = [x + [self.pad] * (pad_length - x_len) for x, x_len in zip(X, lengths)]
        else:
            padded = [[self.pad] * (pad_length - x_len) + x for x, x_len in zip(X, lengths)]
        return torch.from_numpy(np.asarray(padded, dtype=int)).to(self.device)
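

# Usage sketch for Padding (pad index 0, padding at the beginning, as configured by the classifiers above):
# dynamic padding pads up to the longest example in the batch, capped at max_length.
def _padding_usage_example():
    padder = Padding(pad_index=0, max_length=5, dynamic=True, pad_at_end=False)
    print(padder.transform([[7, 8], [1, 2, 3, 4, 5, 6]]))
    # tensor([[0, 0, 0, 7, 8],
    #         [1, 2, 3, 4, 5]])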


def tensor2numpy(t):
    return t.to('cpu').detach().numpy()