forked from moreo/QuaPy
210 lines
8.2 KiB
Python
210 lines
8.2 KiB
Python
from typing import List, Union
|
|
|
|
import numpy as np
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
from quapy.data import LabelledCollection
|
|
from quapy.functional import artificial_prevalence_sampling
|
|
|
|
from skmultilearn.model_selection import iterative_train_test_split
|
|
|
|
class MultilabelledCollection:
|
|
def __init__(self, instances, labels):
|
|
assert labels.ndim==2, f'data does not seem to be multilabel {labels}'
|
|
self.instances = instances
|
|
self.labels = labels
|
|
self.classes_ = np.arange(labels.shape[1])
|
|
|
|
@classmethod
|
|
def load(cls, path: str, loader_func: callable):
|
|
return MultilabelledCollection(*loader_func(path))
|
|
|
|
def __len__(self):
|
|
return self.instances.shape[0]
|
|
|
|
def prevalence(self):
|
|
# return self.labels.mean(axis=0)
|
|
pos = self.labels.mean(axis=0)
|
|
neg = 1-pos
|
|
return np.asarray([neg, pos]).T
|
|
|
|
def counts(self):
|
|
return self.labels.sum(axis=0)
|
|
|
|
@property
|
|
def n_classes(self):
|
|
return len(self.classes_)
|
|
|
|
@property
|
|
def n_features(self):
|
|
return self.instances.shape[1]
|
|
|
|
@property
|
|
def binary(self):
|
|
return False
|
|
|
|
def __gen_index(self):
|
|
return np.arange(len(self))
|
|
|
|
def sampling_multi_index(self, size, cat, prev=None):
|
|
if prev is None: # no prevalence was indicated; returns an index for uniform sampling
|
|
return np.random.choice(len(self), size, replace=size > len(self))
|
|
aux = LabelledCollection(self.__gen_index(), self.labels[:, cat])
|
|
return aux.sampling_index(size, *[1-prev, prev])
|
|
|
|
def uniform_sampling_multi_index(self, size):
|
|
return np.random.choice(len(self), size, replace=size>len(self))
|
|
|
|
def uniform_sampling(self, size):
|
|
unif_index = self.uniform_sampling_multi_index(size)
|
|
return self.sampling_from_index(unif_index)
|
|
|
|
def sampling(self, size, category, prev=None):
|
|
prev_index = self.sampling_multi_index(size, category, prev)
|
|
return self.sampling_from_index(prev_index)
|
|
|
|
def sampling_from_index(self, index):
|
|
documents = self.instances[index]
|
|
labels = self.labels[index]
|
|
return MultilabelledCollection(documents, labels)
|
|
|
|
def train_test_split(self, train_prop=0.6, random_state=None, iterative=False):
|
|
if iterative:
|
|
tr_docs, tr_labels, te_docs, te_labels = \
|
|
iterative_train_test_split(self.instances, self.labels, test_size=1-train_prop)
|
|
else:
|
|
tr_docs, te_docs, tr_labels, te_labels = \
|
|
train_test_split(self.instances, self.labels, train_size=train_prop, random_state=random_state)
|
|
return MultilabelledCollection(tr_docs, tr_labels), MultilabelledCollection(te_docs, te_labels)
|
|
|
|
def artificial_sampling_generator(self, sample_size, category, n_prevalences=101, repeats=1):
|
|
dimensions = 2
|
|
for prevs in artificial_prevalence_sampling(dimensions, n_prevalences, repeats).flatten():
|
|
yield self.sampling(sample_size, category, prevs)
|
|
|
|
def artificial_sampling_index_generator(self, sample_size, category, n_prevalences=101, repeats=1):
|
|
dimensions = 2
|
|
for prevs in artificial_prevalence_sampling(dimensions, n_prevalences, repeats).flatten():
|
|
yield self.sampling_multi_index(sample_size, category, prevs)
|
|
|
|
def natural_sampling_generator(self, sample_size, repeats=100):
|
|
for _ in range(repeats):
|
|
yield self.uniform_sampling(sample_size)
|
|
|
|
def natural_sampling_index_generator(self, sample_size, repeats=100):
|
|
for _ in range(repeats):
|
|
yield self.uniform_sampling_multi_index(sample_size)
|
|
|
|
def asLabelledCollection(self, category):
|
|
return LabelledCollection(self.instances, self.labels[:,category])
|
|
|
|
def genLabelledCollections(self):
|
|
for c in self.classes_:
|
|
yield self.asLabelledCollection(c)
|
|
|
|
# @property
|
|
# def label_cardinality(self):
|
|
# return self.labels.sum()/len(self)
|
|
|
|
@property
|
|
def Xy(self):
|
|
return self.instances, self.labels
|
|
|
|
|
|
class MultilingualLabelledCollection:
|
|
def __init__(self, langs:List[str], labelledCollections:List[Union[LabelledCollection, MultilabelledCollection]]):
|
|
assert len(langs) == len(labelledCollections), 'length mismatch for langs and labelledCollection lists'
|
|
assert all(isinstance(lc, LabelledCollection) or all(isinstance(lc, MultilabelledCollection)) for lc in labelledCollections), \
|
|
'unexpected type for labelledCollections'
|
|
assert all(labelledCollections[0].classes_ == lc_i.classes_ for lc_i in labelledCollections[1:]), \
|
|
'inconsistent classes found for some labelled collections'
|
|
self.llc = {l: lc for l, lc in zip(langs, labelledCollections)}
|
|
self.classes_=labelledCollections[0].classes_
|
|
|
|
@classmethod
|
|
def fromLangDict(cls, lang_labelledCollection:dict):
|
|
return MultilingualLabelledCollection(*list(zip(*list(lang_labelledCollection.items()))))
|
|
|
|
def langs(self):
|
|
return list(sorted(self.llc.keys()))
|
|
|
|
def __getitem__(self, lang)->LabelledCollection:
|
|
return self.llc[lang]
|
|
|
|
@classmethod
|
|
def load(cls, path: str, loader_func: callable):
|
|
return MultilingualLabelledCollection(*loader_func(path))
|
|
|
|
def __len__(self):
|
|
return sum(map(len, self.llc.values()))
|
|
|
|
def prevalence(self):
|
|
prev = np.asarray([lc.prevalence() * len(lc) for lc in self.llc.values()]).sum(axis=0)
|
|
return prev / prev.sum()
|
|
|
|
def language_prevalence(self):
|
|
lang_count = np.asarray([len(self.llc[l]) for l in self.langs()])
|
|
return lang_count / lang_count.sum()
|
|
|
|
def counts(self):
|
|
return np.asarray([lc.counts() for lc in self.llc.values()]).sum(axis=0)
|
|
|
|
@property
|
|
def n_classes(self):
|
|
return len(self.classes_)
|
|
|
|
@property
|
|
def binary(self):
|
|
return self.n_classes == 2
|
|
|
|
def __check_langs(self, l_dict:dict):
|
|
assert len(l_dict)==len(self.langs()), 'wrong number of languages'
|
|
assert all(l in l_dict for l in self.langs()), 'missing languages in l_sizes'
|
|
|
|
def __check_sizes(self, l_sizes: Union[int,dict]):
|
|
assert isinstance(l_sizes, int) or isinstance(l_sizes, dict), 'unexpected type for l_sizes'
|
|
if isinstance(l_sizes, int):
|
|
return {l:l_sizes for l in self.langs()}
|
|
self.__check_langs(l_sizes)
|
|
return l_sizes
|
|
|
|
def sampling_index(self, l_sizes: Union[int,dict], *prevs, shuffle=True):
|
|
l_sizes = self.__check_sizes(l_sizes)
|
|
return {l:lc.sampling_index(l_sizes[l], *prevs, shuffle=shuffle) for l,lc in self.llc.items()}
|
|
|
|
def uniform_sampling_index(self, l_sizes: Union[int, dict]):
|
|
l_sizes = self.__check_sizes(l_sizes)
|
|
return {l: lc.uniform_sampling_index(l_sizes[l]) for l,lc in self.llc.items()}
|
|
|
|
def uniform_sampling(self, l_sizes: Union[int, dict]):
|
|
l_sizes = self.__check_sizes(l_sizes)
|
|
return MultilingualLabelledCollection.fromLangDict(
|
|
{l: lc.uniform_sampling(l_sizes[l]) for l,lc in self.llc.items()}
|
|
)
|
|
|
|
def sampling(self, l_sizes: Union[int, dict], *prevs, shuffle=True):
|
|
l_sizes = self.__check_sizes(l_sizes)
|
|
return MultilingualLabelledCollection.fromLangDict(
|
|
{l: lc.sampling(l_sizes[l], *prevs, shuffle=shuffle) for l,lc in self.llc.items()}
|
|
)
|
|
|
|
def sampling_from_index(self, l_index:dict):
|
|
self.__check_langs(l_index)
|
|
return MultilingualLabelledCollection.fromLangDict(
|
|
{l: lc.sampling_from_index(l_index[l]) for l,lc in self.llc.items()}
|
|
)
|
|
|
|
def split_stratified(self, train_prop=0.6, random_state=None):
|
|
train, test = list(zip(*[self[l].split_stratified(train_prop, random_state) for l in self.langs()]))
|
|
return MultilingualLabelledCollection(self.langs(), train), MultilingualLabelledCollection(self.langs(), test)
|
|
|
|
def asLabelledCollection(self, return_langs=False):
|
|
lXy_list = [([l]*len(lc),*lc.Xy) for l, lc in self.llc.items()] # a list with (lang_i, Xi, yi)
|
|
ls,Xs,ys = list(zip(*lXy_list))
|
|
ls = np.concatenate(ls)
|
|
vertstack = vstack if issparse(Xs[0]) else np.vstack
|
|
Xs = vertstack(Xs)
|
|
ys = np.concatenate(ys)
|
|
lc = LabelledCollection(Xs, ys, classes_=self.classes_)
|
|
# return lc, ls if return_langs else lc
|