forked from moreo/QuaPy
76 lines
2.8 KiB
Python
76 lines
2.8 KiB
Python
import argparse
|
|
from sklearn.calibration import CalibratedClassifierCV
|
|
from sklearn.linear_model import LogisticRegression
|
|
import itertools
|
|
from sklearn.multioutput import ClassifierChain
|
|
from tqdm import tqdm
|
|
from skmultilearn.dataset import load_dataset, available_data_sets
|
|
from scipy.sparse import csr_matrix
|
|
import quapy as qp
|
|
from MultiLabel.main import load_results
|
|
from MultiLabel.mlclassification import MultilabelStackedClassifier
|
|
from MultiLabel.mldata import MultilabelledCollection
|
|
from MultiLabel.mlquantification import MultilabelNaiveQuantifier, MLCC, MLPCC, MLRegressionQuantification, \
|
|
MLACC, \
|
|
MLPACC, MultilabelNaiveAggregativeQuantifier
|
|
from MultiLabel.tabular import Table
|
|
from method.aggregative import PACC, CC, EMQ, PCC, ACC, HDy
|
|
import numpy as np
|
|
from data.dataset import Dataset
|
|
from mlevaluation import ml_natural_prevalence_prediction, ml_artificial_prevalence_prediction, check_error_str
|
|
import sys
|
|
import os
|
|
import pickle
|
|
|
|
# Quantification methods whose results are tabulated, one table column each:
# NaiveCC, NaivePCC, NaiveACC, NaivePACC, NaiveHDy, NaiveSLD (in this order).
models = [f'Naive{method}' for method in ('CC', 'PCC', 'ACC', 'PACC', 'HDy', 'SLD')]

# Alphabetically sorted, de-duplicated dataset names known to scikit-multilearn.
# Only element 0 of each key of available_data_sets() is kept
# (presumably keys are (name, variant) tuples -- confirm against skmultilearn).
datasets = sorted({key[0] for key in available_data_sets().keys()})
|
|
|
|
|
|
def generate_table(path, protocol, error, results_dir=None):
    """Build a LaTeX results table and write it to *path*.

    For every (dataset, model) pair in the module-level ``datasets`` and
    ``models`` lists, the result file ``{results_dir}/{dataset}_{model}.pkl``
    is loaded with ``load_results`` if it exists (missing files are skipped).
    The pair ``(true_prevs, estim_prevs)`` stored under ``result[protocol]``
    is scored sample by sample with ``error``, and the flattened scores are
    added to a ``Table``. The table is then wrapped in a
    ``\\resizebox``/``tabular`` environment and written out.

    :param path: output ``.tex`` file path.
    :param protocol: key selecting the evaluation protocol inside each loaded
        result (this script uses ``'npp'`` and ``'app'``).
    :param error: callable invoked as ``error(trues, estims)`` for each sample;
        its outputs are stacked with ``np.asarray`` and flattened.
    :param results_dir: directory containing the ``.pkl`` result files. When
        ``None`` (default), falls back to the global ``opt.results`` set by the
        command-line entry point, preserving the original behaviour.
    """
    if results_dir is None:
        # Backward-compatible fallback: the original implementation read the
        # module-global ``opt`` that is only defined under ``__main__``.
        results_dir = opt.results

    print(f'generating {path}')

    table = Table(datasets, models)
    for dataset, model in itertools.product(datasets, models):
        result_path = f'{results_dir}/{dataset}_{model}.pkl'
        if not os.path.exists(result_path):
            continue  # no results for this pair; leave the cell empty
        result = load_results(result_path)
        true_prevs, estim_prevs = result[protocol]
        scores = np.asarray(
            [error(trues, estims) for trues, estims in zip(true_prevs, estim_prevs)]
        ).flatten()
        table.add(dataset, model, scores)

    # Backslashes are escaped explicitly: '\h', '\_' and '\e' are invalid
    # Python escape sequences (SyntaxWarning since 3.12). The emitted LaTeX
    # text is unchanged.
    tabular = (
        '\n'
        '\\resizebox{\\textwidth}{!}{%\n'
        '\\begin{tabular}{|c||' + ('c|' * len(models)) + '} \\hline\n'
    )

    dataset_replace = {'tmc2007_500': 'tmc2007\\_500'}  # '_' must be escaped in LaTeX
    method_replace = {}

    tabular += table.latexTabular(benchmark_replace=dataset_replace, method_replace=method_replace)

    tabular += (
        '\n'
        '\\end{tabular}%\n'
        '}\n'
    )

    with open(path, 'wt', encoding='utf-8') as out:
        out.write(tabular)
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: reads pickled results and writes one LaTeX
    # table per evaluation protocol.
    parser = argparse.ArgumentParser(description='Experiments for multi-label quantification')
    parser.add_argument('--results', type=str, default='./results', metavar='str',
                        help='path where the results to be tabulated are read from')
    parser.add_argument('--tablepath', type=str, default='./tables', metavar='str',
                        help='path where to store the tables')
    # NOTE: ``opt`` is assigned at module level on purpose; generate_table may
    # fall back to the global ``opt.results``.
    opt = parser.parse_args()

    os.makedirs(opt.results, exist_ok=True)
    os.makedirs(opt.tablepath, exist_ok=True)

    eval_error = qp.error.ae
    # Same order as before: natural prevalence protocol first, then artificial.
    for protocol in ('npp', 'app'):
        generate_table(f'{opt.tablepath}/{protocol}.ae.tex', protocol=protocol, error=eval_error)
|
|
|
|
|
|
|
|
|
|
|
|
|