QuaPy/Retrieval/commons.py

133 lines
4.3 KiB
Python

import pandas as pd
import numpy as np
from glob import glob
from os.path import join
from quapy.data import LabelledCollection
from quapy.protocol import AbstractProtocol
import json
def load_sample(path, class_name, max_lines=-1):
    """
    Reads a json sample into a dataframe and extracts the texts together
    with the labels stored in the column named class_name

    :param path: location of the json file to read
    :param class_name: name of the column holding the target labels
    :param max_lines: when not None and greater than 0, only the leading
        max_lines instances are returned; otherwise all of them are returned
    :return: a tuple (texts, labels) for class_name
    """
    frame = pd.read_json(path)
    documents = frame.text.values
    try:
        targets = frame[class_name].values
    except KeyError as err:
        # report which file lacks the column before propagating the error
        print(f'error in {path}; key {class_name} not found')
        raise err

    # a None or non-positive limit means "keep every instance"
    truncate = max_lines is not None and max_lines > 0
    if not truncate:
        return documents, targets
    return documents[:max_lines], targets[:max_lines]
class TextRankings:
    """
    In-memory view of a rankings json file, from which the documents
    retrieved for a given query (sample) can be extracted.

    The json file must decode to a dict of columns, each of which maps a
    document id to a value. The columns read by this class are 'qid',
    'text', 'rank' and the one named class_name.
    """

    def __init__(self, path, class_name):
        """
        :param path: path to the rankings json file
        :param class_name: name of the column holding the target labels
        """
        # the context manager guarantees the file handle is released
        with open(path, 'rt') as fin:
            self.obj = json.load(fin)
        self.class_name = class_name

    def get_sample_Xy(self, sample_id, max_lines=-1):
        """
        Returns the texts and labels of the documents whose 'qid' matches
        sample_id.

        :param sample_id: the query id; compared by its string form, so that
            ints and strings are interchangeable, both here and in the file
        :param max_lines: if not None and > 0, and more documents than this
            are available, only the max_lines documents with the smallest
            'rank' are returned (sorted by rank); otherwise all documents
            are returned in file order
        :return: a tuple (texts, labels); these are lists when no truncation
            takes place and numpy arrays otherwise
        """
        sample_id = str(sample_id)
        table = self.obj

        # str() on both sides: the file may store query ids as ints or strings
        docs_ids = [
            doc_id for doc_id, query_id in table['qid'].items()
            if str(query_id) == sample_id
        ]
        texts = [table['text'][doc_id] for doc_id in docs_ids]
        labels = [table[self.class_name][doc_id] for doc_id in docs_ids]

        # None (or a non-positive value) means "no limit"
        if max_lines is not None and 0 < max_lines < len(texts):
            ranks = [int(table['rank'][doc_id]) for doc_id in docs_ids]
            sel = np.argsort(ranks)[:max_lines]
            texts = np.asarray(texts)[sel]
            labels = np.asarray(labels)[sel]

        return texts, labels
def filter_by_classes(X, y, classes):
    """
    Keeps only the instances whose label belongs to classes.

    :param X: the instances; anything supporting boolean-mask indexing
        (plain lists and tuples are converted to numpy arrays first)
    :param y: the labels, aligned with X (plain lists and tuples are
        converted to numpy arrays first)
    :param classes: the labels to keep, or None to disable filtering
    :return: a tuple (X, y) restricted to the selected instances; the
        inputs are returned untouched when classes is None
    """
    # None means "no restriction"; np.isin(y, None) would instead reject
    # every instance and silently produce an empty sample
    if classes is None:
        return X, y

    # python sequences do not support boolean-mask indexing
    if isinstance(X, (list, tuple)):
        X = np.asarray(X)
    if isinstance(y, (list, tuple)):
        y = np.asarray(y)

    idx = np.isin(y, classes)
    return X[idx], y[idx]
