#https://www.wipo.int/classifications/ipc/en/ITsupport/Categorization/dataset/
import os, sys
from os.path import exists, join
from util.file import *  # provides, among others, create_if_not_exist and list_files (used below)
from zipfile import ZipFile
import xml.etree.ElementTree as ET
from tqdm import tqdm
import numpy as np
import pickle
from joblib import Parallel, delayed


WIPO_URL = 'https://www.wipo.int/classifications/ipc/en/ITsupport/Categorization/dataset/'


class WipoGammaDocument:
    def __init__(self, id, text, main_label, all_labels):
        self.id = id
        self.text = text
        self.main_label = main_label
        self.all_labels = all_labels


def remove_nested_claimtext_tags(xmlcontent):
    # flatten nested <claim-text> tags within the <claims> block, so that each
    # <claim> element contains plain text only
    from_pos = xmlcontent.find(b'<claims')
    to_pos = xmlcontent.find(b'</claims>')
    if from_pos > -1 and to_pos > -1:
        in_between = xmlcontent[from_pos:to_pos].replace(b'<claim-text>', b'').replace(b'</claim-text>', b'')
        xmlcontent = (xmlcontent[:from_pos] + in_between + xmlcontent[to_pos:]).strip()
    return xmlcontent

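# Illustrative sketch of the flattening above (hypothetical fragment, not taken from the dataset):
#   <claims lang="EN"><claim><claim-text>A device comprising...</claim-text></claim></claims>
# becomes
#   <claims lang="EN"><claim>A device comprising...</claim></claims>

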
def parse_document(xml_content, text_fields, limit_description):
    root = ET.fromstring(remove_nested_claimtext_tags(xml_content))

    doc_id = root.attrib['ucid']
    lang = root.attrib['lang']

    # take the categories from the categorization, up to the "sub-class" level
    main_group = set(t.text[:6] for t in root.findall('.//bibliographic-data/technical-data/classifications-ipcr/classification-ipcr[@computed="from_ecla_to_ipc_SG"][@generated_main_IPC="true"]'))
    sec_groups = set(t.text[:6] for t in root.findall('.//bibliographic-data/technical-data/classifications-ipcr/classification-ipcr[@computed="from_ecla_to_ipc_SG"][@generated_main_IPC="false"]'))
    sec_groups.update(main_group)

    assert len(main_group) == 1, 'more than one main group'
    main_group = list(main_group)[0]
    sec_groups = sorted(sec_groups)

    assert lang == 'EN', f'only English documents allowed (doc {doc_id})'

    doc_text_fields = []
    if 'abstract' in text_fields:
        abstract = '\n'.join(filter(None, [t.text for t in root.findall('.//abstract[@lang="EN"]/p')]))
        doc_text_fields.append(abstract)
    if 'description' in text_fields:
        description = '\n'.join(filter(None, [t.text for t in root.findall('.//description[@lang="EN"]/p')]))
        if limit_description > -1:
            description = ' '.join(description.split()[:limit_description])
        doc_text_fields.append(description)
    if 'claims' in text_fields:
        claims = '\n'.join(filter(None, [t.text for t in root.findall('.//claims[@lang="EN"]/claim')]))
        doc_text_fields.append(claims)

    text = '\n'.join(doc_text_fields)
    if text:
        return WipoGammaDocument(doc_id, text, main_group, sec_groups)
    else:
        return None

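# Minimal usage sketch for parse_document (hypothetical values; real xml_content comes
# from the zipped .xml files processed by extract() below):
#   document = parse_document(xml_bytes, text_fields=['abstract'], limit_description=-1)
#   if document is not None:
#       print(document.id, document.main_label, document.all_labels)

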
def extract(fin, fout, text_fields, limit_description):
    zipfile = ZipFile(fin)
    ndocs = 0
    with open(fout, 'wt') as out:
        for xmlfile in tqdm(zipfile.namelist()):
            if xmlfile.endswith('.xml'):
                xmlcontent = zipfile.open(xmlfile).read()
                document = parse_document(xmlcontent, text_fields, limit_description)
                if document:
                    line_text = document.text.replace('\n', ' ').replace('\t', ' ').strip()
                    assert line_text, f'empty document in {xmlfile}'
                    all_labels = ' '.join(document.all_labels)
                    out.write('\t'.join([document.id, document.main_label, all_labels, line_text]))
                    out.write('\n')
                    ndocs += 1
                    out.flush()

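# Each line written by extract() is tab-separated (format as produced by the code above):
#   <ucid> \t <main_label> \t <label_1 label_2 ...> \t <single-line text>

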
def read_classification_file(data_path, classification_level):
    assert classification_level in ['subclass', 'maingroup'], 'wrong classification requested'
    z = ZipFile(join(data_path, 'EnglishWipoGamma1.zip'))
    inpath = 'Wipo_Gamma/English/TrainTestSpits'  # sic: path as spelled inside the zip
    document_labels = dict()
    train_ids, test_ids = set(), set()
    labelcut = LabelCut(classification_level)
    for subset in tqdm(['train', 'test'], desc='loading classification file'):
        target_subset = train_ids if subset == 'train' else test_ids
        if classification_level == 'subclass':
            file = f'{subset}set_en_sc.parts'  # sub-class level
        else:
            file = f'{subset}set_en_mg.parts'  # main-group level

        for line in z.open(f'{inpath}/{file}').readlines():
            line = line.decode().strip().split(',')
            id = line[0]
            id = id[id.rfind('/') + 1:].replace('.xml', '')
            labels = labelcut.trim(line[1:])
            document_labels[id] = labels
            target_subset.add(id)

    return document_labels, train_ids, test_ids

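# A line in the .parts files is assumed to look like this (hypothetical example, inferred
# from the parsing above): a path ending in '<doc_id>.xml', then comma-separated labels:
#   EN/.../US-1234567-A.xml,A01B33/12,B02C18/14

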
class LabelCut:
    """
    Labels consist of 1 char for the section, 2 chars for the class, 1 char for the subclass,
    2 chars for the maingroup, and so on.
    This class cuts a label at the desired level (4 chars for subclass, 6 for maingroup).
    """
    def __init__(self, classification_level):
        assert classification_level in {'subclass', 'maingroup'}, 'unknown classification level'
        if classification_level == 'subclass':
            self.cut = 4
        else:
            self.cut = 6

    def trim(self, label):
        if isinstance(label, list):
            return sorted(set(l[:self.cut] for l in label))
        else:
            return label[:self.cut]

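# Example (illustrative IPC-style code; section 'A', class '01', subclass 'B', maingroup '33'):
#   LabelCut('subclass').trim('A01B33/12')   returns 'A01B'
#   LabelCut('maingroup').trim('A01B33/12')  returns 'A01B33'

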
def fetch_WIPOgamma(subset, classification_level, data_home, extracted_path, text_fields=('abstract', 'description'), limit_description=300):
    """
    Fetches the WIPO-gamma dataset

    :param subset: 'train' or 'test' split
    :param classification_level: the classification level, either 'subclass' or 'maingroup'
    :param data_home: directory containing the original 11 English zips
    :param extracted_path: directory used to extract and process the original files
    :param text_fields: the fields to extract, among 'abstract', 'description', and 'claims'
    :param limit_description: the maximum number of words to take from the description field (default 300); set to -1 for all
    :return: the list of WipoGammaDocument objects of the requested subset
    """
    assert subset in {"train", "test"}, 'unknown target request (valid ones are "train" or "test")'
    assert len(text_fields) > 0, 'at least one text field should be indicated'
    if not exists(data_home):
        raise ValueError(f'{data_home} does not exist, and the dataset cannot be downloaded automatically, '
                         f'since you need to request permission. Please refer to {WIPO_URL}')

    create_if_not_exist(extracted_path)
    config = f'{"-".join(text_fields)}'
    if 'description' in text_fields:
        config += f'-{limit_description}'
    pickle_path = join(extracted_path, f'wipo-{subset}-{classification_level}-{config}.pickle')
    if exists(pickle_path):
        print(f'loading pickled file in {pickle_path}')
        return pickle.load(open(pickle_path, 'rb'))

    print('pickle file not found, processing... (this will take some minutes)')
    extracted = sum(exists(f'{extracted_path}/EnglishWipoGamma{i + 1}-{config}.txt') for i in range(11)) == 11
    if not extracted:
        print(f'extraction files not found, extracting files in {data_home}... (this will take some additional minutes)')
        Parallel(n_jobs=-1)(
            delayed(extract)(
                join(data_home, file), join(extracted_path, file.replace('.zip', f'-{config}.txt')), text_fields, limit_description
            )
            for file in list_files(data_home)
        )

    doc_labels, train_ids, test_ids = read_classification_file(data_home, classification_level=classification_level)
    print(f'{len(doc_labels)} classified documents, split into {len(train_ids)} train and {len(test_ids)} test documents')

    train_request = []
    test_request = []
    pbar = tqdm([filename for filename in list_files(extracted_path) if filename.endswith(f'-{config}.txt')])
    labelcut = LabelCut(classification_level)
    errors = 0
    for proc_file in pbar:
        pbar.set_description(f'processing {proc_file} [errors={errors}]')
        lines = open(f'{extracted_path}/{proc_file}', 'rt').readlines()
        for lineno, line in enumerate(lines):
            parts = line.split('\t')
            assert len(parts) == 4, f'wrong format in {extracted_path}/{proc_file} line {lineno}'
            id, mainlabel, alllabels, text = parts
            mainlabel = labelcut.trim(mainlabel)
            alllabels = labelcut.trim(alllabels.split())

            # assert id in train_ids or id in test_ids, f'id {id} out of scope'
            if id not in train_ids and id not in test_ids:
                errors += 1
            else:
                # assert mainlabel == doc_labels[id][0], 'main label not consistent'
                request = train_request if id in train_ids else test_request
                request.append(WipoGammaDocument(id, text, mainlabel, alllabels))

    print('pickling requests for faster subsequent runs')
    pickle.dump(train_request, open(join(extracted_path, f'wipo-train-{classification_level}-{config}.pickle'), 'wb'), pickle.HIGHEST_PROTOCOL)
    pickle.dump(test_request, open(join(extracted_path, f'wipo-test-{classification_level}-{config}.pickle'), 'wb'), pickle.HIGHEST_PROTOCOL)

    return train_request if subset == 'train' else test_request

if __name__ == '__main__':
    data_home = '../../datasets/WIPO/wipo-gamma/en'
    extracted_path = '../../datasets/WIPO-extracted'

    train = fetch_WIPOgamma(subset='train', classification_level='subclass', data_home=data_home, extracted_path=extracted_path, text_fields=('abstract',))
    test = fetch_WIPOgamma(subset='test', classification_level='subclass', data_home=data_home, extracted_path=extracted_path, text_fields=('abstract',))
    # train = fetch_WIPOgamma(subset='train', classification_level='maingroup', data_home=data_home, extracted_path=extracted_path)
    # test = fetch_WIPOgamma(subset='test', classification_level='maingroup', data_home=data_home, extracted_path=extracted_path)

    print('Done')