115 lines
3.4 KiB
Python
115 lines
3.4 KiB
Python
import pandas as pd
|
|
import numpy as np
|
|
from glob import glob
|
|
from os.path import join
|
|
|
|
from quapy.data import LabelledCollection
|
|
from quapy.protocol import AbstractProtocol
|
|
import json
|
|
|
|
|
|
def load_sample(path, class_name):
    """
    Reads a json sample into a dataframe and extracts, from it, the documents
    and the labels stored under the requested class column.

    :param path: location of the json file to read
    :param class_name: name of the column that holds the target labels
    :return: a pair (texts, labels) with the values of the "text" column and of
        the `class_name` column, respectively
    """
    sample = pd.read_json(path)
    return sample.text.values, sample[class_name].values
|
|
|
|
|
|
def get_text_label_score(df, class_name, vectorizer=None, filter_classes=None):
    """
    Extracts documents, labels, and relevance scores from a dataframe and
    returns them sorted by decreasing score.

    :param df: dataframe exposing a "text" column, a "score" column, and a
        column named `class_name`
    :param class_name: name of the column that holds the labels
    :param vectorizer: optional object with a `transform` method; when given,
        it is applied to the texts (after any class filtering)
    :param filter_classes: optional collection of label values; when given,
        only the rows whose label belongs to it are kept
    :return: a tuple (texts, labels, scores), all three reordered from the
        highest to the lowest score
    """
    docs, targets, relevance = df.text.values, df[class_name].values, df.score.values

    if filter_classes is not None:
        # boolean mask selecting the rows labelled with one of the allowed classes
        keep = np.isin(targets, filter_classes)
        docs, targets, relevance = docs[keep], targets[keep], relevance[keep]

    if vectorizer is not None:
        docs = vectorizer.transform(docs)

    # negating the scores turns the ascending argsort into a descending ranking
    ranking = np.argsort(-relevance)
    return docs[ranking], targets[ranking], relevance[ranking]
|
|
|
|
|
|
class RetrievedSamples:
    """
    Generator of (training, test) samples, one pair per query.

    The training samples are the json files in `class_home` whose name matches
    'training_Query*200SPLIT.json'; the test sample of a query is made of the
    rows of the test rankings dataframe whose `qid` equals the query id encoded
    in the name of the training file.
    """

    def __init__(self,
                 class_home: str,
                 test_rankings_path: str,
                 vectorizer,
                 class_name,
                 classes=None
                 ):
        """
        :param class_home: directory containing the training json files
        :param test_rankings_path: path to the json file with the test rankings
            (loaded once, here, with `pd.read_json`)
        :param vectorizer: object passed on to `get_text_label_score` to transform
            the texts (may be None)
        :param class_name: name of the column holding the labels
        :param classes: optional collection of label values used to filter the
            rows of every sample (None means no filtering)
        """
        self.class_home = class_home
        self.test_rankings_df = pd.read_json(test_rankings_path)
        self.vectorizer = vectorizer
        self.class_name = class_name
        self.classes = classes

    def __call__(self):
        """
        Iterates over the queries found in `class_home`.

        Training files that load as an empty dataframe are reported on stdout
        and skipped.

        :return: a generator of pairs (Xtr, ytr, score_tr), (Xte, yte, score_te)
            as returned by `get_text_label_score` for the training and the test
            sample of each query
        """
        tests_df = self.test_rankings_df
        class_name = self.class_name
        vectorizer = self.vectorizer

        for file in self._list_queries():
            # loads the training sample
            train_df = pd.read_json(file)
            if len(train_df) == 0:
                print('empty dataframe: ', file)
                continue

            Xtr, ytr, score_tr = get_text_label_score(
                train_df, class_name, vectorizer, filter_classes=self.classes
            )

            # loads the test sample: the rows of the rankings that belong to this query
            query_id = self._get_query_id_from_path(file)
            sel_df = tests_df[tests_df.qid == int(query_id)]
            Xte, yte, score_te = get_text_label_score(
                sel_df, class_name, vectorizer, filter_classes=self.classes
            )

            yield (Xtr, ytr, score_tr), (Xte, yte, score_te)

    def _list_queries(self):
        """
        :return: the sorted list of paths of the training files in `class_home`
        """
        return sorted(glob(join(self.class_home, 'training_Query*200SPLIT.json')))

    def total(self):
        """
        :return: the number of training files found in `class_home`
        """
        return len(self._list_queries())

    def _get_query_id_from_path(self, path):
        """
        Extracts the query id from the path of a training file, i.e., the text
        found between 'training_Query-' and 'Sample-200SPLIT'.

        :param path: path to a training file
        :return: the query id, as a string
        :raises ValueError: if any of the two markers is missing from `path`
        """
        prefix = 'training_Query-'
        suffix = 'Sample-200SPLIT'
        # anchor on the LAST occurrence of the prefix, and look for the suffix only
        # after it, so that directory names containing either marker cannot
        # corrupt the extracted id
        start = path.rindex(prefix) + len(prefix)
        end = path.index(suffix, start)
        return path[start:end]
|
|
|
|
|