# QuaPy/MultiLabel/data/reuters21578_reader.py
# Modified version of the code originally implemented by Eustache Diemert <eustache@diemert.fr>
# and @FedericoV <https://github.com/FedericoV/>
# License: BSD 3 clause
import os.path
import re
import tarfile
from sklearn.datasets import get_data_home
from six.moves import html_parser
from six.moves import urllib
import pickle
from glob import glob
import numpy as np
from data.labeled import LabelledDocuments


def _not_in_sphinx():
    # Hack to detect whether we are running under the sphinx builder
    return '__file__' in globals()


class ReutersParser(html_parser.HTMLParser):
    """Utility class to parse a SGML file and yield documents one at a time."""

    def __init__(self, encoding='latin-1', data_path=None):
        self.data_path = data_path
        self.download_if_not_exist()
        self.tr_docs = []
        self.te_docs = []
        html_parser.HTMLParser.__init__(self)
        self._reset()
        self.encoding = encoding
        self.empty_docs = 0
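
    # HTMLParser hooks: each SGML tag is dispatched to a start_<tag>/end_<tag>
    # method defined below; tags without a handler are silently ignored.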
    def handle_starttag(self, tag, attrs):
        method = 'start_' + tag
        getattr(self, method, lambda x: None)(attrs)

    def handle_endtag(self, tag):
        method = 'end_' + tag
        getattr(self, method, lambda: None)()

    def _reset(self):
        self.in_title = 0
        self.in_body = 0
        self.in_topics = 0
        self.in_topic_d = 0
        self.in_unproc_text = 0
        self.title = ""
        self.body = ""
        self.topics = []
        self.topic_d = ""
        self.text = ""

    def parse(self, fd):
        for chunk in fd:
            self.feed(chunk.decode(self.encoding))
        self.close()
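
    # Character data accumulates into whichever element the parser is currently inside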
    def handle_data(self, data):
        if self.in_body:
            self.body += data
        elif self.in_title:
            self.title += data
        elif self.in_topic_d:
            self.topic_d += data
        elif self.in_unproc_text:
            self.text += data
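
    # The <REUTERS> tag carries (name, value) attribute pairs, TOPICS first and
    # LEWISSPLIT second; only documents with TOPICS="YES" and a LEWISSPLIT of
    # TRAIN or TEST are kept (the ModApte convention), the rest are marked
    # 'unused' and dropped in end_reuters.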
    def start_reuters(self, attributes):
        topic_attr = attributes[0][1]
        lewissplit_attr = attributes[1][1]
        self.lewissplit = u'unused'
        if topic_attr == u'YES':
            if lewissplit_attr == u'TRAIN':
                self.lewissplit = u'train'
            elif lewissplit_attr == u'TEST':
                self.lewissplit = u'test'

    def end_reuters(self):
        self.body = re.sub(r'\s+', r' ', self.body)
        if self.lewissplit != u'unused':
            parsed_doc = {'title': self.title, 'body': self.body, 'unproc': self.text, 'topics': self.topics}
            if (self.title + self.body + self.text).strip() == '':
                self.empty_docs += 1
            if self.lewissplit == u'train':
                self.tr_docs.append(parsed_doc)
            elif self.lewissplit == u'test':
                self.te_docs.append(parsed_doc)
        self._reset()

    def start_title(self, attributes):
        self.in_title = 1

    def end_title(self):
        self.in_title = 0

    def start_body(self, attributes):
        self.in_body = 1

    def end_body(self):
        self.in_body = 0

    def start_topics(self, attributes):
        self.in_topics = 1

    def end_topics(self):
        self.in_topics = 0

    def start_text(self, attributes):
        if len(attributes) > 0 and attributes[0][1] == u'UNPROC':
            self.in_unproc_text = 1

    def end_text(self):
        self.in_unproc_text = 0

    def start_d(self, attributes):
        self.in_topic_d = 1

    def end_d(self):
        if self.in_topics:
            self.topics.append(self.topic_d)
        self.in_topic_d = 0
        self.topic_d = ""
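
    # Download and unpack the Reuters-21578 archive into data_path, skipping the
    # download entirely if the directory already exists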
    def download_if_not_exist(self):
        DOWNLOAD_URL = ('http://archive.ics.uci.edu/ml/machine-learning-databases/'
                        'reuters21578-mld/reuters21578.tar.gz')
        ARCHIVE_FILENAME = 'reuters21578.tar.gz'

        if self.data_path is None:
            self.data_path = os.path.join(get_data_home(), 'reuters21578')
        if not os.path.exists(self.data_path):
            # Download the dataset
            print("downloading dataset (once and for all) into %s" % self.data_path)
            os.mkdir(self.data_path)

            def progress(blocknum, bs, size):
                total_sz_mb = '%.2f MB' % (size / 1e6)
                current_sz_mb = '%.2f MB' % ((blocknum * bs) / 1e6)
                if _not_in_sphinx():
                    print('\rdownloaded %s / %s' % (current_sz_mb, total_sz_mb), end='')

            archive_path = os.path.join(self.data_path, ARCHIVE_FILENAME)
            urllib.request.urlretrieve(DOWNLOAD_URL, filename=archive_path,
                                       reporthook=progress)
            if _not_in_sphinx():
                print('\r', end='')
            print("untarring Reuters dataset...")
            tarfile.open(archive_path, 'r:gz').extractall(self.data_path)
            print("done.")
def fetch_reuters21578(data_path=None, subset='train'):
    if data_path is None:
        data_path = os.path.join(get_data_home(), 'reuters21578')
    reuters_pickle_path = os.path.join(data_path, "reuters." + subset + ".pickle")
    if not os.path.exists(reuters_pickle_path):
        parser = ReutersParser(data_path=data_path)
        for filename in glob(os.path.join(data_path, "*.sgm")):
            with open(filename, 'rb') as fd:
                parser.parse(fd)
        # index category names with a unique numerical code (only considering categories with training examples)
        tr_categories = np.unique(np.concatenate([doc['topics'] for doc in parser.tr_docs])).tolist()

        def pickle_documents(docs, subset):
            for doc in docs:
                doc['topics'] = [tr_categories.index(t) for t in doc['topics'] if t in tr_categories]
            pickle_docs = {'categories': tr_categories, 'documents': docs}
            with open(os.path.join(data_path, "reuters." + subset + ".pickle"), 'wb') as fout:
                pickle.dump(pickle_docs, fout, protocol=pickle.HIGHEST_PROTOCOL)
            return pickle_docs

        pickle_tr = pickle_documents(parser.tr_docs, "train")
        pickle_te = pickle_documents(parser.te_docs, "test")
        requested_subset = pickle_tr if subset == 'train' else pickle_te
    else:
        with open(reuters_pickle_path, 'rb') as fin:
            requested_subset = pickle.load(fin)

    data = [(u'{title}\n{body}\n{unproc}'.format(**doc), doc['topics']) for doc in requested_subset['documents']]
    text_data, topics = zip(*data)
    return LabelledDocuments(data=text_data, target=topics, target_names=requested_subset['categories'])


if __name__ == '__main__':
    reuters_train = fetch_reuters21578(subset='train')
    print(reuters_train.data)
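
    # Quick sanity check: a sketch assuming LabelledDocuments exposes the data,
    # target, and target_names attributes it is constructed with above
    reuters_test = fetch_reuters21578(subset='test')
    print('train=%d test=%d categories=%d' % (
        len(reuters_train.data), len(reuters_test.data), len(reuters_train.target_names)))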