forked from moreo/QuaPy
54 lines
2.0 KiB
Python
54 lines
2.0 KiB
Python
|
from data import LabelledCollection
|
||
|
from method.base import BaseQuantifier
|
||
|
from utils.util import temp_seed
|
||
|
import numpy as np
|
||
|
from joblib import Parallel, delayed
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
|
||
|
def artificial_sampling_prediction(
        model: BaseQuantifier,
        test: LabelledCollection,
        sample_size,
        prevalence_points=21,
        point_repetitions=1,
        n_jobs=-1,
        random_seed=42):
    """
    Applies the artificial sampling protocol to a test collection and gathers, for every sample drawn, the true
    class prevalence of the sample together with the prevalence estimated by the model.

    :param model: the quantifier whose `quantify` method produces the class prevalence estimations
    :param test: the test collection from which the samples are drawn
    :param sample_size: the size of each sample
    :param prevalence_points: the number of different prevalences to sample
    :param point_repetitions: the number of repetitions for each prevalence
    :param n_jobs: number of jobs to be run in parallel
    :param random_seed: seed that makes the samplings replicable. It is applied through `temp_seed` and only
        while the sample indexes are being generated.
    :return: a pair of ndarrays (true prevalences, estimated prevalences) of shape [m,n], with m the number of
        generated samples and n the number of classes
    """
    # The seeded context wraps index generation only; the generator is fully consumed into a list before the
    # context is left, so every index is drawn under the requested seed.
    with temp_seed(random_seed):
        sample_indexes = list(
            test.artificial_sampling_index_generator(sample_size, prevalence_points, point_repetitions)
        )

    def _evaluate_sample(sample_index):
        # Returns (true prevalence, estimated prevalence) for the sample identified by sample_index.
        drawn = test.sampling_from_index(sample_index)
        return drawn.prevalence(), model.quantify(drawn.instances)

    # One delayed task per sample; tqdm reports progress as the tasks are dispatched.
    tasks = (delayed(_evaluate_sample)(sample_index) for sample_index in tqdm(sample_indexes))
    outcomes = Parallel(n_jobs=n_jobs)(tasks)

    # Transpose the list of (true, estimated) pairs into two sequences and stack each one into an ndarray.
    true_prevalences, estim_prevalences = map(np.asarray, zip(*outcomes))
    return true_prevalences, estim_prevalences
|
||
|
|
||
|
|
||
|
|
||
|
|