# QuaPy/Retrieval/kdey_bandwith_selection_APP.py (78 lines, 2.2 KiB, Python)

import itertools
import os.path
import pickle
from collections import defaultdict
from pathlib import Path
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC
import quapy as qp
from Retrieval.commons import RetrievedSamples, load_sample
from quapy.protocol import UPP
from quapy.method.non_aggregative import MaximumLikelihoodPrevalenceEstimation as Naive
from quapy.model_selection import GridSearchQ
from quapy.method.aggregative import ClassifyAndCount, EMQ, ACC, PCC, PACC, KDEyML
from quapy.data.base import LabelledCollection
from os.path import join
from tqdm import tqdm
from result_table.src.table import Table
"""
"""
# Bandwidth selection for KDEy-ML on the Retrieval datasets.
#
# For each target attribute in `datasets`, this script:
#   1. loads the classifier-training texts of the FULL collection;
#   2. vectorizes them with the TF-IDF model pickled alongside a previously
#      trained classifier (the classifier is kept fixed; only its stored
#      hyperparameters are reused);
#   3. splits the vectors 50/50 (stratified) into a training and a
#      validation part;
#   4. grid-searches the KDEy-ML bandwidth, plus the LogisticRegression C,
#      scoring MRAE on UPP samples drawn from the validation part;
#   5. prints the best configuration and its score.

data_home = 'data'
# Further targets can be enabled by adding them to this list:
# 'relative_pageviews_category', 'num_sitelinks_category'
datasets = ['continent', 'gender', 'years_category']

# Size of every sample generated by the evaluation protocol. It is also
# registered in qp.environ because QuaPy's smoothed error measures (MRAE here)
# derive their smoothing epsilon from qp.environ['SAMPLE_SIZE'] when none is
# given. It does not depend on the dataset, so it is set once, before the loop.
SAMPLE_SIZE = 100
qp.environ['SAMPLE_SIZE'] = SAMPLE_SIZE

for class_name in datasets:
    train_data_path = join(data_home, class_name, 'FULL', 'classifier_training.json')  # <-------- fixed classifier
    texts, labels = load_sample(train_data_path, class_name=class_name)

    # The pickle holds a (vectorizer, classifier) pair, presumably written by
    # the classifier-training step of this project. Only unpickle trusted
    # files: pickle.load can execute arbitrary code.
    classifier_path = join('classifiers', 'FULL', f'classifier_{class_name}.pkl')
    with open(classifier_path, 'rb') as fin:  # close the handle deterministically
        tfidf, classifier_trained = pickle.load(fin)
    trained_params = classifier_trained.get_params()
    print(f'classifier_hyper={trained_params!r}')

    X = tfidf.transform(texts)
    print(f'Xtr shape={X.shape}')
    pool = LabelledCollection(X, labels)
    # `train` is passed to GridSearchQ.fit; `val` feeds the UPP evaluation samples.
    train, val = pool.split_stratified(train_prop=0.5, random_state=0)

    q = KDEyML(LogisticRegression())
    # Candidate C values: the stored classifier's C and a near-zero C
    # (very strong regularization in scikit-learn's parameterization).
    # class_weight is kept as stored.
    classifier_hyper = {
        'classifier__C': [trained_params['C'], 0.00000001],
        'classifier__class_weight': [trained_params['class_weight']],
    }
    quantifier_hyper = {'bandwidth': np.linspace(0.01, 0.2, 20)}
    hyper = {**classifier_hyper, **quantifier_hyper}

    modsel = GridSearchQ(
        model=q,
        param_grid=hyper,
        protocol=UPP(val, sample_size=SAMPLE_SIZE),
        n_jobs=-1,
        error='mrae',
        verbose=True,
    )
    modsel.fit(train)
    print(class_name)
    print(f'{modsel.best_params_}')
    print(f'{modsel.best_score_}')