1
0
Fork 0
QuaPy/MultiLabel/data/rcv_reader.py

153 lines
5.1 KiB
Python
Executable File

from zipfile import ZipFile
import xml.etree.ElementTree as ET
from data.labeled import LabelledDocuments
from util.file import list_files
from os.path import join, exists
from util.file import download_file_if_not_exists
import re
from collections import Counter
# Original RCV1 topic hierarchy; fetch_topic_hierarchy downloads it on demand.
RCV1_TOPICHIER_URL = (
    "http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/"
    "a02-orig-topics-hierarchy/rcv1.topics.hier.orig"
)
# Reference page cited when the raw (non-redistributable) corpus is missing.
RCV1_BASE_URL = "http://www.daviddlewis.com/resources/testcollections/rcv1/"

# File names of the tokenised LYRL2004 distribution and the topic qrels file
# (not referenced by the reader functions in this module).
rcv1_test_data_gz = ['lyrl2004_tokens_test_pt{}.dat.gz'.format(part) for part in range(4)]
rcv1_train_data_gz = ['lyrl2004_tokens_train.dat.gz']
rcv1_doc_cats_data_gz = 'rcv1-v2.topics.qrels.gz'
class RCV_Document:
    """Plain record for one parsed RCV1 item.

    Holds the item id, its concatenated text, its list of topic codes and an
    optional date string (empty by default).
    """

    def __init__(self, id, text, categories, date=''):
        self.id, self.date = id, date
        self.text, self.categories = text, categories
class IDRangeException(Exception):
    """Raised by parse_document when an item's id lies outside the requested range."""
# Module-level scratch list. It is only referenced by commented-out
# word-count statistics code in fetch_RCV1, so it is currently unused.
nwords = []
def parse_document(xml_content, valid_id_range=None):
    """Parse one RCV1 news-item XML record into an RCV_Document.

    :param xml_content: raw XML of a single item (text or bytes, as accepted by
        ``xml.etree.ElementTree.fromstring``).
    :param valid_id_range: optional inclusive ``(low, high)`` pair; items whose
        numeric ``itemid`` falls outside it are rejected.
    :return: an RCV_Document whose text is title, headline and body paragraphs
        joined by newlines (the headline is dropped when it is empty or already
        contained in the title), and whose categories are the codes listed under
        the ``bip:topics:1.0`` code class.
    :raises IDRangeException: if ``itemid`` lies outside ``valid_id_range``.
    :raises ValueError: if the item has no body text (or, when a range is given,
        a non-numeric ``itemid``).
    """
    root = ET.fromstring(xml_content)
    doc_id = root.attrib['itemid']
    # Checked first so that items outside the requested split are discarded cheaply.
    if valid_id_range is not None:
        if not valid_id_range[0] <= int(doc_id) <= valid_id_range[1]:
            raise IDRangeException
    doc_categories = [cat.attrib['code'] for cat in
                      root.findall('.//metadata/codes[@class="bip:topics:1.0"]/code')]
    doc_date = root.attrib['date']
    # findtext yields None for a missing element and '' for an element without
    # text, so neither case can raise here (find(...).text would).
    doc_title = root.findtext('.//title') or ''
    doc_headline = root.findtext('.//headline') or ''
    # A <p> without a text node has .text == None, which str.join rejects.
    doc_body = '\n'.join(p.text for p in root.findall('.//text/p') if p.text is not None)
    if not doc_body:
        raise ValueError('Empty document')
    # Avoid repeating the headline when the title already contains it
    # (this is also true for an empty headline).
    if doc_headline in doc_title:
        doc_headline = ''
    text = '\n'.join([doc_title, doc_headline, doc_body]).strip()
    return RCV_Document(id=doc_id, text=text, categories=doc_categories, date=doc_date)
