# Source: QuaPy/MultiLabel/gentables.py (Python, 119 lines, 4.7 KiB)

import argparse
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
import itertools
from sklearn.multioutput import ClassifierChain
from tqdm import tqdm
from skmultilearn.dataset import load_dataset, available_data_sets
from scipy.sparse import csr_matrix
import quapy as qp
from MultiLabel.main import load_results, SKMULTILEARN_RED_DATASETS, TC_DATASETS, sample_size
from MultiLabel.mlclassification import MLStackedClassifier
from MultiLabel.mldata import MultilabelledCollection
from MultiLabel.mlquantification import MLNaiveQuantifier, MLCC, MLPCC, MLRegressionQuantification, \
MLACC, \
MLPACC, MLNaiveAggregativeQuantifier
from MultiLabel.tabular import Table
from method.aggregative import PACC, CC, EMQ, PCC, ACC, HDy
import numpy as np
from data.dataset import Dataset
from mlevaluation import ml_natural_prevalence_prediction, ml_artificial_prevalence_prediction, check_error_str
import sys
import os
import pickle
# Names of the quantification methods reported in the generated tables.
# Each name is paired with each dataset to locate a result file, and the
# length of this list determines how many columns the LaTeX header declares.
# Entries currently disabled (kept for reference): 'MLPE', 'NaiveHDy', 'NaiveSLD'.
models = [
    # Naive* variants
    'NaiveCC', 'NaivePCC', 'NaivePCCcal', 'NaiveACC',
    'NaivePACC', 'NaivePACCcal', 'NaiveACCit', 'NaivePACCit',
    # Chain* variants
    'ChainCC', 'ChainPCC', 'ChainACC', 'ChainPACC',
    # Stack* variants
    'StackCC', 'StackPCC', 'StackPCCcal', 'StackACC',
    'StackPACC', 'StackPACCcal', 'StackACCit', 'StackPACCit',
    # MRQ-* variants
    'MRQ-CC', 'MRQ-PCC', 'MRQ-ACC', 'MRQ-PACC', 'MRQ-ACCit', 'MRQ-PACCit',
    # StackMRQ-* variants
    'StackMRQ-CC', 'StackMRQ-PCC', 'StackMRQ-ACC', 'StackMRQ-PACC',
    # MRQ-Stack* variants
    'MRQ-StackCC', 'MRQ-StackPCC', 'MRQ-StackACC', 'MRQ-StackPACC',
    # StackMRQ-Stack* variants
    'StackMRQ-StackCC', 'StackMRQ-StackPCC',
    'StackMRQ-StackACC', 'StackMRQ-StackPACC',
    # *-app variants
    'MRQ-StackCC-app', 'MRQ-StackPCC-app',
    'MRQ-StackACC-app', 'MRQ-StackPACC-app',
    'StackMRQ-StackCC-app', 'StackMRQ-StackPCC-app',
    'StackMRQ-StackACC-app', 'StackMRQ-StackPACC-app',
    # LSP-* and MLKNN-* variants
    'LSP-CC', 'LSP-ACC', 'MLKNN-CC', 'MLKNN-ACC',
    # *AdjustedC variants
    'MLAdjustedC', 'MLStackAdjustedC', 'MLprobAdjustedC', 'MLStackProbAdjustedC',
]
# Datasets reported in the tables; currently the TC_DATASETS constant imported
# from MultiLabel.main. The commented-out alternative below would instead derive
# the names from skmultilearn's available_data_sets().
# datasets = sorted(set([x[0] for x in available_data_sets().keys()]))
datasets = TC_DATASETS
def generate_table(path, protocol, error):
    """Build the results table for one protocol/error pair and write it as LaTeX.

    For every (dataset, model) combination of the module-level `datasets` and
    `models`, the matching result file under `opt.results` is loaded (when it
    exists), scored with `error`, and added to a `Table`. The table is saved
    to `path`, and its rank table to `path` with '.tex' replaced by
    '.rank.tex'.

    :param path: destination of the LaTeX table file
    :param protocol: key used to select the (true, estimated) prevalences from
        a loaded result
    :param error: callable applied to each (true, estimated) prevalence pair
    """

    def compute_score_job(args):
        """Score one (dataset, model) pair; return None if no result file exists."""
        dataset, model = args
        # `opt` is the module-level namespace of parsed command-line options.
        result_path = f'{opt.results}/{dataset}_{model}.pkl'

        if not os.path.exists(result_path):
            # Progress marker for a missing result.
            print('-', end='')
            sys.stdout.flush()
            return None

        # Progress marker for a result that was found.
        print('+', end='')
        sys.stdout.flush()
        true_prevs, estim_prevs = load_results(result_path)[protocol]
        per_sample = [error(trues, estims) for trues, estims in zip(true_prevs, estim_prevs)]
        return dataset, model, np.asarray(per_sample).flatten()

    print(f'\ngenerating {path}')
    table = Table(datasets, models, prec_mean=4, significance_test='wilcoxon')

    jobs = list(itertools.product(datasets, models))
    outcomes = qp.util.parallel(compute_score_job, jobs, n_jobs=-1)
    print()  # terminate the line of '+'/'-' progress markers

    for outcome in outcomes:
        if outcome is None:
            continue
        dataset_name, model_name, score_values = outcome
        table.add(dataset_name, model_name, score_values)

    rank_path = path.replace('.tex','.rank.tex')
    save_table(table, path)
    save_table(table.getRankTable(), rank_path)
