forked from moreo/QuaPy
153 lines
5.1 KiB
Python
Executable File
153 lines
5.1 KiB
Python
Executable File
from zipfile import ZipFile
|
|
import xml.etree.ElementTree as ET
|
|
from data.labeled import LabelledDocuments
|
|
from util.file import list_files
|
|
from os.path import join, exists
|
|
from util.file import download_file_if_not_exists
|
|
import re
|
|
from collections import Counter
|
|
|
|
# Location of the original RCV1 topic hierarchy file; fetch_topic_hierarchy
# downloads it from here when it is not already present locally.
RCV1_TOPICHIER_URL = "http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/a02-orig-topics-hierarchy/rcv1.topics.hier.orig"

# Reference page for the collection, quoted to the user when the corpus
# archives are missing (the raw corpus needs a formal permission).
RCV1_BASE_URL = "http://www.daviddlewis.com/resources/testcollections/rcv1/"

# Gzipped token files of the test split, in part order.
rcv1_test_data_gz = [
    'lyrl2004_tokens_test_pt0.dat.gz',
    'lyrl2004_tokens_test_pt1.dat.gz',
    'lyrl2004_tokens_test_pt2.dat.gz',
    'lyrl2004_tokens_test_pt3.dat.gz',
]

# Gzipped token file of the training split.
rcv1_train_data_gz = [
    'lyrl2004_tokens_train.dat.gz',
]

# Gzipped document/topic assignment file (presumably the RCV1-v2 qrels).
rcv1_doc_cats_data_gz = 'rcv1-v2.topics.qrels.gz'
|
|
|
|
class RCV_Document:
    """A single RCV1 news item, as built by ``parse_document``.

    Attributes:
        id: identifier of the item (``parse_document`` passes the XML
            ``itemid`` attribute, i.e. a string).
        date: publication date string; '' when not supplied.
        text: the document text (title, headline and body for documents
            produced by ``parse_document``).
        categories: list of topic codes assigned to the item.
    """

    def __init__(self, id, text, categories, date=''):
        # Values are stored as given; nothing is copied or validated.
        self.id, self.date = id, date
        self.text, self.categories = text, categories
|
|
|
|
class IDRangeException(Exception):
    """Raised by parse_document when an item's id lies outside the requested id range."""
|
|
|
|
# Module-level list; in the visible code it is only referenced from
# commented-out statistics code in fetch_RCV1.
nwords = list()
|
|
|
|
def parse_document(xml_content, valid_id_range=None):
    """Parse one RCV1 XML news item into an ``RCV_Document``.

    Args:
        xml_content: raw XML of a single item (``str`` or ``bytes``, as
            accepted by ``xml.etree.ElementTree.fromstring``).
        valid_id_range: optional inclusive ``(low, high)`` pair. When it is
            given, items whose integer ``itemid`` falls outside it are rejected.

    Returns:
        An ``RCV_Document``. Its text joins with newlines the title, the
        headline and the body paragraphs, then strips the result. The
        headline is dropped when it is empty or already contained in the
        title. Its categories are the codes listed under the
        ``bip:topics:1.0`` codes element.

    Raises:
        IDRangeException: ``itemid`` lies outside ``valid_id_range``.
        ValueError: ``itemid`` is not an integer (only checked when
            ``valid_id_range`` is given), or the item has no body text.
        KeyError: the root lacks an ``itemid``/``date`` attribute, or a topic
            ``code`` element lacks its ``code`` attribute.
        xml.etree.ElementTree.ParseError: ``xml_content`` is not well-formed.
    """
    root = ET.fromstring(xml_content)

    doc_id = root.attrib['itemid']
    if valid_id_range is not None:
        if not valid_id_range[0] <= int(doc_id) <= valid_id_range[1]:
            raise IDRangeException

    doc_categories = [code.attrib['code'] for code in
                      root.findall('.//metadata/codes[@class="bip:topics:1.0"]/code')]
    doc_date = root.attrib['date']

    # findtext returns '' both for a missing element and for an element
    # without text. The previous `.find(...).text` raised AttributeError
    # when the element was missing.
    doc_title = root.findtext('.//title', '')
    doc_headline = root.findtext('.//headline', '')

    # Empty <p/> elements have text None. Joining them raised TypeError,
    # which callers do not catch, so skip them instead.
    doc_body = '\n'.join(p.text for p in root.findall('.//text/p') if p.text)
    if not doc_body:
        raise ValueError('Empty document')

    # Avoid repeating a headline that is already part of the title
    # ('' is always "in" the title, so empty headlines are dropped too).
    if doc_headline in doc_title:
        doc_headline = ''
    text = '\n'.join([doc_title, doc_headline, doc_body]).strip()

    return RCV_Document(id=doc_id, text=text, categories=doc_categories, date=doc_date)
