1
0
Fork 0
QuaPy/LeQua2022/main_binary.py

67 lines
2.5 KiB
Python
Raw Normal View History

2021-10-13 20:36:53 +02:00
import pickle
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.aggregative import *
from data import load_binary_raw_document
import os
# Input data directory and the directory where per-sample errors are stored.
path_binary_raw = 'binary_raw'
result_path = os.path.join('results', 'binary_raw')
os.makedirs(result_path, exist_ok=True)

# Load the labelled training documents and print a short summary of them:
# the classes, the number of documents, and the class prevalence.
train_file = os.path.join(path_binary_raw, 'documents', 'training.txt')
train = LabelledCollection.load(train_file, load_binary_raw_document)
for train_stat in (train.classes_, len(train), train.prevalence()):
    print(train_stat)

# Fit the tf-idf vectorizer on the training documents and replace the
# instances with their vectorized form; the fitted vectorizer is reused
# later to transform the evaluation samples.
tfidf = TfidfVectorizer(min_df=5)
train.instances = tfidf.fit_transform(train.instances)
# Train each quantification method on the training set, evaluate it on every
# sample file, and store the per-sample errors on disk.
# `scores` maps quantifier name -> sample set -> {'mae': ..., 'mrae': ...}.
scores = {}
for quantifier in [CC, ACC, PCC, PACC, EMQ, HDy]:
    # Every method gets its own new CalibratedClassifierCV(LogisticRegression()).
    classifier = CalibratedClassifierCV(LogisticRegression())
    model = quantifier(classifier).fit(train)
    quantifier_name = model.__class__.__name__
    scores[quantifier_name] = {}

    # NOTE: only the validation samples are evaluated; the ('test', 5000)
    # sample set is currently disabled.
    for sample_set, sample_size in [('validation', 1000)]:
        ae_errors, rae_errors = [], []
        for i in tqdm(range(sample_size), total=sample_size, desc=f'testing {quantifier_name} in {sample_set}'):
            # Each sample lives in its own file: <sample_set>_<i>.txt
            test_file = os.path.join(path_binary_raw, 'documents', f'{sample_set}_{i}.txt')
            test = LabelledCollection.load(test_file, load_binary_raw_document, classes=train.classes_)
            # Reuse the vectorizer fitted on the training documents.
            test.instances = tfidf.transform(test.instances)

            # NOTE(review): SAMPLE_SIZE is presumably read by the error
            # functions below (e.g. for smoothing) -- confirm in quapy.
            qp.environ['SAMPLE_SIZE'] = len(test)

            prev_estim = model.quantify(test.instances)
            prev_true = test.prevalence()
            ae_errors.append(qp.error.mae(prev_true, prev_estim))
            rae_errors.append(qp.error.mrae(prev_true, prev_estim))

        ae_errors = np.asarray(ae_errors)
        rae_errors = np.asarray(rae_errors)

        mae = ae_errors.mean()
        mrae = rae_errors.mean()
        scores[quantifier_name][sample_set] = {'mae': mae, 'mrae': mrae}

        # Persist the raw per-sample errors. The context manager guarantees
        # the file is flushed and closed even if pickling fails.
        for error_tag, errors in (('ae', ae_errors), ('rae', rae_errors)):
            dump_file = os.path.join(result_path, f'{quantifier_name}.{sample_set}.{error_tag}.pickle')
            with open(dump_file, 'wb') as fout:
                pickle.dump(errors, fout, pickle.HIGHEST_PROTOCOL)

        print(f'{quantifier_name} {sample_set} MAE={mae:.4f}')
        print(f'{quantifier_name} {sample_set} MRAE={mrae:.4f}')
# Final report: one tab-separated line per quantifier and sample set,
# with the quantifier name followed by its mean MAE and mean MRAE.
for method_name, per_sample_set in scores.items():
    # NOTE: only 'validation' is reported; 'test' is currently disabled.
    for sample_set in ['validation']:
        result = per_sample_set[sample_set]
        print(f'{method_name}\t{result["mae"]:.4f}\t{result["mrae"]:.4f}')
2021-10-14 12:28:11 +02:00
2021-10-13 20:36:53 +02:00