
kde working with kfcv

Alejandro Moreo Fernandez 2023-07-24 16:28:21 +02:00
parent 0ce6106ee1
commit 96758834f3
6 changed files with 184 additions and 102 deletions

View File

@@ -27,7 +27,7 @@ if __name__ == '__main__':
         'classifier__class_weight': ['balanced', None]
     }

-    for method in ['PACC', 'SLD', 'DM', 'KDE', 'HDy', 'DIR']:
+    for method in ['KDE', 'PACC', 'SLD', 'DM', 'HDy-OvA', 'DIR']:

         #if os.path.exists(result_path):
         #    print('Result already exists. Nothing to do')

@@ -43,7 +43,7 @@ if __name__ == '__main__':
         dataset = 'T1B'
         train, val_gen, test_gen = qp.datasets.fetch_lequa2022(dataset)
-        print('init', dataset)
+        print(f'init {dataset} #instances: {len(train)}')
         if method == 'KDE':
             param_grid = {
                 'bandwidth': np.linspace(0.001, 0.2, 21),

@@ -51,6 +51,11 @@ if __name__ == '__main__':
                 'classifier__class_weight': ['balanced', None]
             }
             quantifier = KDEy(LogisticRegression(), target='max_likelihood')
+        elif method == 'KDE-debug':
+            param_grid = None
+            qp.environ['N_JOBS'] = 1
+            quantifier = KDEy(LogisticRegression(), target='max_likelihood', bandwidth=0.02)
+            #train = train.sampling(280, *[1./train.n_classes]*(train.n_classes-1))
         elif method == 'DIR':
             param_grid = hyper_LR
             quantifier = DIRy(LogisticRegression())

@@ -62,7 +67,7 @@ if __name__ == '__main__':
             quantifier = PACC(LogisticRegression())
         elif method == 'HDy-OvA':
             param_grid = {
-                'binary_quantifier__classifier__C': np.logspace(-4,4,9),
+                'binary_quantifier__classifier__C': np.logspace(-3,3,9),
                 'binary_quantifier__classifier__class_weight': ['balanced', None]
             }
             quantifier = OneVsAllAggregative(HDy(LogisticRegression()))

@@ -76,13 +81,17 @@ if __name__ == '__main__':
         else:
             raise NotImplementedError('unknown method', method)

-        modsel = GridSearchQ(quantifier, param_grid, protocol=val_gen, refit=False, n_jobs=-1, verbose=1, error=optim)
-        modsel.fit(train)
-        print(f'best params {modsel.best_params_}')
-        pickle.dump(modsel.best_params_, open(f'{result_dir}/{method}_{dataset}.hyper.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
-        quantifier = modsel.best_model()
+        if param_grid is not None:
+            modsel = GridSearchQ(quantifier, param_grid, protocol=val_gen, refit=False, n_jobs=-1, verbose=1, error=optim)
+            modsel.fit(train)
+            print(f'best params {modsel.best_params_}')
+            pickle.dump(modsel.best_params_, open(f'{result_dir}/{method}_{dataset}.hyper.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
+            quantifier = modsel.best_model()
+        else:
+            print('debug mode... skipping model selection')

+        quantifier.fit(train)
         report = qp.evaluation.evaluation_report(quantifier, protocol=test_gen, error_metrics=['mae', 'mrae', 'kld'], verbose=True)
         means = report.mean()

View File

@@ -6,12 +6,13 @@ import sys
 import pandas as pd
 import quapy as qp
-from quapy.method.aggregative import EMQ, DistributionMatching, PACC, HDy, OneVsAllAggregative
+from quapy.method.aggregative import EMQ, DistributionMatching, PACC, ACC, CC, PCC, HDy, OneVsAllAggregative
 from method_kdey import KDEy
 from method_dirichlety import DIRy
 from quapy.model_selection import GridSearchQ
 from quapy.protocol import UPP

+SEED=1

 if __name__ == '__main__':

@@ -29,7 +30,7 @@ if __name__ == '__main__':
         'classifier__class_weight': ['balanced', None]
     }

-    for method in ['PACC', 'SLD', 'DM', 'KDE', 'HDy', 'DIR']:
+    for method in ['KDE-nomonte', 'KDE-monte2', 'SLD', 'KDE-kfcv']:  # , 'DIR', 'DM', 'HDy-OvA', 'CC', 'ACC', 'PCC']:

         #if os.path.exists(result_path):
         #    print('Result already exists. Nothing to do')

@@ -49,69 +50,100 @@ if __name__ == '__main__':
     for dataset in qp.datasets.TWITTER_SENTIMENT_DATASETS_TEST:
         print('init', dataset)

+        with qp.util.temp_seed(SEED):
+
             is_semeval = dataset.startswith('semeval')

             if not is_semeval or not semeval_trained:

                 if method == 'KDE':
                     param_grid = {
                         'bandwidth': np.linspace(0.001, 0.2, 21),
                         'classifier__C': np.logspace(-4,4,9),
                         'classifier__class_weight': ['balanced', None]
                     }
                     quantifier = KDEy(LogisticRegression(), target='max_likelihood')
+                elif method == 'KDE-kfcv':
+                    param_grid = {
+                        'bandwidth': np.linspace(0.001, 0.2, 21),
+                        'classifier__C': np.logspace(-4,4,9),
+                        'classifier__class_weight': ['balanced', None]
+                    }
+                    quantifier = KDEy(LogisticRegression(), target='max_likelihood', val_split=10)
+                elif method in ['KDE-monte2']:
+                    param_grid = {
+                        'bandwidth': np.linspace(0.001, 0.2, 21),
+                    }
+                    quantifier = KDEy(LogisticRegression(), target='min_divergence')
+                elif method in ['KDE-nomonte']:
+                    param_grid = {
+                        'bandwidth': np.linspace(0.001, 0.2, 21),
+                    }
+                    quantifier = KDEy(LogisticRegression(), target='max_likelihood')
                 elif method == 'DIR':
                     param_grid = hyper_LR
                     quantifier = DIRy(LogisticRegression())
                 elif method == 'SLD':
                     param_grid = hyper_LR
                     quantifier = EMQ(LogisticRegression())
                 elif method == 'PACC':
                     param_grid = hyper_LR
                     quantifier = PACC(LogisticRegression())
+                elif method == 'PACC-kfcv':
+                    param_grid = hyper_LR
+                    quantifier = PACC(LogisticRegression(), val_split=10)
+                elif method == 'PCC':
+                    param_grid = hyper_LR
+                    quantifier = PCC(LogisticRegression())
+                elif method == 'ACC':
+                    param_grid = hyper_LR
+                    quantifier = ACC(LogisticRegression())
+                elif method == 'CC':
+                    param_grid = hyper_LR
+                    quantifier = CC(LogisticRegression())
                 elif method == 'HDy-OvA':
                     param_grid = {
                         'binary_quantifier__classifier__C': np.logspace(-4,4,9),
                         'binary_quantifier__classifier__class_weight': ['balanced', None]
                     }
                     quantifier = OneVsAllAggregative(HDy(LogisticRegression()))
                 elif method == 'DM':
                     param_grid = {
                         'nbins': [5,10,15],
                         'classifier__C': np.logspace(-4,4,9),
                         'classifier__class_weight': ['balanced', None]
                     }
                     quantifier = DistributionMatching(LogisticRegression())
                 else:
                     raise NotImplementedError('unknown method', method)

                 # model selection
                 data = qp.datasets.fetch_twitter(dataset, min_df=3, pickle=True, for_model_selection=True)
                 protocol = UPP(data.test, repeats=n_bags_val)
                 modsel = GridSearchQ(quantifier, param_grid, protocol, refit=False, n_jobs=-1, verbose=1, error=optim)
                 modsel.fit(data.training)
                 print(f'best params {modsel.best_params_}')
                 pickle.dump(modsel.best_params_, open(f'{result_dir}/{method}_{dataset}.hyper.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
                 quantifier = modsel.best_model()

                 if is_semeval:
                     semeval_trained = True
             else:
                 print(f'model selection for {dataset} already done; skipping')

             data = qp.datasets.fetch_twitter(dataset, min_df=3, pickle=True, for_model_selection=False)
             quantifier.fit(data.training)
             protocol = UPP(data.test, repeats=n_bags_test)
             report = qp.evaluation.evaluation_report(quantifier, protocol, error_metrics=['mae', 'mrae', 'kld'], verbose=True)
             report.to_csv(result_path+'.dataframe')
             means = report.mean()
             csv.write(f'{method}\t{data.name}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\t{means["kld"]:.5f}\n')
             csv.flush()

     df = pd.read_csv(result_path+'.csv', sep='\t')
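
The `KDE-kfcv` variant above is what gives this commit its title: as with other QuaPy-style aggregative quantifiers, passing an integer `val_split` asks the method to estimate the validation posteriors via k-fold cross-validation instead of holding out a fraction of the training set. A minimal sketch of the two configurations (assuming the local `method_kdey` module is importable):

    from sklearn.linear_model import LogisticRegression
    from method_kdey import KDEy

    # val_split as a float: hold out 40% of training data for the class-conditional KDEs
    kde_holdout = KDEy(LogisticRegression(), target='max_likelihood', val_split=0.4)

    # val_split as an int: estimate validation posteriors with 10-fold cross-validation (kfcv)
    kde_kfcv = KDEy(LogisticRegression(), target='max_likelihood', val_split=10)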

View File

@@ -30,7 +30,7 @@ class KDEy(AggregativeProbabilisticQuantifier):
     TARGET = ['min_divergence', 'max_likelihood']

     def __init__(self, classifier: BaseEstimator, val_split=0.4, divergence: Union[str, Callable]='HD',
-                 bandwidth='scott', engine='sklearn', target='min_divergence', n_jobs=None):
+                 bandwidth='scott', engine='sklearn', target='min_divergence', n_jobs=None, random_state=0):
         assert bandwidth in KDEy.BANDWIDTH_METHOD or isinstance(bandwidth, float), \
             f'unknown bandwidth_method, valid ones are {KDEy.BANDWIDTH_METHOD}'
         assert engine in KDEy.ENGINE, f'unknown engine, valid ones are {KDEy.ENGINE}'

@@ -42,6 +42,7 @@ class KDEy(AggregativeProbabilisticQuantifier):
         self.engine = engine
         self.target = target
         self.n_jobs = n_jobs
+        self.random_state = random_state

     def search_bandwidth_maxlikelihood(self, posteriors, labels):
         grid = {'bandwidth': np.linspace(0.001, 0.2, 100)}

@@ -84,14 +85,20 @@ class KDEy(AggregativeProbabilisticQuantifier):
             kde = scipy.stats.gaussian_kde(posteriors)
             kde.set_bandwidth(self.bandwidth)
         elif self.engine == 'sklearn':
+            #print('fitting kde')
             kde = KernelDensity(bandwidth=self.bandwidth).fit(posteriors)
+            #print('[fitting done]')
         return kde

     def pdf(self, kde, posteriors):
         if self.engine == 'scipy':
             return kde(posteriors[:, :-1].T)
         elif self.engine == 'sklearn':
-            return np.exp(kde.score_samples(posteriors))
+            #print('pdf...')
+            densities = np.exp(kde.score_samples(posteriors))
+            #print('[pdf done]')
+            return densities
+            #return np.exp(kde.score_samples(posteriors))

     def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
         """

@@ -118,13 +125,13 @@ class KDEy(AggregativeProbabilisticQuantifier):
         return self

-    def val_pdf(self, prev):
-        """
-        Returns a function that computes the mixture model with the given prev as mixture factor
-        :param prev: a prevalence vector, ndarray
-        :return: a function implementing the validation distribution with fixed mixture factor
-        """
-        return lambda posteriors: sum(prev_i * self.pdf(kde_i, posteriors) for kde_i, prev_i in zip(self.val_densities, prev))
+    #def val_pdf(self, prev):
+    #    """
+    #    Returns a function that computes the mixture model with the given prev as mixture factor
+    #    :param prev: a prevalence vector, ndarray
+    #    :return: a function implementing the validation distribution with fixed mixture factor
+    #    """
+    #    return lambda posteriors: sum(prev_i * self.pdf(kde_i, posteriors) for kde_i, prev_i in zip(self.val_densities, prev))

     def aggregate(self, posteriors: np.ndarray):
         if self.target == 'min_divergence':

@@ -134,14 +141,9 @@ class KDEy(AggregativeProbabilisticQuantifier):
         else:
             raise ValueError('unknown target')

-    def _target_divergence(self, posteriors):
-        """
-        Searches for the mixture model parameter (the sought prevalence values) that yields a validation distribution
-        (the mixture) that best matches the test distribution, in terms of the divergence measure of choice.
-        :param instances: instances in the sample
-        :return: a vector of class prevalence estimates
-        """
+    def _target_divergence_depr(self, posteriors):
+        # this variant is, I think, ill-formed, since it evaluates the likelihood on the test points, which are
+        # overconfident in the KDE-test
         test_density = self.get_kde(posteriors)
         # val_test_posteriors = np.concatenate([self.val_posteriors, posteriors])
         test_likelihood = self.pdf(test_density, posteriors)

@@ -164,6 +166,31 @@ class KDEy(AggregativeProbabilisticQuantifier):
         r = optimize.minimize(match, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
         return r.x

+    def _target_divergence(self, posteriors, montecarlo_samples=5000):
+        # in this variant we evaluate the divergence using a Montecarlo approach
+        n_classes = len(self.val_densities)
+
+        samples = qp.functional.uniform_prevalence_sampling(n_classes, size=montecarlo_samples)
+        test_kde = self.get_kde(posteriors)
+        test_likelihood = self.pdf(test_kde, samples)
+        divergence = _get_divergence(self.divergence)
+
+        sample_densities = [self.pdf(kde_i, samples) for kde_i in self.val_densities]
+
+        def match(prev):
+            val_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, sample_densities))
+            return divergence(val_likelihood, test_likelihood)
+
+        # the initial point is set as the uniform distribution
+        uniform_distribution = np.full(fill_value=1 / n_classes, shape=(n_classes,))
+
+        # solutions are bounded to those contained in the unit-simplex
+        bounds = tuple((0, 1) for _ in range(n_classes))  # values in [0,1]
+        constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})  # values summing up to 1
+        r = optimize.minimize(match, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
+        return r.x
+
     def _target_likelihood(self, posteriors, eps=0.000001):
         """
         Searches for the mixture model parameter (the sought prevalence values) that yields a validation distribution

@@ -172,13 +199,20 @@ class KDEy(AggregativeProbabilisticQuantifier):
         :param instances: instances in the sample
         :return: a vector of class prevalence estimates
         """
+        np.random.RandomState(self.random_state)
         n_classes = len(self.val_densities)
+        test_densities = [self.pdf(kde_i, posteriors) for kde_i in self.val_densities]
+        #return lambda posteriors: sum(prev_i * self.pdf(kde_i, posteriors) for kde_i, prev_i in zip(self.val_densities, prev))

         def neg_loglikelihood(prev):
-            val_pdf = self.val_pdf(prev)
-            test_likelihood = val_pdf(posteriors)
-            test_loglikelihood = np.log(test_likelihood + eps)
-            return -np.sum(test_loglikelihood)
+            #print('-neg_likelihood')
+            #val_pdf = self.val_pdf(prev)
+            #test_likelihood = val_pdf(posteriors)
+            test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities))
+            test_loglikelihood = np.log(test_mixture_likelihood + eps)
+            neg_log_likelihood = -np.sum(test_loglikelihood)
+            #print('-neg_likelihood [done!]')
+            return neg_log_likelihood
             #return -np.prod(test_likelihood)

         # the initial point is set as the uniform distribution

@@ -187,5 +221,7 @@ class KDEy(AggregativeProbabilisticQuantifier):
         # solutions are bounded to those contained in the unit-simplex
         bounds = tuple((0, 1) for _ in range(n_classes))  # values in [0,1]
         constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})  # values summing up to 1
+        #print('searching for alpha')
         r = optimize.minimize(neg_loglikelihood, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
+        #print('[optimization ended]')
         return r.x
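
Stripped of the class machinery, the `max_likelihood` target implemented above reduces to a small constrained optimization: fit one KDE per class on the validation posteriors, precompute each test point's density under every class-conditional KDE, and search the unit simplex for the mixture weights that maximize the test log-likelihood. A self-contained sketch with sklearn/scipy only (function and argument names here are illustrative, not part of the repository):

    import numpy as np
    from scipy import optimize
    from sklearn.neighbors import KernelDensity

    def estimate_prevalence(val_posteriors, val_labels, test_posteriors, bandwidth=0.02, eps=1e-6):
        # one KDE per class, fitted on the validation posteriors of that class
        classes = np.unique(val_labels)
        kdes = [KernelDensity(bandwidth=bandwidth).fit(val_posteriors[val_labels == c]) for c in classes]

        # density of every test posterior under each class-conditional KDE, computed once
        test_densities = [np.exp(kde.score_samples(test_posteriors)) for kde in kdes]

        def neg_loglikelihood(prev):
            # density of the test points under the mixture with weights prev
            mixture = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities))
            return -np.sum(np.log(mixture + eps))

        n = len(classes)
        uniform = np.full(n, 1 / n)
        bounds = tuple((0, 1) for _ in range(n))                      # each weight in [0, 1]
        constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})   # weights sum to 1
        r = optimize.minimize(neg_loglikelihood, x0=uniform, method='SLSQP',
                              bounds=bounds, constraints=constraints)
        return r.x

Precomputing `test_densities` outside `neg_loglikelihood` is the same optimization the commit applies to `_target_likelihood`: the KDE evaluations are the expensive part, and they do not depend on the mixture weights being searched.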

View File

@@ -2,18 +2,18 @@ import sys
 from pathlib import Path
 import pandas as pd

-#result_dir = 'results_tweet_1000'
-result_dir = 'results_lequa'
+result_dir = 'results_tweet_1000'
+#result_dir = 'results_lequa'

 dfs = []

 pathlist = Path(result_dir).rglob('*.csv')
 for path in pathlist:
     path_in_str = str(path)
+    print(path_in_str)

     try:
         df = pd.read_csv(path_in_str, sep='\t')
+        df = df[df.iloc[:, 0] != df.columns[0]]
         if not df.empty:
             dfs.append(df)
     except Exception:

@@ -21,7 +21,7 @@ for path in pathlist:
 df = pd.concat(dfs)

-for err in ['MAE', 'MRAE']:
+for err in ['MAE', 'MRAE', 'KLD']:
     print('-'*100)
     print(err)
     print('-'*100)
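
The new `df.iloc[:, 0] != df.columns[0]` filter handles result files that were appended to across runs with the header rewritten each time: pandas reads the extra header lines as data rows, and dropping every row whose first field equals the first column name removes them. A toy demonstration (the inline data is made up for illustration):

    import io
    import pandas as pd

    # a results file whose header line was accidentally written twice
    raw = "method\tdataset\tMAE\nKDE\tsemeval13\t0.05\nmethod\tdataset\tMAE\nSLD\tsemeval14\t0.07\n"

    df = pd.read_csv(io.StringIO(raw), sep='\t')
    df = df[df.iloc[:, 0] != df.columns[0]]  # drop rows that merely repeat the header
    print(df)  # keeps only the KDE and SLD rows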

View File

@@ -4,23 +4,28 @@ and the other a KDE on test), from which the divergence will later be computed (the obj
 generate a single distribution (the mixture model of train) and take the likelihood of the test points as the
 objective to maximize.

-- keep the best hyperparameters, in order to inspect them
-- extract the result dataframes in order to run statistical tests
+- take a look at the hyperparameters
 - make plots
+- study the case in which the target is to minimize a divergence. Possibilities:
+    - evaluate the test points only
+    - evaluate an APP over the simplex?
+    - evaluate a UPP over the simplex? (=Montecarlo)
+    - which divergences? HD, topsoe, L1?

-1) clarify: only test?
 2) implement the auto (bandwidth selection)
     - internal optimization for the likelihood [none seems to work well]
     - over everything (e.g., the whole training set)?
     - independently for each labelled set? (e.g., positives, negatives, neutrals, and test)
     - optimization as a GridSearchQ parameter
-3) clarify: topsoe?
-4) another kind of model selection?
-5) increase the number of bags
-6) optimize the C parameter? optimize the kernel? optimize the distance?
+6) optimize the kernel? optimize the distance?
 7) sklearn's KDE or statsmodels' multivariate KDE? also check what this is (it seems to give P(Y|X), so the
    classifier could maybe be dropped?):
    https://www.statsmodels.org/dev/_modules/statsmodels/nonparametric/kernel_density.html#KDEMultivariateConditional
-8) drop the last dimension in sklearn too?
-9) optimize for RAE instead of AE?
+8) drop the last dimension in sklearn too? I do not see why
+9) optimize for RAE instead of AE? It does not work well...
+10) define a classifier that returns, for each class, a posterior computed as the likelihood in the class-conditional
+    KDE divided by the likelihood in all the classes (as Juanjo proposes), and plug it into EMD. Also the other way
+    around: re-calibrate with EMD and plug that into KDEy
+11) KDEx?
+12) the Dirichlet method (DIR) should be fixed and its results reported...
+13) statistical tests.
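
The "UPP over the simplex (=Montecarlo)" option in the notes above is what the new `_target_divergence` implements: draw prevalence vectors uniformly from the unit simplex, evaluate the class-conditional KDEs and the test KDE at those points, and minimize the divergence between the mixture density and the test density over the sample. A sketch, assuming the KDEs are sklearn `KernelDensity` objects fitted on posteriors of dimension `n_classes`, and using one common form of the Hellinger distance for the code's 'HD' default (the function names are illustrative):

    import numpy as np
    from scipy import optimize
    import quapy as qp

    def hellinger(p, q):
        # Hellinger distance between two densities evaluated on the same sample
        return np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / np.sqrt(2)

    def montecarlo_divergence_prevalence(class_kdes, test_kde, n_classes, trials=5000):
        # evaluation points drawn uniformly at random from the unit simplex (UPP-style)
        samples = qp.functional.uniform_prevalence_sampling(n_classes, size=trials)
        test_densities = np.exp(test_kde.score_samples(samples))
        class_densities = [np.exp(kde.score_samples(samples)) for kde in class_kdes]

        def match(prev):
            mixture = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, class_densities))
            return hellinger(mixture, test_densities)

        uniform = np.full(n_classes, 1 / n_classes)
        bounds = tuple((0, 1) for _ in range(n_classes))
        constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})
        r = optimize.minimize(match, x0=uniform, method='SLSQP', bounds=bounds, constraints=constraints)
        return r.x

Unlike the deprecated variant, the divergence is evaluated on simplex samples rather than on the test points themselves, which avoids the test KDE's overconfidence at its own training points.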

View File

@@ -123,7 +123,7 @@ class LabelledCollection:
             return self.uniform_sampling_index(size, random_state=random_state)
         if len(prevs) == self.n_classes - 1:
             prevs = prevs + (1 - sum(prevs),)
-        assert len(prevs) == self.n_classes, 'unexpected number of prevalences'
+        assert len(prevs) == self.n_classes, f'unexpected number of prevalences (found {len(prevs)}, expected {self.n_classes})'
         assert sum(prevs) == 1, f'prevalences ({prevs}) wrong range (sum={sum(prevs)})'

         # Decide how many instances should be taken for each class in order to satisfy the requested prevalence