forked from moreo/QuaPy
Minor refactoring: add move_documents helper and rename prevalence/error variables
This commit is contained in:
parent
11319ffd0e
commit
14dbfb567b
|
@ -43,10 +43,17 @@ def experiment_name(args:argparse.Namespace):
|
||||||
def split_from_index(collection: LabelledCollection, index: np.ndarray):
|
def split_from_index(collection: LabelledCollection, index: np.ndarray):
|
||||||
in_index_set = set(index)
|
in_index_set = set(index)
|
||||||
out_index_set = set(range(len(collection))) - in_index_set
|
out_index_set = set(range(len(collection))) - in_index_set
|
||||||
out_index = np.asarray(list(out_index_set), dtype=int)
|
out_index = np.asarray(sorted(out_index_set), dtype=int)
|
||||||
return collection.sampling_from_index(index), collection.sampling_from_index(out_index)
|
return collection.sampling_from_index(index), collection.sampling_from_index(out_index)
|
||||||
|
|
||||||
|
|
||||||
|
def move_documents(target: LabelledCollection, origin: LabelledCollection, idx_origin: np.ndarray):
    """Move the documents of `origin` selected by `idx_origin` into `target`.

    The selection is delegated to `split_from_index`, which yields the indexed
    documents and the rest of `origin`; the indexed ones are appended to
    `target` with the collection's `+` operator.

    :param target: collection that receives the selected documents
    :param origin: collection the documents are taken from
    :param idx_origin: indexes (relative to `origin`) of the documents to move
    :return: a tuple (target plus the selected documents, origin without them)
    """
    moved, remaining = split_from_index(origin, idx_origin)
    return target + moved, remaining