def fetch_RCV1(data_path, subset='all'):
    """Read the raw RCV1-v2 corpus from the ``<digits>.zip`` archives in ``data_path``.

    Every archive member is parsed with parse_document. Members whose itemid lies
    outside the requested split, or that have no body text, are skipped.

    :param data_path: directory containing the original RCV1 zip archives (the
        corpus cannot be downloaded without permission).
    :param subset: ``'train'``, ``'test'`` or ``'all'``.
    :return: LabelledDocuments with the document texts as ``data``, each
        document's list of topic codes as ``target``, and the sorted list of all
        topic codes seen as ``target_names``.
    :raises FileNotFoundError: if ``data_path`` contains no ``<digits>.zip`` archive.
    """
    assert subset in ['train', 'test', 'all'], 'split should either be "train", "test", or "all"'

    # Splits are hard-coded inclusive itemid ranges. The expected document counts
    # let the scan stop as soon as the whole split has been collected.
    training_documents = 23149
    test_documents = 781265
    if subset == 'all':
        split_range = (2286, 810596)
        expected = training_documents + test_documents
    elif subset == 'train':
        split_range = (2286, 26150)
        expected = training_documents
    else:
        split_range = (26151, 810596)
        expected = test_documents

    # Only archives named '<digits>.zip' belong to the corpus. The raw string
    # avoids the invalid '\d' escape, and fullmatch rejects e.g. '123.zip.bak'.
    zip_pattern = re.compile(r'\d+\.zip')
    parts = [part for part in list_files(data_path) if zip_pattern.fullmatch(part)]
    if not parts:
        raise FileNotFoundError(
            "You don't seem to have the RCV1 zip files in {}, and the RCV1 corpus can not be downloaded"
            " w/o a formal permission. Please, refer to {} for more information.".format(data_path, RCV1_BASE_URL))

    documents = []
    labels = set()
    for part in parts:
        # Context managers close the archive and its members even on errors.
        with ZipFile(join(data_path, part)) as archive:
            for xmlfile in archive.namelist():
                xmlcontent = archive.read(xmlfile)
                try:
                    doc = parse_document(xmlcontent, valid_id_range=split_range)
                except (IDRangeException, ValueError):
                    pass  # outside the requested split, or no body text
                else:
                    labels.update(doc.categories)
                    documents.append(doc)
                print('\r[{}] read {} documents'.format(part, len(documents)), end='')
                if len(documents) == expected:
                    break
        if len(documents) == expected:
            break
    print()

    # Sort the names so that target_names has the same order on every run;
    # set iteration order depends on string hashing.
    return LabelledDocuments(data=[d.text for d in documents],
                             target=[d.categories for d in documents],
                             target_names=sorted(labels))
def fetch_topic_hierarchy(path, topics='all'):
    """Return RCV1 topic codes read from the original topic-hierarchy file.

    The file is downloaded from RCV1_TOPICHIER_URL to ``path`` if it is not
    already there. Each non-blank line is split on whitespace; field 1 is taken
    as the parent code and field 3 as the child code. The synthetic ``'None'``
    and ``'Root'`` parents are discarded.

    :param path: local path of the hierarchy file.
    :param topics: ``'all'`` for every topic code, or ``'leaves'`` for only the
        codes that never appear as a parent.
    :return: sorted list of topic codes.
    """
    assert topics in ['all', 'leaves']
    download_file_if_not_exists(RCV1_TOPICHIER_URL, path)

    hierarchy = {}
    with open(path, 'rt') as hierarchy_file:
        for line in hierarchy_file:
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines, e.g. a trailing empty line
            parent, child = parts[1], parts[3]
            hierarchy.setdefault(parent, []).append(child)
    # Drop the top of the tree so that neither 'None' nor 'Root' is reported;
    # pop() tolerates files where either entry is absent.
    hierarchy.pop('None', None)
    hierarchy.pop('Root', None)

    parents = set(hierarchy)
    children = {child for kids in hierarchy.values() for child in kids}
    # Sort for a deterministic order (set order depends on string hashing).
    if topics == 'all':
        return sorted(parents | children)
    return sorted(children - parents)
if __name__ == '__main__':
    # Example: load both splits from a local copy of the raw corpus and report sizes.
    RCV1_PATH = '../../datasets/RCV1-v2/unprocessed_corpus'
    rcv1_train = fetch_RCV1(RCV1_PATH, subset='train')
    rcv1_test = fetch_RCV1(RCV1_PATH, subset='test')
    print('read {} documents in rcv1-train, and {} labels'.format(len(rcv1_train.data), len(rcv1_train.target_names)))
    print('read {} documents in rcv1-test, and {} labels'.format(len(rcv1_test.data), len(rcv1_test.target_names)))

    # Count how many training documents carry each topic code. The loop variable
    # must differ from the Counter's name; reusing it would rebind the Counter to
    # a plain list, and .update would then raise AttributeError.
    cats = Counter()
    for doc_cats in rcv1_train.target:
        cats.update(doc_cats)
    print('RCV1', cats)