class RetrievedSamples(AbstractProtocol):
    """
    Protocol that generates one (training sample, test sample) pair for
    each query file found in class_home.

    The training sample is loaded from the query file by means of load_fn;
    the test sample is made of the rows of the test rankings whose 'qid'
    equals the query id encoded in the file name.
    """

    def __init__(self,
                 class_home: str,
                 test_rankings_path: str,
                 load_fn,
                 vectorizer,
                 class_name,
                 max_train_lines=None,
                 max_test_lines=None,
                 classes=None
                 ):
        """
        :param class_home: directory containing the training query files
            (those matching 'training_Query*200SPLIT.json')
        :param test_rankings_path: path to a json file readable by
            pandas.read_json, with (at least) the columns 'qid', 'text',
            'rank' and class_name
        :param load_fn: callable invoked as
            load_fn(file, class_name=..., max_lines=...) that returns the
            texts and labels of a training file
        :param vectorizer: object whose transform method maps texts into
            the instances used to build the samples
        :param class_name: name of the column holding the target labels
        :param max_train_lines: forwarded to load_fn as max_lines
        :param max_test_lines: if not None and > 0, maximum number of test
            documents per query (the best ranked ones are kept)
        :param classes: the labels to retain, or None to retain all of them
        """
        self.class_home = class_home
        self.test_rankings_df = pd.read_json(test_rankings_path)
        self.load_fn = load_fn
        self.vectorizer = vectorizer
        self.class_name = class_name
        self.max_train_lines = max_train_lines
        self.max_test_lines = max_test_lines
        self.classes = classes

    def __call__(self):
        """
        Generates the samples, one pair per query file (in sorted order).

        :return: yields tuples (train_sample, test_sample); a tuple
            (None, None) is yielded instead if building the test sample
            raises a ValueError
        """
        for file in self._list_queries():
            texts, y = self.load_fn(file, class_name=self.class_name, max_lines=self.max_train_lines)
            texts, y = self._restrict_to_classes(texts, y)
            X = self.vectorizer.transform(texts)
            train_sample = LabelledCollection(X, y, classes=self.classes)

            query_id = self._get_query_id_from_path(file)
            texts, y = self._get_test_sample(query_id, max_lines=self.max_test_lines)
            texts, y = self._restrict_to_classes(texts, y)
            X = self.vectorizer.transform(texts)

            try:
                test_sample = LabelledCollection(X, y, classes=train_sample.classes_)
            except ValueError as e:
                # best effort: report the problematic query and keep going
                print(f'file {file} caused an exception: {e}')
                yield None, None
            else:
                yield train_sample, test_sample

    def _restrict_to_classes(self, texts, y):
        # Applies the class filter only when a set of classes was requested;
        # filtering against None would discard every instance.
        if self.classes is None:
            return texts, y
        return filter_by_classes(texts, y, self.classes)

    def _list_queries(self):
        """
        :return: the sorted list of paths to the training query files
        """
        return sorted(glob(join(self.class_home, 'training_Query*200SPLIT.json')))

    def _get_test_sample(self, query_id, max_lines=-1):
        """
        Extracts the test documents retrieved for a query.

        :param query_id: the query id (converted to int for the comparison
            against the 'qid' column)
        :param max_lines: if not None and > 0, and more documents than this
            are available, only the max_lines documents with the smallest
            'rank' are returned (sorted by rank)
        :return: a tuple (texts, labels)
        :raises KeyError: if the test rankings lack the class_name column
        """
        df = self.test_rankings_df
        sel_df = df[df.qid == int(query_id)]
        texts = sel_df.text.values
        try:
            labels = sel_df[self.class_name].values
        except KeyError as e:
            print(f'error: key {self.class_name} not found in test rankings')
            raise e

        # None (or a non-positive value) means "no limit"
        if max_lines is not None and 0 < max_lines < len(texts):
            # bracket access is mandatory here: the attribute sel_df.rank
            # resolves to the DataFrame.rank method, not to the column
            ranks = sel_df['rank'].values
            idx = np.argsort(ranks)[:max_lines]
            texts = np.asarray(texts)[idx]
            labels = np.asarray(labels)[idx]

        return texts, labels

    def total(self):
        """
        :return: the number of query files, i.e., of sample pairs generated
        """
        return len(self._list_queries())

    def _get_query_id_from_path(self, path):
        """
        Extracts the query id from a file name of the form
        'training_Query-<qid>Sample-200SPLIT...'.

        :param path: path to a training query file
        :return: the query id, as a string
        :raises ValueError: if the file name does not follow the pattern
        """
        from os.path import basename

        prefix = 'training_Query-'
        posfix = 'Sample-200SPLIT'
        # inspect the file name only, so that directory names that happen
        # to contain the markers cannot interfere
        qid = basename(path)
        qid = qid[:qid.index(posfix)]
        qid = qid[qid.index(prefix) + len(prefix):]
        return qid