Compare commits
47 Commits
|
|
@ -0,0 +1,3 @@
|
|||
- Add other methods that natively provide uncertainty quantification?
- Explore the neighbourhood in the CLR space rather than in the simplex!
|
||||
-
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
from sklearn.base import BaseEstimator
|
||||
import numpy as np
|
||||
from quapy.method._kdey import KDEBase
|
||||
from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC, CLRtransformation
|
||||
from quapy.method.aggregative import AggregativeSoftQuantifier
|
||||
from tqdm import tqdm
|
||||
import quapy.functional as F
|
||||
|
||||
|
||||
class BayesianKDEy(AggregativeSoftQuantifier, KDEBase, WithConfidenceABC):
|
||||
"""
|
||||
Bayesian version of KDEy, which samples from the posterior distribution of the class prevalence values by means of (random-walk) Metropolis-Hastings.
|
||||
|
||||
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
|
||||
the one indicated in `qp.environ['DEFAULT_CLS']`
|
||||
:param val_split: specifies the data used for generating classifier predictions. This specification
|
||||
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
|
||||
be extracted from the training set; or as an integer (default 5), indicating that the predictions
|
||||
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
|
||||
for `k`); or as a tuple `(X,y)` defining the specific set of data to use for validation. Set to
|
||||
None when the method does not require any validation data, in order to avoid that some portion of
|
||||
the training data be wasted.
|
||||
:param num_warmup: number of warmup iterations for the MCMC sampler (default 500)
|
||||
:param num_samples: number of samples to draw from the posterior (default 1000)
|
||||
:param mcmc_seed: random seed for the MCMC sampler (default 0)
|
||||
:param confidence_level: float in [0,1] to construct a confidence region around the point estimate (default 0.95)
|
||||
:param region: string, set to `intervals` for constructing confidence intervals (default), or to
|
||||
`ellipse` for constructing an ellipse in the probability simplex, or to `ellipse-clr` for
|
||||
constructing an ellipse in the Centered-Log Ratio (CLR) unconstrained space.
|
||||
:param kernel: the kernel to be used in the KDE (e.g., 'gaussian' or 'aitchison')
:param bandwidth: bandwidth of the kernel
:param explore_CLR: bool, whether the random-walk proposals are generated in the CLR-transformed space
    instead of directly in the probability simplex (default False)
:param step_size: initial step size of the random-walk proposal; it is adapted during warmup (default 0.05)
:param verbose: bool, whether to display a progress bar
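
Example (a minimal usage sketch; the dataset choice and the bandwidth are merely illustrative):

>>> from sklearn.linear_model import LogisticRegression
>>> import quapy as qp
>>> dataset_name = qp.datasets.UCI_BINARY_DATASETS[0]
>>> train, test = qp.datasets.fetch_UCIBinaryDataset(dataset_name).train_test
>>> bayes_kdey = BayesianKDEy(LogisticRegression(), bandwidth=0.1).fit(*train.Xy)
>>> point_estimate, conf_region = bayes_kdey.predict_conf(test.X)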
|
||||
"""
|
||||
def __init__(self,
|
||||
classifier: BaseEstimator=None,
|
||||
fit_classifier=True,
|
||||
val_split: int = 5,
|
||||
kernel='gaussian',
|
||||
bandwidth=0.1,
|
||||
num_warmup: int = 500,
|
||||
num_samples: int = 1_000,
|
||||
mcmc_seed: int = 0,
|
||||
confidence_level: float = 0.95,
|
||||
region: str = 'intervals',
|
||||
explore_CLR=False,
|
||||
step_size=0.05,
|
||||
verbose: bool = False):
|
||||
|
||||
if num_warmup <= 0:
|
||||
raise ValueError(f'parameter {num_warmup=} must be a positive integer')
|
||||
if num_samples <= 0:
|
||||
raise ValueError(f'parameter {num_samples=} must be a positive integer')
|
||||
|
||||
super().__init__(classifier, fit_classifier, val_split)
|
||||
self.bandwidth = KDEBase._check_bandwidth(bandwidth, kernel)
|
||||
self.kernel = self._check_kernel(kernel)
|
||||
self.num_warmup = num_warmup
|
||||
self.num_samples = num_samples
|
||||
self.mcmc_seed = mcmc_seed
|
||||
self.confidence_level = confidence_level
|
||||
self.region = region
|
||||
self.explore_CLR = explore_CLR
|
||||
self.step_size = step_size
|
||||
self.verbose = verbose
|
||||
|
||||
def aggregation_fit(self, classif_predictions, labels):
|
||||
self.mix_densities = self.get_mixture_components(classif_predictions, labels, self.classes_, self.bandwidth, self.kernel)
|
||||
return self
|
||||
|
||||
def aggregate(self, classif_predictions):
|
||||
self.prevalence_samples = self._bayesian_kde(classif_predictions, init=None, verbose=self.verbose)
|
||||
return self.prevalence_samples.mean(axis=0)
|
||||
|
||||
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
|
||||
if confidence_level is None:
|
||||
confidence_level = self.confidence_level
|
||||
classif_predictions = self.classify(instances)
|
||||
point_estimate = self.aggregate(classif_predictions)
|
||||
samples = self.prevalence_samples # available after calling "aggregate" function
|
||||
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
|
||||
return point_estimate, region
|
||||
|
||||
def _bayesian_kde(self, X_probs, init=None, verbose=False):
|
||||
"""
|
||||
Bayes:
|
||||
P(prev|data) = P(data|prev) P(prev) / P(data)
|
||||
i.e.,
|
||||
posterior = likelihood * prior / evidence
|
||||
we assume the likelihood to be:
    prev @ [kde_i_likelihood(data) for i in 1..n]
and the prior to be uniform over the simplex
|
||||
"""
|
||||
|
||||
rng = np.random.default_rng(self.mcmc_seed)
|
||||
kdes = self.mix_densities
|
||||
test_densities = np.asarray([self.pdf(kde_i, X_probs, self.kernel) for kde_i in kdes])
|
||||
|
||||
def log_likelihood(prev, epsilon=1e-10):
|
||||
test_likelihoods = prev @ test_densities
|
||||
test_loglikelihood = np.log(test_likelihoods + epsilon)
|
||||
return np.sum(test_loglikelihood)
|
||||
|
||||
# def log_prior(prev):
|
||||
# todo: adapt to arbitrary prior knowledge (e.g., something around training prevalence)
|
||||
# return 1/np.sum((prev-init)**2) # it is not 1, but we assume a uniform prior, so it is anyway a useless constant
|
||||
|
||||
# def log_prior(prev, alpha_scale=1000):
|
||||
# alpha = np.array(init) * alpha_scale
|
||||
# return dirichlet.logpdf(prev, alpha)
|
||||
|
||||
def log_prior(prev):
|
||||
return 0
|
||||
|
||||
def sample_neighbour(prev, step_size):
|
||||
# random-walk Metropolis-Hastings
|
||||
d = len(prev)
|
||||
if not self.explore_CLR:
|
||||
dir_noise = rng.normal(scale=step_size/np.sqrt(d), size=d)
|
||||
neighbour = F.normalize_prevalence(prev + dir_noise, method='mapsimplex')
|
||||
else:
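# random-walk in the unconstrained CLR space; the inverse CLR transformation maps the proposal back to the simplex, so the result is guaranteed to be a valid prevalence vector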
|
||||
clr = CLRtransformation()
|
||||
clr_point = clr(prev)
|
||||
dir_noise = rng.normal(scale=step_size, size=d)
|
||||
clr_neighbour = clr_point+dir_noise
|
||||
neighbour = clr.inverse(clr_neighbour)
|
||||
assert in_simplex(neighbour), 'wrong CLR transformation'
|
||||
return neighbour
|
||||
|
||||
n_classes = X_probs.shape[1]
|
||||
current_prev = F.uniform_prevalence(n_classes) if init is None else init
|
||||
current_likelihood = log_likelihood(current_prev) + log_prior(current_prev)
|
||||
|
||||
# Metropolis-Hastings with step size adapted during warmup towards a target acceptance rate
|
||||
step_size = self.step_size
|
||||
target_acceptance = 0.3
|
||||
adapt_rate = 0.05
|
||||
acceptance_history = []
|
||||
|
||||
samples = []
|
||||
total_steps = self.num_samples + self.num_warmup
|
||||
for i in tqdm(range(total_steps), total=total_steps, disable=not verbose):
|
||||
proposed_prev = sample_neighbour(current_prev, step_size)
|
||||
|
||||
# probability of acceptance
|
||||
proposed_likelihood = log_likelihood(proposed_prev) + log_prior(proposed_prev)
|
||||
acceptance = proposed_likelihood - current_likelihood
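# log of the Metropolis acceptance ratio; the proposal is treated as symmetric, so no Hastings correction term is applied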
|
||||
|
||||
# decide acceptance
|
||||
accepted = np.log(rng.random()) < acceptance
|
||||
if accepted:
|
||||
current_prev = proposed_prev
|
||||
current_likelihood = proposed_likelihood
|
||||
|
||||
samples.append(current_prev)
|
||||
acceptance_history.append(1. if accepted else 0.)
|
||||
|
||||
if i < self.num_warmup and i%10==0 and len(acceptance_history)>=100:
|
||||
recent_accept_rate = np.mean(acceptance_history[-100:])
|
||||
step_size *= np.exp(adapt_rate * (recent_accept_rate - target_acceptance))
|
||||
# step_size = float(np.clip(step_size, min_step, max_step))
|
||||
if verbose:
    print(f'acceptance-rate={recent_accept_rate*100:.3f}%, step-size={step_size:.5f}')
|
||||
|
||||
# remove "warmup" initial iterations
|
||||
samples = np.asarray(samples[self.num_warmup:])
|
||||
return samples
|
||||
|
||||
|
||||
def in_simplex(x):
|
||||
return np.all(x >= 0) and np.isclose(x.sum(), 1)
|
||||
|
|
@ -0,0 +1,186 @@
|
|||
import os
|
||||
import warnings
|
||||
from os.path import join
|
||||
from pathlib import Path
|
||||
|
||||
from sklearn.calibration import CalibratedClassifierCV
|
||||
from sklearn.linear_model import LogisticRegression as LR
|
||||
from sklearn.model_selection import GridSearchCV, StratifiedKFold
|
||||
from copy import deepcopy as cp
|
||||
import quapy as qp
|
||||
from BayesianKDEy._bayeisan_kdey import BayesianKDEy
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.method.aggregative import DistributionMatchingY as DMy, AggregativeQuantifier
|
||||
from quapy.method.base import BinaryQuantifier, BaseQuantifier
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.data import Dataset
|
||||
# from BayesianKDEy.plot_simplex import plot_prev_points, plot_prev_points_matplot
|
||||
from quapy.method.confidence import ConfidenceIntervals, BayesianCC, PQ, WithConfidenceABC, AggregativeBootstrap
|
||||
from quapy.functional import strprev
|
||||
from quapy.method.aggregative import KDEyML, ACC
|
||||
from quapy.protocol import UPP
|
||||
import quapy.functional as F
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from scipy.stats import dirichlet
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
from sklearn.base import clone, BaseEstimator
|
||||
|
||||
|
||||
class KDEyCLR(KDEyML):
|
||||
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, bandwidth=1., random_state=None):
|
||||
super().__init__(
|
||||
classifier=classifier, fit_classifier=fit_classifier, val_split=val_split, bandwidth=bandwidth,
|
||||
random_state=random_state, kernel='aitchison'
|
||||
)
|
||||
|
||||
def methods__():
|
||||
acc_hyper = {}
|
||||
hdy_hyper = {'nbins': [3,4,5,8,16,32]}
|
||||
kdey_hyper = {'bandwidth': [0.001, 0.005, 0.01, 0.05, 0.1, 0.2], 'classifier__C':[1]}
|
||||
wrap_hyper = lambda dic: {f'quantifier__{k}':v for k,v in dic.items()}
|
||||
# yield 'BootstrapACC', AggregativeBootstrap(ACC(LR()), n_test_samples=1000, random_state=0), wrap_hyper(acc_hyper)
|
||||
# yield 'BootstrapHDy', AggregativeBootstrap(DMy(LR(), divergence='HD'), n_test_samples=1000, random_state=0), wrap_hyper(hdy_hyper)
|
||||
yield 'BootstrapKDEy', AggregativeBootstrap(KDEyML(LR()), n_test_samples=1000, random_state=0), wrap_hyper(kdey_hyper)
|
||||
# yield 'BayesianACC', BayesianCC(LR(), mcmc_seed=0), acc_hyper
|
||||
# yield 'BayesianHDy', PQ(LR(), stan_seed=0), hdy_hyper
|
||||
# yield 'BayesianKDEy', BayesianKDEy(LR(), mcmc_seed=0), kdey_hyper
|
||||
|
||||
|
||||
def methods():
|
||||
"""
|
||||
Yields tuples (name, quantifier, hyperparams, bayesian/bootstrap_constructor), where:
|
||||
- name: is a str representing the name of the method (e.g., 'BayesianKDEy')
|
||||
- quantifier: is the base model (e.g., KDEyML())
|
||||
- hyperparams: is a dictionary for the quantifier (e.g., {'bandwidth': [0.001, 0.005, 0.01, 0.05, 0.1, 0.2]})
|
||||
- bayesian/bootstrap_constructor: is a function that instantiates the bayesian or bootstrap method with the
|
||||
quantifier with optimized hyperparameters
|
||||
"""
|
||||
acc_hyper = {}
|
||||
hdy_hyper = {'nbins': [3,4,5,8,16,32]}
|
||||
kdey_hyper = {'bandwidth': [0.001, 0.005, 0.01, 0.05, 0.1, 0.2]}
|
||||
kdey_hyper_clr = {'bandwidth': [0.05, 0.1, 0.5, 1., 2., 5.]}
|
||||
|
||||
yield 'BootstrapACC', ACC(LR()), acc_hyper, lambda hyper: AggregativeBootstrap(ACC(LR()), n_test_samples=1000, random_state=0),
|
||||
yield 'BayesianACC', ACC(LR()), acc_hyper, lambda hyper: BayesianCC(LR(), mcmc_seed=0)
|
||||
|
||||
yield 'BootstrapHDy', DMy(LR()), hdy_hyper, lambda hyper: AggregativeBootstrap(DMy(LR(), **hyper), n_test_samples=1000, random_state=0),
|
||||
|
||||
yield 'BootstrapKDEy', KDEyML(LR()), kdey_hyper, lambda hyper: AggregativeBootstrap(KDEyML(LR(), **hyper), n_test_samples=1000, random_state=0, verbose=True),
|
||||
yield 'BayesianKDEy', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, **hyper),
|
||||
yield 'BayesianKDEy*', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0, **hyper),
|
||||
|
||||
|
||||
def model_selection(train: LabelledCollection, point_quantifier: AggregativeQuantifier, grid: dict):
|
||||
with qp.util.temp_seed(0):
|
||||
print(f'performing model selection for {point_quantifier.__class__.__name__} with grid {grid}')
|
||||
# model selection
|
||||
if len(grid)>0:
|
||||
train, val = train.split_stratified(train_prop=0.6, random_state=0)
|
||||
mod_sel = GridSearchQ(
|
||||
model=point_quantifier,
|
||||
param_grid=grid,
|
||||
protocol=qp.protocol.UPP(val, repeats=250, random_state=0),
|
||||
refit=False,
|
||||
n_jobs=-1,
|
||||
verbose=True
|
||||
).fit(*train.Xy)
|
||||
best_params = mod_sel.best_params_
|
||||
else:
|
||||
best_params = {}
|
||||
|
||||
return best_params
|
||||
|
||||
|
||||
def experiment(dataset: Dataset, point_quantifier: AggregativeQuantifier, method_name:str, grid: dict, withconf_constructor, hyper_choice_path: Path):
|
||||
with qp.util.temp_seed(0):
|
||||
|
||||
training, test = dataset.train_test
|
||||
|
||||
# model selection
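# qp.util.pickled_resource caches the output of model_selection at hyper_choice_path, so the (expensive) grid search runs only once and is re-loaded from disk in subsequent runs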
|
||||
best_hyperparams = qp.util.pickled_resource(
|
||||
hyper_choice_path, model_selection, training, cp(point_quantifier), grid
|
||||
)
|
||||
|
||||
t_init = time()
|
||||
withconf_quantifier = withconf_constructor(best_hyperparams).fit(*training.Xy)
|
||||
tr_time = time() - t_init
|
||||
|
||||
# test
|
||||
train_prevalence = training.prevalence()
|
||||
results = defaultdict(list)
|
||||
test_generator = UPP(test, repeats=100, random_state=0)
|
||||
for i, (sample_X, true_prevalence) in tqdm(enumerate(test_generator()), total=test_generator.total(), desc=f'{method_name} predictions'):
|
||||
t_init = time()
|
||||
point_estimate, region = withconf_quantifier.predict_conf(sample_X)
|
||||
ttime = time()-t_init
|
||||
results['true-prevs'].append(true_prevalence)
|
||||
results['point-estim'].append(point_estimate)
|
||||
results['shift'].append(qp.error.ae(true_prevalence, train_prevalence))
|
||||
results['ae'].append(qp.error.ae(prevs_true=true_prevalence, prevs_hat=point_estimate))
|
||||
results['rae'].append(qp.error.rae(prevs_true=true_prevalence, prevs_hat=point_estimate))
|
||||
results['coverage'].append(region.coverage(true_prevalence))
|
||||
results['amplitude'].append(region.montecarlo_proportion(n_trials=50_000))
|
||||
results['test-time'].append(ttime)
|
||||
results['samples'].append(region.samples)
|
||||
|
||||
report = {
|
||||
'optim_hyper': best_hyperparams,
|
||||
'train_time': tr_time,
|
||||
'train-prev': train_prevalence,
|
||||
'results': {k:np.asarray(v) for k,v in results.items()}
|
||||
}
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def experiment_path(dir:Path, dataset_name:str, method_name:str):
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
return dir/f'{dataset_name}__{method_name}.pkl'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
binary = {
|
||||
'datasets': qp.datasets.UCI_BINARY_DATASETS,
|
||||
'fetch_fn': qp.datasets.fetch_UCIBinaryDataset,
|
||||
'sample_size': 500
|
||||
}
|
||||
|
||||
multiclass = {
|
||||
'datasets': qp.datasets.UCI_MULTICLASS_DATASETS,
|
||||
'fetch_fn': qp.datasets.fetch_UCIMulticlassDataset,
|
||||
'sample_size': 1000
|
||||
}
|
||||
|
||||
result_dir = Path('./results')
|
||||
|
||||
for setup in [binary, multiclass]: # [binary, multiclass]:
|
||||
qp.environ['SAMPLE_SIZE'] = setup['sample_size']
|
||||
for data_name in setup['datasets']:
|
||||
print(f'dataset={data_name}')
|
||||
# if data_name=='breast-cancer' or data_name.startswith("cmc") or data_name.startswith("ctg"):
|
||||
# print(f'skipping dataset: {data_name}')
|
||||
# continue
|
||||
data = setup['fetch_fn'](data_name)
|
||||
is_binary = data.n_classes==2
|
||||
result_subdir = result_dir / ('binary' if is_binary else 'multiclass')
|
||||
hyper_subdir = result_dir / 'hyperparams' / ('binary' if is_binary else 'multiclass')
|
||||
for method_name, method, hyper_params, withconf_constructor in methods():
|
||||
if isinstance(method, BinaryQuantifier) and not is_binary:
|
||||
continue
|
||||
result_path = experiment_path(result_subdir, data_name, method_name)
|
||||
hyper_path = experiment_path(hyper_subdir, data_name, method.__class__.__name__)
|
||||
report = qp.util.pickled_resource(
|
||||
result_path, experiment, data, method, method_name, hyper_params, withconf_constructor, hyper_path
|
||||
)
|
||||
print(f'dataset={data_name}, '
|
||||
f'method={method_name}: '
|
||||
f'mae={report["results"]["ae"].mean():.3f}, '
|
||||
f'coverage={report["results"]["coverage"].mean():.5f}, '
|
||||
f'amplitude={report["results"]["amplitude"].mean():.5f}, ')
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
import pickle
|
||||
from collections import defaultdict
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
import quapy as qp
|
||||
from quapy.method.confidence import ConfidenceEllipseSimplex, ConfidenceEllipseCLR
|
||||
|
||||
pd.set_option('display.max_columns', None)
|
||||
pd.set_option('display.width', 2000)
|
||||
pd.set_option('display.max_rows', None)
|
||||
pd.set_option("display.expand_frame_repr", False)
|
||||
pd.set_option("display.precision", 4)
|
||||
pd.set_option("display.float_format", "{:.4f}".format)
|
||||
|
||||
|
||||
def compute_coverage_amplitude(region_constructor):
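# note: this function reads the module-level variable `results`, which is (re)assigned inside the loop below before the function is called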
|
||||
all_samples = results['samples']
|
||||
all_true_prevs = results['true-prevs']
|
||||
|
||||
def process_one(samples, true_prevs):
|
||||
ellipse = region_constructor(samples)
|
||||
return ellipse.coverage(true_prevs), ellipse.montecarlo_proportion()
|
||||
|
||||
out = Parallel(n_jobs=3)(
|
||||
delayed(process_one)(samples, true_prevs)
|
||||
for samples, true_prevs in tqdm(
|
||||
zip(all_samples, all_true_prevs),
|
||||
total=len(all_samples),
|
||||
desc='constructing ellipses'
|
||||
)
|
||||
)
|
||||
|
||||
# unzip results
|
||||
coverage, amplitude = zip(*out)
|
||||
return list(coverage), list(amplitude)
|
||||
|
||||
|
||||
def update_pickle(report, pickle_path, updated_dict:dict):
|
||||
for k,v in updated_dict.items():
|
||||
report[k]=v
|
||||
pickle.dump(report, open(pickle_path, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
for setup in ['binary', 'multiclass']:
|
||||
path = f'./results/{setup}/*.pkl'
|
||||
table = defaultdict(list)
|
||||
for file in tqdm(glob(path), desc='processing results', total=len(glob(path))):
|
||||
file = Path(file)
|
||||
dataset, method = file.name.replace('.pkl', '').split('__')
|
||||
report = pickle.load(open(file, 'rb'))
|
||||
results = report['results']
|
||||
n_samples = len(results['ae'])
|
||||
table['method'].extend([method.replace('Bayesian','Ba').replace('Bootstrap', 'Bo')] * n_samples)
|
||||
table['dataset'].extend([dataset] * n_samples)
|
||||
table['ae'].extend(results['ae'])
|
||||
table['c-CI'].extend(results['coverage'])
|
||||
table['a-CI'].extend(results['amplitude'])
|
||||
|
||||
if 'coverage-CE' not in report:
|
||||
covCE, ampCE = compute_coverage_amplitude(ConfidenceEllipseSimplex)
|
||||
covCLR, ampCLR = compute_coverage_amplitude(ConfidenceEllipseCLR)
|
||||
|
||||
update_fields = {
|
||||
'coverage-CE': covCE,
|
||||
'amplitude-CE': ampCE,
|
||||
'coverage-CLR': covCLR,
|
||||
'amplitude-CLR': ampCLR
|
||||
}
|
||||
|
||||
update_pickle(report, file, update_fields)
|
||||
|
||||
table['c-CE'].extend(report['coverage-CE'])
|
||||
table['a-CE'].extend(report['amplitude-CE'])
|
||||
|
||||
table['c-CLR'].extend(report['coverage-CLR'])
|
||||
table['a-CLR'].extend(report['amplitude-CLR'])
|
||||
|
||||
|
||||
df = pd.DataFrame(table)
|
||||
|
||||
n_classes = {}
|
||||
tr_size = {}
|
||||
for dataset in df['dataset'].unique():
|
||||
fetch_fn = {
|
||||
'binary': qp.datasets.fetch_UCIBinaryDataset,
|
||||
'multiclass': qp.datasets.fetch_UCIMulticlassDataset
|
||||
}[setup]
|
||||
data = fetch_fn(dataset)
|
||||
n_classes[dataset] = data.n_classes
|
||||
tr_size[dataset] = len(data.training)
|
||||
|
||||
# remove datasets with more than max_classes classes
|
||||
max_classes = 30
|
||||
for data_name, n in n_classes.items():
|
||||
if n > max_classes:
|
||||
df = df[df["dataset"] != data_name]
|
||||
|
||||
for region in ['CI', 'CE', 'CLR']:
|
||||
pv = pd.pivot_table(
|
||||
df, index='dataset', columns='method', values=['ae', f'c-{region}', f'a-{region}'], margins=True
|
||||
)
|
||||
pv['n_classes'] = pv.index.map(n_classes).astype('Int64')
|
||||
pv['tr_size'] = pv.index.map(tr_size).astype('Int64')
|
||||
pv = pv.drop(columns=[col for col in pv.columns if col[-1] == "All"])
|
||||
print(f'{setup=}')
|
||||
print(pv)
|
||||
print('-'*80)
|
||||
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.stats import gaussian_kde
|
||||
|
||||
|
||||
def plot_prev_points(prevs, true_prev, point_estim, train_prev):
|
||||
plt.rcParams.update({
|
||||
'font.size': 10,         # base size for all text
'axes.titlesize': 12,    # axes title
'axes.labelsize': 10,    # axis labels
'xtick.labelsize': 8,    # tick labels
'ytick.labelsize': 8,
'legend.fontsize': 9,    # legend
|
||||
})
|
||||
|
||||
def cartesian(p):
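# barycentric -> Cartesian projection onto the equilateral triangle with vertices (0,0), (1,0) and (0.5, sqrt(3)/2)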
|
||||
dim = p.shape[-1]
|
||||
p = p.reshape(-1,dim)
|
||||
x = p[:, 1] + p[:, 2] * 0.5
|
||||
y = p[:, 2] * np.sqrt(3) / 2
|
||||
return x, y
|
||||
|
||||
# simplex coordinates
|
||||
v1 = np.array([0, 0])
|
||||
v2 = np.array([1, 0])
|
||||
v3 = np.array([0.5, np.sqrt(3)/2])
|
||||
|
||||
# Plot
|
||||
fig, ax = plt.subplots(figsize=(6, 6))
|
||||
ax.scatter(*cartesian(prevs), s=10, alpha=0.5, edgecolors='none', label='samples')
|
||||
ax.scatter(*cartesian(prevs.mean(axis=0)), s=10, alpha=1, label='sample-mean', edgecolors='black')
|
||||
ax.scatter(*cartesian(true_prev), s=10, alpha=1, label='true-prev', edgecolors='black')
|
||||
ax.scatter(*cartesian(point_estim), s=10, alpha=1, label='KDEy-estim', edgecolors='black')
|
||||
ax.scatter(*cartesian(train_prev), s=10, alpha=1, label='train-prev', edgecolors='black')
|
||||
|
||||
# edges
|
||||
triangle = np.array([v1, v2, v3, v1])
|
||||
ax.plot(triangle[:, 0], triangle[:, 1], color='black')
|
||||
|
||||
# vertex labels
|
||||
ax.text(-0.05, -0.05, "y=0", ha='right', va='top')
|
||||
ax.text(1.05, -0.05, "y=1", ha='left', va='top')
|
||||
ax.text(0.5, np.sqrt(3)/2 + 0.05, "y=2", ha='center', va='bottom')
|
||||
|
||||
ax.set_aspect('equal')
|
||||
ax.axis('off')
|
||||
plt.legend(
|
||||
loc='center left',
|
||||
bbox_to_anchor=(1.05, 0.5),
|
||||
# ncol=3,
|
||||
# frameon=False
|
||||
)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_prev_points_matplot(points):
|
||||
|
||||
# project 2D
|
||||
v1 = np.array([0, 0])
|
||||
v2 = np.array([1, 0])
|
||||
v3 = np.array([0.5, np.sqrt(3) / 2])
|
||||
x = points[:, 1] + points[:, 2] * 0.5
|
||||
y = points[:, 2] * np.sqrt(3) / 2
|
||||
|
||||
# kde
|
||||
xy = np.vstack([x, y])
|
||||
kde = gaussian_kde(xy, bw_method=0.25)
|
||||
xmin, xmax = 0, 1
|
||||
ymin, ymax = 0, np.sqrt(3) / 2
|
||||
|
||||
# grid
|
||||
xx, yy = np.mgrid[xmin:xmax:200j, ymin:ymax:200j]
|
||||
positions = np.vstack([xx.ravel(), yy.ravel()])
|
||||
zz = np.reshape(kde(positions).T, xx.shape)
|
||||
|
||||
# mask points in simplex
|
||||
def in_triangle(x, y):
|
||||
return (y >= 0) & (y <= np.sqrt(3) * np.minimum(x, 1 - x))
|
||||
|
||||
mask = in_triangle(xx, yy)
|
||||
zz_masked = np.ma.array(zz, mask=~mask)
|
||||
|
||||
# plot
|
||||
fig, ax = plt.subplots(figsize=(6, 6))
|
||||
ax.imshow(
|
||||
np.rot90(zz_masked),
|
||||
cmap=plt.cm.viridis,
|
||||
extent=[xmin, xmax, ymin, ymax],
|
||||
alpha=0.8,
|
||||
)
|
||||
|
||||
# triangle edges
|
||||
triangle = np.array([v1, v2, v3, v1])
|
||||
ax.plot(triangle[:, 0], triangle[:, 1], color='black', lw=2)
|
||||
|
||||
# points (optional)
|
||||
ax.scatter(x, y, s=5, c='white', alpha=0.3)
|
||||
|
||||
# vertex labels
|
||||
ax.text(-0.05, -0.05, "A (1,0,0)", ha='right', va='top')
|
||||
ax.text(1.05, -0.05, "B (0,1,0)", ha='left', va='top')
|
||||
ax.text(0.5, np.sqrt(3) / 2 + 0.05, "C (0,0,1)", ha='center', va='bottom')
|
||||
|
||||
ax.set_aspect('equal')
|
||||
ax.axis('off')
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
n = 1000
|
||||
points = np.random.dirichlet([2, 3, 4], size=n)
|
||||
plot_prev_points_matplot(points)
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
import warnings
|
||||
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
import quapy as qp
|
||||
from BayesianKDEy._bayeisan_kdey import BayesianKDEy
|
||||
from BayesianKDEy.plot_simplex import plot_prev_points, plot_prev_points_matplot
|
||||
from quapy.method.confidence import ConfidenceIntervals
|
||||
from quapy.functional import strprev
|
||||
from quapy.method.aggregative import KDEyML
|
||||
from quapy.protocol import UPP
|
||||
import quapy.functional as F
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from scipy.stats import dirichlet
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
qp.environ["SAMPLE_SIZE"] = 500
|
||||
cls = LogisticRegression()
|
||||
bayes_kdey = BayesianKDEy(cls, bandwidth=.3, kernel='aitchison', mcmc_seed=0)
|
||||
|
||||
datasets = qp.datasets.UCI_BINARY_DATASETS
|
||||
train, test = qp.datasets.fetch_UCIBinaryDataset(datasets[0]).train_test
|
||||
|
||||
# train, test = qp.datasets.fetch_UCIMulticlassDataset('academic-success', standardize=True).train_test
|
||||
|
||||
with qp.util.temp_seed(0):
|
||||
print('fitting KDEy')
|
||||
bayes_kdey.fit(*train.Xy)
|
||||
|
||||
shifted = test.sampling(500, *[0.2, 0.8])
|
||||
# shifted = test.sampling(500, *test.prevalence()[::-1])
|
||||
# shifted = test.sampling(500, *F.uniform_prevalence_sampling(train.n_classes))
|
||||
prev_hat = bayes_kdey.predict(shifted.X)
|
||||
mae = qp.error.mae(shifted.prevalence(), prev_hat)
|
||||
print(f'true_prev={strprev(shifted.prevalence())}')
|
||||
print(f'prev_hat={strprev(prev_hat)}, {mae=:.4f}')
|
||||
|
||||
prev_hat, conf_interval = bayes_kdey.predict_conf(shifted.X)
|
||||
|
||||
mae = qp.error.mae(shifted.prevalence(), prev_hat)
|
||||
print(f'mean posterior {strprev(prev_hat)}, {mae=:.4f}')
|
||||
print(f'CI={conf_interval}')
|
||||
print(f'\tcontains true={conf_interval.coverage(true_value=shifted.prevalence())==1}')
|
||||
print(f'\tamplitude={conf_interval.montecarlo_proportion(50_000)*100.:.3f}%')
|
||||
|
||||
if train.n_classes == 3:
|
||||
plot_prev_points(bayes_kdey.prevalence_samples, true_prev=shifted.prevalence(), point_estim=prev_hat, train_prev=train.prevalence())
|
||||
# plot_prev_points_matplot(samples)
|
||||
|
||||
# report = qp.evaluation.evaluation_report(kdey, protocol=UPP(test), verbose=True)
|
||||
# print(report.mean(numeric_only=True))
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
import os
|
||||
import warnings
|
||||
from os.path import join
|
||||
from pathlib import Path
|
||||
|
||||
from sklearn.calibration import CalibratedClassifierCV
|
||||
from sklearn.linear_model import LogisticRegression as LR
|
||||
from sklearn.model_selection import GridSearchCV, StratifiedKFold
|
||||
from copy import deepcopy as cp
|
||||
import quapy as qp
|
||||
from BayesianKDEy._bayeisan_kdey import BayesianKDEy
|
||||
from BayesianKDEy.full_experiments import experiment, experiment_path, KDEyCLR
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.method.aggregative import DistributionMatchingY as DMy, AggregativeQuantifier
|
||||
from quapy.method.base import BinaryQuantifier, BaseQuantifier
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.data import Dataset
|
||||
# from BayesianKDEy.plot_simplex import plot_prev_points, plot_prev_points_matplot
|
||||
from quapy.method.confidence import ConfidenceIntervals, BayesianCC, PQ, WithConfidenceABC, AggregativeBootstrap
|
||||
from quapy.functional import strprev
|
||||
from quapy.method.aggregative import KDEyML, ACC
|
||||
from quapy.protocol import UPP
|
||||
import quapy.functional as F
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from scipy.stats import dirichlet
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
from sklearn.base import clone, BaseEstimator
|
||||
|
||||
|
||||
def method():
|
||||
"""
|
||||
Returns a tuple (name, quantifier, hyperparams, bayesian/bootstrap_constructor), where:
|
||||
- name: is a str representing the name of the method (e.g., 'BayesianKDEy')
|
||||
- quantifier: is the base model (e.g., KDEyML())
|
||||
- hyperparams: is a dictionary for the quantifier (e.g., {'bandwidth': [0.001, 0.005, 0.01, 0.05, 0.1, 0.2]})
|
||||
- bayesian/bootstrap_constructor: is a function that instantiates the bayesian or bootstrap method with the
|
||||
quantifier with optimized hyperparameters
|
||||
"""
|
||||
acc_hyper = {}
|
||||
hdy_hyper = {'nbins': [3,4,5,8,16,32]}
|
||||
kdey_hyper = {'bandwidth': [0.001, 0.005, 0.01, 0.05, 0.1, 0.2]}
|
||||
kdey_hyper_clr = {'bandwidth': [0.05, 0.1, 0.5, 1., 2., 5.]}
|
||||
|
||||
wrap_hyper = lambda dic: {f'quantifier__{k}':v for k,v in dic.items()}
|
||||
|
||||
# yield 'BootstrapKDEy', KDEyML(LR()), kdey_hyper, lambda hyper: AggregativeBootstrap(KDEyML(LR(), **hyper), n_test_samples=1000, random_state=0, verbose=True),
|
||||
# yield 'BayesianKDEy', KDEyML(LR()), kdey_hyper, lambda hyper: BayesianKDEy(mcmc_seed=0, **hyper),
|
||||
return 'BayKDE*CLR', KDEyCLR(LR()), kdey_hyper_clr, lambda hyper: BayesianKDEy(kernel='aitchison', mcmc_seed=0,
|
||||
explore_CLR=True,
|
||||
step_size=.15,
|
||||
# num_warmup = 5000,
|
||||
# num_samples = 10_000,
|
||||
# region='ellipse',
|
||||
**hyper),
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
binary = {
|
||||
'datasets': qp.datasets.UCI_BINARY_DATASETS,
|
||||
'fetch_fn': qp.datasets.fetch_UCIBinaryDataset,
|
||||
'sample_size': 500
|
||||
}
|
||||
|
||||
multiclass = {
|
||||
'datasets': qp.datasets.UCI_MULTICLASS_DATASETS,
|
||||
'fetch_fn': qp.datasets.fetch_UCIMulticlassDataset,
|
||||
'sample_size': 1000
|
||||
}
|
||||
|
||||
result_dir = Path('./results')
|
||||
|
||||
setup = multiclass
|
||||
qp.environ['SAMPLE_SIZE'] = setup['sample_size']
|
||||
data_name = 'digits'
|
||||
print(f'dataset={data_name}')
|
||||
data = setup['fetch_fn'](data_name)
|
||||
is_binary = data.n_classes==2
|
||||
hyper_subdir = result_dir / 'hyperparams' / ('binary' if is_binary else 'multiclass')
|
||||
method_name, method, hyper_params, withconf_constructor = method()
|
||||
hyper_path = experiment_path(hyper_subdir, data_name, method.__class__.__name__)
|
||||
report = experiment(data, method, method_name, hyper_params, withconf_constructor, hyper_path)
|
||||
|
||||
print(f'dataset={data_name}, '
|
||||
f'method={method_name}: '
|
||||
f'mae={report["results"]["ae"].mean():.3f}, '
|
||||
f'coverage={report["results"]["coverage"].mean():.5f}, '
|
||||
f'amplitude={report["results"]["amplitude"].mean():.5f}, ')
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,7 +1,12 @@
|
|||
Change Log 0.2.0
|
||||
Change Log 0.2.1
|
||||
-----------------
|
||||
|
||||
CLEAN TODO-FILE
|
||||
- Improved documentation of confidence regions.
|
||||
- Added the ReadMe method by Daniel Hopkins and Gary King.
|
||||
- Internal index in LabelledCollection is now "lazy", and is only constructed if required.
|
||||
|
||||
Change Log 0.2.0
|
||||
-----------------
|
||||
|
||||
- Base code Refactor:
|
||||
- Removing coupling between LabelledCollection and quantification methods; the fit interface changes:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,56 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from quapy.method.aggregative import EMQ, KDEyML, PACC
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
METHODS = ['PACC',
|
||||
'EMQ',
|
||||
'KDEy-ML',
|
||||
'KDEy-MLA'
|
||||
]
|
||||
|
||||
|
||||
# common hyperparameters
|
||||
hyper_LR = {
|
||||
'classifier__C': np.logspace(-3, 3, 7),
|
||||
'classifier__class_weight': ['balanced', None]
|
||||
}
|
||||
|
||||
hyper_kde = {
|
||||
'bandwidth': np.linspace(0.001, 0.5, 100)
|
||||
}
|
||||
|
||||
hyper_kde_aitchison = {
|
||||
'bandwidth': np.linspace(0.01, 2, 100)
|
||||
}
|
||||
|
||||
# instantiates a new quantifier based on its string name
|
||||
def new_method(method, **lr_kwargs):
|
||||
lr = LogisticRegression(**lr_kwargs)
|
||||
|
||||
if method == 'KDEy-ML':
|
||||
param_grid = {**hyper_kde, **hyper_LR}
|
||||
quantifier = KDEyML(lr, kernel='gaussian')
|
||||
elif method == 'KDEy-MLA':
|
||||
param_grid = {**hyper_kde_aitchison, **hyper_LR}
|
||||
quantifier = KDEyML(lr, kernel='aitchison')
|
||||
elif method == 'EMQ':
|
||||
param_grid = hyper_LR
|
||||
quantifier = EMQ(lr)
|
||||
elif method == 'PACC':
|
||||
param_grid = hyper_LR
|
||||
quantifier = PACC(lr)
|
||||
else:
|
||||
raise NotImplementedError('unknown method', method)
|
||||
|
||||
return param_grid, quantifier
|
||||
|
||||
|
||||
def show_results(result_path):
|
||||
df = pd.read_csv(result_path+'.csv', sep='\t')
|
||||
|
||||
pd.set_option('display.max_columns', None)
|
||||
pd.set_option('display.max_rows', None)
|
||||
pv = df.pivot_table(index='Dataset', columns="Method", values=["MAE", "MRAE"])
|
||||
print(pv)
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
import pickle
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import quapy as qp
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.protocol import UPP
|
||||
from commons import METHODS, new_method, show_results
|
||||
from new_table import LatexTable
|
||||
|
||||
|
||||
SEED = 1
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(qp.datasets.UCI_MULTICLASS_DATASETS)
|
||||
for optim in ['mae', 'mrae']:
|
||||
table = LatexTable()
|
||||
result_dir = f'results/ucimulti/{optim}'
|
||||
|
||||
for method in METHODS:
|
||||
print()
|
||||
global_result_path = f'{result_dir}/{method}'
|
||||
print(f'Method\tDataset\tMAE\tMRAE\tKLD')
|
||||
for dataset in qp.datasets.UCI_MULTICLASS_DATASETS:
|
||||
# print(dataset)
|
||||
local_result_path = global_result_path + '_' + dataset
|
||||
if os.path.exists(local_result_path + '.dataframe'):
|
||||
report = pd.read_csv(local_result_path+'.dataframe')
|
||||
print(f'{method}\t{dataset}\t{report[optim].mean():.5f}')
|
||||
table.add(benchmark=dataset, method=method, v=report[optim].values)
|
||||
else:
|
||||
print(dataset, 'not found for method', method)
|
||||
table.latexPDF(f'./tables/{optim}.pdf', landscape=False)
|
||||
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
import pickle
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import quapy as qp
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.protocol import UPP
|
||||
from commons import METHODS, new_method, show_results
|
||||
|
||||
|
||||
SEED = 1
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
qp.environ['SAMPLE_SIZE'] = 500
|
||||
qp.environ['N_JOBS'] = -1
|
||||
n_bags_val = 250
|
||||
n_bags_test = 1000
|
||||
for optim in ['mae', 'mrae']:
|
||||
result_dir = f'results/ucimulti/{optim}'
|
||||
|
||||
os.makedirs(result_dir, exist_ok=True)
|
||||
|
||||
for method in METHODS:
|
||||
|
||||
print('Init method', method)
|
||||
|
||||
global_result_path = f'{result_dir}/{method}'
|
||||
# show_results(global_result_path)
|
||||
# sys.exit(0)
|
||||
|
||||
if not os.path.exists(global_result_path + '.csv'):
|
||||
with open(global_result_path + '.csv', 'wt') as csv:
|
||||
csv.write(f'Method\tDataset\tMAE\tMRAE\tKLD\n')
|
||||
|
||||
with open(global_result_path + '.csv', 'at') as csv:
|
||||
|
||||
for dataset in qp.datasets.UCI_MULTICLASS_DATASETS:
|
||||
|
||||
print('init', dataset)
|
||||
|
||||
local_result_path = global_result_path + '_' + dataset
|
||||
if os.path.exists(local_result_path + '.dataframe'):
|
||||
print(f'result file {local_result_path}.dataframe already exist; skipping')
|
||||
report = pd.read_csv(local_result_path+'.dataframe')
|
||||
print(report["mae"].mean())
|
||||
# data = qp.datasets.fetch_UCIMulticlassDataset(dataset)
|
||||
# csv.write(f'{method}\t{data.name}\t{report["mae"].mean():.5f}\t{report["mrae"].mean():.5f}\t{report["kld"].mean():.5f}\n')
|
||||
continue
|
||||
|
||||
with qp.util.temp_seed(SEED):
|
||||
|
||||
param_grid, quantifier = new_method(method, max_iter=3000)
|
||||
|
||||
data = qp.datasets.fetch_UCIMulticlassDataset(dataset)
|
||||
|
||||
# model selection
|
||||
train, test = data.train_test
|
||||
train, val = train.split_stratified(random_state=SEED)
|
||||
|
||||
protocol = UPP(val, repeats=n_bags_val)
|
||||
modsel = GridSearchQ(
|
||||
quantifier, param_grid, protocol, refit=True, n_jobs=-1, verbose=True, error=optim
|
||||
)
|
||||
|
||||
try:
|
||||
modsel.fit(*train.Xy)
|
||||
|
||||
print(f'best params {modsel.best_params_}')
|
||||
print(f'best score {modsel.best_score_}')
|
||||
pickle.dump(
|
||||
(modsel.best_params_, modsel.best_score_,),
|
||||
open(f'{local_result_path}.hyper.pkl', 'wb'), pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
quantifier = modsel.best_model()
|
||||
except Exception as e:
    print(f'something went wrong ({e})... trying to fit the default model')
|
||||
quantifier.fit(*train.Xy)
|
||||
|
||||
|
||||
protocol = UPP(test, repeats=n_bags_test)
|
||||
report = qp.evaluation.evaluation_report(
|
||||
quantifier, protocol, error_metrics=['mae', 'mrae', 'kld'], verbose=True
|
||||
)
|
||||
report.to_csv(f'{local_result_path}.dataframe')
|
||||
print(f'{method}\t{data.name}\t{report["mae"].mean():.5f}\t{report["mrae"].mean():.5f}\t{report["kld"].mean():.5f}\n')
|
||||
csv.write(f'{method}\t{data.name}\t{report["mae"].mean():.5f}\t{report["mrae"].mean():.5f}\t{report["kld"].mean():.5f}\n')
|
||||
csv.flush()
|
||||
|
||||
show_results(global_result_path)
|
||||
|
|
@ -604,7 +604,10 @@ estim_prevalence = model.predict(dataset.test.X)
|
|||
|
||||
_(New in v0.2.0!)_ Some quantification methods go beyond providing a single point estimate of class prevalence values and also produce confidence regions, which characterize the uncertainty around the point estimate. In QuaPy, two such methods are currently implemented:
|
||||
|
||||
* Aggregative Bootstrap: The Aggregative Bootstrap method extends any aggregative quantifier by generating confidence regions for class prevalence estimates through bootstrapping. Key features of this method include:
|
||||
* Aggregative Bootstrap: The Aggregative Bootstrap method extends any aggregative quantifier by generating confidence regions for class prevalence estimates through bootstrapping. The method is described in the paper [Moreo, A., Salvati, N.
|
||||
An Efficient Method for Deriving Confidence Intervals in Aggregative Quantification.
|
||||
Learning to Quantify: Methods and Applications (LQ 2025), co-located at ECML-PKDD 2025.
|
||||
pp 12-33, Porto (Portugal)](https://lq-2025.github.io/proceedings/CompleteVolume.pdf). Key features of this method include:
|
||||
|
||||
* Optimized Computation: The bootstrap is applied to pre-classified instances, significantly speeding up training and inference.
|
||||
During training, bootstrap repetitions are performed only after training the classifier once. These repetitions are used to train multiple aggregation functions.
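
  A minimal usage sketch follows (the dataset, the base quantifier, and the number of bootstrap samples are illustrative choices):

  ```python
  from sklearn.linear_model import LogisticRegression

  import quapy as qp
  from quapy.method.aggregative import ACC
  from quapy.method.confidence import AggregativeBootstrap

  train, test = qp.datasets.fetch_UCIMulticlassDataset('digits').train_test

  # the classifier is trained only once; bootstrap repetitions only re-train the aggregation function
  quantifier = AggregativeBootstrap(ACC(LogisticRegression()), n_test_samples=1000, random_state=0)
  quantifier.fit(*train.Xy)

  # point estimate of the class prevalence values, plus a confidence region around it
  prev_estim, conf_region = quantifier.predict_conf(test.X)
  print(prev_estim)
  print(conf_region.coverage(test.prevalence()))  # fraction of true values covered (1.0 if the true prevalence lies inside the region)
  ```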
|
||||
|
|
|
|||
|
|
@ -60,6 +60,14 @@ quapy.method.composable module
|
|||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
quapy.method.confidence module
|
||||
------------------------------
|
||||
|
||||
.. automodule:: quapy.method.confidence
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ with qp.util.temp_seed(0):
|
|||
true_prev = shifted_test.prevalence()
|
||||
|
||||
# by calling "quantify_conf", we obtain the point estimate and the confidence intervals around it
|
||||
pred_prev, conf_intervals = pacc.quantify_conf(shifted_test.X)
|
||||
pred_prev, conf_intervals = pacc.predict_conf(shifted_test.X)
|
||||
|
||||
# conf_intervals is an instance of ConfidenceRegionABC, which provides some useful utilities like:
|
||||
# - coverage: a function which computes the fraction of true values that belong to the confidence region
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.feature_selection import SelectKBest, chi2
|
||||
|
||||
import quapy as qp
|
||||
from quapy.method.non_aggregative import ReadMe
|
||||
import quapy.functional as F
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
"""
|
||||
This example showcases how to use the non-aggregative method ReadMe proposed by Hopkins and King.
|
||||
This method is for text analysis, so let us first instantiate a dataset for sentiment quantification (we
|
||||
use IMDb for this example). The method is quite computationally expensive, so we will restrict the training
|
||||
set to 1000 documents only.
|
||||
"""
|
||||
reviews = qp.datasets.fetch_reviews('imdb').reduce(n_train=1000, random_state=0)
|
||||
|
||||
"""
|
||||
We need to convert text to bag-of-words representations. Actually, ReadMe requires the representations to be
|
||||
binary (i.e., storing a 1 whenever a document contains a certain word, and 0 otherwise), so we will not use
|
||||
TFIDF weighting. We will also retain the top 1000 most important features according to chi2.
|
||||
"""
|
||||
encode_0_1 = Pipeline([
|
||||
('0_1_terms', CountVectorizer(min_df=5, binary=True)),
|
||||
('feat_sel', SelectKBest(chi2, k=1000))
|
||||
])
|
||||
train, test = qp.data.preprocessing.instance_transformation(reviews, encode_0_1, inplace=True).train_test
|
||||
|
||||
"""
|
||||
We now instantiate ReadMe, with the prob_model='full' (default behaviour, implementing the Hopkins and King original
|
||||
idea). This method consists of estimating Q(Y) by solving:
|
||||
|
||||
Q(X) = \sum_i Q(X|Y=i) Q(Y=i)
|
||||
|
||||
without resorting to estimating the posteriors Q(Y=i|X), by solving a linear least-squares problem.
|
||||
However, since Q(X) and Q(X|Y=i) are matrices of shape (2^K, 1) and (2^K, n), with K the number of features
|
||||
and n the number of classes, their calculation becomes intractable. ReadMe instead performs bagging (i.e., it
|
||||
samples small sets of features and averages the results) thus reducing K to a few terms. In our example we
|
||||
set K (bagging_range) to 20, and the number of bagging_trials to 100.
|
||||
|
||||
ReadMe also computes confidence intervals via bootstrap. We set the number of bootstrap trials to 100.
|
||||
"""
|
||||
readme = ReadMe(prob_model='full', bootstrap_trials=100, bagging_trials=100, bagging_range=20, random_state=0, verbose=True)
|
||||
readme.fit(*train.Xy) # <- there is actually nothing happening here (only bootstrap resampling); the method is "lazy"
|
||||
# and postpones most of the calculations to the test phase.
|
||||
|
||||
# since the method is slow, we will only test 3 cases with different imbalances
|
||||
few_negatives = [0.25, 0.75]
|
||||
balanced = [0.5, 0.5]
|
||||
few_positives = [0.75, 0.25]
|
||||
|
||||
for test_prev in [few_negatives, balanced, few_positives]:
|
||||
sample = reviews.test.sampling(500, *test_prev, random_state=0) # draw sets of 500 documents with desired prevs
|
||||
prev_estim, conf = readme.predict_conf(sample.X)
|
||||
err = qp.error.mae(sample.prevalence(), prev_estim)
|
||||
print(f'true-prevalence={F.strprev(sample.prevalence())},\n'
|
||||
f'predicted-prevalence={F.strprev(prev_estim)}, with confidence intervals {conf},\n'
|
||||
f'MAE={err:.4f}')
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
from scipy.sparse import csc_matrix, csr_matrix
|
||||
from sklearn.base import BaseEstimator, TransformerMixin
|
||||
from sklearn.feature_extraction.text import TfidfTransformer, TfidfVectorizer, CountVectorizer
|
||||
import numpy as np
|
||||
from joblib import Parallel, delayed
|
||||
import sklearn
|
||||
import math
|
||||
from scipy.stats import t
|
||||
|
||||
|
||||
class ContTable:
|
||||
def __init__(self, tp=0, tn=0, fp=0, fn=0):
|
||||
self.tp=tp
|
||||
self.tn=tn
|
||||
self.fp=fp
|
||||
self.fn=fn
|
||||
|
||||
def get_d(self): return self.tp + self.tn + self.fp + self.fn
|
||||
|
||||
def get_c(self): return self.tp + self.fn
|
||||
|
||||
def get_not_c(self): return self.tn + self.fp
|
||||
|
||||
def get_f(self): return self.tp + self.fp
|
||||
|
||||
def get_not_f(self): return self.tn + self.fn
|
||||
|
||||
def p_c(self): return (1.0*self.get_c())/self.get_d()
|
||||
|
||||
def p_not_c(self): return 1.0-self.p_c()
|
||||
|
||||
def p_f(self): return (1.0*self.get_f())/self.get_d()
|
||||
|
||||
def p_not_f(self): return 1.0-self.p_f()
|
||||
|
||||
def p_tp(self): return (1.0*self.tp) / self.get_d()
|
||||
|
||||
def p_tn(self): return (1.0*self.tn) / self.get_d()
|
||||
|
||||
def p_fp(self): return (1.0*self.fp) / self.get_d()
|
||||
|
||||
def p_fn(self): return (1.0*self.fn) / self.get_d()
|
||||
|
||||
def tpr(self):
|
||||
c = 1.0*self.get_c()
|
||||
return self.tp / c if c > 0.0 else 0.0
|
||||
|
||||
def fpr(self):
|
||||
_c = 1.0*self.get_not_c()
|
||||
return self.fp / _c if _c > 0.0 else 0.0
|
||||
|
||||
|
||||
def __ig_factor(p_tc, p_t, p_c):
|
||||
den = p_t * p_c
|
||||
if den != 0.0 and p_tc != 0:
|
||||
return p_tc * math.log(p_tc / den, 2)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
|
||||
def information_gain(cell):
|
||||
return __ig_factor(cell.p_tp(), cell.p_f(), cell.p_c()) + \
|
||||
__ig_factor(cell.p_fp(), cell.p_f(), cell.p_not_c()) +\
|
||||
__ig_factor(cell.p_fn(), cell.p_not_f(), cell.p_c()) + \
|
||||
__ig_factor(cell.p_tn(), cell.p_not_f(), cell.p_not_c())
|
||||
|
||||
|
||||
def squared_information_gain(cell):
|
||||
return information_gain(cell)**2
|
||||
|
||||
|
||||
def posneg_information_gain(cell):
|
||||
ig = information_gain(cell)
|
||||
if cell.tpr() < cell.fpr():
|
||||
return -ig
|
||||
else:
|
||||
return ig
|
||||
|
||||
|
||||
def pos_information_gain(cell):
|
||||
if cell.tpr() < cell.fpr():
|
||||
return 0
|
||||
else:
|
||||
return information_gain(cell)
|
||||
|
||||
def pointwise_mutual_information(cell):
|
||||
return __ig_factor(cell.p_tp(), cell.p_f(), cell.p_c())
|
||||
|
||||
|
||||
def gss(cell):
|
||||
return cell.p_tp()*cell.p_tn() - cell.p_fp()*cell.p_fn()
|
||||
|
||||
|
||||
def chi_square(cell):
|
||||
den = cell.p_f() * cell.p_not_f() * cell.p_c() * cell.p_not_c()
|
||||
if den==0.0: return 0.0
|
||||
num = gss(cell)**2
|
||||
return num / den
|
||||
|
||||
|
||||
def conf_interval(xt, n):
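# ~95% confidence interval for the proportion xt/n: normal approximation (z^2) for n>30, Student's t otherwise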
|
||||
if n>30:
|
||||
z2 = 3.84145882069 # norm.ppf(0.5+0.95/2.0)**2
|
||||
else:
|
||||
z2 = t.ppf(0.5 + 0.95 / 2.0, df=max(n-1,1)) ** 2
|
||||
p = (xt + 0.5 * z2) / (n + z2)
|
||||
amplitude = 0.5 * z2 * math.sqrt((p * (1.0 - p)) / (n + z2))
|
||||
return p, amplitude
|
||||
|
||||
|
||||
def strength(minPosRelFreq, minPos, maxNeg):
|
||||
if minPos > maxNeg:
|
||||
return math.log(2.0 * minPosRelFreq, 2.0)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
|
||||
# set cancel_features=True to allow some features to be weighted as 0 (as in the original article);
# however, for some extremely imbalanced datasets this caused all documents to be zeroed out
|
||||
def conf_weight(cell, cancel_features=False):
|
||||
c = cell.get_c()
|
||||
not_c = cell.get_not_c()
|
||||
tp = cell.tp
|
||||
fp = cell.fp
|
||||
|
||||
pos_p, pos_amp = conf_interval(tp, c)
|
||||
neg_p, neg_amp = conf_interval(fp, not_c)
|
||||
|
||||
min_pos = pos_p-pos_amp
|
||||
max_neg = neg_p+neg_amp
|
||||
den = (min_pos + max_neg)
|
||||
minpos_relfreq = min_pos / (den if den != 0 else 1)
|
||||
|
||||
str_tplus = strength(minpos_relfreq, min_pos, max_neg)
|
||||
|
||||
if str_tplus == 0 and not cancel_features:
|
||||
return 1e-20
|
||||
|
||||
return str_tplus
|
||||
|
||||
|
||||
def get_tsr_matrix(cell_matrix, tsr_score_function):
|
||||
nC = len(cell_matrix)
|
||||
nF = len(cell_matrix[0])
|
||||
tsr_matrix = [[tsr_score_function(cell_matrix[c,f]) for f in range(nF)] for c in range(nC)]
|
||||
return np.array(tsr_matrix)
|
||||
|
||||
|
||||
def feature_label_contingency_table(positive_document_indexes, feature_document_indexes, nD):
|
||||
tp_ = len(positive_document_indexes & feature_document_indexes)
|
||||
fp_ = len(feature_document_indexes - positive_document_indexes)
|
||||
fn_ = len(positive_document_indexes - feature_document_indexes)
|
||||
tn_ = nD - (tp_ + fp_ + fn_)
|
||||
return ContTable(tp=tp_, tn=tn_, fp=fp_, fn=fn_)
|
||||
|
||||
|
||||
def category_tables(feature_sets, category_sets, c, nD, nF):
|
||||
return [feature_label_contingency_table(category_sets[c], feature_sets[f], nD) for f in range(nF)]
|
||||
|
||||
|
||||
def get_supervised_matrix(coocurrence_matrix, label_matrix, n_jobs=-1):
|
||||
"""
|
||||
Computes the nC x nF supervised matrix M where Mcf is the 4-cell contingency table for feature f and class c.
|
||||
Efficiency O(nF x nC x log(S)) where S is the sparse factor
|
||||
"""
|
||||
|
||||
nD, nF = coocurrence_matrix.shape
|
||||
nD2, nC = label_matrix.shape
|
||||
|
||||
if nD != nD2:
|
||||
raise ValueError('Number of rows in coocurrence matrix shape %s and label matrix shape %s is not consistent' %
|
||||
(coocurrence_matrix.shape,label_matrix.shape))
|
||||
|
||||
def nonzero_set(matrix, col):
|
||||
return set(matrix[:, col].nonzero()[0])
|
||||
|
||||
if isinstance(coocurrence_matrix, csr_matrix):
|
||||
coocurrence_matrix = csc_matrix(coocurrence_matrix)
|
||||
feature_sets = [nonzero_set(coocurrence_matrix, f) for f in range(nF)]
|
||||
category_sets = [nonzero_set(label_matrix, c) for c in range(nC)]
|
||||
cell_matrix = Parallel(n_jobs=n_jobs, backend="threading")(
|
||||
delayed(category_tables)(feature_sets, category_sets, c, nD, nF) for c in range(nC)
|
||||
)
|
||||
return np.array(cell_matrix)
|
||||
|
||||
|
||||
class TSRweighting(BaseEstimator,TransformerMixin):
|
||||
"""
|
||||
Supervised Term Weighting function based on any Term Selection Reduction (TSR) function (e.g., information gain,
|
||||
chi-square, etc.) or, more generally, on any function that could be computed on the 4-cell contingency table for
|
||||
each category-feature pair.
|
||||
The supervised_4cell_matrix is a `(n_classes, n_words)` matrix containing the 4-cell contingency tables
|
||||
for each class-word pair, and can be pre-computed (e.g., during the feature selection phase) and passed as an
|
||||
argument.
|
||||
When `n_classes>1`, i.e., in multiclass scenarios, a global_policy is used to determine a single
per-feature score that summarizes its relevance. Accepted policies include "max" (takes the max score
|
||||
across categories), "ave" and "wave" (take the average, or weighted average, across all categories -- weights
|
||||
correspond to the class prevalence), and "sum" (which sums all category scores).
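
Example (a minimal sketch; the corpus variables, TSR function, and parameters are illustrative):

>>> tsr = TSRweighting(tsr_function=information_gain, global_policy='max', min_df=3)
>>> Xtr = tsr.fit_transform(train_texts, train_labels)   # csr_matrix of supervised-weighted features
>>> Xte = tsr.transform(test_texts)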
|
||||
"""
|
||||
|
||||
def __init__(self, tsr_function, global_policy='max', supervised_4cell_matrix=None, sublinear_tf=True, norm='l2', min_df=3, n_jobs=-1):
|
||||
if global_policy not in ['max', 'ave', 'wave', 'sum']: raise ValueError('Global policy should be in {"max", "ave", "wave", "sum"}')
|
||||
self.tsr_function = tsr_function
|
||||
self.global_policy = global_policy
|
||||
self.supervised_4cell_matrix = supervised_4cell_matrix
|
||||
self.sublinear_tf = sublinear_tf
|
||||
self.norm = norm
|
||||
self.min_df = min_df
|
||||
self.n_jobs = n_jobs
|
||||
|
||||
def fit(self, X, y):
|
||||
self.count_vectorizer = CountVectorizer(min_df=self.min_df)
|
||||
X = self.count_vectorizer.fit_transform(X)
|
||||
|
||||
self.tf_vectorizer = TfidfTransformer(
|
||||
norm=None, use_idf=False, smooth_idf=False, sublinear_tf=self.sublinear_tf
|
||||
).fit(X)
|
||||
|
||||
if len(y.shape) == 1:
|
||||
y = np.expand_dims(y, axis=1)
|
||||
|
||||
nD, nC = y.shape
|
||||
nF = len(self.tf_vectorizer.get_feature_names_out())
|
||||
|
||||
if self.supervised_4cell_matrix is None:
|
||||
self.supervised_4cell_matrix = get_supervised_matrix(X, y, n_jobs=self.n_jobs)
|
||||
else:
|
||||
if self.supervised_4cell_matrix.shape != (nC, nF):
|
||||
raise ValueError("Shape of supervised information matrix is inconsistent with X and y")
|
||||
|
||||
tsr_matrix = get_tsr_matrix(self.supervised_4cell_matrix, self.tsr_function)
|
||||
|
||||
if self.global_policy == 'ave':
|
||||
self.global_tsr_vector = np.average(tsr_matrix, axis=0)
|
||||
elif self.global_policy == 'wave':
|
||||
category_prevalences = [sum(y[:,c])*1.0/nD for c in range(nC)]
|
||||
self.global_tsr_vector = np.average(tsr_matrix, axis=0, weights=category_prevalences)
|
||||
elif self.global_policy == 'sum':
|
||||
self.global_tsr_vector = np.sum(tsr_matrix, axis=0)
|
||||
elif self.global_policy == 'max':
|
||||
self.global_tsr_vector = np.amax(tsr_matrix, axis=0)
|
||||
return self
|
||||
|
||||
def fit_transform(self, X, y):
|
||||
return self.fit(X,y).transform(X)
|
||||
|
||||
def transform(self, X):
|
||||
if not hasattr(self, 'global_tsr_vector'): raise NameError('TSRweighting: transform method called before fit.')
|
||||
X = self.count_vectorizer.transform(X)
|
||||
tf_X = self.tf_vectorizer.transform(X).toarray()
|
||||
weighted_X = np.multiply(tf_X, self.global_tsr_vector)
|
||||
if self.norm is not None and self.norm!='none':
|
||||
weighted_X = sklearn.preprocessing.normalize(weighted_X, norm=self.norm, axis=1, copy=False)
|
||||
return csr_matrix(weighted_X)
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
from scipy.sparse import issparse
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
import quapy as qp
|
||||
from data import LabelledCollection
|
||||
import numpy as np
|
||||
|
||||
from experimental_non_aggregative.custom_vectorizers import *
|
||||
from method._kdey import KDEBase
|
||||
from protocol import APP
|
||||
from quapy.method.aggregative import HDy, DistributionMatchingY
|
||||
from quapy.method.base import BaseQuantifier
|
||||
from scipy import optimize
|
||||
import pandas as pd
|
||||
import quapy.functional as F
|
||||
|
||||
|
||||
# TODO: explore the bernoulli (term presence/absence) variant
|
||||
# TODO: explore the multinomial (term frequency) variant
|
||||
# TODO: explore the multinomial + length normalization variant
|
||||
# TODO: consolidate the TSR-variant (e.g., using information gain) variant;
|
||||
# - works better with the idf?
|
||||
# - works better with length normalization?
|
||||
# - etc
|
||||
|
||||
class DxS(BaseQuantifier):
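# distribution matching over raw document representations: each class is summarized by the mean feature vector of
# its training documents, and the test prevalence is estimated as the mixture weights that minimize the divergence
# between the test mean vector and the corresponding mixture of class representations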
|
||||
def __init__(self, vectorizer=None, divergence='topsoe'):
|
||||
self.vectorizer = vectorizer
|
||||
self.divergence = divergence
|
||||
|
||||
# def __as_distribution(self, instances):
|
||||
# return np.asarray(instances.sum(axis=0) / instances.sum()).flatten()
|
||||
|
||||
def __as_distribution(self, instances):
|
||||
dist = instances.mean(axis=0)
|
||||
return np.asarray(dist).flatten()
|
||||
|
||||
def fit(self, text_instances, labels):
|
||||
|
||||
classes = np.unique(labels)
|
||||
|
||||
if self.vectorizer is not None:
|
||||
text_instances = self.vectorizer.fit_transform(text_instances, y=labels)
|
||||
|
||||
distributions = []
|
||||
for class_i in classes:
|
||||
distributions.append(self.__as_distribution(text_instances[labels == class_i]))
|
||||
|
||||
self.validation_distribution = np.asarray(distributions)
|
||||
|
||||
return self
|
||||
|
||||
def predict(self, text_instances):
|
||||
if self.vectorizer is not None:
|
||||
text_instances = self.vectorizer.transform(text_instances)
|
||||
|
||||
test_distribution = self.__as_distribution(text_instances)
|
||||
divergence = qp.functional.get_divergence(self.divergence)
|
||||
n_classes, n_feats = self.validation_distribution.shape
|
||||
|
||||
def match(prev):
|
||||
prev = np.expand_dims(prev, axis=0)
|
||||
mixture_distribution = (prev @ self.validation_distribution).flatten()
|
||||
return divergence(test_distribution, mixture_distribution)
|
||||
|
||||
# the initial point is set as the uniform distribution
|
||||
uniform_distribution = np.full(fill_value=1 / n_classes, shape=(n_classes,))
|
||||
|
||||
# solutions are bounded to those contained in the unit-simplex
|
||||
bounds = tuple((0, 1) for x in range(n_classes)) # values in [0,1]
|
||||
constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)}) # values summing up to 1
|
||||
r = optimize.minimize(match, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
|
||||
return r.x
|
||||
|
||||
|
||||
|
||||
class KDExML(BaseQuantifier, KDEBase):
|
||||
|
||||
def __init__(self, bandwidth=0.1, standardize=False):
|
||||
self._check_bandwidth(bandwidth)
|
||||
self.bandwidth = bandwidth
|
||||
self.standardize = standardize
|
||||
|
||||
def fit(self, X, y):
|
||||
classes = sorted(np.unique(y))
|
||||
|
||||
if self.standardize:
|
||||
self.scaler = StandardScaler()
|
||||
X = self.scaler.fit_transform(X)
|
||||
|
||||
if issparse(X):
|
||||
X = X.toarray()
|
||||
|
||||
self.mix_densities = self.get_mixture_components(X, y, classes, self.bandwidth)
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
"""
|
||||
Searches for the mixture model parameter (the sought prevalence values) that maximizes the likelihood
|
||||
of the data (i.e., that minimizes the negative log-likelihood)
|
||||
|
||||
:param X: instances in the sample
|
||||
:return: a vector of class prevalence estimates
|
||||
"""
|
||||
epsilon = 1e-10
|
||||
if issparse(X):
|
||||
X = X.toarray()
|
||||
n_classes = len(self.mix_densities)
|
||||
if self.standardize:
|
||||
X = self.scaler.transform(X)
|
||||
test_densities = [self.pdf(kde_i, X) for kde_i in self.mix_densities]
|
||||
|
||||
def neg_loglikelihood(prev):
|
||||
test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip (prev, test_densities))
|
||||
test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
|
||||
return -np.sum(test_loglikelihood)
|
||||
|
||||
return F.optim_minimize(neg_loglikelihood, n_classes)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
qp.environ['SAMPLE_SIZE'] = 250
|
||||
qp.environ['N_JOBS'] = -1
|
||||
min_df = 10
|
||||
# dataset = 'imdb'
|
||||
repeats = 10
|
||||
error = 'mae'
|
||||
|
||||
div = 'topsoe'
|
||||
|
||||
# generates tuples (dataset, method, method_name)
|
||||
# (the dataset is needed for methods that process the dataset differently)
|
||||
def gen_methods():
|
||||
|
||||
for dataset in qp.datasets.REVIEWS_SENTIMENT_DATASETS:
|
||||
|
||||
data = qp.datasets.fetch_reviews(dataset, tfidf=False)
|
||||
|
||||
# bernoulli_vectorizer = CountVectorizer(min_df=min_df, binary=True)
|
||||
# dxs = DxS(divergence=div, vectorizer=bernoulli_vectorizer)
|
||||
# yield data, dxs, 'DxS-Bernoulli'
|
||||
#
|
||||
# multinomial_vectorizer = CountVectorizer(min_df=min_df, binary=False)
|
||||
# dxs = DxS(divergence=div, vectorizer=multinomial_vectorizer)
|
||||
# yield data, dxs, 'DxS-multinomial'
|
||||
#
|
||||
# tf_vectorizer = TfidfVectorizer(sublinear_tf=False, use_idf=False, min_df=min_df, norm=None)
|
||||
# dxs = DxS(divergence=div, vectorizer=tf_vectorizer)
|
||||
# yield data, dxs, 'DxS-TF'
|
||||
#
|
||||
# logtf_vectorizer = TfidfVectorizer(sublinear_tf=True, use_idf=False, min_df=min_df, norm=None)
|
||||
# dxs = DxS(divergence=div, vectorizer=logtf_vectorizer)
|
||||
# yield data, dxs, 'DxS-logTF'
|
||||
#
|
||||
# tfidf_vectorizer = TfidfVectorizer(use_idf=True, min_df=min_df, norm=None)
|
||||
# dxs = DxS(divergence=div, vectorizer=tfidf_vectorizer)
|
||||
# yield data, dxs, 'DxS-TFIDF'
|
||||
#
|
||||
# tfidf_vectorizer = TfidfVectorizer(use_idf=True, min_df=min_df, norm='l2')
|
||||
# dxs = DxS(divergence=div, vectorizer=tfidf_vectorizer)
|
||||
# yield data, dxs, 'DxS-TFIDF-l2'
|
||||
|
||||
tsr_vectorizer = TSRweighting(tsr_function=information_gain, min_df=min_df, norm='l2')
|
||||
dxs = DxS(divergence=div, vectorizer=tsr_vectorizer)
|
||||
yield data, dxs, 'DxS-TFTSR-l2'
|
||||
|
||||
data = qp.datasets.fetch_reviews(dataset, tfidf=True, min_df=min_df)
|
||||
|
||||
kdex = KDExML()
|
||||
reduction = TruncatedSVD(n_components=100, random_state=0)
|
||||
red_data = qp.data.preprocessing.instance_transformation(data, transformer=reduction, inplace=False)
|
||||
yield red_data, kdex, 'KDEx'
|
||||
|
||||
hdy = HDy(LogisticRegression())
|
||||
yield data, hdy, 'HDy'
|
||||
|
||||
# dm = DistributionMatchingY(LogisticRegression(), divergence=div, nbins=5)
|
||||
# yield data, dm, 'DM-5b'
|
||||
#
|
||||
# dm = DistributionMatchingY(LogisticRegression(), divergence=div, nbins=10)
|
||||
# yield data, dm, 'DM-10b'
|
||||
|
||||
|
||||
|
||||
|
||||
result_path = 'results.csv'
|
||||
with open(result_path, 'wt') as csv:
|
||||
csv.write(f'Method\tDataset\tMAE\tMRAE\n')
|
||||
for data, quantifier, quant_name in gen_methods():
|
||||
quantifier.fit(*data.training.Xy)
|
||||
report = qp.evaluation.evaluation_report(quantifier, APP(data.test, repeats=repeats), error_metrics=['mae','mrae'], verbose=True)
|
||||
means = report.mean(numeric_only=True)
|
||||
csv.write(f'{quant_name}\t{data.name}\t{means["mae"]:.5f}\t{means["mrae"]:.5f}\n')
|
||||
|
||||
df = pd.read_csv(result_path, sep='\t')
|
||||
# print(df)
|
||||
|
||||
pv = df.pivot_table(index='Method', columns="Dataset", values=["MAE", "MRAE"])
|
||||
print(pv)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
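A minimal usage sketch for the DxS quantifier defined above (not part of the script); the toy corpus and the CountVectorizer settings are illustrative assumptions:

# hypothetical usage sketch for DxS; the toy texts and the vectorizer choice are assumptions
from sklearn.feature_extraction.text import CountVectorizer
import numpy as np

texts_tr = ['good movie', 'bad movie', 'great plot', 'terrible acting']
labels_tr = np.asarray([1, 0, 1, 0])
texts_te = ['good plot', 'bad acting', 'great movie']

dxs = DxS(divergence='topsoe', vectorizer=CountVectorizer(binary=True))
dxs.fit(texts_tr, labels_tr)
print(dxs.predict(texts_te))  # estimated class prevalences, approximately summing to 1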
@ -13,7 +13,7 @@ from . import model_selection
from . import classification
import os

__version__ = '0.2.0'
__version__ = '0.2.1'


def _default_cls():
@ -33,7 +33,6 @@ class LabelledCollection:
        else:
            self.instances = np.asarray(instances)
        self.labels = np.asarray(labels)
        n_docs = len(self)
        if classes is None:
            self.classes_ = F.classes_from_labels(self.labels)
        else:
@ -41,7 +40,13 @@ class LabelledCollection:
            self.classes_.sort()
            if len(set(self.labels).difference(set(classes))) > 0:
                raise ValueError(f'labels ({set(self.labels)}) contain values not included in classes_ ({set(classes)})')
        self.index = {class_: np.arange(n_docs)[self.labels == class_] for class_ in self.classes_}
        self._index = None

    @property
    def index(self):
        if not hasattr(self, '_index') or self._index is None:
            self._index = {class_: np.arange(len(self))[self.labels == class_] for class_ in self.classes_}
        return self._index

    @classmethod
    def load(cls, path: str, loader_func: callable, classes=None, **loader_kwargs):
@ -10,6 +10,37 @@ from quapy.util import map_parallel
from .base import LabelledCollection


def instance_transformation(dataset:Dataset, transformer, inplace=False):
    """
    Transforms a :class:`quapy.data.base.Dataset` applying the `fit_transform` and `transform` functions
    of a (sklearn's) transformer.

    :param dataset: a :class:`quapy.data.base.Dataset` where the instances of training and test collections are
        lists of str
    :param transformer: TransformerMixin implementing `fit_transform` and `transform` functions
    :param inplace: whether or not to apply the transformation inplace (True), or to a new copy (False, default)
    :return: a new :class:`quapy.data.base.Dataset` with transformed instances (if inplace=False) or a reference to the
        current Dataset (if inplace=True) where the instances have been transformed
    """
    training_transformed = transformer.fit_transform(*dataset.training.Xy)
    test_transformed = transformer.transform(dataset.test.X)
    orig_name = dataset.name

    if inplace:
        dataset.training = LabelledCollection(training_transformed, dataset.training.labels, dataset.classes_)
        dataset.test = LabelledCollection(test_transformed, dataset.test.labels, dataset.classes_)
        if hasattr(transformer, 'vocabulary_'):
            dataset.vocabulary = transformer.vocabulary_
        return dataset
    else:
        training = LabelledCollection(training_transformed, dataset.training.labels.copy(), dataset.classes_)
        test = LabelledCollection(test_transformed, dataset.test.labels.copy(), dataset.classes_)
        vocab = None
        if hasattr(transformer, 'vocabulary_'):
            vocab = transformer.vocabulary_
        return Dataset(training, test, vocabulary=vocab, name=orig_name)


def text2tfidf(dataset:Dataset, min_df=3, sublinear_tf=True, inplace=False, **kwargs):
    """
    Transforms a :class:`quapy.data.base.Dataset` of textual instances into a :class:`quapy.data.base.Dataset` of
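An illustrative sketch of how instance_transformation can be used (mirroring the experiment script earlier in this PR); the dataset name and the SVD settings are assumptions:

import quapy as qp
from sklearn.decomposition import TruncatedSVD

# reduce a tfidf Dataset to 100 latent dimensions, leaving the original Dataset untouched
data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=10)
svd = TruncatedSVD(n_components=100, random_state=0)
red_data = instance_transformation(data, transformer=svd, inplace=False)
print(red_data.training.X.shape)  # (n_training_documents, 100)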
@ -29,18 +60,7 @@ def text2tfidf(dataset:Dataset, min_df=3, sublinear_tf=True, inplace=False, **kw
    __check_type(dataset.test.instances, np.ndarray, str)

    vectorizer = TfidfVectorizer(min_df=min_df, sublinear_tf=sublinear_tf, **kwargs)
    training_documents = vectorizer.fit_transform(dataset.training.instances)
    test_documents = vectorizer.transform(dataset.test.instances)

    if inplace:
        dataset.training = LabelledCollection(training_documents, dataset.training.labels, dataset.classes_)
        dataset.test = LabelledCollection(test_documents, dataset.test.labels, dataset.classes_)
        dataset.vocabulary = vectorizer.vocabulary_
        return dataset
    else:
        training = LabelledCollection(training_documents, dataset.training.labels.copy(), dataset.classes_)
        test = LabelledCollection(test_documents, dataset.test.labels.copy(), dataset.classes_)
        return Dataset(training, test, vectorizer.vocabulary_)
    return instance_transformation(dataset, vectorizer, inplace)


def reduce_columns(dataset: Dataset, min_df=5, inplace=False):
@ -583,8 +583,8 @@ def solve_adjustment(
    """
    Function that tries to solve for :math:`p` the equation :math:`q = M p`, where :math:`q` is the vector of
    `unadjusted counts` (as estimated, e.g., via classify and count) with :math:`q_i` an estimate of
    :math:`P(\hat{Y}=y_i)`, and where :math:`M` is the matrix of `class-conditional rates` with :math:`M_{ij}` an
    estimate of :math:`P(\hat{Y}=y_i|Y=y_j)`.
    :math:`P(\\hat{Y}=y_i)`, and where :math:`M` is the matrix of `class-conditional rates` with :math:`M_{ij}` an
    estimate of :math:`P(\\hat{Y}=y_i|Y=y_j)`.

    :param class_conditional_rates: array of shape `(n_classes, n_classes,)` with entry `(i,j)` being the estimate
        of :math:`P(\hat{Y}=y_i|Y=y_j)`, that is, the probability that an instance that belongs to class :math:`y_j`
@ -29,6 +29,7 @@ AGGREGATIVE_METHODS = {
    aggregative.KDEyHD,
    # aggregative.OneVsAllAggregative,
    confidence.BayesianCC,
    confidence.PQ,
}

BINARY_METHODS = {

@ -40,6 +41,7 @@ BINARY_METHODS = {
    aggregative.MAX,
    aggregative.MS,
    aggregative.MS2,
    confidence.PQ,
}

MULTICLASS_METHODS = {
@ -1,13 +1,21 @@
"""
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
"""
import contextlib
import os
import sys

import numpy as np
import importlib.resources

try:
    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    import stan
    import logging
    import stan.common

    DEPENDENCIES_INSTALLED = True
except ImportError:

@ -15,6 +23,7 @@ except ImportError:
    jnp = None
    numpyro = None
    dist = None
    stan = None

    DEPENDENCIES_INSTALLED = False
@ -77,3 +86,71 @@ def sample_posterior(
    rng_key = jax.random.PRNGKey(seed)
    mcmc.run(rng_key, n_c_unlabeled=n_c_unlabeled, n_y_and_c_labeled=n_y_and_c_labeled)
    return mcmc.get_samples()


def load_stan_file():
    return importlib.resources.files('quapy.method').joinpath('stan/pq.stan').read_text(encoding='utf-8')


@contextlib.contextmanager
def _suppress_stan_logging():
    with open(os.devnull, "w") as devnull:
        old_stderr = sys.stderr
        sys.stderr = devnull
        try:
            yield
        finally:
            sys.stderr = old_stderr


def pq_stan(stan_code, n_bins, pos_hist, neg_hist, test_hist, number_of_samples, num_warmup, stan_seed):
    """
    Perform Bayesian prevalence estimation using a Stan model for probabilistic quantification.

    This function builds and samples from a Stan model that implements a bin-based Bayesian
    quantifier. It uses the class-conditional histograms of the classifier
    outputs for positive and negative examples, along with the test histogram, to estimate
    the posterior distribution of prevalence in the test set.

    Parameters
    ----------
    stan_code : str
        The Stan model code as a string.
    n_bins : int
        Number of bins used to build the histograms for positive and negative examples.
    pos_hist : array-like of shape (n_bins,)
        Histogram counts of the classifier outputs for the positive class.
    neg_hist : array-like of shape (n_bins,)
        Histogram counts of the classifier outputs for the negative class.
    test_hist : array-like of shape (n_bins,)
        Histogram counts of the classifier outputs for the test set, binned using the same bins.
    number_of_samples : int
        Number of post-warmup samples to draw from the Stan posterior.
    num_warmup : int
        Number of warmup iterations for the sampler.
    stan_seed : int
        Random seed for Stan model compilation and sampling, ensuring reproducibility.

    Returns
    -------
    prev_samples : numpy.ndarray
        An array of posterior samples of the prevalence (`prev`) in the test set.
        Each element corresponds to one draw from the posterior distribution.
    """

    logging.getLogger("stan.common").setLevel(logging.ERROR)

    stan_data = {
        'n_bucket': n_bins,
        'train_neg': neg_hist.tolist(),
        'train_pos': pos_hist.tolist(),
        'test': test_hist.tolist(),
        'posterior': 1
    }

    with _suppress_stan_logging():
        stan_model = stan.build(stan_code, data=stan_data, random_seed=stan_seed)
        fit = stan_model.sample(num_chains=1, num_samples=number_of_samples, num_warmup=num_warmup)

    return fit['prev']
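A hypothetical end-to-end sketch of pq_stan, assuming pystan is installed; the beta-distributed dummy scores stand in for real classifier outputs and are not part of the PR:

import numpy as np

rng = np.random.default_rng(0)
pos_scores = rng.beta(5, 2, size=500)    # dummy posteriors for positive validation documents
neg_scores = rng.beta(2, 5, size=500)    # dummy posteriors for negative validation documents
test_scores = rng.beta(3, 3, size=250)   # dummy posteriors for a test sample

n_bins = 4
bins = np.linspace(0, 1, n_bins + 1)     # shared bin edges for the three histograms
pos_hist, _ = np.histogram(pos_scores, bins=bins)
neg_hist, _ = np.histogram(neg_scores, bins=bins)
test_hist, _ = np.histogram(test_scores, bins=bins)

prev_samples = pq_stan(load_stan_file(), n_bins, pos_hist, neg_hist, test_hist,
                       number_of_samples=1000, num_warmup=500, stan_seed=0).flatten()
print(prev_samples.mean(), np.percentile(prev_samples, [2.5, 97.5]))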
@ -9,15 +9,31 @@ import quapy.functional as F
from sklearn.metrics.pairwise import rbf_kernel


# class KDE(KernelDensity):
#
#     KERNELS = ['gaussian', 'aitchison']
#
#     def __init__(self, bandwidth, kernel):
#         assert kernel in KDE.KERNELS, f'unknown {kernel=}'
#         self.bandwidth = bandwidth
#         self.kernel = kernel
#
#     def


class KDEBase:
    """
    Common ancestor for KDE-based methods. Implements some common routines.
    """

    BANDWIDTH_METHOD = ['scott', 'silverman']
    KERNELS = ['gaussian', 'aitchison']

    @classmethod
    def _check_bandwidth(cls, bandwidth):
    def _check_bandwidth(cls, bandwidth, kernel):
        """
        Checks that the bandwidth parameter is correct

@ -27,32 +43,64 @@ class KDEBase:
        assert bandwidth in KDEBase.BANDWIDTH_METHOD or isinstance(bandwidth, float), \
            f'invalid bandwidth, valid ones are {KDEBase.BANDWIDTH_METHOD} or float values'
        if isinstance(bandwidth, float):
            assert 0 < bandwidth < 1, \
                "the bandwith for KDEy should be in (0,1), since this method models the unit simplex"
            assert kernel!='gaussian' or (0 < bandwidth < 1), \
                ("the bandwidth for a Gaussian kernel in KDEy should be in (0,1), "
                 "since this method models the unit simplex")
        return bandwidth

    def get_kde_function(self, X, bandwidth):
    @classmethod
    def _check_kernel(cls, kernel):
        """
        Checks that the kernel parameter is correct

        :param kernel: str
        :return: the validated kernel
        """
        assert kernel in KDEBase.KERNELS, f'unknown {kernel=}'
        return kernel

    @classmethod
    def clr_transform(cls, P, eps=1e-7):
        """
        Centered-Log Ratio (CLR) transform.
        P: array (n_samples, n_classes), every row is a point in the probability simplex
        eps: smoothing, to avoid log(0)
        """
        X_safe = np.clip(P, eps, None)
        X_safe /= X_safe.sum(axis=1, keepdims=True)  # renormalize
        gm = np.exp(np.mean(np.log(X_safe), axis=1, keepdims=True))
        return np.log(X_safe / gm)

    def get_kde_function(self, X, bandwidth, kernel):
        """
        Wraps the KDE function from scikit-learn.

        :param X: data for which the density function is to be estimated
        :param bandwidth: the bandwidth of the kernel
        :param kernel: the kernel (valid ones are in KDEBase.KERNELS)
        :return: a scikit-learn's KernelDensity object
        """
        if kernel == 'aitchison':
            X = KDEBase.clr_transform(X)

        return KernelDensity(bandwidth=bandwidth).fit(X)

    def pdf(self, kde, X):
    def pdf(self, kde, X, kernel):
        """
        Wraps the density evaluation of scikit-learn's KDE. Scikit-learn returns log-scores (s), so this
        function returns :math:`e^{s}`

        :param kde: a previously fit KDE function
        :param X: the data for which the density is to be estimated
        :param kernel: the kernel (valid ones are in KDEBase.KERNELS)
        :return: np.ndarray with the densities
        """
        if kernel == 'aitchison':
            X = KDEBase.clr_transform(X)

        return np.exp(kde.score_samples(X))

    def get_mixture_components(self, X, y, classes, bandwidth):
    def get_mixture_components(self, X, y, classes, bandwidth, kernel):
        """
        Returns an array containing the mixture components, i.e., the KDE functions for each class.


@ -60,6 +108,7 @@ class KDEBase:
        :param y: the class labels
        :param n_classes: integer, the number of classes
        :param bandwidth: float, the bandwidth of the kernel
        :param kernel: the kernel (valid ones are in KDEBase.KERNELS)
        :return: a list of KernelDensity objects, each fitted with the corresponding class-specific covariates
        """
        class_cond_X = []

@ -67,8 +116,12 @@ class KDEBase:
            selX = X[y==cat]
            if selX.size==0:
                selX = [F.uniform_prevalence(len(classes))]

            # if kernel == 'aitchison':
            #     # this is already done within get_kde_function
            #     selX = KDEBase.clr_transform(selX)
            class_cond_X.append(selX)
        return [self.get_kde_function(X_cond_yi, bandwidth) for X_cond_yi in class_cond_X]
        return [self.get_kde_function(X_cond_yi, bandwidth, kernel) for X_cond_yi in class_cond_X]


class KDEyML(AggregativeSoftQuantifier, KDEBase):
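A small illustrative check (not in the diff) of the clr_transform helper: the CLR map sends points of the simplex to zero-sum vectors, so the 'aitchison' kernel above amounts to a Gaussian KDE in that unconstrained space:

import numpy as np

P = np.asarray([[0.2, 0.3, 0.5],
                [0.7, 0.2, 0.1]])
Z = KDEBase.clr_transform(P)
print(Z.sum(axis=1))  # approximately [0., 0.]: every CLR-transformed row sums to zero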
@ -107,17 +160,19 @@ class KDEyML(AggregativeSoftQuantifier, KDEBase):
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
    :param bandwidth: float, the bandwidth of the Kernel
    :param kernel: kernel of KDE, valid ones are in KDEBase.KERNELS
    :param random_state: a seed to be set before fitting any base quantifier (default None)
    """

    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, bandwidth=0.1,
    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, bandwidth=0.1, kernel='gaussian',
                 random_state=None):
        super().__init__(classifier, fit_classifier, val_split)
        self.bandwidth = KDEBase._check_bandwidth(bandwidth)
        self.bandwidth = KDEBase._check_bandwidth(bandwidth, kernel)
        self.kernel = self._check_kernel(kernel)
        self.random_state=random_state

    def aggregation_fit(self, classif_predictions, labels):
        self.mix_densities = self.get_mixture_components(classif_predictions, labels, self.classes_, self.bandwidth)
        self.mix_densities = self.get_mixture_components(classif_predictions, labels, self.classes_, self.bandwidth, self.kernel)
        return self

    def aggregate(self, posteriors: np.ndarray):
@ -131,10 +186,11 @@ class KDEyML(AggregativeSoftQuantifier, KDEBase):
        with qp.util.temp_seed(self.random_state):
            epsilon = 1e-10
            n_classes = len(self.mix_densities)
            test_densities = [self.pdf(kde_i, posteriors) for kde_i in self.mix_densities]
            test_densities = [self.pdf(kde_i, posteriors, self.kernel) for kde_i in self.mix_densities]

            def neg_loglikelihood(prev):
                test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities))
                # test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip(prev, test_densities))
                test_mixture_likelihood = prev @ test_densities
                test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
                return -np.sum(test_loglikelihood)

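Written out, the objective minimized above is the negative log-likelihood of the mixture (a restatement of the code, with \alpha the candidate prevalence vector and \hat{p}_i the class-wise KDE densities):

\mathcal{L}(\alpha) \;=\; -\sum_{x \in \mathrm{test}} \log\Big(\sum_{i=1}^{n} \alpha_i\,\hat{p}_i(x) + \epsilon\Big)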
@ -191,7 +247,7 @@ class KDEyHD(AggregativeSoftQuantifier, KDEBase):

        super().__init__(classifier, fit_classifier, val_split)
        self.divergence = divergence
        self.bandwidth = KDEBase._check_bandwidth(bandwidth)
        self.bandwidth = KDEBase._check_bandwidth(bandwidth, kernel='gaussian')
        self.random_state=random_state
        self.montecarlo_trials = montecarlo_trials

@ -278,7 +334,7 @@ class KDEyCS(AggregativeSoftQuantifier):

    def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, bandwidth=0.1):
        super().__init__(classifier, fit_classifier, val_split)
        self.bandwidth = KDEBase._check_bandwidth(bandwidth)
        self.bandwidth = KDEBase._check_bandwidth(bandwidth, kernel='gaussian')

    def gram_matrix_mix_sum(self, X, Y=None):
        # this adapts the output of the rbf_kernel function (pairwise evaluations of Gaussian kernels k(x,y))
@ -1,19 +1,20 @@
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator
from sklearn.metrics import confusion_matrix

import quapy as qp
import quapy.functional as F
from quapy.method import _bayesian
from quapy.method.aggregative import AggregativeCrispQuantifier
from quapy.data import LabelledCollection
from quapy.method.aggregative import AggregativeQuantifier
from quapy.method.aggregative import AggregativeQuantifier, AggregativeCrispQuantifier, AggregativeSoftQuantifier, BinaryAggregativeQuantifier
from scipy.stats import chi2
from sklearn.utils import resample
from abc import ABC, abstractmethod
from scipy.special import softmax, factorial
import copy
from functools import lru_cache
from tqdm import tqdm

"""
This module provides implementation of different types of confidence regions, and the implementation of Bootstrap
@ -80,6 +81,12 @@ class ConfidenceRegionABC(ABC):
        proportion = np.clip(self.coverage(uniform_simplex), 0., 1.)
        return proportion

    @property
    @abstractmethod
    def samples(self):
        """ Returns internal samples """
        ...


class WithConfidenceABC(ABC):
    """
@ -88,20 +95,32 @@ class WithConfidenceABC(ABC):
    METHODS = ['intervals', 'ellipse', 'ellipse-clr']

    @abstractmethod
    def quantify_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
    def predict_conf(self, instances, confidence_level=0.95) -> (np.ndarray, ConfidenceRegionABC):
        """
        Adds the method `quantify_conf` to the interface. This method returns not only the point-estimate, but
        Adds the method `predict_conf` to the interface. This method returns not only the point-estimate, but
        also the confidence region around it.

        :param instances: a np.ndarray of shape (n_instances, n_features,)
        :confidence_level: float in (0, 1)
        :param confidence_level: float in (0, 1), default is 0.95
        :return: a tuple (`point_estimate`, `conf_region`), where `point_estimate` is a np.ndarray of shape
            (n_classes,) and `conf_region` is an object from :class:`ConfidenceRegionABC`
        """
        ...

    def quantify_conf(self, instances, confidence_level=0.95) -> (np.ndarray, ConfidenceRegionABC):
        """
        Alias to `predict_conf`. This method returns not only the point-estimate, but
        also the confidence region around it.

        :param instances: a np.ndarray of shape (n_instances, n_features,)
        :param confidence_level: float in (0, 1), default is 0.95
        :return: a tuple (`point_estimate`, `conf_region`), where `point_estimate` is a np.ndarray of shape
            (n_classes,) and `conf_region` is an object from :class:`ConfidenceRegionABC`
        """
        return self.predict_conf(instances=instances, confidence_level=confidence_level)

    @classmethod
    def construct_region(cls, prev_estims, confidence_level=0.95, method='intervals'):
    def construct_region(cls, prev_estims, confidence_level=0.95, method='intervals')->ConfidenceRegionABC:
        """
        Construct a confidence region given many prevalence estimations.

@ -173,30 +192,35 @@ class ConfidenceEllipseSimplex(ConfidenceRegionABC):
    """
    Instantiates a Confidence Ellipse in the probability simplex.

    :param X: np.ndarray of shape (n_bootstrap_samples, n_classes)
    :param samples: np.ndarray of shape (n_bootstrap_samples, n_classes)
    :param confidence_level: float, the confidence level (default 0.95)
    """

    def __init__(self, X, confidence_level=0.95):
    def __init__(self, samples, confidence_level=0.95):

        assert 0. < confidence_level < 1., f'{confidence_level=} must be in range(0,1)'

        X = np.asarray(X)
        samples = np.asarray(samples)

        self.mean_ = X.mean(axis=0)
        self.cov_ = np.cov(X, rowvar=False, ddof=1)
        self.mean_ = samples.mean(axis=0)
        self.cov_ = np.cov(samples, rowvar=False, ddof=1)

        try:
            self.precision_matrix_ = np.linalg.inv(self.cov_)
        except:
            self.precision_matrix_ = None

        self.dim = X.shape[-1]
        self.dim = samples.shape[-1]
        self.ddof = self.dim - 1

        # critical chi-square value
        self.confidence_level = confidence_level
        self.chi2_critical_ = chi2.ppf(confidence_level, df=self.ddof)
        self._samples = samples

    @property
    def samples(self):
        return self._samples

    def point_estimate(self):
        """
@ -222,15 +246,21 @@ class ConfidenceEllipseCLR(ConfidenceRegionABC):
    """
    Instantiates a Confidence Ellipse in the Centered-Log Ratio (CLR) space.

    :param X: np.ndarray of shape (n_bootstrap_samples, n_classes)
    :param samples: np.ndarray of shape (n_bootstrap_samples, n_classes)
    :param confidence_level: float, the confidence level (default 0.95)
    """

    def __init__(self, X, confidence_level=0.95):
    def __init__(self, samples, confidence_level=0.95):
        samples = np.asarray(samples)
        self.clr = CLRtransformation()
        Z = self.clr(X)
        self.mean_ = np.mean(X, axis=0)
        Z = self.clr(samples)
        self.mean_ = np.mean(samples, axis=0)
        self.conf_region_clr = ConfidenceEllipseSimplex(Z, confidence_level=confidence_level)
        self._samples = samples

    @property
    def samples(self):
        return self._samples

    def point_estimate(self):
        """
@ -260,19 +290,24 @@ class ConfidenceIntervals(ConfidenceRegionABC):
    """
    Instantiates a region based on (independent) Confidence Intervals.

    :param X: np.ndarray of shape (n_bootstrap_samples, n_classes)
    :param samples: np.ndarray of shape (n_bootstrap_samples, n_classes)
    :param confidence_level: float, the confidence level (default 0.95)
    """
    def __init__(self, X, confidence_level=0.95):
    def __init__(self, samples, confidence_level=0.95):
        assert 0 < confidence_level < 1, f'{confidence_level=} must be in range(0,1)'

        X = np.asarray(X)
        samples = np.asarray(samples)

        self.means_ = X.mean(axis=0)
        self.means_ = samples.mean(axis=0)
        alpha = 1-confidence_level
        low_perc = (alpha/2.)*100
        high_perc = (1-alpha/2.)*100
        self.I_low, self.I_high = np.percentile(X, q=[low_perc, high_perc], axis=0)
        self.I_low, self.I_high = np.percentile(samples, q=[low_perc, high_perc], axis=0)
        self._samples = samples

    @property
    def samples(self):
        return self._samples

    def point_estimate(self):
        """
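An illustrative sketch (dummy values) of the quantities ConfidenceIntervals now exposes, including the new samples property:

import numpy as np

samples = np.asarray([[0.28, 0.72], [0.31, 0.69], [0.25, 0.75], [0.33, 0.67]])
ci = ConfidenceIntervals(samples, confidence_level=0.95)
print(ci.point_estimate())   # per-class means of the samples
print(ci.I_low, ci.I_high)   # per-class 2.5% and 97.5% percentile bounds
print(ci.samples.shape)      # the raw samples, via the new property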
@ -297,6 +332,9 @@ class ConfidenceIntervals(ConfidenceRegionABC):

        return proportion

    def __repr__(self):
        return '['+', '.join(f'({low:.4f}, {high:.4f})' for (low,high) in zip(self.I_low, self.I_high))+']'


class CLRtransformation:
    """
@ -339,6 +377,12 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):

    During inference, the bootstrap repetitions are applied to the pre-classified test instances.

    See
    `Moreo, A., Salvati, N.
    An Efficient Method for Deriving Confidence Intervals in Aggregative Quantification.
    Learning to Quantify: Methods and Applications (LQ 2025), co-located at ECML-PKDD 2025.
    pp 12-33 <https://lq-2025.github.io/proceedings/CompleteVolume.pdf>`_

    :param quantifier: an aggregative quantifier
    :param n_train_samples: int, the number of training resamplings (defaults to 1, set to > 1 to activate a
        model-based bootstrap approach)
@ -357,7 +401,8 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):
                 n_test_samples=500,
                 confidence_level=0.95,
                 region='intervals',
                 random_state=None):
                 random_state=None,
                 verbose=False):

        assert isinstance(quantifier, AggregativeQuantifier), \
            f'base quantifier does not seem to be an instance of {AggregativeQuantifier.__name__}'
@ -374,6 +419,7 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):
        self.confidence_level = confidence_level
        self.region = region
        self.random_state = random_state
        self.verbose = verbose

    def aggregation_fit(self, classif_predictions, labels):
        data = LabelledCollection(classif_predictions, labels, classes=self.classes_)
@ -399,6 +445,24 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):
        prev_mean, self.confidence = self.aggregate_conf(classif_predictions)
        return prev_mean

    def aggregate_conf_sequential__(self, classif_predictions: np.ndarray, confidence_level=None):
        if confidence_level is None:
            confidence_level = self.confidence_level

        n_samples = classif_predictions.shape[0]
        prevs = []
        with qp.util.temp_seed(self.random_state):
            for quantifier in self.quantifiers:
                for i in tqdm(range(self.n_test_samples), desc='resampling', total=self.n_test_samples, disable=not self.verbose):
                    sample_i = resample(classif_predictions, n_samples=n_samples)
                    prev_i = quantifier.aggregate(sample_i)
                    prevs.append(prev_i)

        conf = WithConfidenceABC.construct_region(prevs, confidence_level, method=self.region)
        prev_estim = conf.point_estimate()

        return prev_estim, conf

    def aggregate_conf(self, classif_predictions: np.ndarray, confidence_level=None):
        if confidence_level is None:
            confidence_level = self.confidence_level
@ -407,10 +471,15 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):
        prevs = []
        with qp.util.temp_seed(self.random_state):
            for quantifier in self.quantifiers:
                for i in range(self.n_test_samples):
                    sample_i = resample(classif_predictions, n_samples=n_samples)
                    prev_i = quantifier.aggregate(sample_i)
                    prevs.append(prev_i)
                results = Parallel(n_jobs=-1)(
                    delayed(bootstrap_once)(i, classif_predictions, quantifier, n_samples)
                    for i in range(self.n_test_samples)
                )
                prevs.extend(results)
                # for i in tqdm(range(self.n_test_samples), desc='resampling', total=self.n_test_samples, disable=not self.verbose):
                #     sample_i = resample(classif_predictions, n_samples=n_samples)
                #     prev_i = quantifier.aggregate(sample_i)
                #     prevs.append(prev_i)

        conf = WithConfidenceABC.construct_region(prevs, confidence_level, method=self.region)
        prev_estim = conf.point_estimate()
@ -423,7 +492,7 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):
        self.aggregation_fit(classif_predictions, labels)
        return self

    def quantify_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
    def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
        predictions = self.quantifier.classify(instances)
        return self.aggregate_conf(predictions, confidence_level=confidence_level)

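An illustrative sketch of the renamed predict_conf in use (assuming, as the docstring suggests, that the base quantifier is the constructor's first argument; the dataset and classifier are assumptions):

import quapy as qp
from sklearn.linear_model import LogisticRegression
from quapy.method.aggregative import PACC

data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=10)
boot = AggregativeBootstrap(PACC(LogisticRegression()), n_test_samples=500, region='intervals')
boot.fit(*data.training.Xy)
point, region = boot.predict_conf(data.test.X, confidence_level=0.95)
print(point, region)  # point estimate plus the confidence intervals around it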
@ -435,9 +504,16 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):
        return self.quantifier._classifier_method()


def bootstrap_once(i, classif_predictions, quantifier, n_samples):
    idx = np.random.randint(0, len(classif_predictions), n_samples)
    sample = classif_predictions[idx]
    prev = quantifier.aggregate(sample)
    return prev


class BayesianCC(AggregativeCrispQuantifier, WithConfidenceABC):
    """
    `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ method,
    `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ method (by Albert Ziegler and Paweł Czyż),
    which is a variant of :class:`ACC` that calculates the posterior probability distribution
    over the prevalence vectors, rather than providing a point estimate obtained
    by matrix inversion.
@ -543,9 +619,115 @@ class BayesianCC(AggregativeCrispQuantifier, WithConfidenceABC):
        samples = self.sample_from_posterior(classif_predictions)[_bayesian.P_TEST_Y]
        return np.asarray(samples.mean(axis=0), dtype=float)

    def quantify_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
    def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
        if confidence_level is None:
            confidence_level = self.confidence_level
        classif_predictions = self.classify(instances)
        point_estimate = self.aggregate(classif_predictions)
        samples = self.get_prevalence_samples()  # available after calling "aggregate" function
        region = WithConfidenceABC.construct_region(samples, confidence_level=self.confidence_level, method=self.region)
        region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
        return point_estimate, region


class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
    """
    `Precise Quantifier: Bayesian distribution matching quantifier <https://arxiv.org/abs/2507.06061>`_,
    which is a variant of :class:`HDy` that calculates the posterior probability distribution
    over the prevalence vectors, rather than providing a point estimate.

    This method relies on extra dependencies, which have to be installed via:
    `$ pip install quapy[bayes]`

    :param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
        the one indicated in `qp.environ['DEFAULT_CLS']`
    :param val_split: specifies the data used for generating classifier predictions. This specification
        can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
        be extracted from the training set; or as an integer (default 5), indicating that the predictions
        are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
        for `k`); or as a tuple `(X,y)` defining the specific set of data to use for validation. Set to
        None when the method does not require any validation data, in order to avoid wasting a portion of
        the training data.
    :param num_warmup: number of warmup iterations for the STAN sampler (default 500)
    :param num_samples: number of samples to draw from the posterior (default 1000)
    :param stan_seed: random seed for the STAN sampler (default 0)
    :param region: string, set to `intervals` for constructing confidence intervals (default), or to
        `ellipse` for constructing an ellipse in the probability simplex, or to `ellipse-clr` for
        constructing an ellipse in the Centered-Log Ratio (CLR) unconstrained space.
    """
    def __init__(self,
                 classifier: BaseEstimator=None,
                 fit_classifier=True,
                 val_split: int = 5,
                 nbins: int = 4,
                 fixed_bins: bool = False,
                 num_warmup: int = 500,
                 num_samples: int = 1_000,
                 stan_seed: int = 0,
                 confidence_level: float = 0.95,
                 region: str = 'intervals'):

        if num_warmup <= 0:
            raise ValueError(f'parameter {num_warmup=} must be a positive integer')
        if num_samples <= 0:
            raise ValueError(f'parameter {num_samples=} must be a positive integer')

        if not _bayesian.DEPENDENCIES_INSTALLED:
            raise ImportError("Auxiliary dependencies are required. "
                              "Run `$ pip install quapy[bayes]` to install them.")

        super().__init__(classifier, fit_classifier, val_split)

        self.nbins = nbins
        self.fixed_bins = fixed_bins
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.stan_seed = stan_seed
        self.stan_code = _bayesian.load_stan_file()
        self.confidence_level = confidence_level
        self.region = region

    def aggregation_fit(self, classif_predictions, labels):
        y_pred = classif_predictions[:, self.pos_label]

        # Compute bin limits
        if self.fixed_bins:
            # Uniform bins in [0,1]
            self.bin_limits = np.linspace(0, 1, self.nbins + 1)
        else:
            # Quantile bins
            self.bin_limits = np.quantile(y_pred, np.linspace(0, 1, self.nbins + 1))

        # Assign each prediction to a bin
        bin_indices = np.digitize(y_pred, self.bin_limits[1:-1], right=True)

        # Positive and negative masks
        pos_mask = (labels == self.pos_label)
        neg_mask = ~pos_mask

        # Count positives and negatives per bin
        self.pos_hist = np.bincount(bin_indices[pos_mask], minlength=self.nbins)
        self.neg_hist = np.bincount(bin_indices[neg_mask], minlength=self.nbins)

    def aggregate(self, classif_predictions):
        Px_test = classif_predictions[:, self.pos_label]
        test_hist, _ = np.histogram(Px_test, bins=self.bin_limits)
        prevs = _bayesian.pq_stan(
            self.stan_code, self.nbins, self.pos_hist, self.neg_hist, test_hist,
            self.num_samples, self.num_warmup, self.stan_seed
        ).flatten()
        self.prev_distribution = np.vstack([1-prevs, prevs]).T
        return self.prev_distribution.mean(axis=0)

    def aggregate_conf(self, predictions, confidence_level=None):
        if confidence_level is None:
            confidence_level = self.confidence_level
        point_estimate = self.aggregate(predictions)
        samples = self.prev_distribution
        region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
        return point_estimate, region

    def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
        predictions = self.classify(instances)
        return self.aggregate_conf(predictions, confidence_level=confidence_level)
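A hypothetical end-to-end sketch of PQ; it requires the optional dependencies (pip install quapy[bayes]), and the dataset/classifier choices are assumptions:

import quapy as qp
from sklearn.linear_model import LogisticRegression

data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=10)
pq = PQ(LogisticRegression(), nbins=4, num_warmup=500, num_samples=1000, region='intervals')
pq.fit(*data.training.Xy)
point, region = pq.predict_conf(data.test.X, confidence_level=0.95)
print(point)   # posterior-mean prevalence, shape (2,)
print(region)  # confidence intervals around it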
@ -1,11 +1,17 @@
from typing import Union, Callable
from itertools import product
from tqdm import tqdm
from typing import Union, Callable, Counter
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.utils import resample
from sklearn.preprocessing import normalize

from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC
from quapy.functional import get_divergence
from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier, BinaryQuantifier
import quapy.functional as F
from scipy.optimize import lsq_linear
from scipy import sparse


class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
@ -149,53 +155,164 @@ class DMx(BaseQuantifier):
        return F.argmin_prevalence(loss, n_classes, method=self.search)


# class ReadMe(BaseQuantifier):
#
#     def __init__(self, bootstrap_trials=100, bootstrap_range=100, bagging_trials=100, bagging_range=25, **vectorizer_kwargs):
#         raise NotImplementedError('under development ...')
#         self.bootstrap_trials = bootstrap_trials
#         self.bootstrap_range = bootstrap_range
#         self.bagging_trials = bagging_trials
#         self.bagging_range = bagging_range
#         self.vectorizer_kwargs = vectorizer_kwargs
#
#     def fit(self, data: LabelledCollection):
#         X, y = data.Xy
#         self.vectorizer = CountVectorizer(binary=True, **self.vectorizer_kwargs)
#         X = self.vectorizer.fit_transform(X)
#         self.class_conditional_X = {i: X[y==i] for i in range(data.classes_)}
#
#     def predict(self, X):
#         X = self.vectorizer.transform(X)
#
#         # number of features
#         num_docs, num_feats = X.shape
#
#         # bootstrap
#         p_boots = []
#         for _ in range(self.bootstrap_trials):
#             docs_idx = np.random.choice(num_docs, size=self.bootstra_range, replace=False)
#             class_conditional_X = {i: X[docs_idx] for i, X in self.class_conditional_X.items()}
#             Xboot = X[docs_idx]
#
#             # bagging
#             p_bags = []
#             for _ in range(self.bagging_trials):
#                 feat_idx = np.random.choice(num_feats, size=self.bagging_range, replace=False)
#                 class_conditional_Xbag = {i: X[:, feat_idx] for i, X in class_conditional_X.items()}
#                 Xbag = Xboot[:,feat_idx]
#                 p = self.std_constrained_linear_ls(Xbag, class_conditional_Xbag)
#                 p_bags.append(p)
#             p_boots.append(np.mean(p_bags, axis=0))
#
#         p_mean = np.mean(p_boots, axis=0)
#         p_std = np.std(p_bags, axis=0)
#
#         return p_mean
#
#
#     def std_constrained_linear_ls(self, X, class_cond_X: dict):
#         pass


class ReadMe(BaseQuantifier, WithConfidenceABC):
    """
    ReadMe is a non-aggregative quantification system proposed by
    `Daniel Hopkins and Gary King, 2007. A method of automated nonparametric content analysis for
    social science. American Journal of Political Science, 54(1):229–247.
    <https://onlinelibrary.wiley.com/doi/abs/10.1111/j.1540-5907.2009.00428.x>`_.
    The idea is to estimate `Q(Y=i)` directly from:

    :math:`Q(X)=\\sum_{i=1} Q(X|Y=i) Q(Y=i)`

    via least-squares regression, i.e., without incurring the cost of computing posterior probabilities.
    However, this poses a very difficult representation in which the vector `Q(X)` and the matrix `Q(X|Y=i)`
    can be of very high dimensions. In order to render the problem tractable, ReadMe performs bagging in
    the feature space. ReadMe also combines bagging with bootstrap in order to derive confidence intervals
    around point estimations.

    We use the same default parameters as in the official
    `R implementation <https://github.com/iqss-research/ReadMeV1/blob/master/R/prototype.R>`_.

    :param prob_model: str ('naive', or 'full'), selects the modality in which the probabilities `Q(X)` and
        `Q(X|Y)` are to be modelled. Options include "full", which corresponds to the original formulation of
        ReadMe, in which X is constrained to be a binary matrix (e.g., of term presence/absence) and in which
        `Q(X)` and `Q(X|Y)` are modelled, respectively, as matrices of `(2^K, 1)` and `(2^K, n)` values, where
        `K` is the number of columns in the data matrix (i.e., `bagging_range`), and `n` is the number of classes.
        Of course, this approach is computationally prohibitive for large `K`, so the computation is restricted to data
        matrices with `K<=25` (although we recommend even smaller values of `K`). A much faster model is "naive", which
        considers `Q(X)` and `Q(X|Y)` to be multinomial distributions under the `bag-of-words` perspective. In this
        case, `bagging_range` can be set to much larger values. Default is "full" (i.e., original ReadMe behavior).
    :param bootstrap_trials: int, number of bootstrap trials (default 300)
    :param bagging_trials: int, number of bagging trials (default 300)
    :param bagging_range: int, number of features to keep for each bagging trial (default 15)
    :param confidence_level: float, a value in (0,1) reflecting the desired confidence level (default 0.95)
    :param region: str in 'intervals', 'ellipse', 'ellipse-clr'; indicates the preferred method for
        defining the confidence region (see :class:`WithConfidenceABC`)
    :param random_state: int or None, allows replicability (default None)
    :param verbose: bool, whether to display information during the process (default False)
    """

    MAX_FEATURES_FOR_EMPIRICAL_ESTIMATION = 25
    PROBABILISTIC_MODELS = ["naive", "full"]

    def __init__(self,
                 prob_model="full",
                 bootstrap_trials=300,
                 bagging_trials=300,
                 bagging_range=15,
                 confidence_level=0.95,
                 region='intervals',
                 random_state=None,
                 verbose=False):
        assert prob_model in ReadMe.PROBABILISTIC_MODELS, \
            f'unknown {prob_model=}, valid ones are {ReadMe.PROBABILISTIC_MODELS=}'
        self.prob_model = prob_model
        self.bootstrap_trials = bootstrap_trials
        self.bagging_trials = bagging_trials
        self.bagging_range = bagging_range
        self.confidence_level = confidence_level
        self.region = region
        self.random_state = random_state
        self.verbose = verbose

    def fit(self, X, y):
        self._check_matrix(X)

        self.rng = np.random.default_rng(self.random_state)
        self.classes_ = np.unique(y)

        Xsize = X.shape[0]

        # Bootstrap loop
        self.Xboots, self.yboots = [], []
        for _ in range(self.bootstrap_trials):
            idx = self.rng.choice(Xsize, size=Xsize, replace=True)
            self.Xboots.append(X[idx])
            self.yboots.append(y[idx])

        return self

    def predict_conf(self, X, confidence_level=0.95) -> (np.ndarray, ConfidenceRegionABC):
        self._check_matrix(X)

        n_features = X.shape[1]
        boots_prevalences = []
        for Xboots, yboots in tqdm(
                zip(self.Xboots, self.yboots),
                desc='bootstrap predictions', total=self.bootstrap_trials, disable=not self.verbose
        ):
            bagging_estimates = []
            for _ in range(self.bagging_trials):
                feat_idx = self.rng.choice(n_features, size=self.bagging_range, replace=False)
                Xboots_bagging = Xboots[:, feat_idx]
                X_boots_bagging = X[:, feat_idx]
                bagging_prev = self._quantify_iteration(Xboots_bagging, yboots, X_boots_bagging)
                bagging_estimates.append(bagging_prev)

            boots_prevalences.append(np.mean(bagging_estimates, axis=0))

        conf = WithConfidenceABC.construct_region(boots_prevalences, confidence_level, method=self.region)
        prev_estim = conf.point_estimate()

        return prev_estim, conf

    def predict(self, X):
        prev_estim, _ = self.predict_conf(X)
        return prev_estim

    def _quantify_iteration(self, Xtr, ytr, Xte):
        """Single ReadMe estimate."""
        PX_given_Y = np.asarray([self._compute_P(Xtr[ytr == c]) for i, c in enumerate(self.classes_)])
        PX = self._compute_P(Xte)

        res = lsq_linear(A=PX_given_Y.T, b=PX, bounds=(0, 1))
        pY = np.maximum(res.x, 0)
        return pY / pY.sum()

    def _check_matrix(self, X):
        """the "full" model requires estimating empirical distributions; due to the high computational cost,
        this function is only made available for binary matrices"""
        if self.prob_model == 'full' and not self._is_binary_matrix(X):
            raise ValueError('the empirical distribution can only be computed efficiently on binary matrices')

    def _is_binary_matrix(self, X):
        data = X.data if sparse.issparse(X) else X
        return np.all((data == 0) | (data == 1))

    def _compute_P(self, X):
        if self.prob_model == 'naive':
            return self._multinomial_distribution(X)
        elif self.prob_model == 'full':
            return self._empirical_distribution(X)
        else:
            raise ValueError(f'unknown {self.prob_model}; valid ones are {ReadMe.PROBABILISTIC_MODELS=}')

    def _empirical_distribution(self, X):

        if X.shape[1] > self.MAX_FEATURES_FOR_EMPIRICAL_ESTIMATION:
            raise ValueError(f'the empirical distribution can only be computed efficiently for dimensions '
                             f'less or equal than {self.MAX_FEATURES_FOR_EMPIRICAL_ESTIMATION}')

        # we convert every binary row (e.g., 0 0 1 0 1) into the equivalent number (e.g., 5)
        K = X.shape[1]
        binary_powers = 1 << np.arange(K-1, -1, -1)  # (2^(K-1), ..., 32, 16, 8, 4, 2, 1)
        X_as_binary_numbers = X @ binary_powers

        # count occurrences and compute probs
        counts = np.bincount(X_as_binary_numbers, minlength=2 ** K).astype(float)
        probs = counts / counts.sum()
        return probs

    def _multinomial_distribution(self, X):
        PX = np.asarray(X.sum(axis=0))
        PX = normalize(PX, norm='l1', axis=1)
        return PX.ravel()


def _get_features_range(X):
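An illustrative sketch of ReadMe on a binary term-presence matrix (the vectorizer, the dense conversion, and the reduced trial counts are assumptions made only to keep the example light):

import quapy as qp
from sklearn.feature_extraction.text import CountVectorizer

data = qp.datasets.fetch_reviews('imdb', tfidf=False)
vect = CountVectorizer(min_df=10, binary=True)            # the "full" model requires a binary matrix
Xtr = vect.fit_transform(data.training.instances).toarray()
Xte = vect.transform(data.test.instances).toarray()

readme = ReadMe(prob_model='full', bootstrap_trials=50, bagging_trials=50, bagging_range=15,
                random_state=0, verbose=True)
readme.fit(Xtr, data.training.labels)
point, region = readme.predict_conf(Xte, confidence_level=0.95)
print(point, region)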
@ -0,0 +1,39 @@
data {
    int<lower=0> n_bucket;
    array[n_bucket] int<lower=0> train_pos;
    array[n_bucket] int<lower=0> train_neg;
    array[n_bucket] int<lower=0> test;
    int<lower=0,upper=1> posterior;
}

transformed data{
    row_vector<lower=0>[n_bucket] train_pos_rv;
    row_vector<lower=0>[n_bucket] train_neg_rv;
    row_vector<lower=0>[n_bucket] test_rv;
    real n_test;

    train_pos_rv = to_row_vector( train_pos );
    train_neg_rv = to_row_vector( train_neg );
    test_rv = to_row_vector( test );
    n_test = sum( test );
}

parameters {
    simplex[n_bucket] p_neg;
    simplex[n_bucket] p_pos;
    real<lower=0,upper=1> prev_prior;
}

model {
    if( posterior ) {
        target += train_neg_rv * log( p_neg );
        target += train_pos_rv * log( p_pos );
        target += test_rv * log( p_neg * ( 1 - prev_prior) + p_pos * prev_prior );
    }
}

generated quantities {
    real<lower=0,upper=1> prev;
    prev = sum( binomial_rng(test, 1 / ( 1 + (p_neg./p_pos) *(1-prev_prior)/prev_prior ) ) ) / n_test;
}
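Read as a generative model, the model block above adds these log-density terms (a restatement of the Stan code, with \pi = prev_prior, \theta^{+} = p_pos, \theta^{-} = p_neg, and c^{+}, c^{-}, t the positive, negative and test histograms over B bins):

\log p \;\propto\; \sum_{b=1}^{B} c^{-}_b \log \theta^{-}_b
              \;+\; \sum_{b=1}^{B} c^{+}_b \log \theta^{+}_b
              \;+\; \sum_{b=1}^{B} t_b \log\big((1-\pi)\,\theta^{-}_b + \pi\,\theta^{+}_b\big)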
@ -410,7 +410,7 @@ def group_params(param_grid: dict):
    """
    classifier_params, quantifier_params = {}, {}
    for key, values in param_grid.items():
        if key.startswith('classifier__') or key == 'val_split':
        if 'classifier__' in key or key == 'val_split':
            classifier_params[key] = values
        else:
            quantifier_params[key] = values
setup.py

@ -111,6 +111,12 @@ setup(
    #
    packages=find_packages(include=['quapy', 'quapy.*']),  # Required

    package_data={
        # For the 'quapy.method' package, include all files
        # in the 'stan' subdirectory that end with .stan
        'quapy.method': ['stan/*.stan']
    },

    python_requires='>=3.8, <4',

    install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo', 'certifi'],

@ -124,7 +130,7 @@ setup(
    # Similar to `install_requires` above, these must be valid existing
    # projects.
    extras_require={  # Optional
        'bayes': ['jax', 'jaxlib', 'numpyro'],
        'bayes': ['jax', 'jaxlib', 'numpyro', 'pystan'],
        'neural': ['torch'],
        'tests': ['certifi'],
        'docs' : ['sphinx-rtd-theme', 'myst-parser'],