# QuaPy/MultiLabel/mlevaluation.py
from typing import Union, Callable
import numpy as np
import quapy as qp
from MultiLabel.mlquantification import MLAggregativeQuantifier
from mldata import MultilabelledCollection
import itertools
from tqdm import tqdm
def check_error_str(error_metric):
    """Resolve *error_metric* into a callable error function.

    :param error_metric: either a callable implementing an error metric, or the
        name (str) of an error metric known to `quapy.error` (e.g., 'mae')
    :return: the callable error function
    :raises ValueError: if the argument cannot be resolved to a callable
    """
    if isinstance(error_metric, str):
        error_metric = qp.error.from_name(error_metric)
    # raise explicitly instead of `assert`: assertions are stripped under `python -O`,
    # which would silently let invalid metrics through
    if not callable(error_metric):
        raise ValueError('invalid error function')
    return error_metric
def _ml_prevalence_predictions(model,
                               test: MultilabelledCollection,
                               test_indexes):
    """Compute the true and the estimated prevalence vectors for a series of
    samples of `test`, one per index set in `test_indexes`.

    For aggregative quantifiers the whole test set is pre-classified once up
    front, so each individual sample only requires the (cheap) aggregation step.
    """
    if isinstance(model, MLAggregativeQuantifier):
        # classify once; the samples drawn below then carry the predictions
        test = MultilabelledCollection(model.preclassify(test.instances), test.labels)
        batch_predictor = _predict_aggregative_batch
    else:
        batch_predictor = _predict_quantification_batch
    return batch_predictor((model, test, test_indexes))
def ml_natural_prevalence_prediction(model,
                                     test: MultilabelledCollection,
                                     sample_size,
                                     repeats=100,
                                     random_seed=42):
    """Draw `repeats` natural-prevalence samples of size `sample_size` from
    `test` (reproducibly, under `random_seed`) and return the true and the
    estimated prevalences for each of them.
    """
    with qp.util.temp_seed(random_seed):
        index_generator = test.natural_sampling_index_generator(sample_size=sample_size, repeats=repeats)
        test_indexes = list(index_generator)
    return _ml_prevalence_predictions(model, test, test_indexes)
def ml_natural_prevalence_evaluation(model,
                                     test: MultilabelledCollection,
                                     sample_size,
                                     repeats=100,
                                     error_metric: Union[str, Callable]='mae',
                                     random_seed=42):
    """Return the error of `model`, averaged across `repeats` natural-prevalence
    samples of `test`, according to `error_metric` (a callable or the name of a
    metric known to `quapy.error`).
    """
    error_metric = check_error_str(error_metric)
    true_prevs, estim_prevs = ml_natural_prevalence_prediction(model, test, sample_size, repeats, random_seed)
    return np.mean([error_metric(true_i, estim_i) for true_i, estim_i in zip(true_prevs, estim_prevs)])
def ml_artificial_prevalence_prediction(model,
                                        test: MultilabelledCollection,
                                        sample_size,
                                        n_prevalences=21,
                                        repeats=10,
                                        random_seed=42):
    """Draw artificial-prevalence samples of `test` (a grid of `n_prevalences`
    prevalence values, `repeats` samples each, independently per category) and
    return the true and the estimated prevalences for every sample.

    Per-category batches are quantified in parallel.
    """
    with qp.util.temp_seed(random_seed):
        # one batch of sampling index sets per category
        nested_indexes = [
            list(test.artificial_sampling_index_generator(sample_size=sample_size,
                                                          category=category,
                                                          n_prevalences=n_prevalences,
                                                          repeats=repeats))
            for category in test.classes_
        ]

    def _predict_batch(batch_indexes):
        return _ml_prevalence_predictions(model, test, batch_indexes)

    predictions = qp.util.parallel(_predict_batch, nested_indexes, n_jobs=-1)
    # flatten the per-category (trues, estims) pairs into two flat lists
    true_prevs = [prev for trues, _ in predictions for prev in trues]
    estim_prevs = [prev for _, estims in predictions for prev in estims]
    return true_prevs, estim_prevs
def ml_artificial_prevalence_evaluation(model,
                                        test: MultilabelledCollection,
                                        sample_size,
                                        n_prevalences=21,
                                        repeats=10,
                                        error_metric: Union[str, Callable]='mae',
                                        random_seed=42):
    """Return the error of `model`, averaged across all artificial-prevalence
    samples of `test`, according to `error_metric` (a callable or the name of a
    metric known to `quapy.error`).
    """
    error_metric = check_error_str(error_metric)
    true_prevs, estim_prevs = ml_artificial_prevalence_prediction(model, test, sample_size, n_prevalences, repeats, random_seed)
    return np.mean([error_metric(true_i, estim_i) for true_i, estim_i in zip(true_prevs, estim_prevs)])
def _predict_quantification_batch(args):
    """Quantify each sample of the batch with the model's full `quantify` step."""
    model = args[0]
    return __predict_batch_fn(args, model.quantify)
def _predict_aggregative_batch(args):
    """Quantify each sample of the batch via `aggregate` only (instances are
    assumed to have been pre-classified already)."""
    model = args[0]
    return __predict_batch_fn(args, model.aggregate)
def __predict_batch_fn(args, quant_fn):
    """Apply `quant_fn` to every sample of `test` indexed by `indexes`,
    returning two parallel lists: the true prevalences and the estimates."""
    _, test, indexes = args
    trues, estims = [], []
    for index in indexes:
        sample = test.sampling_from_index(index)
        # quantify first, then record the ground truth (as in the original order)
        estims.append(quant_fn(sample.instances))
        trues.append(sample.prevalence())
    return trues, estims