|
|
|
|
|
|
def fetch_RCV1(data_path, subset='all'):
    """Read the RCV1-v2 corpus from the numbered ``*.zip`` archives in *data_path*.

    Every member of every archive is parsed with ``parse_document``. A member
    is kept only when its ``itemid`` lies in the id range of the requested
    split and it has body text.

    Args:
        data_path: directory holding the corpus archives. Their names are
            digits followed by ``.zip``. The corpus cannot be downloaded
            automatically; see ``RCV1_BASE_URL``.
        subset: ``'train'``, ``'test'`` or ``'all'``.

    Returns:
        ``LabelledDocuments`` with the document texts as ``data``, the list
        of topic codes of each document as ``target``, and the sorted union
        of all codes seen as ``target_names``.

    Raises:
        AssertionError: *subset* is not one of the accepted values.
        FileNotFoundError: *data_path* contains no corpus archive.
    """
    assert subset in ['train', 'test', 'all'], 'split should either be "train", "test", or "all"'

    # Inclusive itemid range of each split, and the number of documents that
    # range yields. Reading stops as soon as that many have been accepted.
    training_documents = 23149
    test_documents = 781265
    if subset == 'all':
        split_range = (2286, 810596)
        expected = training_documents + test_documents
    elif subset == 'train':
        split_range = (2286, 26150)
        expected = training_documents
    else:
        split_range = (26151, 810596)
        expected = test_documents

    # fullmatch (with a raw-string pattern) rejects names such as '1.zip.bak'.
    # Sorting makes the reading order, and thus the document order, reproducible.
    parts = sorted(part for part in list_files(data_path) if re.fullmatch(r'\d+\.zip', part))
    if not parts:
        # The previous per-file existence assert could never fail, since the
        # names came from listing data_path. A missing corpus silently
        # produced an empty dataset instead of this message.
        raise FileNotFoundError(
            "You don't seem to have the RCV1 corpus files (<number>.zip) in {}, and the RCV1 corpus "
            "can not be downloaded w/o a formal permission. Please, refer to {} for more "
            "information.".format(data_path, RCV1_BASE_URL))

    documents = []
    labels = set()
    for part in parts:
        with ZipFile(join(data_path, part)) as archive:
            for xmlfile in archive.namelist():
                with archive.open(xmlfile) as member:
                    xmlcontent = member.read()
                try:
                    doc = parse_document(xmlcontent, valid_id_range=split_range)
                except (IDRangeException, ValueError):
                    pass  # outside the requested split, bad itemid, or empty body
                else:
                    labels.update(doc.categories)
                    documents.append(doc)
                print('\r[{}] read {} documents'.format(part, len(documents)), end='')
                if len(documents) == expected:
                    break
        if len(documents) == expected:
            break
    print()

    # Sorted so the label order (and any index mapping built on it) does not
    # depend on the per-process string hash seed.
    return LabelledDocuments(data=[d.text for d in documents],
                             target=[d.categories for d in documents],
                             target_names=sorted(labels))
|
|
|
|
|
|
|
|
def fetch_topic_hierarchy(path, topics='all'):
    """Return RCV1 topic codes taken from the original topic hierarchy file.

    The file is fetched from ``RCV1_TOPICHIER_URL`` into *path* unless it
    already exists there. Each non-blank line is split on whitespace, and its
    2nd and 4th tokens are read as a (parent, child) pair. The 'None' and
    'Root' parents are discarded (presumably structural entries of the file
    rather than topic codes).

    Args:
        path: local file path of the hierarchy file.
        topics: ``'all'`` for every remaining parent and child code, or
            ``'leaves'`` for the codes that never appear as a parent.

    Returns:
        A sorted list of topic codes.
    """
    assert topics in ['all', 'leaves']

    download_file_if_not_exists(RCV1_TOPICHIER_URL, path)

    hierarchy = {}
    with open(path, 'rt') as fin:
        for line in fin:
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            parent, child = parts[1], parts[3]
            hierarchy.setdefault(parent, []).append(child)

    # pop() instead of del: do not fail if a file variant lacks these entries.
    hierarchy.pop('None', None)
    hierarchy.pop('Root', None)

    parents = set(hierarchy)
    children = {child for kids in hierarchy.values() for child in kids}
    # Sorted for a reproducible order (sets iterate in hash-seed order).
    if topics == 'all':
        return sorted(parents | children)
    return sorted(children - parents)
|
|
|
|
|
|
if __name__ == '__main__':

    # example
    RCV1_PATH = '../../datasets/RCV1-v2/unprocessed_corpus'

    rcv1_train = fetch_RCV1(RCV1_PATH, subset='train')
    rcv1_test = fetch_RCV1(RCV1_PATH, subset='test')

    print('read {} documents in rcv1-train, and {} labels'.format(len(rcv1_train.data), len(rcv1_train.target_names)))
    print('read {} documents in rcv1-test, and {} labels'.format(len(rcv1_test.data), len(rcv1_test.target_names)))

    # Count how many training documents carry each topic code. The loop
    # variable previously reused the name `cats`, which rebound the Counter
    # to a list and failed on `list.update`.
    cats = Counter()
    for doc_cats in rcv1_train.target:
        cats.update(doc_cats)
    print('RCV1', cats)
|