forked from moreo/QuaPy
116 lines
4.5 KiB
Python
116 lines
4.5 KiB
Python
import os.path
|
|
|
|
from sklearn.metrics import f1_score
|
|
|
|
import functions as fn
|
|
import quapy as qp
|
|
import argparse
|
|
from data import LabelledCollection
|
|
|
|
|
|
def eval_classifier(learner, test: LabelledCollection, *, average='binary'):
    """Return the F1 score of ``learner`` on the labelled collection ``test``.

    Args:
        learner: fitted model exposing ``predict(instances)``.
        test: labelled collection; its ``instances`` are classified and its
            ``labels`` are used as ground truth.
        average: averaging mode forwarded to ``sklearn.metrics.f1_score``.
            The default ``'binary'`` scores only the positive label
            (sklearn's default ``pos_label=1``); pass ``'macro'`` for
            macro-averaged F1.

    Returns:
        The F1 score (a per-class array if ``average=None``).
    """
    predictions = learner.predict(test.instances)
    return f1_score(test.labels, predictions, average=average)
|
|
|
|
|
|
def main(args):
    """Run the iterative active-learning experiment and log per-iteration metrics.

    An initial labelled sample is drawn from the dataset. Each iteration then:

    - estimates the class prevalence of the unlabelled pool with
      ``fn.estimate_prev_CC`` and with the quantifier named by
      ``args.quantifier``;
    - logs recall estimates, train/pool shift, absolute errors and F1 scores;
    - moves the ``args.k`` documents chosen by the ``fn`` function named by
      ``args.sampling`` from the pool into the training set.

    Rows are tab-separated. They are printed to stdout and written to
    ``./results/<fn.experiment_name(args)>``.

    Args:
        args: parsed command-line namespace. Reads ``dataset``, ``k``,
            ``initsize``, ``initprev``, ``sampling``, ``iter``,
            ``quantifier`` and ``seed``, and is passed whole to
            ``fn.experiment_name``.
    """
    datasetname = args.dataset
    k = args.k
    init_nD = args.initsize
    # prevalence vector of the initial sample: [class 0, class 1]
    init_prev = [1 - args.initprev, args.initprev]
    sampling_fn = getattr(fn, args.sampling)
    # i+1 is always >= 1, so non-positive values (default -1) mean "no limit"
    max_iterations = args.iter
    outputdir = './results'

    # built with fn.create_dataset and cached as a pickle under ./dataset/
    collection = qp.util.pickled_resource(f'./dataset/{datasetname}.pkl', fn.create_dataset, datasetname)
    nD = len(collection)

    with qp.util.temp_seed(args.seed):
        # initial labelled data selection
        idx = collection.sampling_index(init_nD, *init_prev)
        train, pool = fn.split_from_index(collection, idx)

        qp.util.create_if_not_exist(outputdir)

        i = 0
        with open(os.path.join(outputdir, fn.experiment_name(args)), 'wt') as foo:
            def tee(msg):
                """Append ``msg`` to the results file (flushed at once) and echo it to stdout."""
                foo.write(msg + '\n')
                foo.flush()
                print(msg)

            # NOTE(review): the last two columns are labelled MF1 (macro-F1?) but
            # eval_classifier computes binary F1 by default -- confirm the intended metric.
            tee('it\t%\ttr-size\tte-size\ttr-prev\tte-prev\tte-estim\tte-estimCC\tR\tRhat\tRhatCC\tShift\tAE\tAE_CC\tMF1_Q\tMF1_Clf')

            while True:
                # pool prevalence estimates and the models that produced them
                pool_p_hat_cc, classifier = fn.estimate_prev_CC(train, pool)
                pool_p_hat, q_classifier = fn.estimate_prev_Q(train, pool, args.quantifier)

                f1_clf = eval_classifier(classifier, pool)
                f1_q = eval_classifier(q_classifier, pool)

                tr_p = train.prevalence()
                te_p = pool.prevalence()
                nDtr = len(train)
                nDte = len(pool)

                # recall from the true pool prevalence (r) vs. from the estimated ones
                r_hat_cc = fn.recall(tr_p, pool_p_hat_cc, nDtr, nDte)
                r_hat = fn.recall(tr_p, pool_p_hat, nDtr, nDte)
                r = fn.recall(tr_p, te_p, nDtr, nDte)
                tr_te_shift = qp.error.ae(tr_p, te_p)

                # percentage of the collection labelled so far
                progress = 100 * nDtr / nD

                q_ae = qp.error.ae(te_p, pool_p_hat)
                cc_ae = qp.error.ae(te_p, pool_p_hat_cc)

                tee(f'{i}\t{progress:.2f}\t{nDtr}\t{nDte}\t{tr_p[1]:.3f}\t{te_p[1]:.3f}\t{pool_p_hat[1]:.3f}\t{pool_p_hat_cc[1]:.3f}'
                    f'\t{r:.3f}\t{r_hat:.3f}\t{r_hat_cc:.3f}\t{tr_te_shift:.5f}\t{q_ae:.4f}\t{cc_ae:.4f}\t{f1_q:.3f}\t{f1_clf:.3f}')

                # With exactly k documents left, the next batch would empty the
                # pool. The following iteration would then have to estimate and
                # evaluate on zero documents, so stop here instead.
                if nDte <= k:
                    print('[stop] too few documents remaining')
                    break
                elif i + 1 == max_iterations:
                    print('[stop] maximum number of iterations reached')
                    break

                top_relevant_idx = sampling_fn(pool, classifier, k, progress)
                selected, pool = fn.split_from_index(pool, top_relevant_idx)
                train = train + selected

                i += 1
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse and validate the command-line options of the e-Discovery experiment.

    Args:
        argv: argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: with status 2 (via ``parser.error``) on invalid arguments,
            including an ``--initprev`` outside the open interval (0, 1).
    """
    parser = argparse.ArgumentParser(description='e-Discovery')
    parser.add_argument('--dataset', metavar='DATASET', type=str, help='Dataset name',
                        default='RCV1.C4')
    parser.add_argument('--quantifier', metavar='METHOD', type=str, help='Quantification method',
                        default='EMQ')
    parser.add_argument('--sampling', metavar='SAMPLING', type=str, help='Sampling criterion',
                        default='relevance_sampling')
    parser.add_argument('--iter', metavar='INT', type=int, help='number of iterations (-1 to set no limit)',
                        default=-1)
    parser.add_argument('--k', metavar='BATCH', type=int, help='number of documents in a batch',
                        default=100)
    parser.add_argument('--initsize', metavar='SIZE', type=int, help='number of labelled documents at the beginning',
                        default=1000)
    parser.add_argument('--initprev', metavar='PREV', type=float,
                        help='prevalence of the initial sample; must lie strictly between 0 and 1',
                        default=0.5)
    parser.add_argument('--seed', metavar='SEED', type=int,
                        help='random seed',
                        default=1)
    args = parser.parse_args(argv)

    # Validate explicitly instead of with `assert`, which `python -O` strips out.
    if not 0 < args.initprev < 1:
        parser.error('wrong value for initprev; should be in (0., 1.)')
    return args


if __name__ == '__main__':
    args = parse_args()
    main(args)
|