
Some preliminary experiments with density ratio estimation

Alejandro Moreo Fernandez 2023-02-20 18:33:07 +01:00
parent fb2390e8d7
commit 24e755dcc1
6 changed files with 512 additions and 3 deletions

View File

@@ -0,0 +1,305 @@
import itertools
from functools import cache
import numpy as np
from densratio import densratio
from scipy.sparse import issparse, vstack
from scipy.stats import multivariate_normal
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, cross_val_predict
import quapy as qp
from Transduction_office.pykliep import DensityRatioEstimator
from quapy.protocol import AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol
from quapy.data import LabelledCollection
from quapy.method.aggregative import *
import quapy.functional as F
from time import time
def gaussian(mean, cov=1., label=0, size=100, random_state=0):
"""
Creates a labelled collection in which the instances are distributed according to a Gaussian with the
specified parameters, and labels all data points with the given class label.
:param mean: ndarray of shape (n_dimensions) with the center
:param cov: ndarray of shape (n_dimensions, n_dimensions) with the covariance matrix, or a number for np.eye
:param label: the class label for the collection
:param size: number of instances
:param random_state: allows for replicating experiments
:return: an instance of LabelledCollection
"""
mean = np.asarray(mean)
assert mean.ndim==1, 'wrong shape for mean'
n_features = mean.shape[0]
if isinstance(cov, (int, float)):
cov = np.eye(n_features) * cov
instances = multivariate_normal.rvs(mean, cov, size, random_state=random_state)
return LabelledCollection(instances, labels=[label]*size)
# ------------------------------------------------------------------------------------
# Protocol for generating prior probability shift + covariate shift by mixing "domains"
# ------------------------------------------------------------------------------------
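# Each generated sample mixes the per-class sub-collections of the given domains with proportions drawn
# uniformly at random from the simplex; as a result, both the class prevalence (prior shift) and the input
# distribution (covariate shift) vary across samples.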
class CovPriorShift(AbstractStochasticSeededProtocol):
def __init__(self, domains: list[LabelledCollection], sample_size=None, repeats=100, min_support=0, random_state=0,
return_type='sample_prev'):
super(CovPriorShift, self).__init__(random_state)
self.domains = list(itertools.chain.from_iterable(lc.separate() for lc in domains))
self.sample_size = qp._get_sample_size(sample_size)
self.repeats = repeats
self.min_support = min_support
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def samples_parameters(self):
"""
Returns all the parameters needed to replicate the samples generated by this protocol.
:return: a list of lists of sampling indexes, one per domain/class sub-collection
"""
indexes = []
tentatives = 0
while len(indexes) < self.repeats:
alpha = F.uniform_simplex_sampling(n_classes=len(self.domains))
# sizes = np.asarray([round(len(lc_i) * alpha_i) for lc_i, alpha_i in zip(self.domains, alpha)])
sizes = (alpha * self.sample_size).astype(int)
if all(sizes > self.min_support):
indexes_i = [lc.sampling_index(size) for lc, size in zip(self.domains, sizes)]
indexes.append(indexes_i)
tentatives = 0
else:
tentatives += 1
if tentatives > 100:
raise ValueError('the support is too strict, and it is difficult '
'or impossible to generate valid samples')
return indexes
def sample(self, params):
indexes = params
lcs = [lc.sampling_from_index(index) for index, lc in zip(indexes, self.domains)]
return LabelledCollection.join(*lcs)
def total(self):
"""
Returns the number of samples that will be generated
:return: int
"""
return self.repeats
# ---------------------------------------------------------------------------------------
# Methods of "importance weight", e.g., by density ratio estimation (KLIEP, uLSIF, LogReg)
# ---------------------------------------------------------------------------------------
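# Common interface: weights(Xtr, ytr, Xte) returns one importance weight per training instance,
# estimating the density ratio w(x) = p_test(x) / p_train(x).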
class ImportanceWeight:
@abstractmethod
def weights(self, Xtr, ytr, Xte):
pass
class KLIEP(ImportanceWeight):
def __init__(self):
pass
def weights(self, Xtr, ytr, Xte):
kliep = DensityRatioEstimator()
kliep.fit(Xtr, Xte)
return kliep.predict(Xtr)
class USILF(ImportanceWeight):
def __init__(self, alpha=0.):
self.alpha = alpha
def weights(self, Xtr, ytr, Xte):
dense_ratio_obj = densratio(Xtr, Xte, alpha=self.alpha, verbose=False)
return dense_ratio_obj.compute_density_ratio(Xtr)
class LogReg(ImportanceWeight):
def __init__(self):
pass
def weights(self, Xtr, ytr, Xte):
# check "Direct Density Ratio Estimation for
# Large-scale Covariate Shift Adaptation", Eq.28
if issparse(Xtr):
X = vstack([Xtr, Xte])
else:
X = np.concatenate([Xtr, Xte])
y = [0]*Xtr.shape[0] + [1]*Xte.shape[0]
logreg = GridSearchCV(
LogisticRegression(),
param_grid={'C':np.logspace(-3,3,7), 'class_weight': ['balanced', None]},
n_jobs=-1
)
logreg.fit(X, y)
probs = logreg.predict_proba(Xtr)
prob_train = probs[:, 0]
prob_test = probs[:, 1]
prior_train = Xtr.shape[0]
prior_test = Xte.shape[0]
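# Eq.28 cited above: w(x) = (n_tr/n_te) * P(test|x) / P(train|x), with the posteriors given by the
# train-vs-test classifier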
w = (prior_train/prior_test)*(prob_test/prob_train)
return w
class MostTest(ImportanceWeight):
def __init__(self):
pass
def weights(self, Xtr, ytr, Xte):
# check "Direct Density Ratio Estimation for
# Large-scale Covariate Shift Adaptation", Eq.28
if issparse(Xtr):
X = vstack([Xtr, Xte])
else:
X = np.concatenate([Xtr, Xte])
y = [0]*Xtr.shape[0] + [1]*Xte.shape[0]
logreg = GridSearchCV(
LogisticRegression(),
param_grid={'C':np.logspace(-3,3,7), 'class_weight': ['balanced', None]},
n_jobs=-1
)
# logreg = LogisticRegression()
# logreg.fit(X, y)
# prob_test = logreg.predict_proba(Xtr)[:,1]
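# out-of-fold posterior P(test|x) for each training instance; only the ranking matters here, so the
# (n_tr/n_te) factor of Eq.28 is omitted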
prob_test = cross_val_predict(logreg, X, y, n_jobs=-1, method="predict_proba")[:Xtr.shape[0], 1]
return prob_test
class Random(ImportanceWeight):
def __init__(self):
pass
def weights(self, Xtr, ytr, Xte):
return np.random.rand(len(Xtr))
# --------------------------------------------------------------------------------------------
# Quantification Methods that rely on Importance Weight for reweighting the training instances
# --------------------------------------------------------------------------------------------
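# These quantifiers are transductive: the importance weights depend on the specific test sample, so the
# classifier and quantifier are refitted at quantification time.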
class TransductiveQuantifier(BaseQuantifier):
def fit(self, data: LabelledCollection):
self.training_ = data
return self
@property
def training(self):
return self.training_
class ReweightingAggregative(TransductiveQuantifier):
def __init__(self, classifier, weighter: ImportanceWeight, quantif_method=CC):
self.classifier = classifier
self.weighter = weighter
self.quantif_method = quantif_method
def quantify(self, instances):
# time_weight = 2.95340 time_train = 0.00619
w = self.weighter.weights(*self.training.Xy, instances)
self.classifier.fit(*self.training.Xy, sample_weight=w)
quantifier = self.quantif_method(self.classifier).fit(self.training, fit_classifier=False)
return quantifier.quantify(instances)
# --------------------------------------------------------------------------------------------
# Quantification Methods that rely on Importance Weight for selecting a validation partition
# --------------------------------------------------------------------------------------------
def select_from_weights(w, data: LabelledCollection, val_prop=0.4):
# w[w<1]=0
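# sort training instances by weight (ascending): the val_prop fraction with the largest weights, i.e. the
# instances that look most like the test sample, becomes the validation split used to adjust the quantifier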
order = np.argsort(w)
split_point = int(len(w)*val_prop)
train_idx, val_idx = order[:-split_point], order[-split_point:]
return data.sampling_from_index(train_idx), data.sampling_from_index(val_idx)
class SelectorQuantifiers(TransductiveQuantifier):
def __init__(self, classifier, weighter: ImportanceWeight, quantif_method=ACC, val_split=0.4):
self.classifier = classifier
self.weighter = weighter
self.quantif_method = quantif_method
self.val_split = val_split
def quantify(self, instances):
w = self.weighter.weights(*self.training.Xy, instances)
train, val = select_from_weights(w, self.training, self.val_split)
quantifier = self.quantif_method(self.classifier).fit(train, val_split=val)
return quantifier.quantify(instances)
if __name__ == '__main__':
qp.environ['SAMPLE_SIZE'] = 500
dA_l0 = gaussian(mean=[0,0], label=0, size=1000)
dA_l1 = gaussian(mean=[1,0], label=1, size=1000)
dB_l0 = gaussian(mean=[0,1], label=0, size=1000)
dB_l1 = gaussian(mean=[1,1], label=1, size=1000)
dA = LabelledCollection.join(dA_l0, dA_l1)
dB = LabelledCollection.join(dB_l0, dB_l1)
dA_train, dA_test = dA.split_stratified(0.5, random_state=0)
dB_train, dB_test = dB.split_stratified(0.5, random_state=0)
train = LabelledCollection.join(dA_train, dB_train)
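# the training set pools both domains; CovPriorShift will generate test samples that mix the held-out
# domain/class parts in random proportions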
def lr():
return LogisticRegression()
# def lr():
# return GridSearchCV(
# LogisticRegression(),
# param_grid={'C':np.logspace(-3,3,7), 'class_weight': ['balanced', None]},
# n_jobs=-1
# )
methods = [
('CC', CC(lr())),
('PCC', PCC(lr())),
('ACC', ACC(lr())),
('PACC', PACC(lr())),
('HDy', HDy(lr())),
('EMQ', EMQ(lr())),
('Sel-ACC', SelectorQuantifiers(lr(), MostTest(), ACC)),
('Sel-PACC', SelectorQuantifiers(lr(), MostTest(), PACC)),
('Sel-HDy', SelectorQuantifiers(lr(), MostTest(), HDy)),
('LogReg-CC', ReweightingAggregative(lr(), LogReg(), CC)),
('LogReg-PCC', ReweightingAggregative(lr(), LogReg(), PCC)),
('LogReg-EMQ', ReweightingAggregative(lr(), LogReg(), EMQ)),
# ('KLIEP-CC', ReweightingAggregative(lr(), KLIEP(), CC)),
# ('KLIEP-PCC', ReweightingAggregative(lr(), KLIEP(), PCC)),
# ('KLIEP-EMQ', ReweightingAggregative(lr(), KLIEP(), EMQ)),
# ('SILF-CC', ReweightingAggregative(lr(), USILF(), CC)),
# ('SILF-PCC', ReweightingAggregative(lr(), USILF(), PCC)),
# ('SILF-EMQ', ReweightingAggregative(lr(), USILF(), EMQ))
]
for name, model in methods:
with qp.util.temp_seed(1):
model.fit(train)
prot = CovPriorShift([dA_test, dB_test], repeats=10)
mae = qp.evaluation.evaluate(model, protocol=prot, error_metric='mae')
print(f'{name}: {mae = :.4f}')
# mrae = qp.evaluation.evaluate(model, protocol=prot, error_metric='mrae')
# print(f'{name}: {mrae = :.4f}')

View File

@@ -0,0 +1,188 @@
import numpy as np
import warnings
class DensityRatioEstimator:
"""
Class to accomplish direct density ratio estimation, implementing the original KLIEP
algorithm from Direct Importance Estimation with Model Selection
and Its Application to Covariate Shift Adaptation by Sugiyama et al.
The training set is distributed via
train ~ p(x)
and the test set is distributed via
test ~ q(x).
The KLIEP algorithm and its variants approximate w(x) = q(x) / p(x) directly. The predict function returns the
estimate of w(x). The function w(x) can serve as sample weights for the training set during
training to modify the expectation function that the model's loss function is optimized via,
i.e.
E_{x ~ w(x)p(x)} loss(x) = E_{x ~ q(x)} loss(x).
Usage :
The fit method runs the KLIEP algorithm using LCV to select sigma, then trains on the entire
training/test set with the best sigma found and stores the resulting value of J.
Use the predict method on the training set to determine the sample weights from the KLIEP algorithm.
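Example (a minimal sketch; `X_train` and `X_test` are assumed to be 2-d numpy arrays, train ~ p(x), test ~ q(x)):
>>> kliep = DensityRatioEstimator()
>>> kliep.fit(X_train, X_test)  # LCV model selection over sigmas, then fit on the full sets
>>> w = kliep.predict(X_train)  # estimated w(x) = q(x)/p(x), one weight per training instance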
"""
def __init__(self, max_iter=5000, num_params=[.1, .2], epsilon=1e-4, cv=3, sigmas=[.01, .1, .25, .5, .75, 1],
random_state=None, verbose=0):
"""
Direct density ratio estimation using an inner LCV loop to select the model (sigma and the number of
kernel centers). Can be used with sklearn cross-validation methods, with or without storing the inner CV,
or within a standard grid search.
max_iter : number of iterations of the iterative algorithm.
num_params : list of fractions of test-set vectors used as kernel centers in the inner LCV.
Each value must be a float; the original paper used 10%, i.e. .1
sigmas : list of kernel widths (sigmas) to be tried in the inner LCV loop.
epsilon : additive factor in the iterative algorithm, for numerical stability.
"""
self.max_iter = max_iter
self.num_params = num_params
self.epsilon = epsilon
self.verbose = verbose
self.sigmas = sigmas
self.cv = cv
self.random_state = random_state
def fit(self, X_train, X_test, alpha_0=None):
""" Uses cross validation to select sigma as in the original paper (LCV).
In a break from sklearn convention, y=X_test.
The parameter cv corresponds to R in the original paper.
Once found, the best sigma is used to train on the full set."""
# LCV loop, shuffle a copy in place for performance.
cv = self.cv
chunk = int(X_test.shape[0] / float(cv))
if self.random_state is not None:
np.random.seed(self.random_state)
X_test_shuffled = X_test.copy()
np.random.shuffle(X_test_shuffled)
j_scores = {}
if type(self.sigmas) != list:
self.sigmas = [self.sigmas]
if type(self.num_params) != list:
self.num_params = [self.num_params]
if len(self.sigmas) * len(self.num_params) > 1:
# Inner LCV loop
for num_param in self.num_params:
for sigma in self.sigmas:
j_scores[(num_param, sigma)] = np.zeros(cv)
for k in range(1, cv + 1):
if self.verbose > 0:
print('Training: sigma: %s R: %s' % (sigma, k))
X_test_fold = X_test_shuffled[(k - 1) * chunk:k * chunk, :]
j_scores[(num_param, sigma)][k - 1] = self._fit(X_train=X_train,
X_test=X_test_fold,
num_parameters=num_param,
sigma=sigma)
j_scores[(num_param, sigma)] = np.mean(j_scores[(num_param, sigma)])
sorted_scores = sorted([x for x in j_scores.items() if np.isfinite(x[1])], key=lambda x: x[1],
reverse=True)
if len(sorted_scores) == 0:
warnings.warn('LCV failed to converge for all values of sigma.')
return self
self._sigma = sorted_scores[0][0][1]
self._num_parameters = sorted_scores[0][0][0]
self._j_scores = sorted_scores
else:
self._sigma = self.sigmas[0]
self._num_parameters = self.num_params[0]
# best sigma
self._j = self._fit(X_train=X_train, X_test=X_test_shuffled, num_parameters=self._num_parameters,
sigma=self._sigma)
return self # Compatibility with sklearn
def _fit(self, X_train, X_test, num_parameters, sigma, alpha_0=None):
""" Fits the estimator with the given parameters w-hat and returns J"""
if isinstance(num_parameters, float):
num_parameters = int(X_test.shape[0] * num_parameters)
self._select_param_vectors(X_test=X_test,
sigma=sigma,
num_parameters=num_parameters)
X_train = self._reshape_X(X_train)
X_test = self._reshape_X(X_test)
if alpha_0 is None:
alpha_0 = np.ones(shape=(num_parameters, 1)) / float(num_parameters)
self._find_alpha(X_train=X_train,
X_test=X_test,
num_parameters=num_parameters,
epsilon=self.epsilon,
alpha_0=alpha_0,
sigma=sigma)
return self._calculate_j(X_test, sigma=sigma)
def _calculate_j(self, X_test, sigma):
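# J is the objective from the KLIEP paper: the average log of the estimated ratio w(x) over the test
# points (a tiny constant guards against log(0))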
pred = self.predict(X_test, sigma=sigma)+0.0000001
log = np.log(pred).sum()
return log / (X_test.shape[0])
def score(self, X_test):
""" Return the J score, similar to sklearn's API """
return self._calculate_j(X_test=X_test, sigma=self._sigma)
@staticmethod
def _reshape_X(X):
""" Reshape input from mxn to mx1xn to take advantage of numpy broadcasting. """
if len(X.shape) != 3:
return X.reshape((X.shape[0], 1, X.shape[1]))
return X
def _select_param_vectors(self, X_test, sigma, num_parameters):
""" X_test is the test set. b is the number of parameters. """
indices = np.random.choice(X_test.shape[0], size=num_parameters, replace=False)
self._test_vectors = X_test[indices, :].copy()
self._phi_fitted = True
def _phi(self, X, sigma=None):
if sigma is None:
sigma = self._sigma
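# Gaussian (RBF) basis functions centered at the test vectors chosen in _select_param_vectors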
if self._phi_fitted:
return np.exp(-np.sum((X - self._test_vectors) ** 2, axis=-1) / (2 * sigma ** 2))
raise Exception('Phi not fitted.')
def _find_alpha(self, alpha_0, X_train, X_test, num_parameters, sigma, epsilon):
A = self._phi(X_test, sigma)
b = self._phi(X_train, sigma).sum(axis=0) / X_train.shape[0]
b = b.reshape((num_parameters, 1))
out = alpha_0.copy()
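# KLIEP update: gradient step on sum log(A @ alpha), then projection onto the constraints alpha >= 0 and
# b^T alpha = 1 (i.e., the estimated weights average to 1 over the training set)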
for k in range(self.max_iter):
mat = np.dot(A, out)
mat += 0.000000001
out += epsilon * np.dot(np.transpose(A), 1. / mat)
out += b * (((1 - np.dot(np.transpose(b), out)) / np.dot(np.transpose(b), b)))
out = np.maximum(0, out)
out /= (np.dot(np.transpose(b), out))
self._alpha = out
self._fitted = True
def predict(self, X, sigma=None):
""" Equivalent of w(X) from the original paper."""
X = self._reshape_X(X)
if not self._fitted:
raise Exception('Not fitted!')
return np.dot(self._phi(X, sigma=sigma), self._alpha).reshape((X.shape[0],))

View File

@@ -322,6 +322,22 @@ class LabelledCollection:
classes = np.unique(labels).sort()
return LabelledCollection(instances, labels, classes=classes)
def separate(self):
"""
Breaks down this labelled collection into a list of labelled collections such that each element in the list
contains all instances from a different class. The order in the list is consistent with the order in
`self.classes_`. If some class has 0 elements, then None will be returned in that position in the list.
:return: list `L` of :class:`LabelledCollection` with `len(L)==len(self.classes_)`
"""
lcs = []
for class_label in self.classes_:
instances = self.instances[self.labels == class_label]
n_instances = len(instances)
new_lc = LabelledCollection(instances, [class_label]*n_instances) if (n_instances > 0) else None
lcs.append(new_lc)
return lcs
@property
def Xy(self):
"""

View File

@@ -223,7 +223,7 @@ def fetch_UCILabelledCollection(dataset_name, data_home=None, verbose=False) ->
>>> import quapy as qp
>>> collection = qp.datasets.fetch_UCILabelledCollection("yeast")
>>> for data in qp.data.Dataset.kFCV(collection, nfolds=5, nrepeats=2):
>>> for data in qp.domains.Dataset.kFCV(collection, nfolds=5, nrepeats=2):
>>> ...
The list of valid dataset names can be accessed in `quapy.data.datasets.UCI_DATASETS`

View File

@@ -28,7 +28,7 @@ class QuaNetTrainer(BaseQuantifier):
>>>
>>> # load the kindle dataset as text, and convert words to numerical indexes
>>> dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
>>> qp.data.preprocessing.index(dataset, min_df=5, inplace=True)
>>> qp.domains.preprocessing.index(dataset, min_df=5, inplace=True)
>>>
>>> # the text classifier is a CNN trained by NeuralClassifierTrainer
>>> cnn = CNNnet(dataset.vocabulary_size, dataset.n_classes)

View File

@@ -218,7 +218,7 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
to "labelled_collection" to get instead instances of LabelledCollection
"""
def __init__(self, data:LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
def __init__(self, data: LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
smooth_limits_epsilon=0, random_state=0, return_type='sample_prev'):
super(APP, self).__init__(random_state)
self.data = data