# forked from moreo/QuaPy
import os,sys
|
|
from sklearn.datasets import get_data_home, fetch_20newsgroups
|
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
from sklearn.preprocessing import MultiLabelBinarizer
|
|
from MultiLabel.data.jrcacquis_reader import fetch_jrcacquis
|
|
from MultiLabel.data.ohsumed_reader import fetch_ohsumed50k
|
|
from MultiLabel.data.reuters21578_reader import fetch_reuters21578
|
|
from MultiLabel.data.rcv_reader import fetch_RCV1
|
|
from MultiLabel.data.wipo_reader import fetch_WIPOgamma, WipoGammaDocument
|
|
import pickle
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from os.path import join
|
|
import re
|
|
|
|
|
|
def init_vectorizer():
    """Build the tf-idf vectorizer shared by all datasets.

    Terms occurring in fewer than 5 documents are discarded and term
    frequencies are log-scaled (sublinear tf).
    """
    vectorizer = TfidfVectorizer(sublinear_tf=True, min_df=5)
    return vectorizer
|
class Dataset:
    """
    A text-classification dataset: raw train (devel) and test documents, their
    label matrices, and a tf-idf vectorizer fitted on the devel documents.

    After construction the instance exposes:
      devel_raw / test_raw           -- lists of (number-masked) document strings
      devel_target / test_target     -- targets (label matrix or int array)
      devel_labelmatrix / test_labelmatrix -- sparse binary label matrices
      classification_type            -- 'multilabel' or 'singlelabel'
      nC                             -- number of classes
      vocabulary                     -- vectorizer vocabulary (term -> index)
    """

    # NOTE(review): the original set listed 'ohsumed' twice; duplicate removed
    # (harmless in a set literal, but misleading).
    dataset_available = {'reuters21578', '20newsgroups', 'ohsumed', 'rcv1', 'jrcall',
                         'wipo-sl-mg', 'wipo-ml-mg', 'wipo-sl-sc', 'wipo-ml-sc'}

    def __init__(self, name):
        """Fetch and preprocess dataset `name` (must be in `dataset_available`)."""
        assert name in Dataset.dataset_available, f'dataset {name} is not available'

        if name == 'reuters21578':
            self._load_reuters()
        elif name == '20newsgroups':
            self._load_20news()
        elif name == 'rcv1':
            self._load_rcv1()
        elif name == 'ohsumed':
            self._load_ohsumed()
        elif name == 'jrcall':
            self._load_jrc(version='all')
        elif name == 'wipo-sl-mg':
            self._load_wipo('singlelabel', 'maingroup')
        elif name == 'wipo-ml-mg':
            self._load_wipo('multilabel', 'maingroup')
        elif name == 'wipo-sl-sc':
            self._load_wipo('singlelabel', 'subclass')
        elif name == 'wipo-ml-sc':
            self._load_wipo('multilabel', 'subclass')

        self.nC = self.devel_labelmatrix.shape[1]  # number of classes
        # the vectorizer is fitted on devel only; test docs are transformed later
        self._vectorizer = init_vectorizer()
        self._vectorizer.fit(self.devel_raw)
        self.vocabulary = self._vectorizer.vocabulary_

    def show(self):
        """Print a one-line summary (#docs, #features, #classes); return self."""
        nTr_docs = len(self.devel_raw)
        nTe_docs = len(self.test_raw)
        nfeats = len(self._vectorizer.vocabulary_)
        nC = self.devel_labelmatrix.shape[1]
        nD = nTr_docs + nTe_docs
        print(f'{self.classification_type}, nD={nD}=({nTr_docs}+{nTe_docs}), nF={nfeats}, nC={nC}')
        return self

    def _load_reuters(self):
        """Load Reuters-21578 (multilabel) from sklearn's data home."""
        data_path = os.path.join(get_data_home(), 'reuters21578')
        devel = fetch_reuters21578(subset='train', data_path=data_path)
        test = fetch_reuters21578(subset='test', data_path=data_path)

        self.classification_type = 'multilabel'
        self.devel_raw, self.test_raw = mask_numbers(devel.data), mask_numbers(test.data)
        self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(devel.target, test.target)
        self.devel_target, self.test_target = self.devel_labelmatrix, self.test_labelmatrix

    def _load_rcv1(self):
        """Load RCV1-v2 (multilabel) from a local unprocessed-corpus path."""
        data_path = '../datasets/RCV1-v2/unprocessed_corpus'  # TODO: check when missing
        devel = fetch_RCV1(subset='train', data_path=data_path)
        test = fetch_RCV1(subset='test', data_path=data_path)

        self.classification_type = 'multilabel'
        self.devel_raw, self.test_raw = mask_numbers(devel.data), mask_numbers(test.data)
        self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(devel.target, test.target)
        self.devel_target, self.test_target = self.devel_labelmatrix, self.test_labelmatrix

    def _load_jrc(self, version):
        """Load JRC-Acquis v3 (multilabel): train on 1986-2005, test on 2006.

        `version` selects either the 300 most frequent categories ('300') or
        all categories with at least one training document ('all').
        """
        assert version in ['300', 'all'], 'allowed versions are "300" or "all"'
        data_path = "../datasets/JRC_Acquis_v3"
        tr_years = list(range(1986, 2006))
        te_years = [2006]
        if version == '300':
            training_docs, tr_cats = fetch_jrcacquis(data_path=data_path, years=tr_years, cat_threshold=1, most_frequent=300)
            test_docs, te_cats = fetch_jrcacquis(data_path=data_path, years=te_years, cat_filter=tr_cats)
        else:
            training_docs, tr_cats = fetch_jrcacquis(data_path=data_path, years=tr_years, cat_threshold=1)
            test_docs, te_cats = fetch_jrcacquis(data_path=data_path, years=te_years, cat_filter=tr_cats)
        print(f'load jrc-acquis (English) with {len(tr_cats)} tr categories ({len(te_cats)} te categories)')

        # NOTE(review): JRCAcquis_Document is not imported at the top of this
        # file (only fetch_jrcacquis is) -- confirm it is exported by
        # MultiLabel.data.jrcacquis_reader and import it explicitly.
        devel_data = JRCAcquis_Document.get_text(training_docs)
        test_data = JRCAcquis_Document.get_text(test_docs)
        devel_target = JRCAcquis_Document.get_target(training_docs)
        test_target = JRCAcquis_Document.get_target(test_docs)

        self.classification_type = 'multilabel'
        self.devel_raw, self.test_raw = mask_numbers(devel_data), mask_numbers(test_data)
        self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(devel_target, test_target)
        self.devel_target, self.test_target = self.devel_labelmatrix, self.test_labelmatrix

    def _load_ohsumed(self):
        """Load Ohsumed-50k (multilabel) from sklearn's data home."""
        data_path = os.path.join(get_data_home(), 'ohsumed50k')
        devel = fetch_ohsumed50k(subset='train', data_path=data_path)
        test = fetch_ohsumed50k(subset='test', data_path=data_path)

        self.classification_type = 'multilabel'
        self.devel_raw, self.test_raw = mask_numbers(devel.data), mask_numbers(test.data)
        self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(devel.target, test.target)
        self.devel_target, self.test_target = self.devel_labelmatrix, self.test_labelmatrix

    def _load_20news(self):
        """Load 20 Newsgroups (singlelabel), stripping headers/footers/quotes."""
        metadata = ('headers', 'footers', 'quotes')
        devel = fetch_20newsgroups(subset='train', remove=metadata)
        test = fetch_20newsgroups(subset='test', remove=metadata)
        self.classification_type = 'singlelabel'
        self.devel_raw, self.test_raw = mask_numbers(devel.data), mask_numbers(test.data)
        self.devel_target, self.test_target = devel.target, test.target
        # single-label targets are binarized too, so multilabel code paths work
        self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(self.devel_target.reshape(-1, 1), self.test_target.reshape(-1, 1))

    def _load_fasttext_data(self, name):
        """Load a fastText-format dataset (singlelabel) from ../datasets/fastText."""
        data_path = '../datasets/fastText'
        self.classification_type = 'singlelabel'
        name = name.replace('-', '_')
        train_file = join(data_path, f'{name}.train')
        assert os.path.exists(train_file), f'file {name} not found, please place the fasttext data in {data_path}'  # ' or specify the path' #todo
        self.devel_raw, self.devel_target = load_fasttext_format(train_file)
        self.test_raw, self.test_target = load_fasttext_format(join(data_path, f'{name}.test'))
        self.devel_raw = mask_numbers(self.devel_raw)
        self.test_raw = mask_numbers(self.test_raw)
        self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(self.devel_target.reshape(-1, 1), self.test_target.reshape(-1, 1))

    def _load_wipo(self, classmode, classlevel):
        """Load WIPO-gamma (abstracts only) in single- or multi-label mode.

        `classlevel` selects the classification granularity ('maingroup' or
        'subclass'); in single-label mode, test documents whose main label has
        no training document are discarded.
        """
        assert classmode in {'singlelabel', 'multilabel'}, 'available class_mode are sl (single-label) or ml (multi-label)'
        data_path = '../datasets/WIPO/wipo-gamma/en'
        data_proc = '../datasets/WIPO-extracted'

        devel = fetch_WIPOgamma(subset='train', classification_level=classlevel, data_home=data_path, extracted_path=data_proc, text_fields=['abstract'])
        test = fetch_WIPOgamma(subset='test', classification_level=classlevel, data_home=data_path, extracted_path=data_proc, text_fields=['abstract'])

        devel_data = [d.text for d in devel]
        test_data = [d.text for d in test]
        self.devel_raw, self.test_raw = mask_numbers(devel_data), mask_numbers(test_data)

        self.classification_type = classmode
        if classmode == 'multilabel':
            devel_target = [d.all_labels for d in devel]
            test_target = [d.all_labels for d in test]
            self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(devel_target, test_target)
            self.devel_target, self.test_target = self.devel_labelmatrix, self.test_labelmatrix
        else:
            devel_target = [d.main_label for d in devel]
            test_target = [d.main_label for d in test]
            # class ids are assigned only to labels with at least one training document
            class_id = {labelname: index for index, labelname in enumerate(sorted(set(devel_target)))}
            devel_target = np.array([class_id[id] for id in devel_target]).astype(int)
            # unknown test labels map to None and are filtered out below
            test_target = np.array([class_id.get(id, None) for id in test_target])
            if None in test_target:
                print(f'deleting {(test_target==None).sum()} test documents without valid categories')
                keep_pos = test_target != None
                self.test_raw = (np.asarray(self.test_raw)[keep_pos]).tolist()
                test_target = test_target[keep_pos]
            test_target = test_target.astype(int)
            self.devel_target, self.test_target = devel_target, test_target
            self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(self.devel_target.reshape(-1, 1), self.test_target.reshape(-1, 1))

    def vectorize(self):
        """Return tf-idf matrices (Xtr, Xte), computing and caching them once."""
        if not hasattr(self, 'Xtr') or not hasattr(self, 'Xte'):
            self.Xtr = self._vectorizer.transform(self.devel_raw)
            self.Xte = self._vectorizer.transform(self.test_raw)
            # sorted indices are required by some downstream sparse consumers
            self.Xtr.sort_indices()
            self.Xte.sort_indices()
        return self.Xtr, self.Xte

    def analyzer(self):
        """Return the vectorizer's callable analyzer (tokenizer+preprocessor)."""
        return self._vectorizer.build_analyzer()

    @classmethod
    def load(cls, dataset_name, pickle_path=None):
        """Load `dataset_name`, optionally caching the vectorized Dataset as a pickle.

        If `pickle_path` is given and exists, the pickled dataset is loaded
        from it; if given but missing, the dataset is fetched, vectorized and
        dumped there. With no `pickle_path` the dataset is fetched directly.
        """
        if pickle_path:
            if os.path.exists(pickle_path):
                print(f'loading pickled dataset from {pickle_path}')
                with open(pickle_path, 'rb') as fin:
                    dataset = pickle.load(fin)
            else:
                print(f'fetching dataset and dumping it into {pickle_path}')
                dataset = Dataset(name=dataset_name)
                print('vectorizing for faster processing')
                dataset.vectorize()
                print('dumping')
                # BUG FIX: the protocol was previously passed as open()'s
                # `buffering` argument and silently ignored by pickle.dump;
                # it belongs to pickle.dump, and `with` closes the file.
                with open(pickle_path, 'wb') as fout:
                    pickle.dump(dataset, fout, pickle.HIGHEST_PROTOCOL)
        else:
            print(f'loading dataset {dataset_name}')
            dataset = Dataset(name=dataset_name)

        print('[Done]')
        return dataset