|
||||||
|
|
||||||
|
|
||||||
def uniform_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
|
def uniform_sampling(pool: LabelledCollection, classifier: BaseEstimator, k: int, *args):
|
||||||
return np.random.choice(len(pool), k, replace=False)
|
return np.random.choice(len(pool), k, replace=False)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
from sklearn.metrics import f1_score
|
from sklearn.metrics import f1_score
|
||||||
|
|
||||||
import functions as fn
|
import functions as fn
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -22,10 +20,13 @@ def main(args):
|
||||||
datasetname = args.dataset
|
datasetname = args.dataset
|
||||||
k = args.k
|
k = args.k
|
||||||
init_nD = args.initsize
|
init_nD = args.initsize
|
||||||
init_prev = [1-args.initprev, args.initprev]
|
|
||||||
sampling_fn = getattr(fn, args.sampling)
|
sampling_fn = getattr(fn, args.sampling)
|
||||||
max_iterations = args.iter
|
max_iterations = args.iter
|
||||||
outputdir = './results'
|
outputdir = './results'
|
||||||
|
clf_name = args.classifier
|
||||||
|
q_name = args.quantifier
|
||||||
|
|
||||||
|
qp.util.create_if_not_exist(outputdir)
|
||||||
|
|
||||||
collection = qp.util.pickled_resource(f'./dataset/{datasetname}.pkl', fn.create_dataset, datasetname)
|
collection = qp.util.pickled_resource(f'./dataset/{datasetname}.pkl', fn.create_dataset, datasetname)
|
||||||
nD = len(collection)
|
nD = len(collection)
|
||||||
|
@ -37,11 +38,9 @@ def main(args):
|
||||||
else:
|
else:
|
||||||
idx = collection.sampling_index(init_nD, *[1 - args.initprev, args.initprev])
|
idx = collection.sampling_index(init_nD, *[1 - args.initprev, args.initprev])
|
||||||
train, pool = fn.split_from_index(collection, idx)
|
train, pool = fn.split_from_index(collection, idx)
|
||||||
first_train = LabelledCollection(train.instances, train.labels)
|
#first_train = LabelledCollection(train.instances, train.labels)
|
||||||
|
|
||||||
# recall_target = 0.99
|
# recall_target = 0.99
|
||||||
qp.util.create_if_not_exist(outputdir)
|
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
with open(os.path.join(outputdir, fn.experiment_name(args)), 'wt') as foo:
|
with open(os.path.join(outputdir, fn.experiment_name(args)), 'wt') as foo:
|
||||||
def tee(msg):
|
def tee(msg):
|
||||||
|
@ -53,8 +52,8 @@ def main(args):
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
pool_p_hat_cc, classifier = fn.estimate_prev_CC(train, pool, args.classifier)
|
pool_p_hat_cc, classifier = fn.estimate_prev_CC(train, pool, clf_name)
|
||||||
pool_p_hat, q_classifier = fn.estimate_prev_Q(train, pool, args.quantifier, args.classifier)
|
pool_p_hat_q, q_classifier = fn.estimate_prev_Q(train, pool, q_name, clf_name)
|
||||||
|
|
||||||
f1_clf = eval_classifier(classifier, pool)
|
f1_clf = eval_classifier(classifier, pool)
|
||||||
f1_q = eval_classifier(q_classifier, pool)
|
f1_q = eval_classifier(q_classifier, pool)
|
||||||
|
@ -65,17 +64,17 @@ def main(args):
|
||||||
nDte = len(pool)
|
nDte = len(pool)
|
||||||
|
|
||||||
r_hat_cc = fn.recall(tr_p, pool_p_hat_cc, nDtr, nDte)
|
r_hat_cc = fn.recall(tr_p, pool_p_hat_cc, nDtr, nDte)
|
||||||
r_hat = fn.recall(tr_p, pool_p_hat, nDtr, nDte)
|
r_hat_q = fn.recall(tr_p, pool_p_hat_q, nDtr, nDte)
|
||||||
r = fn.recall(tr_p, te_p, nDtr, nDte)
|
r = fn.recall(tr_p, te_p, nDtr, nDte)
|
||||||
tr_te_shift = qp.error.ae(tr_p, te_p)
|
tr_te_shift = qp.error.ae(tr_p, te_p)
|
||||||
|
|
||||||
progress = 100 * nDtr / nD
|
progress = 100 * nDtr / nD
|
||||||
|
|
||||||
q_ae = qp.error.ae(te_p, pool_p_hat)
|
ae_q = qp.error.ae(te_p, pool_p_hat_q)
|
||||||
cc_ae = qp.error.ae(te_p, pool_p_hat_cc)
|
ae_cc = qp.error.ae(te_p, pool_p_hat_cc)
|
||||||
|
|
||||||
tee(f'{i}\t{progress:.2f}\t{nDtr}\t{nDte}\t{tr_p[1]:.3f}\t{te_p[1]:.3f}\t{pool_p_hat[1]:.3f}\t{pool_p_hat_cc[1]:.3f}'
|
tee(f'{i}\t{progress:.2f}\t{nDtr}\t{nDte}\t{tr_p[1]:.3f}\t{te_p[1]:.3f}\t{pool_p_hat_q[1]:.3f}\t{pool_p_hat_cc[1]:.3f}'
|
||||||
f'\t{r:.3f}\t{r_hat:.3f}\t{r_hat_cc:.3f}\t{tr_te_shift:.5f}\t{q_ae:.4f}\t{cc_ae:.4f}\t{f1_q:.3f}\t{f1_clf:.3f}')
|
f'\t{r:.3f}\t{r_hat_q:.3f}\t{r_hat_cc:.3f}\t{tr_te_shift:.5f}\t{ae_q:.4f}\t{ae_cc:.4f}\t{f1_q:.3f}\t{f1_clf:.3f}')
|
||||||
|
|
||||||
if nDte < k:
|
if nDte < k:
|
||||||
print('[stop] too few documents remaining')
|
print('[stop] too few documents remaining')
|
||||||
|
@ -85,13 +84,12 @@ def main(args):
|
||||||
break
|
break
|
||||||
|
|
||||||
top_relevant_idx = sampling_fn(pool, classifier, k, progress)
|
top_relevant_idx = sampling_fn(pool, classifier, k, progress)
|
||||||
selected, pool = fn.split_from_index(pool, top_relevant_idx)
|
train, pool = fn.move_documents(train, pool, top_relevant_idx)
|
||||||
train = train + selected
|
|
||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
|
||||||
if __name__=='__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='e-Discovery')
|
parser = argparse.ArgumentParser(description='e-Discovery')
|
||||||
parser.add_argument('--dataset', metavar='DATASET', type=str, help='Dataset name',
|
parser.add_argument('--dataset', metavar='DATASET', type=str, help='Dataset name',
|
||||||
default='RCV1.C4')
|
default='RCV1.C4')
|
||||||
|
|
|
@ -8,8 +8,8 @@ initsize=500
|
||||||
initprev=-1
|
initprev=-1
|
||||||
seed=1
|
seed=1
|
||||||
Q=ACC
|
Q=ACC
|
||||||
CLS=svm
|
CLS=lr
|
||||||
sampling=mix_sampling
|
sampling=proportional_sampling
|
||||||
|
|
||||||
filepath="./results/classifier:"$CLS"__dataset:"$dataset"__initprev:"$initprev"__initsize:"$initsize"__iter:"$iter"__k:"$k"__quantifier:"$Q"__sampling:"$sampling"__seed:"$seed".csv"
|
filepath="./results/classifier:"$CLS"__dataset:"$dataset"__initprev:"$initprev"__initsize:"$initsize"__iter:"$iter"__k:"$k"__quantifier:"$Q"__sampling:"$sampling"__seed:"$seed".csv"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue