forked from moreo/QuaPy
added Ensemble methods (methods ALL, ACC, Ptr, DS from Pérez-Gallego et al 2017 and 2019) and some UCI ML datasets used in those articles (only 5 datasets out of 32 they used)
This commit is contained in:
parent
d8e2f7556e
commit
326a8ab803
16
TODO.txt
16
TODO.txt
|
@ -9,4 +9,18 @@ negative class). This is not covered in this new implementation, in which the bi
|
||||||
an instance of single-label with 2 labels. Check
|
an instance of single-label with 2 labels. Check
|
||||||
Add classnames to LabelledCollection ?
|
Add classnames to LabelledCollection ?
|
||||||
Check the overhead in OneVsAll for SVMperf-based (?)
|
Check the overhead in OneVsAll for SVMperf-based (?)
|
||||||
|
Add HDy to QuaNet? if so, wrap HDy into OneVsAll in case the dataset is not binary.
|
||||||
|
Plots (one for binary -- the "diagonal", or for a specific class), another for the error as a funcition of drift.
|
||||||
|
Add datasets for topic.
|
||||||
|
Add other methods
|
||||||
|
Clarify whether QuaNet is an aggregative method or not.
|
||||||
|
Add datasets from Pérez-Gallego et al. 2017, 2019
|
||||||
|
Add ensemble models from Pérez-Gallego et al. 2017, 2019
|
||||||
|
Add plots models like those in Pérez-Gallego et al. 2017 (error boxes)
|
||||||
|
Add support for CV prediction in ACC and PACC for tpr, fpr
|
||||||
|
Add medium swap method
|
||||||
|
Explore the hyperparameter "number of bins" in HDy
|
||||||
|
Implement HDy for single-label?
|
||||||
|
Rename EMQ to SLD ?
|
||||||
|
How many times is the system of equations for ACC and PACC not solved? How many times is it clipped? Do they sum up
|
||||||
|
to one always?
|
|
@ -1,8 +1,8 @@
|
||||||
from . import data
|
from . import error
|
||||||
from .data import datasets
|
from .data import datasets
|
||||||
from . import functional
|
from . import functional
|
||||||
from . import method
|
from . import method
|
||||||
from . import error
|
from . import data
|
||||||
from . import evaluation
|
from . import evaluation
|
||||||
from method.aggregative import isaggregative, isprobabilistic
|
from method.aggregative import isaggregative, isprobabilistic
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
from sklearn.decomposition import TruncatedSVD
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
|
||||||
|
|
||||||
|
class PCALR:
|
||||||
|
|
||||||
|
def __init__(self, n_components=300, C=10, class_weight=None):
|
||||||
|
self.n_components = n_components
|
||||||
|
self.learner = LogisticRegression(C=C, class_weight=class_weight, max_iter=1000)
|
||||||
|
|
||||||
|
def get_params(self):
|
||||||
|
params = {'n_components': self.n_components}
|
||||||
|
params.update(self.learner.get_params())
|
||||||
|
return params
|
||||||
|
|
||||||
|
def set_params(self, **params):
|
||||||
|
if 'n_components' in params:
|
||||||
|
self.n_components = params['n_components']
|
||||||
|
del params['n_components']
|
||||||
|
self.learner.set_params(**params)
|
||||||
|
|
||||||
|
def fit(self, documents, labels):
|
||||||
|
self.pca = TruncatedSVD(self.n_components)
|
||||||
|
embedded = self.pca.fit_transform(documents, labels)
|
||||||
|
self.learner.fit(embedded, labels)
|
||||||
|
self.classes_ = self.learner.classes_
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, documents):
|
||||||
|
embedded = self.transform(documents)
|
||||||
|
return self.learner.predict(embedded)
|
||||||
|
|
||||||
|
def predict_proba(self, documents):
|
||||||
|
embedded = self.transform(documents)
|
||||||
|
return self.learner.predict_proba(embedded)
|
||||||
|
|
||||||
|
def transform(self, documents):
|
||||||
|
return self.pca.transform(documents)
|
|
@ -1,11 +1,9 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.sparse import issparse
|
from scipy.sparse import issparse
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from quapy.functional import artificial_prevalence_sampling
|
from quapy.functional import artificial_prevalence_sampling, strprev
|
||||||
from scipy.sparse import vstack
|
from scipy.sparse import vstack
|
||||||
|
|
||||||
from util import temp_seed
|
|
||||||
|
|
||||||
|
|
||||||
class LabelledCollection:
|
class LabelledCollection:
|
||||||
|
|
||||||
|
@ -130,6 +128,21 @@ class LabelledCollection:
|
||||||
labels = np.concatenate([self.labels, other.labels])
|
labels = np.concatenate([self.labels, other.labels])
|
||||||
return LabelledCollection(join_instances, labels)
|
return LabelledCollection(join_instances, labels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def Xy(self):
|
||||||
|
return self.instances, self.labels
|
||||||
|
|
||||||
|
def stats(self):
|
||||||
|
ninstances = len(self)
|
||||||
|
instance_type = type(self.instances[0])
|
||||||
|
if instance_type == list:
|
||||||
|
nfeats = len(self.instances[0])
|
||||||
|
elif instance_type == np.ndarray:
|
||||||
|
nfeats = self.instances.shape[1]
|
||||||
|
else:
|
||||||
|
nfeats = '?'
|
||||||
|
print(f'#instances={ninstances}, type={instance_type}, features={nfeats}, n_classes={self.n_classes}, '
|
||||||
|
f'prevs={strprev(self.prevalence())}')
|
||||||
|
|
||||||
|
|
||||||
class Dataset:
|
class Dataset:
|
||||||
|
@ -153,7 +166,7 @@ class Dataset:
|
||||||
return self.training.binary
|
return self.training.binary
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, train_path, test_path, loader_func:callable):
|
def load(cls, train_path, test_path, loader_func: callable):
|
||||||
training = LabelledCollection.load(train_path, loader_func)
|
training = LabelledCollection.load(train_path, loader_func)
|
||||||
test = LabelledCollection.load(test_path, loader_func)
|
test = LabelledCollection.load(test_path, loader_func)
|
||||||
return Dataset(training, test)
|
return Dataset(training, test)
|
||||||
|
|
|
@ -2,13 +2,15 @@ import zipfile
|
||||||
from util import download_file_if_not_exists, download_file, get_quapy_home, pickled_resource
|
from util import download_file_if_not_exists, download_file, get_quapy_home, pickled_resource
|
||||||
import os
|
import os
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from data.base import Dataset
|
from data.base import Dataset, LabelledCollection
|
||||||
from data.reader import from_text, from_sparse
|
from data.reader import *
|
||||||
from data.preprocessing import text2tfidf, reduce_columns
|
from data.preprocessing import text2tfidf, reduce_columns
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
REVIEWS_SENTIMENT_DATASETS = ['hp', 'kindle', 'imdb']
|
REVIEWS_SENTIMENT_DATASETS = ['hp', 'kindle', 'imdb']
|
||||||
TWITTER_SENTIMENT_DATASETS = ['gasp', 'hcr', 'omd', 'sanders', 'semeval13', 'semeval14', 'semeval15', 'semeval16',
|
TWITTER_SENTIMENT_DATASETS = ['gasp', 'hcr', 'omd', 'sanders',
|
||||||
|
'semeval13', 'semeval14', 'semeval15', 'semeval16',
|
||||||
'sst', 'wa', 'wb']
|
'sst', 'wa', 'wb']
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,4 +119,88 @@ def fetch_twitter(dataset_name, for_model_selection=False, min_df=None, data_hom
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
UCI_DATASETS = ['acute.a', 'acute.b',
|
||||||
|
'balance.1', 'balance.2', 'balance.3']
|
||||||
|
|
||||||
|
def fetch_UCIDataset(dataset_name, data_home=None, verbose=False):
|
||||||
|
|
||||||
|
assert dataset_name in UCI_DATASETS, \
|
||||||
|
f'Name {dataset_name} does not match any known dataset from the UCI Machine Learning datasets repository. ' \
|
||||||
|
f'Valid ones are {UCI_DATASETS}'
|
||||||
|
if data_home is None:
|
||||||
|
data_home = get_quapy_home()
|
||||||
|
|
||||||
|
identifier_map = {
|
||||||
|
'acute.a': 'acute',
|
||||||
|
'acute.b': 'acute',
|
||||||
|
'balance.1': 'balance-scale',
|
||||||
|
'balance.2': 'balance-scale',
|
||||||
|
'balance.3': 'balance-scale',
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset_fullname = {
|
||||||
|
'acute.a': 'Acute Inflammations (urinary bladder)',
|
||||||
|
'acute.b': 'Acute Inflammations (renal pelvis)',
|
||||||
|
'balance.1': 'Balance Scale Weight & Distance Database (left)',
|
||||||
|
'balance.2': 'Balance Scale Weight & Distance Database (balanced)',
|
||||||
|
'balance.3': 'Balance Scale Weight & Distance Database (right)',
|
||||||
|
}
|
||||||
|
|
||||||
|
data_folder = {
|
||||||
|
'acute': 'diagnosis',
|
||||||
|
'balance-scale': 'balance-scale',
|
||||||
|
}
|
||||||
|
|
||||||
|
identifier = identifier_map[dataset_name]
|
||||||
|
URL = f'http://archive.ics.uci.edu/ml/machine-learning-databases/{identifier}'
|
||||||
|
data_path = join(data_home, 'uci_datasets', identifier)
|
||||||
|
download_file_if_not_exists(f'{URL}/{data_folder[identifier]}.data', f'{data_path}/{identifier}.data')
|
||||||
|
download_file_if_not_exists(f'{URL}/{data_folder[identifier]}.names', f'{data_path}/{identifier}.names')
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(open(f'{data_path}/{identifier}.names', 'rt').read())
|
||||||
|
|
||||||
|
print(f'Loading {dataset_name} ({dataset_fullname[dataset_name]})')
|
||||||
|
if identifier == 'acute':
|
||||||
|
df = pd.read_csv(f'{data_path}/{identifier}.data', header=None, encoding='utf-16', sep='\t')
|
||||||
|
if dataset_name == 'acute.a':
|
||||||
|
y = binarize(df[6], pos_class='yes')
|
||||||
|
elif dataset_name == 'acute.b':
|
||||||
|
y = binarize(df[7], pos_class='yes')
|
||||||
|
|
||||||
|
mintemp, maxtemp = 35, 42
|
||||||
|
df[0] = df[0].apply(lambda x:(float(x.replace(',','.'))-mintemp)/(maxtemp-mintemp)).astype(float, copy=False)
|
||||||
|
[df_replace(df, col) for col in range(1, 6)]
|
||||||
|
X = df.loc[:, 0:5].values
|
||||||
|
|
||||||
|
if identifier == 'balance-scale':
|
||||||
|
df = pd.read_csv(f'{data_path}/{identifier}.data', header=None, sep=',')
|
||||||
|
if dataset_name == 'balance.1':
|
||||||
|
y = binarize(df[0], pos_class='L')
|
||||||
|
elif dataset_name == 'balance.2':
|
||||||
|
y = binarize(df[0], pos_class='B')
|
||||||
|
elif dataset_name == 'balance.3':
|
||||||
|
y = binarize(df[0], pos_class='R')
|
||||||
|
X = df.loc[:, 1:].astype(float).values
|
||||||
|
|
||||||
|
data = LabelledCollection(X, y)
|
||||||
|
data.stats()
|
||||||
|
#print(df)
|
||||||
|
#print(df.loc[:, 0:5].values)
|
||||||
|
#print(y)
|
||||||
|
|
||||||
|
# X = __read_csv(f'{data_path}/{identifier}.data', separator='\t')
|
||||||
|
# print(X)
|
||||||
|
|
||||||
|
#X, y = from_csv(f'{data_path}/{dataset_name}.data')
|
||||||
|
#y, classnames = reindex_labels(y)
|
||||||
|
|
||||||
|
|
||||||
|
#def __read_csv(path, separator=','):
|
||||||
|
# x = []
|
||||||
|
# for instance in tqdm(open(path, 'rt', encoding='utf-16').readlines(), desc=f'reading {path}'):
|
||||||
|
# x.append(instance.strip().split(separator))
|
||||||
|
# return x
|
||||||
|
|
||||||
|
def df_replace(df, col, repl={'yes': 1, 'no':0}, astype=float):
|
||||||
|
df[col] = df[col].apply(lambda x:repl[x]).astype(astype, copy=False)
|
|
@ -1,6 +1,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.sparse import dok_matrix
|
from scipy.sparse import dok_matrix
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
def from_text(path):
|
def from_text(path):
|
||||||
|
@ -55,3 +56,42 @@ def from_sparse(path):
|
||||||
y = np.asarray(all_labels) + 1
|
y = np.asarray(all_labels) + 1
|
||||||
return X, y
|
return X, y
|
||||||
|
|
||||||
|
|
||||||
|
def from_csv(path):
|
||||||
|
"""
|
||||||
|
Reas a csv file in which columns are separated by ','.
|
||||||
|
File fomart <label>,<feat1>,<feat2>,...,<featn>\n
|
||||||
|
:param path: path to the csv file
|
||||||
|
:return: a ndarray for the labels and a ndarray (float) for the covariates
|
||||||
|
"""
|
||||||
|
|
||||||
|
X, y = [], []
|
||||||
|
for instance in tqdm(open(path, 'rt').readlines(), desc=f'reading {path}'):
|
||||||
|
yi, *xi = instance.strip().split(',')
|
||||||
|
X.append(list(map(float,xi)))
|
||||||
|
y.append(yi)
|
||||||
|
X = np.asarray(X)
|
||||||
|
y = np.asarray(y)
|
||||||
|
return X, y
|
||||||
|
|
||||||
|
|
||||||
|
def reindex_labels(y):
|
||||||
|
"""
|
||||||
|
Re-indexes a list of labels as a list of indexes, and returns the classnames corresponding to the indexes.
|
||||||
|
E.g., y=['B', 'B', 'A', 'C'] -> [1,1,0,2], ['A','B','C']
|
||||||
|
:param y: the list or array of original labels
|
||||||
|
:return: a ndarray (int) of class indexes, and a ndarray of classnames corresponding to the indexes.
|
||||||
|
"""
|
||||||
|
classnames = sorted(np.unique(y))
|
||||||
|
label2index = {label: index for index, label in enumerate(classnames)}
|
||||||
|
indexed = np.empty(y.shape, dtype=np.int)
|
||||||
|
for label in classnames:
|
||||||
|
indexed[y==label] = label2index[label]
|
||||||
|
return indexed, classnames
|
||||||
|
|
||||||
|
|
||||||
|
def binarize(y, pos_class):
|
||||||
|
y = np.asarray(y)
|
||||||
|
ybin = np.zeros(y.shape, dtype=np.int)
|
||||||
|
ybin[y == pos_class] = 1
|
||||||
|
return ybin
|
|
@ -77,6 +77,9 @@ def __check_eps(eps):
|
||||||
|
|
||||||
CLASSIFICATION_ERROR = {f1e, acce}
|
CLASSIFICATION_ERROR = {f1e, acce}
|
||||||
QUANTIFICATION_ERROR = {mae, mrae, mse, mkld, mnkld}
|
QUANTIFICATION_ERROR = {mae, mrae, mse, mkld, mnkld}
|
||||||
|
CLASSIFICATION_ERROR_NAMES = {func.__name__ for func in CLASSIFICATION_ERROR}
|
||||||
|
QUANTIFICATION_ERROR_NAMES = {func.__name__ for func in QUANTIFICATION_ERROR}
|
||||||
|
ERROR_NAMES = CLASSIFICATION_ERROR_NAMES | QUANTIFICATION_ERROR_NAMES
|
||||||
|
|
||||||
f1_error = f1e
|
f1_error = f1e
|
||||||
acc_error = acce
|
acc_error = acce
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
|
from typing import Union, Callable, Iterable
|
||||||
from data import LabelledCollection
|
from data import LabelledCollection
|
||||||
from quapy.method.aggregative import AggregativeQuantifier, AggregativeProbabilisticQuantifier
|
from method.aggregative import AggregativeQuantifier, AggregativeProbabilisticQuantifier
|
||||||
from method.base import BaseQuantifier
|
from method.base import BaseQuantifier
|
||||||
from util import temp_seed
|
from util import temp_seed
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from joblib import Parallel, delayed
|
from joblib import Parallel, delayed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import error
|
||||||
|
|
||||||
|
|
||||||
def artificial_sampling_prediction(
|
def artificial_sampling_prediction(
|
||||||
|
@ -64,5 +66,19 @@ def artificial_sampling_prediction(
|
||||||
return true_prevalences, estim_prevalences
|
return true_prevalences, estim_prevalences
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model: BaseQuantifier, test_samples:Iterable[LabelledCollection], err:Union[str, Callable], n_jobs:int=-1):
|
||||||
|
if isinstance(err, str):
|
||||||
|
err = getattr(error, err)
|
||||||
|
assert err.__name__ in error.QUANTIFICATION_ERROR_NAMES, \
|
||||||
|
f'error={err} does not seem to be a quantification error'
|
||||||
|
scores = Parallel(n_jobs=n_jobs)(
|
||||||
|
delayed(_delayed_eval)(model, Ti, err) for Ti in test_samples
|
||||||
|
)
|
||||||
|
return np.mean(scores)
|
||||||
|
|
||||||
|
|
||||||
|
def _delayed_eval(model:BaseQuantifier, test:LabelledCollection, error:Callable):
|
||||||
|
prev_estim = model.quantify(test.instances)
|
||||||
|
prev_true = test.prevalence()
|
||||||
|
return error(prev_true, prev_estim)
|
||||||
|
|
||||||
|
|
|
@ -57,6 +57,37 @@ def prevalence_from_probabilities(posteriors, binarize: bool = False):
|
||||||
return prevalences
|
return prevalences
|
||||||
|
|
||||||
|
|
||||||
|
def HellingerDistance(P, Q):
|
||||||
|
return np.sqrt(np.sum((np.sqrt(P) - np.sqrt(Q))**2))
|
||||||
|
|
||||||
|
|
||||||
|
#def uniform_simplex_sampling(n_classes):
|
||||||
|
# from https://cs.stackexchange.com/questions/3227/uniform-sampling-from-a-simplex
|
||||||
|
# r = [0.] + sorted(np.random.rand(n_classes-1)) + [1.]
|
||||||
|
# return np.asarray([b-a for a,b in zip(r[:-1],r[1:])])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def uniform_prevalence_sampling(n_classes, size=1):
|
||||||
|
if n_classes == 2:
|
||||||
|
u = np.random.rand(size)
|
||||||
|
u = np.vstack([1-u, u]).T
|
||||||
|
else:
|
||||||
|
# from https://cs.stackexchange.com/questions/3227/uniform-sampling-from-a-simplex
|
||||||
|
u = np.random.rand(size, n_classes-1)
|
||||||
|
u.sort(axis=-1)
|
||||||
|
_0s = np.zeros(shape=(size, 1))
|
||||||
|
_1s = np.ones(shape=(size, 1))
|
||||||
|
a = np.hstack([_0s, u])
|
||||||
|
b = np.hstack([u, _1s])
|
||||||
|
u = b-a
|
||||||
|
if size == 1:
|
||||||
|
u = u.flatten()
|
||||||
|
return u
|
||||||
|
#return np.asarray([uniform_simplex_sampling(n_classes) for _ in range(size)])
|
||||||
|
|
||||||
|
uniform_simplex_sampling = uniform_prevalence_sampling
|
||||||
|
|
||||||
def strprev(prevalences, prec=3):
|
def strprev(prevalences, prec=3):
|
||||||
return '['+ ', '.join([f'{p:.{prec}f}' for p in prevalences]) + ']'
|
return '['+ ', '.join([f'{p:.{prec}f}' for p in prevalences]) + ']'
|
||||||
|
|
||||||
|
@ -72,14 +103,17 @@ def adjusted_quantification(prevalence_estim, tpr, fpr, clip=True):
|
||||||
|
|
||||||
|
|
||||||
def normalize_prevalence(prevalences):
|
def normalize_prevalence(prevalences):
|
||||||
assert prevalences.ndim==1, 'unexpected shape'
|
prevalences = np.asarray(prevalences)
|
||||||
accum = prevalences.sum()
|
n_classes = prevalences.shape[-1]
|
||||||
if accum > 0:
|
accum = prevalences.sum(axis=-1, keepdims=True)
|
||||||
return prevalences / accum
|
prevalences = np.true_divide(prevalences, accum, where=accum>0)
|
||||||
else:
|
allzeros = accum.flatten()==0
|
||||||
# if all classifiers are trivial rejectors
|
if any(allzeros):
|
||||||
return np.ones_like(prevalences) / prevalences.size
|
if prevalences.ndim == 1:
|
||||||
|
prevalences = np.full(shape=n_classes, fill_value=1./n_classes)
|
||||||
|
else:
|
||||||
|
prevalences[accum.flatten()==0] = np.full(shape=n_classes, fill_value=1./n_classes)
|
||||||
|
return prevalences
|
||||||
|
|
||||||
|
|
||||||
def num_prevalence_combinations(n_prevpoints:int, n_classes:int, n_repeats:int=1):
|
def num_prevalence_combinations(n_prevpoints:int, n_classes:int, n_repeats:int=1):
|
||||||
|
|
|
@ -1,23 +1,28 @@
|
||||||
from . import base
|
from . import base
|
||||||
from . import aggregative as agg
|
from . import aggregative
|
||||||
from . import non_aggregative
|
from . import non_aggregative
|
||||||
|
from . import meta
|
||||||
|
|
||||||
|
|
||||||
AGGREGATIVE_METHODS = {
|
AGGREGATIVE_METHODS = {
|
||||||
agg.ClassifyAndCount,
|
aggregative.ClassifyAndCount,
|
||||||
agg.AdjustedClassifyAndCount,
|
aggregative.AdjustedClassifyAndCount,
|
||||||
agg.ProbabilisticClassifyAndCount,
|
aggregative.ProbabilisticClassifyAndCount,
|
||||||
agg.ProbabilisticAdjustedClassifyAndCount,
|
aggregative.ProbabilisticAdjustedClassifyAndCount,
|
||||||
agg.ExplicitLossMinimisation,
|
aggregative.ExplicitLossMinimisation,
|
||||||
agg.ExpectationMaximizationQuantifier,
|
aggregative.ExpectationMaximizationQuantifier,
|
||||||
agg.HellingerDistanceY
|
aggregative.HellingerDistanceY
|
||||||
}
|
}
|
||||||
|
|
||||||
NON_AGGREGATIVE_METHODS = {
|
NON_AGGREGATIVE_METHODS = {
|
||||||
non_aggregative.MaximumLikelihoodPrevalenceEstimation
|
non_aggregative.MaximumLikelihoodPrevalenceEstimation
|
||||||
}
|
}
|
||||||
|
|
||||||
QUANTIFICATION_METHODS = AGGREGATIVE_METHODS | NON_AGGREGATIVE_METHODS
|
META_METHODS = {
|
||||||
|
meta.QuaNet
|
||||||
|
}
|
||||||
|
|
||||||
|
QUANTIFICATION_METHODS = AGGREGATIVE_METHODS | NON_AGGREGATIVE_METHODS | META_METHODS
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import numpy as np
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import functional as F
|
import functional as F
|
||||||
import error
|
import error
|
||||||
from method.base import BaseQuantifier
|
from method.base import BaseQuantifier, BinaryQuantifier
|
||||||
from classification.svmperf import SVMperf
|
from classification.svmperf import SVMperf
|
||||||
from data import LabelledCollection
|
from data import LabelledCollection
|
||||||
from sklearn.metrics import confusion_matrix
|
from sklearn.metrics import confusion_matrix
|
||||||
|
@ -22,7 +22,7 @@ class AggregativeQuantifier(BaseQuantifier):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit(self, data: LabelledCollection, fit_learner=True, *args): ...
|
def fit(self, data: LabelledCollection, fit_learner=True): ...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def learner(self):
|
def learner(self):
|
||||||
|
@ -35,12 +35,12 @@ class AggregativeQuantifier(BaseQuantifier):
|
||||||
def classify(self, instances):
|
def classify(self, instances):
|
||||||
return self.learner.predict(instances)
|
return self.learner.predict(instances)
|
||||||
|
|
||||||
def quantify(self, instances, *args):
|
def quantify(self, instances):
|
||||||
classif_predictions = self.classify(instances)
|
classif_predictions = self.classify(instances)
|
||||||
return self.aggregate(classif_predictions, *args)
|
return self.aggregate(classif_predictions)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def aggregate(self, classif_predictions:np.ndarray, *args): ...
|
def aggregate(self, classif_predictions:np.ndarray): ...
|
||||||
|
|
||||||
def get_params(self, deep=True):
|
def get_params(self, deep=True):
|
||||||
return self.learner.get_params()
|
return self.learner.get_params()
|
||||||
|
@ -68,9 +68,9 @@ class AggregativeProbabilisticQuantifier(AggregativeQuantifier):
|
||||||
def posterior_probabilities(self, data):
|
def posterior_probabilities(self, data):
|
||||||
return self.learner.predict_proba(data)
|
return self.learner.predict_proba(data)
|
||||||
|
|
||||||
def quantify(self, instances, *args):
|
def quantify(self, instances):
|
||||||
classif_posteriors = self.posterior_probabilities(instances)
|
classif_posteriors = self.posterior_probabilities(instances)
|
||||||
return self.aggregate(classif_posteriors, *args)
|
return self.aggregate(classif_posteriors)
|
||||||
|
|
||||||
def set_params(self, **parameters):
|
def set_params(self, **parameters):
|
||||||
if isinstance(self.learner, CalibratedClassifierCV):
|
if isinstance(self.learner, CalibratedClassifierCV):
|
||||||
|
@ -78,11 +78,6 @@ class AggregativeProbabilisticQuantifier(AggregativeQuantifier):
|
||||||
self.learner.set_params(**parameters)
|
self.learner.set_params(**parameters)
|
||||||
|
|
||||||
|
|
||||||
class BinaryQuantifier(BaseQuantifier):
|
|
||||||
def _check_binary(self, data : LabelledCollection, quantifier_name):
|
|
||||||
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \
|
|
||||||
f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'
|
|
||||||
|
|
||||||
|
|
||||||
# Helper
|
# Helper
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
|
@ -144,18 +139,17 @@ class ClassifyAndCount(AggregativeQuantifier):
|
||||||
def __init__(self, learner):
|
def __init__(self, learner):
|
||||||
self.learner = learner
|
self.learner = learner
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, fit_learner=True, *args):
|
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||||
"""
|
"""
|
||||||
Trains the Classify & Count method unless _fit_learner_ is False, in which case it is assumed to be already fit.
|
Trains the Classify & Count method unless _fit_learner_ is False, in which case it is assumed to be already fit.
|
||||||
:param data: training data
|
:param data: training data
|
||||||
:param fit_learner: if False, the classifier is assumed to be fit
|
:param fit_learner: if False, the classifier is assumed to be fit
|
||||||
:param args: unused
|
|
||||||
:return: self
|
:return: self
|
||||||
"""
|
"""
|
||||||
self.learner, _ = training_helper(self.learner, data, fit_learner)
|
self.learner, _ = training_helper(self.learner, data, fit_learner)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aggregate(self, classif_predictions, *args):
|
def aggregate(self, classif_predictions):
|
||||||
return F.prevalence_from_labels(classif_predictions, self.n_classes)
|
return F.prevalence_from_labels(classif_predictions, self.n_classes)
|
||||||
|
|
||||||
|
|
||||||
|
@ -186,7 +180,7 @@ class AdjustedClassifyAndCount(AggregativeQuantifier):
|
||||||
def classify(self, data):
|
def classify(self, data):
|
||||||
return self.cc.classify(data)
|
return self.cc.classify(data)
|
||||||
|
|
||||||
def aggregate(self, classif_predictions, *args):
|
def aggregate(self, classif_predictions):
|
||||||
prevs_estim = self.cc.aggregate(classif_predictions)
|
prevs_estim = self.cc.aggregate(classif_predictions)
|
||||||
return AdjustedClassifyAndCount.solve_adjustment(self.Pte_cond_estim_, prevs_estim)
|
return AdjustedClassifyAndCount.solve_adjustment(self.Pte_cond_estim_, prevs_estim)
|
||||||
|
|
||||||
|
@ -208,11 +202,11 @@ class ProbabilisticClassifyAndCount(AggregativeProbabilisticQuantifier):
|
||||||
def __init__(self, learner):
|
def __init__(self, learner):
|
||||||
self.learner = learner
|
self.learner = learner
|
||||||
|
|
||||||
def fit(self, data : LabelledCollection, fit_learner=True, *args):
|
def fit(self, data : LabelledCollection, fit_learner=True):
|
||||||
self.learner, _ = training_helper(self.learner, data, fit_learner, ensure_probabilistic=True)
|
self.learner, _ = training_helper(self.learner, data, fit_learner, ensure_probabilistic=True)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aggregate(self, classif_posteriors, *args):
|
def aggregate(self, classif_posteriors):
|
||||||
return F.prevalence_from_probabilities(classif_posteriors, binarize=False)
|
return F.prevalence_from_probabilities(classif_posteriors, binarize=False)
|
||||||
|
|
||||||
|
|
||||||
|
@ -235,14 +229,22 @@ class ProbabilisticAdjustedClassifyAndCount(AggregativeProbabilisticQuantifier):
|
||||||
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split
|
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split
|
||||||
)
|
)
|
||||||
self.pcc = ProbabilisticClassifyAndCount(self.learner)
|
self.pcc = ProbabilisticClassifyAndCount(self.learner)
|
||||||
y_ = self.classify(validation.instances)
|
y_ = self.soft_classify(validation.instances)
|
||||||
y = validation.labels
|
y = validation.labels
|
||||||
|
confusion = np.empty(shape=(data.n_classes, data.n_classes))
|
||||||
|
for yi in range(data.n_classes):
|
||||||
|
confusion[yi] = y_[y==yi].mean(axis=0)
|
||||||
|
|
||||||
|
self.Pte_cond_estim_ = confusion.T
|
||||||
|
|
||||||
|
#y_ = self.classify(validation.instances)
|
||||||
|
#y = validation.labels
|
||||||
# estimate the matrix with entry (i,j) being the estimate of P(yi|yj), that is, the probability that a
|
# estimate the matrix with entry (i,j) being the estimate of P(yi|yj), that is, the probability that a
|
||||||
# document that belongs to yj ends up being classified as belonging to yi
|
# document that belongs to yj ends up being classified as belonging to yi
|
||||||
self.Pte_cond_estim_ = confusion_matrix(y, y_).T / validation.counts()
|
#self.Pte_cond_estim_ = confusion_matrix(y, y_).T / validation.counts()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aggregate(self, classif_posteriors, *args):
|
def aggregate(self, classif_posteriors):
|
||||||
prevs_estim = self.pcc.aggregate(classif_posteriors)
|
prevs_estim = self.pcc.aggregate(classif_posteriors)
|
||||||
return AdjustedClassifyAndCount.solve_adjustment(self.Pte_cond_estim_, prevs_estim)
|
return AdjustedClassifyAndCount.solve_adjustment(self.Pte_cond_estim_, prevs_estim)
|
||||||
|
|
||||||
|
@ -261,7 +263,7 @@ class ExpectationMaximizationQuantifier(AggregativeProbabilisticQuantifier):
|
||||||
def __init__(self, learner):
|
def __init__(self, learner):
|
||||||
self.learner = learner
|
self.learner = learner
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, fit_learner=True, *args):
|
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||||
self.learner, _ = training_helper(self.learner, data, fit_learner, ensure_probabilistic=True)
|
self.learner, _ = training_helper(self.learner, data, fit_learner, ensure_probabilistic=True)
|
||||||
self.train_prevalence = F.prevalence_from_labels(data.labels, self.n_classes)
|
self.train_prevalence = F.prevalence_from_labels(data.labels, self.n_classes)
|
||||||
return self
|
return self
|
||||||
|
@ -320,20 +322,20 @@ class HellingerDistanceY(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
self._check_binary(data, self.__class__.__name__)
|
self._check_binary(data, self.__class__.__name__)
|
||||||
self.learner, validation = training_helper(
|
self.learner, validation = training_helper(
|
||||||
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
|
self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
|
||||||
Px = self.posterior_probabilities(validation.instances)
|
Px = self.posterior_probabilities(validation.instances)[:,1] # takes only the P(y=+1|x)
|
||||||
self.Pxy1 = Px[validation.labels == 1]
|
self.Pxy1 = Px[validation.labels == 1]
|
||||||
self.Pxy0 = Px[validation.labels == 0]
|
self.Pxy0 = Px[validation.labels == 0]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aggregate(self, classif_posteriors, *args):
|
def aggregate(self, classif_posteriors):
|
||||||
# "In this work, the number of bins b used in HDx and HDy was chosen from 10 to 110 in steps of 10,
|
# "In this work, the number of bins b used in HDx and HDy was chosen from 10 to 110 in steps of 10,
|
||||||
# and the final estimated a priori probability was taken as the median of these 11 estimates."
|
# and the final estimated a priori probability was taken as the median of these 11 estimates."
|
||||||
# (González-Castro, et al., 2013).
|
# (González-Castro, et al., 2013).
|
||||||
|
|
||||||
Px = classif_posteriors
|
Px = classif_posteriors[:,1] # takes only the P(y=+1|x)
|
||||||
|
|
||||||
prev_estimations = []
|
prev_estimations = []
|
||||||
for bins in np.linspace(10, 110, 11, dtype=int): #[10, 20, 30, ..., 100, 110]
|
for bins in np.linspace(10, 110, 11, dtype=int): #[10, 20, 30, ..., 100, 110]
|
||||||
Pxy0_density, _ = np.histogram(self.Pxy0, bins=bins, range=(0, 1), density=True)
|
Pxy0_density, _ = np.histogram(self.Pxy0, bins=bins, range=(0, 1), density=True)
|
||||||
Pxy1_density, _ = np.histogram(self.Pxy1, bins=bins, range=(0, 1), density=True)
|
Pxy1_density, _ = np.histogram(self.Pxy1, bins=bins, range=(0, 1), density=True)
|
||||||
|
|
||||||
|
@ -342,7 +344,7 @@ class HellingerDistanceY(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
prev_selected, min_dist = None, None
|
prev_selected, min_dist = None, None
|
||||||
for prev in F.prevalence_linspace(n_prevalences=100, repeat=1, smooth_limits_epsilon=0.0):
|
for prev in F.prevalence_linspace(n_prevalences=100, repeat=1, smooth_limits_epsilon=0.0):
|
||||||
Px_train = prev*Pxy1_density + (1 - prev)*Pxy0_density
|
Px_train = prev*Pxy1_density + (1 - prev)*Pxy0_density
|
||||||
hdy = HellingerDistanceY.HellingerDistance(Px_train, Px_test)
|
hdy = F.HellingerDistance(Px_train, Px_test)
|
||||||
if prev_selected is None or hdy < min_dist:
|
if prev_selected is None or hdy < min_dist:
|
||||||
prev_selected, min_dist = prev, hdy
|
prev_selected, min_dist = prev, hdy
|
||||||
prev_estimations.append(prev_selected)
|
prev_estimations.append(prev_selected)
|
||||||
|
@ -350,10 +352,6 @@ class HellingerDistanceY(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
pos_class_prev = np.median(prev_estimations)
|
pos_class_prev = np.median(prev_estimations)
|
||||||
return np.asarray([1-pos_class_prev, pos_class_prev])
|
return np.asarray([1-pos_class_prev, pos_class_prev])
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def HellingerDistance(cls, P, Q):
|
|
||||||
return np.sqrt(np.sum((np.sqrt(P) - np.sqrt(Q))**2))
|
|
||||||
|
|
||||||
|
|
||||||
class ExplicitLossMinimisation(AggregativeQuantifier, BinaryQuantifier):
|
class ExplicitLossMinimisation(AggregativeQuantifier, BinaryQuantifier):
|
||||||
|
|
||||||
|
@ -362,13 +360,13 @@ class ExplicitLossMinimisation(AggregativeQuantifier, BinaryQuantifier):
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, fit_learner=True, *args):
|
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||||
self._check_binary(data, self.__class__.__name__)
|
self._check_binary(data, self.__class__.__name__)
|
||||||
assert fit_learner, 'the method requires that fit_learner=True'
|
assert fit_learner, 'the method requires that fit_learner=True'
|
||||||
self.learner = SVMperf(self.svmperf_base, loss=self.loss, **self.kwargs).fit(data.instances, data.labels)
|
self.learner = SVMperf(self.svmperf_base, loss=self.loss, **self.kwargs).fit(data.instances, data.labels)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aggregate(self, classif_predictions:np.ndarray, *args):
|
def aggregate(self, classif_predictions:np.ndarray):
|
||||||
return F.prevalence_from_labels(classif_predictions, self.learner.n_classes_)
|
return F.prevalence_from_labels(classif_predictions, self.learner.n_classes_)
|
||||||
|
|
||||||
def classify(self, X, y=None):
|
def classify(self, X, y=None):
|
||||||
|
@ -423,23 +421,24 @@ class OneVsAll(AggregativeQuantifier):
|
||||||
self.binary_quantifier = binary_quantifier
|
self.binary_quantifier = binary_quantifier
|
||||||
self.n_jobs = n_jobs
|
self.n_jobs = n_jobs
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, **kwargs):
|
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||||
assert not data.binary, \
|
assert not data.binary, \
|
||||||
f'{self.__class__.__name__} expect non-binary data'
|
f'{self.__class__.__name__} expect non-binary data'
|
||||||
assert isinstance(self.binary_quantifier, BaseQuantifier), \
|
assert isinstance(self.binary_quantifier, BaseQuantifier), \
|
||||||
f'{self.binary_quantifier} does not seem to be a Quantifier'
|
f'{self.binary_quantifier} does not seem to be a Quantifier'
|
||||||
|
assert fit_learner==True, 'fit_learner must be True'
|
||||||
if not isinstance(self.binary_quantifier, BinaryQuantifier):
|
if not isinstance(self.binary_quantifier, BinaryQuantifier):
|
||||||
raise ValueError(f'{self.binary_quantifier.__class__.__name__} does not seem to be an instance of '
|
raise ValueError(f'{self.binary_quantifier.__class__.__name__} does not seem to be an instance of '
|
||||||
f'{BinaryQuantifier.__class__.__name__}')
|
f'{BinaryQuantifier.__class__.__name__}')
|
||||||
self.dict_binary_quantifiers = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
|
self.dict_binary_quantifiers = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
|
||||||
self.__parallel(self._delayed_binary_fit, data, **kwargs)
|
self.__parallel(self._delayed_binary_fit, data)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def classify(self, instances):
|
def classify(self, instances):
|
||||||
classif_predictions_bin = self.__parallel(self._delayed_binary_classification, instances)
|
classif_predictions_bin = self.__parallel(self._delayed_binary_classification, instances)
|
||||||
return classif_predictions_bin.T
|
return classif_predictions_bin.T
|
||||||
|
|
||||||
def aggregate(self, classif_predictions_bin, *args):
|
def aggregate(self, classif_predictions_bin):
|
||||||
assert set(np.unique(classif_predictions_bin)) == {0,1}, \
|
assert set(np.unique(classif_predictions_bin)) == {0,1}, \
|
||||||
'param classif_predictions_bin does not seem to be a valid matrix (ndarray) of binary ' \
|
'param classif_predictions_bin does not seem to be a valid matrix (ndarray) of binary ' \
|
||||||
'predictions for each document (row) and class (columns)'
|
'predictions for each document (row) and class (columns)'
|
||||||
|
@ -450,7 +449,7 @@ class OneVsAll(AggregativeQuantifier):
|
||||||
#prevalences = np.asarray(prevalences)
|
#prevalences = np.asarray(prevalences)
|
||||||
return F.normalize_prevalence(prevalences)
|
return F.normalize_prevalence(prevalences)
|
||||||
|
|
||||||
def quantify(self, X, *args):
|
def quantify(self, X):
|
||||||
prevalences = self.__parallel(self._delayed_binary_quantify, X)
|
prevalences = self.__parallel(self._delayed_binary_quantify, X)
|
||||||
return F.normalize_prevalence(prevalences)
|
return F.normalize_prevalence(prevalences)
|
||||||
|
|
||||||
|
@ -480,9 +479,9 @@ class OneVsAll(AggregativeQuantifier):
|
||||||
def _delayed_binary_aggregate(self, c, classif_predictions):
|
def _delayed_binary_aggregate(self, c, classif_predictions):
|
||||||
return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:,c])[1] # the estimation for the positive class prevalence
|
return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:,c])[1] # the estimation for the positive class prevalence
|
||||||
|
|
||||||
def _delayed_binary_fit(self, c, data, **kwargs):
|
def _delayed_binary_fit(self, c, data):
|
||||||
bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2)
|
bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2)
|
||||||
self.dict_binary_quantifiers[c].fit(bindata, **kwargs)
|
self.dict_binary_quantifiers[c].fit(bindata)
|
||||||
|
|
||||||
|
|
||||||
def isaggregative(model):
|
def isaggregative(model):
|
||||||
|
@ -497,5 +496,3 @@ def isbinary(model):
|
||||||
return isinstance(model, BinaryQuantifier)
|
return isinstance(model, BinaryQuantifier)
|
||||||
|
|
||||||
|
|
||||||
from . import neural
|
|
||||||
QuaNet = neural.QuaNetTrainer
|
|
|
@ -1,15 +1,18 @@
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from data import LabelledCollection
|
||||||
|
|
||||||
|
|
||||||
# Base Quantifier abstract class
|
# Base Quantifier abstract class
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class BaseQuantifier(metaclass=ABCMeta):
|
class BaseQuantifier(metaclass=ABCMeta):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit(self, data, *args): ...
|
def fit(self, data): ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def quantify(self, instances, *args): ...
|
def quantify(self, instances): ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_params(self, **parameters): ...
|
def set_params(self, **parameters): ...
|
||||||
|
@ -18,6 +21,12 @@ class BaseQuantifier(metaclass=ABCMeta):
|
||||||
def get_params(self, deep=True): ...
|
def get_params(self, deep=True): ...
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryQuantifier(BaseQuantifier):
|
||||||
|
def _check_binary(self, data: LabelledCollection, quantifier_name):
|
||||||
|
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \
|
||||||
|
f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'
|
||||||
|
|
||||||
|
|
||||||
# class OneVsAll(AggregativeQuantifier):
|
# class OneVsAll(AggregativeQuantifier):
|
||||||
# """
|
# """
|
||||||
# Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
|
# Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
|
||||||
|
|
|
@ -0,0 +1,304 @@
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
|
||||||
|
|
||||||
|
import quapy as qp
|
||||||
|
from sklearn.model_selection import GridSearchCV, cross_val_predict
|
||||||
|
from model_selection import GridSearchQ
|
||||||
|
from .base import BaseQuantifier, BinaryQuantifier
|
||||||
|
from joblib import Parallel, delayed
|
||||||
|
from copy import deepcopy
|
||||||
|
from data import LabelledCollection
|
||||||
|
from quapy import functional as F
|
||||||
|
from . import neural
|
||||||
|
from evaluation import evaluate
|
||||||
|
|
||||||
|
QuaNet = neural.QuaNetTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class Ensemble(BaseQuantifier):
|
||||||
|
|
||||||
|
VALID_POLICIES = {'ave', 'ptr', 'ds'} | qp.error.QUANTIFICATION_ERROR_NAMES
|
||||||
|
|
||||||
|
"""
|
||||||
|
Methods from the articles:
|
||||||
|
Pérez-Gállego, P., Quevedo, J. R., & del Coz, J. J. (2017).
|
||||||
|
Using ensembles for problems with characterizable changes in data distribution: A case study on quantification.
|
||||||
|
Information Fusion, 34, 87-100.
|
||||||
|
and
|
||||||
|
Pérez-Gállego, P., Castano, A., Quevedo, J. R., & del Coz, J. J. (2019).
|
||||||
|
Dynamic ensemble selection for quantification tasks.
|
||||||
|
Information Fusion, 45, 1-15.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quantifier: BaseQuantifier, size=50, min_pos=1, red_size=25, policy='ave', n_jobs=1):
|
||||||
|
assert policy in Ensemble.VALID_POLICIES, f'unknown policy={policy}; valid are {Ensemble.VALID_POLICIES}'
|
||||||
|
self.base_quantifier = quantifier
|
||||||
|
self.size = size
|
||||||
|
self.min_pos = min_pos
|
||||||
|
self.red_size = red_size
|
||||||
|
self.policy = policy
|
||||||
|
self.n_jobs = n_jobs
|
||||||
|
self.post_proba_fn = None
|
||||||
|
|
||||||
|
def fit(self, data: LabelledCollection):
|
||||||
|
if self.policy=='ds' and not data.binary:
|
||||||
|
raise ValueError(f'ds policy is only defined for binary quantification, but this dataset is not binary')
|
||||||
|
|
||||||
|
# randomly chooses the prevalences for each member of the ensemble (preventing classes with less than
|
||||||
|
# min_pos positive examples)
|
||||||
|
prevs = [_draw_simplex(ndim=data.n_classes, min_val=self.min_pos / len(data)) for _ in range(self.size)]
|
||||||
|
|
||||||
|
posteriors = None
|
||||||
|
if self.policy == 'ds':
|
||||||
|
# precompute the training posterior probabilities
|
||||||
|
posteriors, self.post_proba_fn = self.ds_policy_get_posteriors(data)
|
||||||
|
|
||||||
|
is_static_policy = (self.policy in qp.error.QUANTIFICATION_ERROR_NAMES)
|
||||||
|
self.ensemble = Parallel(n_jobs=self.n_jobs)(
|
||||||
|
delayed(_delayed_new_instance)(
|
||||||
|
self.base_quantifier, data, prev, posteriors, keep_samples=is_static_policy
|
||||||
|
) for prev in prevs
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.ensemble = [deepcopy(self.base_quantifier) for _ in range(self.size)]
|
||||||
|
# self.prevs = [self._valid_simplex_sampling(data.n_classes, min_val=min_freq) for _ in range(self.size)]
|
||||||
|
# self.samples = [data.sampling(sample_size, *Pi) for Pi in self.prevs]
|
||||||
|
|
||||||
|
# Parallel(n_jobs=self.n_jobs)(
|
||||||
|
# delayed(_delayed_fit)(Qi, Si) for Si, Qi, in zip(self.samples, self.ensemble)
|
||||||
|
# )
|
||||||
|
|
||||||
|
# static selection policy (the name of a quantification-oriented error function to minimize)
|
||||||
|
if self.policy in qp.error.QUANTIFICATION_ERROR_NAMES:
|
||||||
|
self.accuracy_policy(error_name=self.policy)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def quantify(self, instances):
|
||||||
|
predictions = np.asarray(Parallel(n_jobs=self.n_jobs)(
|
||||||
|
delayed(_delayed_quantify)(Qi, instances) for Qi in self.ensemble
|
||||||
|
))
|
||||||
|
|
||||||
|
if self.policy == 'ptr':
|
||||||
|
predictions = self.ptr_policy(predictions)
|
||||||
|
elif self.policy == 'ds':
|
||||||
|
predictions = self.ds_policy(predictions, instances)
|
||||||
|
|
||||||
|
predictions = np.mean(predictions, axis=0)
|
||||||
|
return F.normalize_prevalence(predictions)
|
||||||
|
|
||||||
|
def set_params(self, **parameters):
|
||||||
|
raise NotImplementedError(f'{self.__class__.__name__} should not be used within GridSearchQ; '
|
||||||
|
f'instead, use GridSearchQ within Ensemble, or GridSearchCV whithin the '
|
||||||
|
f'base quantifier if it is an aggregative one.')
|
||||||
|
|
||||||
|
def get_params(self, deep=True):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def accuracy_policy(self, error_name):
|
||||||
|
"""
|
||||||
|
Selects the red_size best performant quantifiers in a static way (i.e., dropping all non-selected instances).
|
||||||
|
For each model in the ensemble, the performance is measured in terms of _error_name_ on the quantification of
|
||||||
|
the samples used for training the rest of the models in the ensemble.
|
||||||
|
"""
|
||||||
|
error = getattr(qp.error, error_name)
|
||||||
|
tests = [m[3] for m in self.ensemble]
|
||||||
|
scores = []
|
||||||
|
for i, model in enumerate(self.ensemble):
|
||||||
|
scores.append(evaluate(model[0], tests[:i] + tests[i+1:], error, self.n_jobs))
|
||||||
|
order = np.argsort(scores)
|
||||||
|
|
||||||
|
self.ensemble = select_k(self.ensemble, order, k=self.red_size)
|
||||||
|
|
||||||
|
def ptr_policy(self, predictions):
|
||||||
|
"""
|
||||||
|
Selects the predictions made by models that have been trained on samples with a prevalence that is most similar
|
||||||
|
to a first approximation of the test prevalence as made by all models in the ensemble.
|
||||||
|
"""
|
||||||
|
test_prev_estim = predictions.mean(axis=0)
|
||||||
|
tr_prevs = [m[1] for m in self.ensemble]
|
||||||
|
ptr_differences = [qp.error.mse(ptr_i, test_prev_estim) for ptr_i in tr_prevs]
|
||||||
|
order = np.argsort(ptr_differences)
|
||||||
|
return select_k(predictions, order, k=self.red_size)
|
||||||
|
|
||||||
|
def ds_policy_get_posteriors(self, data: LabelledCollection):
|
||||||
|
"""
|
||||||
|
In the original article, this procedure is not described in a sufficient level of detail. The paper only says
|
||||||
|
that the distribution of posterior probabilities from training and test examples is compared by means of the
|
||||||
|
Hellinger Distance. However, how these posterior probabilities are generated is not specified. In the article,
|
||||||
|
a Logistic Regressor (LR) is used as the classifier device and that could be used for this purpose. However, in
|
||||||
|
general, a Quantifier is not necessarily an instance of Aggreggative Probabilistic Quantifiers, and so that the
|
||||||
|
quantifier builds on top of a probabilistic classifier cannot be given for granted. Additionally, it would not
|
||||||
|
be correct to generate the posterior probabilities for training documents that have concurred in training the
|
||||||
|
classifier that generates them.
|
||||||
|
This function thus generates the posterior probabilities for all training documents in a cross-validation way,
|
||||||
|
using a LR with hyperparameters that have previously been optimized via grid search in 5FCV.
|
||||||
|
:return P,f, where P is a ndarray containing the posterior probabilities of the training data, generated via
|
||||||
|
cross-validation and using an optimized LR, and the function to be used in order to generate posterior
|
||||||
|
probabilities for test instances.
|
||||||
|
"""
|
||||||
|
X, y = data.Xy
|
||||||
|
lr_base = LogisticRegression(class_weight='balanced', max_iter=1000)
|
||||||
|
|
||||||
|
optim = GridSearchCV(
|
||||||
|
lr_base, param_grid={'C': np.logspace(-4,4,9)}, cv=5, n_jobs=self.n_jobs, refit=True
|
||||||
|
).fit(X, y)
|
||||||
|
|
||||||
|
posteriors = cross_val_predict(
|
||||||
|
optim.best_estimator_, X, y, cv=5, n_jobs=self.n_jobs, method='predict_proba'
|
||||||
|
)
|
||||||
|
posteriors_generator = optim.best_estimator_.predict_proba
|
||||||
|
|
||||||
|
return posteriors, posteriors_generator
|
||||||
|
|
||||||
|
def ds_policy(self, predictions, test):
|
||||||
|
test_posteriors = self.post_proba_fn(test)
|
||||||
|
test_distribution = get_probability_distribution(test_posteriors)
|
||||||
|
tr_distributions = [m[2] for m in self.ensemble]
|
||||||
|
dist = [F.HellingerDistance(tr_dist_i, test_distribution) for tr_dist_i in tr_distributions]
|
||||||
|
order = np.argsort(dist)
|
||||||
|
return select_k(predictions, order, k=self.red_size)
|
||||||
|
|
||||||
|
|
||||||
|
def get_probability_distribution(posterior_probabilities, bins=8):
|
||||||
|
assert posterior_probabilities.shape[1]==2, 'the posterior probabilities do not seem to be for a binary problem'
|
||||||
|
posterior_probabilities = posterior_probabilities[:,1] # take the positive posteriors only
|
||||||
|
distribution, _ = np.histogram(posterior_probabilities, bins=bins, range=(0, 1), density=True)
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
|
def select_k(elements, order, k):
|
||||||
|
return [elements[idx] for idx in order[:k]]
|
||||||
|
|
||||||
|
|
||||||
|
def _delayed_new_instance(base_quantifier, data:LabelledCollection, prev, posteriors, keep_samples):
|
||||||
|
model = deepcopy(base_quantifier)
|
||||||
|
sample_index = data.sampling_index(len(data), *prev)
|
||||||
|
sample = data.sampling_from_index(sample_index)
|
||||||
|
model.fit(sample)
|
||||||
|
tr_prevalence = sample.prevalence()
|
||||||
|
tr_distribution = get_probability_distribution(posteriors[sample_index]) if (posteriors is not None) else None
|
||||||
|
return (model, tr_prevalence, tr_distribution, sample if keep_samples else None)
|
||||||
|
|
||||||
|
|
||||||
|
def _delayed_fit(quantifier, data):
|
||||||
|
quantifier.fit(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _delayed_quantify(quantifier, instances):
|
||||||
|
return quantifier[0].quantify(instances)
|
||||||
|
|
||||||
|
|
||||||
|
def _draw_simplex(ndim, min_val, max_trials=100):
|
||||||
|
"""
|
||||||
|
returns a uniform sampling from the ndim-dimensional simplex but guarantees that all dimensions
|
||||||
|
are >= min_class_prev (for min_val>0, this makes the sampling not truly uniform)
|
||||||
|
:param ndim: number of dimensions of the simplex
|
||||||
|
:param min_val: minimum class prevalence allowed. If less than 1/ndim a ValueError will be throw since
|
||||||
|
there is no possible solution.
|
||||||
|
:return: a sample from the ndim-dimensional simplex that is uniform in S(ndim)-R where S(ndim) is the simplex
|
||||||
|
and R is the simplex subset containing dimensions lower than min_val
|
||||||
|
"""
|
||||||
|
if min_val >= 1/ndim:
|
||||||
|
raise ValueError(f'no sample can be draw from the {ndim}-dimensional simplex so that '
|
||||||
|
f'all its values are >={min_val} (try with a larger value for min_pos)')
|
||||||
|
trials = 0
|
||||||
|
while True:
|
||||||
|
u = F.uniform_simplex_sampling(ndim)
|
||||||
|
if all(u >= min_val):
|
||||||
|
return u
|
||||||
|
trials += 1
|
||||||
|
if trials >= max_trials:
|
||||||
|
raise ValueError(f'it looks like finding a random simplex with all its dimensions being'
|
||||||
|
f'>= {min_val} is unlikely (it failed after {max_trials} trials)')
|
||||||
|
|
||||||
|
|
||||||
|
def _instantiate_ensemble(learner, base_quantifier_class, param_grid, optim, sample_size, eval_budget, **kwargs):
|
||||||
|
if optim is None:
|
||||||
|
base_quantifier = base_quantifier_class(learner)
|
||||||
|
elif optim in qp.error.CLASSIFICATION_ERROR:
|
||||||
|
learner = GridSearchCV(learner, param_grid)
|
||||||
|
base_quantifier = base_quantifier_class(learner)
|
||||||
|
elif optim in qp.error.QUANTIFICATION_ERROR:
|
||||||
|
base_quantifier = GridSearchQ(base_quantifier_class(learner),
|
||||||
|
param_grid=param_grid,
|
||||||
|
sample_size=sample_size,
|
||||||
|
eval_budget=eval_budget,
|
||||||
|
error=optim)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'value optim={optim} not understood')
|
||||||
|
|
||||||
|
return Ensemble(base_quantifier, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class EnsembleFactory(BaseQuantifier):
|
||||||
|
|
||||||
|
def __init__(self, learner, base_quantifier_class, param_grid=None, optim=None, sample_size=None, eval_budget=None,
|
||||||
|
size=50, min_pos=1, red_size=25, policy='ave', n_jobs=1):
|
||||||
|
if param_grid is None and optim is not None:
|
||||||
|
raise ValueError(f'param_grid is None but optim was requested.')
|
||||||
|
error = self._check_error(optim)
|
||||||
|
self.model = _instantiate_ensemble(learner, base_quantifier_class, param_grid, error, sample_size,
|
||||||
|
eval_budget, size=size, min_pos=min_pos, red_size=red_size,
|
||||||
|
policy=policy, n_jobs=n_jobs)
|
||||||
|
|
||||||
|
def fit(self, data):
|
||||||
|
return self.model.fit(data)
|
||||||
|
|
||||||
|
def quantify(self, instances):
|
||||||
|
return self.model.quantify(instances)
|
||||||
|
|
||||||
|
def set_params(self, **parameters):
|
||||||
|
return self.model.set_params(**parameters)
|
||||||
|
|
||||||
|
def get_params(self, deep=True):
|
||||||
|
return self.model.get_params(deep)
|
||||||
|
|
||||||
|
def _check_error(self, error):
|
||||||
|
if error is None:
|
||||||
|
return None
|
||||||
|
if error in qp.error.QUANTIFICATION_ERROR or error in qp.error.CLASSIFICATION_ERROR:
|
||||||
|
return error
|
||||||
|
elif isinstance(error, str):
|
||||||
|
assert error in qp.error.ERROR_NAMES, \
|
||||||
|
f'unknown error name; valid ones are {qp.error.ERROR_NAMES}'
|
||||||
|
return getattr(qp.error, error)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
||||||
|
f'the name of an error function in {qp.error.ERROR_NAMES}')
|
||||||
|
|
||||||
|
|
||||||
|
class ECC(EnsembleFactory):
|
||||||
|
def __init__(self, learner, param_grid=None, optim=None, sample_size=None, eval_budget=None,
|
||||||
|
size=50, min_pos=1, red_size=25, policy='ave', n_jobs=1):
|
||||||
|
super().__init__(
|
||||||
|
learner, qp.method.aggregative.CC, param_grid, optim, sample_size, eval_budget, size, min_pos,
|
||||||
|
red_size, policy, n_jobs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EACC(EnsembleFactory):
|
||||||
|
def __init__(self, learner, param_grid=None, optim=None, sample_size=None, eval_budget=None,
|
||||||
|
size=50, min_pos=1, red_size=25, policy='ave', n_jobs=1):
|
||||||
|
super().__init__(
|
||||||
|
learner, qp.method.aggregative.ACC, param_grid, optim, sample_size, eval_budget, size, min_pos,
|
||||||
|
red_size, policy, n_jobs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EHDy(EnsembleFactory):
|
||||||
|
def __init__(self, learner, param_grid=None, optim=None, sample_size=None, eval_budget=None,
|
||||||
|
size=50, min_pos=1, red_size=25, policy='ave', n_jobs=1):
|
||||||
|
super().__init__(
|
||||||
|
learner, qp.method.aggregative.HDy, param_grid, optim, sample_size, eval_budget, size, min_pos,
|
||||||
|
red_size, policy, n_jobs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EEMQ(EnsembleFactory):
|
||||||
|
def __init__(self, learner, param_grid=None, optim=None, sample_size=None, eval_budget=None,
|
||||||
|
size=50, min_pos=1, red_size=25, policy='ave', n_jobs=1):
|
||||||
|
super().__init__(
|
||||||
|
learner, qp.method.aggregative.EMQ, param_grid, optim, sample_size, eval_budget, size, min_pos,
|
||||||
|
red_size, policy, n_jobs
|
||||||
|
)
|
|
@ -75,23 +75,26 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
# estimate the hard and soft stats tpr and fpr of the classifier
|
# estimate the hard and soft stats tpr and fpr of the classifier
|
||||||
self.tr_prev = data.prevalence()
|
self.tr_prev = data.prevalence()
|
||||||
|
|
||||||
self.quantifiers = [
|
self.quantifiers = {
|
||||||
ClassifyAndCount(self.learner).fit(data, fit_learner=False),
|
'cc': ClassifyAndCount(self.learner).fit(data, fit_learner=False),
|
||||||
AdjustedClassifyAndCount(self.learner).fit(data, fit_learner=False),
|
'acc': AdjustedClassifyAndCount(self.learner).fit(data, fit_learner=False),
|
||||||
ProbabilisticClassifyAndCount(self.learner).fit(data, fit_learner=False),
|
'pcc': ProbabilisticClassifyAndCount(self.learner).fit(data, fit_learner=False),
|
||||||
ProbabilisticAdjustedClassifyAndCount(self.learner).fit(data, fit_learner=False),
|
'pacc': ProbabilisticAdjustedClassifyAndCount(self.learner).fit(data, fit_learner=False),
|
||||||
ExpectationMaximizationQuantifier(self.learner).fit(data, fit_learner=False),
|
'emq': ExpectationMaximizationQuantifier(self.learner).fit(data, fit_learner=False),
|
||||||
]
|
}
|
||||||
|
|
||||||
self.status = {
|
self.status = {
|
||||||
'tr-loss': -1,
|
'tr-loss': -1,
|
||||||
'va-loss': -1,
|
'va-loss': -1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nQ = len(self.quantifiers)
|
||||||
|
nC = data.n_classes
|
||||||
self.quanet = QuaNetModule(
|
self.quanet = QuaNetModule(
|
||||||
doc_embedding_size=train_data.instances.shape[1],
|
doc_embedding_size=train_data.instances.shape[1],
|
||||||
n_classes=data.n_classes,
|
n_classes=data.n_classes,
|
||||||
stats_size=len(self.quantifiers) * data.n_classes,
|
stats_size=nQ*nC + 2*nC*nC,
|
||||||
|
order_by=0 if data.binary else None,
|
||||||
**self.quanet_params
|
**self.quanet_params
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
|
@ -119,10 +122,15 @@ class QuaNetTrainer(BaseQuantifier):
|
||||||
def get_aggregative_estims(self, posteriors):
|
def get_aggregative_estims(self, posteriors):
|
||||||
label_predictions = np.argmax(posteriors, axis=-1)
|
label_predictions = np.argmax(posteriors, axis=-1)
|
||||||
prevs_estim = []
|
prevs_estim = []
|
||||||
for quantifier in self.quantifiers:
|
for quantifier in self.quantifiers.values():
|
||||||
predictions = posteriors if isprobabilistic(quantifier) else label_predictions
|
predictions = posteriors if isprobabilistic(quantifier) else label_predictions
|
||||||
prevs_estim.append(quantifier.aggregate(predictions))
|
prevs_estim.extend(quantifier.aggregate(predictions))
|
||||||
return np.asarray(prevs_estim).flatten()
|
|
||||||
|
# add the class-conditional predictions P(y'i|yj) from ACC and PACC
|
||||||
|
prevs_estim.extend(self.quantifiers['acc'].Pte_cond_estim_.flatten())
|
||||||
|
prevs_estim.extend(self.quantifiers['pacc'].Pte_cond_estim_.flatten())
|
||||||
|
|
||||||
|
return prevs_estim
|
||||||
|
|
||||||
def quantify(self, instances, *args):
|
def quantify(self, instances, *args):
|
||||||
posteriors = self.learner.predict_proba(instances)
|
posteriors = self.learner.predict_proba(instances)
|
||||||
|
|
|
@ -4,14 +4,14 @@ from evaluation import artificial_sampling_prediction
|
||||||
from data.base import LabelledCollection
|
from data.base import LabelledCollection
|
||||||
from method.aggregative import BaseQuantifier
|
from method.aggregative import BaseQuantifier
|
||||||
from typing import Union, Callable
|
from typing import Union, Callable
|
||||||
import quapy.functional as F
|
import functional as F
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
class GridSearchQ:
|
class GridSearchQ(BaseQuantifier):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model : BaseQuantifier,
|
model: BaseQuantifier,
|
||||||
param_grid: dict,
|
param_grid: dict,
|
||||||
sample_size: int,
|
sample_size: int,
|
||||||
n_prevpoints: int = None,
|
n_prevpoints: int = None,
|
||||||
|
@ -105,14 +105,14 @@ class GridSearchQ:
|
||||||
if error in qp.error.QUANTIFICATION_ERROR:
|
if error in qp.error.QUANTIFICATION_ERROR:
|
||||||
self.error = error
|
self.error = error
|
||||||
elif isinstance(error, str):
|
elif isinstance(error, str):
|
||||||
assert error in {func.__name__ for func in qp.error.QUANTIFICATION_ERROR}, \
|
assert error in qp.error.QUANTIFICATION_ERROR_NAMES, \
|
||||||
f'unknown error name; valid ones are {qp.error.QUANTIFICATION_ERROR}'
|
f'unknown error name; valid ones are {qp.error.QUANTIFICATION_ERROR_NAMES}'
|
||||||
self.error = getattr(qp.error, error)
|
self.error = getattr(qp.error, error)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
||||||
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR}')
|
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
||||||
|
|
||||||
def fit(self, training: LabelledCollection, validation: Union[LabelledCollection, float]):
|
def fit(self, training: LabelledCollection, validation: Union[LabelledCollection, float]=0.3):
|
||||||
"""
|
"""
|
||||||
:param training: the training set on which to optimize the hyperparameters
|
:param training: the training set on which to optimize the hyperparameters
|
||||||
:param validation: either a LabelledCollection on which to test the performance of the different settings, or
|
:param validation: either a LabelledCollection on which to test the performance of the different settings, or
|
||||||
|
@ -158,5 +158,14 @@ class GridSearchQ:
|
||||||
self.sout(f'refitting on the whole development set')
|
self.sout(f'refitting on the whole development set')
|
||||||
self.best_model_.fit(training + validation)
|
self.best_model_.fit(training + validation)
|
||||||
|
|
||||||
return self.best_model_
|
return self
|
||||||
|
|
||||||
|
def quantify(self, instances):
|
||||||
|
return self.best_model_.quantify(instances)
|
||||||
|
|
||||||
|
def set_params(self, **parameters):
|
||||||
|
self.param_grid = parameters
|
||||||
|
|
||||||
|
def get_params(self, deep=True):
|
||||||
|
return self.param_grid
|
||||||
|
|
||||||
|
|
52
test.py
52
test.py
|
@ -1,45 +1,65 @@
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.model_selection import GridSearchCV
|
||||||
from sklearn.svm import LinearSVC
|
from sklearn.svm import LinearSVC
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
import quapy.functional as F
|
import quapy.functional as F
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from classification.methods import PCALR
|
||||||
from classification.neural import NeuralClassifierTrainer, CNNnet
|
from classification.neural import NeuralClassifierTrainer, CNNnet
|
||||||
from quapy.model_selection import GridSearchQ
|
from quapy.model_selection import GridSearchQ
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#qp.datasets.fetch_UCIDataset('acute.b', verbose=True)
|
||||||
|
|
||||||
|
#sys.exit(0)
|
||||||
qp.environ['SAMPLE_SIZE'] = 500
|
qp.environ['SAMPLE_SIZE'] = 500
|
||||||
|
#param_grid = {'C': np.logspace(-3,3,7), 'class_weight': ['balanced', None]}
|
||||||
|
param_grid = {'C': np.logspace(0,3,4), 'class_weight': ['balanced']}
|
||||||
|
max_evaluations = 5000
|
||||||
|
|
||||||
sample_size = qp.environ['SAMPLE_SIZE']
|
sample_size = qp.environ['SAMPLE_SIZE']
|
||||||
binary = True
|
binary = True
|
||||||
svmperf_home = './svm_perf_quantification'
|
svmperf_home = './svm_perf_quantification'
|
||||||
|
|
||||||
if binary:
|
if binary:
|
||||||
dataset = qp.datasets.fetch_reviews('kindle', tfidf=False, min_df=5)
|
dataset = qp.datasets.fetch_reviews('kindle', tfidf=True, min_df=5)
|
||||||
qp.data.preprocessing.index(dataset, inplace=True)
|
#qp.data.preprocessing.index(dataset, inplace=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
dataset = qp.datasets.fetch_twitter('hcr', for_model_selection=False, min_df=10, pickle=True)
|
dataset = qp.datasets.fetch_twitter('hcr', for_model_selection=False, min_df=10, pickle=True)
|
||||||
# dataset.training = dataset.training.sampling(SAMPLE_SIZE, 0.2, 0.5, 0.3)
|
dataset.training = dataset.training.sampling(sample_size, 0.2, 0.5, 0.3)
|
||||||
|
|
||||||
print(f'dataset loaded: #training={len(dataset.training)} #test={len(dataset.test)}')
|
print(f'dataset loaded: #training={len(dataset.training)} #test={len(dataset.test)}')
|
||||||
|
|
||||||
|
|
||||||
# training a quantifier
|
# training a quantifier
|
||||||
# learner = LogisticRegression(max_iter=1000)
|
# learner = LogisticRegression(max_iter=1000)
|
||||||
# model = qp.method.aggregative.ClassifyAndCount(learner)
|
#model = qp.method.aggregative.ClassifyAndCount(learner)
|
||||||
# model = qp.method.aggregative.AdjustedClassifyAndCount(learner)
|
# model = qp.method.aggregative.AdjustedClassifyAndCount(learner)
|
||||||
# model = qp.method.aggregative.ProbabilisticClassifyAndCount(learner)
|
# model = qp.method.aggregative.ProbabilisticClassifyAndCount(learner)
|
||||||
# model = qp.method.aggregative.ProbabilisticAdjustedClassifyAndCount(learner)
|
# model = qp.method.aggregative.ProbabilisticAdjustedClassifyAndCount(learner)
|
||||||
|
# model = qp.method.aggregative.HellingerDistanceY(learner)
|
||||||
# model = qp.method.aggregative.ExpectationMaximizationQuantifier(learner)
|
# model = qp.method.aggregative.ExpectationMaximizationQuantifier(learner)
|
||||||
# model = qp.method.aggregative.ExplicitLossMinimisationBinary(svmperf_home, loss='q', C=100)
|
# model = qp.method.aggregative.ExplicitLossMinimisationBinary(svmperf_home, loss='q', C=100)
|
||||||
# model = qp.method.aggregative.SVMQ(svmperf_home, C=1)
|
# model = qp.method.aggregative.SVMQ(svmperf_home, C=1)
|
||||||
|
|
||||||
learner = NeuralClassifierTrainer(CNNnet(dataset.vocabulary_size, dataset.n_classes))
|
#learner = PCALR()
|
||||||
print(learner.get_params())
|
#learner = NeuralClassifierTrainer(CNNnet(dataset.vocabulary_size, dataset.n_classes))
|
||||||
model = qp.method.aggregative.QuaNet(learner, sample_size, device='cpu')
|
#print(learner.get_params())
|
||||||
|
#model = qp.method.meta.QuaNet(learner, sample_size, device='cpu')
|
||||||
|
|
||||||
if qp.isbinary(model) and not qp.isbinary(dataset):
|
#learner = GridSearchCV(LogisticRegression(max_iter=1000), param_grid=param_grid, n_jobs=-1, verbose=1)
|
||||||
model = qp.method.aggregative.OneVsAll(model)
|
learner = LogisticRegression(max_iter=1000)
|
||||||
|
model = qp.method.meta.ECC(learner, size=20, red_size=10, param_grid=None, optim=None, policy='ds')
|
||||||
|
#model = qp.method.meta.EHDy(learner, param_grid=param_grid, optim='mae',
|
||||||
|
# sample_size=sample_size, eval_budget=max_evaluations//10, n_jobs=-1)
|
||||||
|
#model = qp.method.aggregative.ClassifyAndCount(learner)
|
||||||
|
|
||||||
|
|
||||||
|
#if qp.isbinary(model) and not qp.isbinary(dataset):
|
||||||
|
# model = qp.method.aggregative.OneVsAll(model)
|
||||||
|
|
||||||
|
|
||||||
# Model fit and Evaluation on the test data
|
# Model fit and Evaluation on the test data
|
||||||
|
@ -49,6 +69,10 @@ print(f'fitting model {model.__class__.__name__}')
|
||||||
#train, val = dataset.training.split_stratified(0.6)
|
#train, val = dataset.training.split_stratified(0.6)
|
||||||
#model.fit(train, val_split=val)
|
#model.fit(train, val_split=val)
|
||||||
model.fit(dataset.training)
|
model.fit(dataset.training)
|
||||||
|
#for i,e in enumerate(model.ensemble):
|
||||||
|
#print(i, e.learner.best_estimator_)
|
||||||
|
# print(i, e.best_model_.learner)
|
||||||
|
|
||||||
|
|
||||||
# estimating class prevalences
|
# estimating class prevalences
|
||||||
print('quantifying')
|
print('quantifying')
|
||||||
|
@ -67,7 +91,7 @@ print(f'mae={error:.3f}')
|
||||||
# Model fit and Evaluation according to the artificial sampling protocol
|
# Model fit and Evaluation according to the artificial sampling protocol
|
||||||
# ----------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------
|
||||||
|
|
||||||
max_evaluations = 5000
|
|
||||||
n_prevpoints = F.get_nprevpoints_approximation(combinations_budget=max_evaluations, n_classes=dataset.n_classes)
|
n_prevpoints = F.get_nprevpoints_approximation(combinations_budget=max_evaluations, n_classes=dataset.n_classes)
|
||||||
n_evaluations = F.num_prevalence_combinations(n_prevpoints, dataset.n_classes)
|
n_evaluations = F.num_prevalence_combinations(n_prevpoints, dataset.n_classes)
|
||||||
print(f'the prevalence interval [0,1] will be split in {n_prevpoints} prevalence points for each class, so that\n'
|
print(f'the prevalence interval [0,1] will be split in {n_prevpoints} prevalence points for each class, so that\n'
|
||||||
|
@ -76,7 +100,7 @@ print(f'the prevalence interval [0,1] will be split in {n_prevpoints} prevalence
|
||||||
|
|
||||||
true_prev, estim_prev = qp.evaluation.artificial_sampling_prediction(model, dataset.test, sample_size, n_prevpoints)
|
true_prev, estim_prev = qp.evaluation.artificial_sampling_prediction(model, dataset.test, sample_size, n_prevpoints)
|
||||||
|
|
||||||
qp.error.SAMPLE_SIZE = sample_size
|
#qp.error.SAMPLE_SIZE = sample_size
|
||||||
print(f'Evaluation according to the artificial sampling protocol ({len(true_prev)} evals)')
|
print(f'Evaluation according to the artificial sampling protocol ({len(true_prev)} evals)')
|
||||||
for error in qp.error.QUANTIFICATION_ERROR:
|
for error in qp.error.QUANTIFICATION_ERROR:
|
||||||
score = error(true_prev, estim_prev)
|
score = error(true_prev, estim_prev)
|
||||||
|
@ -86,7 +110,7 @@ for error in qp.error.QUANTIFICATION_ERROR:
|
||||||
# Model selection and Evaluation according to the artificial sampling protocol
|
# Model selection and Evaluation according to the artificial sampling protocol
|
||||||
# ----------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
param_grid = {'C': np.logspace(-3,3,7), 'class_weight': ['balanced', None]}
|
|
||||||
|
|
||||||
model_selection = GridSearchQ(model,
|
model_selection = GridSearchQ(model,
|
||||||
param_grid=param_grid,
|
param_grid=param_grid,
|
||||||
|
@ -96,8 +120,8 @@ model_selection = GridSearchQ(model,
|
||||||
refit=True,
|
refit=True,
|
||||||
verbose=True)
|
verbose=True)
|
||||||
|
|
||||||
# model = model_selection.fit(dataset.training, validation=0.3)
|
model = model_selection.fit(dataset.training, validation=0.3)
|
||||||
model = model_selection.fit(train, validation=val)
|
#model = model_selection.fit(train, validation=val)
|
||||||
print(f'Model selection: best_params = {model_selection.best_params_}')
|
print(f'Model selection: best_params = {model_selection.best_params_}')
|
||||||
print(f'param scores:')
|
print(f'param scores:')
|
||||||
for params, score in model_selection.param_scores_.items():
|
for params, score in model_selection.param_scores_.items():
|
||||||
|
|
Loading…
Reference in New Issue