|
|
|
|
|
def _label_matrix(tr_target, te_target):
    """Binarize train/test label collections into sparse indicator matrices.

    The binarizer is fitted on the training labels only, so test labels not
    seen in training are dropped. The discovered classes are printed.
    Returns (ytr, yte) as sparse matrices with one column per class.
    """
    binarizer = MultiLabelBinarizer(sparse_output=True)
    ytr = binarizer.fit_transform(tr_target)
    yte = binarizer.transform(te_target)
    print(binarizer.classes_)
    return ytr, yte
|
|
|
def load_fasttext_format(path):
    """Parse a fastText classification file into (docs, labels).

    Each line has the form '__label__<k> <text>'; labels are converted to
    0-based ints. Returns a list of document strings (each keeping its
    trailing newline, as before) and an int numpy array of labels.
    """
    print(f'loading {path}')
    labels, docs = [], []
    # FIX: iterate the file lazily instead of materializing readlines(),
    # and close it deterministically via the context manager (the original
    # leaked the file handle).
    with open(path, 'rt') as fin:
        for line in tqdm(fin):
            # NOTE(review): the split index comes from the *stripped* line but
            # is applied to the raw line -- assumes no leading whitespace.
            space = line.strip().find(' ')
            label = int(line[:space].replace('__label__', '')) - 1
            labels.append(label)
            docs.append(line[space+1:])
    labels = np.asarray(labels, dtype=int)
    return docs, labels
|
|
|
def mask_numbers(data, number_mask='numbermask'):
    """Replace each standalone number token in every text with `number_mask`.

    A number token starts with a digit and may continue with digits, dots,
    commas and hyphens (word-bounded). Returns a new list of masked texts.
    """
    pattern = re.compile(r'\b[0-9][0-9.,-]*\b')
    return [pattern.sub(number_mask, text) for text in tqdm(data, desc='masking numbers')]
|
|
|
|