def save_table(table, path):
    """Write `table` to `path` as a LaTeX tabular wrapped in a resizebox.

    The column specification declares one label column plus one centred column
    per entry of the module-level `models` list. The table body comes from
    `table.latexTabularT(...)`.

    :param table: object exposing `latexTabularT(benchmark_replace=...,
        method_replace=..., side=...)` and returning the tabular body as a string
    :param path: file the LaTeX code is written to (overwritten if it exists)
    """
    # Backslashes are escaped explicitly: sequences such as "\h", "\_" and "\e"
    # are invalid escapes in ordinary string literals and trigger warnings.
    column_spec = '|c||' + ('c|' * len(models))
    header = ''.join([
        '\n',
        '\\resizebox{\\textwidth}{!}{%\n',
        '\\begin{tabular}{', column_spec, '} \\hline\n',
    ])
    footer = '\n\\end{tabular}%\n}\n'

    # Underscores must be escaped in LaTeX text, hence the replacement names.
    dataset_replace = {
        'tmc2007_500': 'tmc2007\\_500',
        'tmc2007_500-red': 'tmc2007\\_500-red',
    }
    method_replace = {}

    body = table.latexTabularT(benchmark_replace=dataset_replace, method_replace=method_replace, side=True)

    with open(path, 'wt') as foo:
        foo.write(header + body + footer)
if __name__ == '__main__':
    # Script entry point: parse the command line, then generate one table per
    # (error measure, protocol) combination from previously stored results.
    parser = argparse.ArgumentParser(description='Experiments for multi-label quantification')
    parser.add_argument('--results', type=str, default='./results', metavar='str',
                        help='path from which the stored results are loaded')
    parser.add_argument('--tablepath', type=str, default='./tables', metavar='str',
                        help='path where to store the tables')
    # NOTE: `opt` must stay a module-level name; generate_table reads opt.results.
    opt = parser.parse_args()

    # Explicit check instead of `assert`, which is removed when running with -O.
    if not os.path.exists(opt.results):
        parser.error(f'result directory {opt.results} does not exist')
    os.makedirs(opt.tablepath, exist_ok=True)

    qp.environ["SAMPLE_SIZE"] = sample_size

    absolute_error = qp.error.ae
    relative_absolute_error = qp.error.rae

    # Output order: npp.ae, app.ae, npp.rae, app.rae.
    for error_tag, error_fn in (('ae', absolute_error), ('rae', relative_absolute_error)):
        for protocol_name in ('npp', 'app'):
            generate_table(f'{opt.tablepath}/{protocol_name}.{error_tag}.tex',
                           protocol=protocol_name, error=error_fn)