forked from moreo/QuaPy
64 lines
2.8 KiB
Python
Executable File
64 lines
2.8 KiB
Python
Executable File
import os
|
|
import pickle
|
|
import tarfile
|
|
from os.path import join
|
|
import urllib.request
|
|
from data.labeled import LabelledDocuments
|
|
from util.file import create_if_not_exist, download_file_if_not_exists
|
|
import math
|
|
|
|
|
|
def fetch_ohsumed50k(data_path=None, subset='train', train_test_split=0.7):
|
|
_dataname = 'ohsumed50k'
|
|
if data_path is None:
|
|
data_path = join(os.path.expanduser('~'), _dataname)
|
|
create_if_not_exist(data_path)
|
|
|
|
pickle_file = join(data_path, _dataname + '.' + subset + str(train_test_split) + '.pickle')
|
|
if not os.path.exists(pickle_file):
|
|
DOWNLOAD_URL = ('http://disi.unitn.it/moschitti/corpora/ohsumed-all-docs.tar.gz')
|
|
archive_path = os.path.join(data_path, 'ohsumed-all-docs.tar.gz')
|
|
download_file_if_not_exists(DOWNLOAD_URL, archive_path)
|
|
untardir = 'ohsumed-all'
|
|
if not os.path.exists(os.path.join(data_path, untardir)):
|
|
print("untarring ohsumed...")
|
|
tarfile.open(archive_path, 'r:gz').extractall(data_path)
|
|
|
|
target_names = []
|
|
doc_classes = dict()
|
|
class_docs = dict()
|
|
content = dict()
|
|
doc_ids = set()
|
|
for cat_id in os.listdir(join(data_path, untardir)):
|
|
target_names.append(cat_id)
|
|
class_docs[cat_id] = []
|
|
for doc_id in os.listdir(join(data_path, untardir, cat_id)):
|
|
doc_ids.add(doc_id)
|
|
text_content = open(join(data_path, untardir, cat_id, doc_id), 'r').read()
|
|
if doc_id not in doc_classes: doc_classes[doc_id] = []
|
|
doc_classes[doc_id].append(cat_id)
|
|
if doc_id not in content: content[doc_id] = text_content
|
|
class_docs[cat_id].append(doc_id)
|
|
target_names.sort()
|
|
print('Read %d different documents' % len(doc_ids))
|
|
|
|
splitdata = dict({'train': [], 'test': []})
|
|
for cat_id in target_names:
|
|
free_docs = [d for d in class_docs[cat_id] if (d not in splitdata['train'] and d not in splitdata['test'])]
|
|
if len(free_docs) > 0:
|
|
split_point = int(math.floor(len(free_docs) * train_test_split))
|
|
splitdata['train'].extend(free_docs[:split_point])
|
|
splitdata['test'].extend(free_docs[split_point:])
|
|
for split in ['train', 'test']:
|
|
dataset = LabelledDocuments([], [], target_names)
|
|
for doc_id in splitdata[split]:
|
|
dataset.data.append(content[doc_id])
|
|
dataset.target.append([target_names.index(cat_id) for cat_id in doc_classes[doc_id]])
|
|
pickle.dump(dataset,
|
|
open(join(data_path, _dataname + '.' + split + str(train_test_split) + '.pickle'), 'wb'),
|
|
protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
|
print(pickle_file)
|
|
return pickle.load(open(pickle_file, 'rb'))
|
|
|