from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
import quapy as qp
from quapy.method.base import BaseQuantifier
from quapy.data import LabelledCollection
from quapy.method.aggregative import EMQ, ClassifyAndCount, PACC
from quapy import functional as F
import numpy as np


def split_from_index(collection: LabelledCollection, index: np.ndarray):
    """Split *collection* into the items at *index* and all remaining items.

    :param collection: the collection to split
    :param index: integer positions (into *collection*) of the items to extract
    :return: a pair ``(selected, rest)`` of collections; ``rest`` holds every
        position not listed in *index*, in ascending positional order
    """
    # A boolean mask yields the complement in a well-defined ascending order,
    # whereas iterating a Python set gives an order that is an implementation detail.
    is_out = np.ones(len(collection), dtype=bool)
    is_out[np.asarray(index, dtype=int)] = False
    out_index = np.flatnonzero(is_out)
    return collection.sampling_from_index(index), collection.sampling_from_index(out_index)


def relevance_sampling_index(pool: LabelledCollection, classifier: BaseEstimator, k: int):
    """Return the positions of the *k* pool items with the highest relevance score.

    The score is column 1 of ``classifier.predict_proba``. That column is presumably
    the relevant/positive class, consistent with ``recall`` reading prevalence[1];
    confirm against the label encoding.

    :param pool: the unlabelled pool to rank
    :param classifier: a fitted classifier exposing ``predict_proba``
    :param k: how many positions to return (fewer if the pool is smaller)
    :return: integer positions into *pool*, most relevant first
    """
    prob = classifier.predict_proba(pool.instances)[:, 1]
    # Negate so that argsort ranks by decreasing probability.
    return np.argsort(-prob)[:k]


def recall(train_prev, pool_prev, train_len, pool_len):
    """Compute the recall reached by the reviewed (training) set.

    Recall is the fraction of all class-1 items that already sit in the
    training set. Each class-1 count is the class-1 prevalence times the set size.

    :param train_prev: class prevalence vector of the training set
    :param pool_prev: class prevalence vector of the pool (true or estimated)
    :param train_len: number of items in the training set
    :param pool_len: number of items in the pool
    :return: ``n1_train / (n1_train + n1_pool)``; undefined (division by zero)
        when neither set contains any class-1 mass
    """
    relevant_in_train = train_prev[1] * train_len
    relevant_in_pool = pool_prev[1] * pool_len
    return relevant_in_train / (relevant_in_train + relevant_in_pool)


if __name__ == '__main__':
    data = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5)
    # collection = data.training + data.test
    collection = data.training.sampling(10000, 0.75)
    nD = len(collection)

    # initial labelled data selection
    init_nD = 100
    init_prev = 0.5
    idx = collection.sampling_index(init_nD, init_prev)
    train, pool = split_from_index(collection, idx)

    k = 50  # documents moved from the pool into the training set per iteration
    recall_target = 0.95  # only referenced by the alternative (commented-out) stopping rule

    # Alternative quantifiers kept for experimentation:
    # Q = EMQ(CalibratedClassifierCV(LogisticRegression()))
    # Q = ClassifyAndCount(LogisticRegression())
    Q = PACC(LogisticRegression())

    i = 0
    while True:
        if len(pool) == 0:
            # Everything has been moved into the training set; nothing left to quantify.
            break

        Q.fit(train)
        pool_p_hat = Q.quantify(pool.instances)
        tr_p = train.prevalence()
        te_p = pool.prevalence()
        nDtr = len(train)
        nDte = len(pool)

        # Recall using the estimated pool prevalence vs. the true one.
        r_hat = recall(tr_p, pool_p_hat, nDtr, nDte)
        r = recall(tr_p, te_p, nDtr, nDte)
        r_error = abs(r_hat - r)

        proc_percent = 100 * nDtr / nD

        print(f'{i}\t [{proc_percent:.2f}%] tr-prev={F.strprev(tr_p)} te-prev={F.strprev(te_p)} te-estim={F.strprev(pool_p_hat)} R={r:.3f} Rhat={r_hat:.3f} E={r_error:.3f}')

        # if r_hat >= recall_target:
        if proc_percent > 95:
            break

        # NOTE(review): the fitted classifier is read from `Q.learner`; the
        # attribute name depends on the quapy version in use.
        top_relevant_idx = relevance_sampling_index(pool, Q.learner, k)
        selected, pool = split_from_index(pool, top_relevant_idx)
        train = train + selected

        i += 1