# dataset: https://www.wipo.int/classifications/ipc/en/ITsupport/Categorization/dataset/
from os.path import exists, join
from util.file import create_if_not_exist, list_files
from zipfile import ZipFile
import xml.etree.ElementTree as ET
from tqdm import tqdm
import pickle
from joblib import Parallel, delayed

WIPO_URL= 'https://www.wipo.int/classifications/ipc/en/ITsupport/Categorization/dataset/'


class WipoGammaDocument:
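    """
    Plain container for a single WIPO-gamma patent document: its id (the ucid attribute),
    the extracted text, the main IPC label, and the complete label set.
    """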
    def __init__(self, id, text, main_label, all_labels):
        self.id = id
        self.text = text
        self.main_label = main_label
        self.all_labels = all_labels


def remove_nested_claimtext_tags(xmlcontent):
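    """
    Strips the <claim-text> tags nested inside the <claims> element, so that each <claim>
    can later be read as a single text node. Illustrative example (simplified, not an
    actual dataset excerpt):

        b'<claims lang="EN"><claim><claim-text>A device ...</claim-text></claim></claims>'
    becomes
        b'<claims lang="EN"><claim>A device ...</claim></claims>'
    """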
    from_pos = xmlcontent.find(b'<claims')
    to_pos = xmlcontent.find(b'</claims>')
    if from_pos > -1 and to_pos > -1:
        in_between = xmlcontent[from_pos:to_pos].replace(b'<claim-text>',b'').replace(b'</claim-text>',b'')
        xmlcontent = (xmlcontent[:from_pos]+in_between+xmlcontent[to_pos:]).strip()
    return xmlcontent


def parse_document(xml_content, text_fields, limit_description):
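    """
    Parses a single patent XML file into a WipoGammaDocument.

    :param xml_content: raw bytes of the XML document
    :param text_fields: fields to extract, among 'abstract', 'description', and 'claims'
    :param limit_description: maximum number of words kept from the description (-1 for no limit)
    :return: a WipoGammaDocument, or None if the requested fields contain no text
    """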
    root = ET.fromstring(remove_nested_claimtext_tags(xml_content))

    doc_id  = root.attrib['ucid']
    lang    = root.attrib['lang']

    # take the categories from the categorization, up to the subclass level
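    # the queries below match elements like (illustrative sketch, not an actual excerpt):
    #   <classification-ipcr computed="from_ecla_to_ipc_SG" generated_main_IPC="true">A61K...</classification-ipcr>
    # where generated_main_IPC="true" marks the main category and "false" the secondary ones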
    main_group = set(t.text[:6] for t in root.findall('.//bibliographic-data/technical-data/classifications-ipcr/classification-ipcr[@computed="from_ecla_to_ipc_SG"][@generated_main_IPC="true"]'))
    sec_groups = set(t.text[:6] for t in root.findall('.//bibliographic-data/technical-data/classifications-ipcr/classification-ipcr[@computed="from_ecla_to_ipc_SG"][@generated_main_IPC="false"]'))
    sec_groups.update(main_group)

    assert len(main_group) == 1, f'expected exactly one main group, found {len(main_group)}'
    main_group = list(main_group)[0]
    sec_groups = sorted(sec_groups)

    assert lang == 'EN', f'only English documents allowed (doc {doc_id})'

    doc_text_fields=[]
    if 'abstract' in text_fields:
        abstract = '\n'.join(filter(None, [t.text for t in root.findall('.//abstract[@lang="EN"]/p')]))
        doc_text_fields.append(abstract)
    if 'description' in text_fields:
        description = '\n'.join(filter(None, [t.text for t in root.findall('.//description[@lang="EN"]/p')]))
        if limit_description > -1:
            description = ' '.join(description.split()[:limit_description])
        doc_text_fields.append(description)
    if 'claims' in text_fields:
        claims = '\n'.join(filter(None, [t.text for t in root.findall('.//claims[@lang="EN"]/claim')]))
        doc_text_fields.append(claims)

    text = '\n'.join(doc_text_fields)
    if text:
        return WipoGammaDocument(doc_id, text, main_group, sec_groups)
    else:
        return None


def extract(fin, fout, text_fields, limit_description):
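    """
    Extracts every XML document in the zip archive fin and writes it to fout as one
    tab-separated line per document: id, main label, all labels (space-separated), text.
    Returns the number of documents written.
    """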
    ndocs = 0
    with ZipFile(fin) as zipfile, open(fout, 'wt') as out:
        for xmlfile in tqdm(zipfile.namelist()):
            if not xmlfile.endswith('.xml'):
                continue
            xmlcontent = zipfile.open(xmlfile).read()
            document = parse_document(xmlcontent, text_fields, limit_description)
            if document:
                line_text = document.text.replace('\n', ' ').replace('\t', ' ').strip()
                assert line_text, f'empty document in {xmlfile}'
                all_labels = ' '.join(document.all_labels)
                out.write('\t'.join([document.id, document.main_label, all_labels, line_text]))
                out.write('\n')
                out.flush()
                ndocs += 1
    return ndocs



def read_classification_file(data_path, classification_level):
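    """
    Reads the official train/test split files shipped inside EnglishWipoGamma1.zip.
    Returns (document_labels, train_ids, test_ids), where document_labels maps each document
    id to its label list, trimmed at the requested classification level.
    """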
    assert classification_level in ['subclass', 'maingroup'], 'wrong classification level requested'
    z = ZipFile(join(data_path, 'EnglishWipoGamma1.zip'))
    inpath = 'Wipo_Gamma/English/TrainTestSpits'  # path inside the archive, spelling as-is
    document_labels = dict()
    train_ids, test_ids = set(), set()
    labelcut = LabelCut(classification_level)
    for subset in tqdm(['train', 'test'], desc='loading classification file'):
        target_subset = train_ids if subset=='train' else test_ids
        if classification_level == 'subclass':
            file = f'{subset}set_en_sc.parts' #sub-class level
        else:
            file = f'{subset}set_en_mg.parts' #main-group level

        for line in z.open(f'{inpath}/{file}').readlines():
            line = line.decode().strip().split(',')
            id = line[0]
            id = id[id.rfind('/')+1:].replace('.xml','')
            labels = labelcut.trim(line[1:])
            document_labels[id]=labels
            target_subset.add(id)

    return document_labels, train_ids, test_ids


class LabelCut:
    """
    Labels consist of 1 char for the section, 2 chars for the class, 1 char for the subclass,
    2 chars for the main group, and so on.
    This class cuts a label at the desired level (4 chars for subclass, 6 for maingroup).
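
    Illustrative example (made-up IPC code):
        LabelCut('subclass').trim('A61K31/00')  returns 'A61K'
        LabelCut('maingroup').trim('A61K31/00') returns 'A61K31'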
    """
    def __init__(self, classification_level):
        assert classification_level in {'subclass','maingroup'}, 'unknown classification level'
        if classification_level == 'subclass': self.cut = 4
        else: self.cut = 6

    def trim(self, label):
        if isinstance(label, list):
            return sorted(set([l[:self.cut] for l in label]))
        else:
            return label[:self.cut]



def fetch_WIPOgamma(subset, classification_level, data_home, extracted_path, text_fields=('abstract', 'description'), limit_description=300):
    """
    Fetches the WIPO-gamma dataset.
    :param subset: 'train' or 'test' split
    :param classification_level: the classification level, either 'subclass' or 'maingroup'
    :param data_home: directory containing the original 11 English zips
    :param extracted_path: directory used to extract and process the original files
    :param text_fields: the fields to extract, among 'abstract', 'description', and 'claims'
    :param limit_description: the maximum number of words to take from the description field (default 300); set to -1 for all
    :return: a list of WipoGammaDocument objects for the requested subset
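
    Example (paths are illustrative placeholders):
        >>> docs = fetch_WIPOgamma('train', 'subclass', data_home='/path/to/zips', extracted_path='/path/to/extracted')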
    """
    assert subset in {"train", "test"}, 'unknown target request (valid ones are "train" or "test")'
    assert len(text_fields)>0, 'at least some text field should be indicated'
    if not exists(data_home):
        raise ValueError(f'{data_home} does not exist, and the dataset cannot be downloaded automatically, '
              f'since you need to request permission first. Please refer to {WIPO_URL}')

    create_if_not_exist(extracted_path)
    config = '-'.join(text_fields)
    if 'description' in text_fields: config += f'-{limit_description}'
    pickle_path=join(extracted_path, f'wipo-{subset}-{classification_level}-{config}.pickle')
    if exists(pickle_path):
        print(f'loading pickled file in {pickle_path}')
        return pickle.load(open(pickle_path,'rb'))

    print('pickle file not found, processing... (this may take several minutes)')
    extracted = all(exists(f'{extracted_path}/EnglishWipoGamma{i+1}-{config}.txt') for i in range(11))
    if not extracted:
        print(f'extraction files not found, extracting files in {data_home}... (this will take some additional minutes)')
        Parallel(n_jobs=-1)(
            delayed(extract)(
                join(data_home, file), join(extracted_path, file.replace('.zip', f'-{config}.txt')), text_fields, limit_description
            )
            for file in list_files(data_home)
        )
    doc_labels, train_ids, test_ids = read_classification_file(data_home, classification_level=classification_level)
    print(f'{len(doc_labels)} documents classified split in {len(train_ids)} train and {len(test_ids)} test documents')

    train_request = []
    test_request  = []
    pbar = tqdm([filename for filename in list_files(extracted_path) if filename.endswith(f'-{config}.txt')])
    labelcut = LabelCut(classification_level)
    errors=0
    for proc_file in pbar:
        pbar.set_description(f'processing {proc_file} [errors={errors}]')
        if not proc_file.endswith(f'-{config}.txt'): continue
        lines = open(f'{extracted_path}/{proc_file}', 'rt').readlines()
        for lineno,line in enumerate(lines):
            parts = line.split('\t')
            assert len(parts)==4, f'wrong format in {extracted_path}/{proc_file} line {lineno}'
            id,mainlabel,alllabels,text=parts
            mainlabel = labelcut.trim(mainlabel)
            alllabels = labelcut.trim(alllabels.split())

            # assert id in train_ids or id in test_ids, f'id {id} out of scope'
            if id not in train_ids and id not in test_ids:
                errors+=1
            else:
                # assert mainlabel == doc_labels[id][0], 'main label not consistent'
                request = train_request if id in train_ids else test_request
                request.append(WipoGammaDocument(id, text, mainlabel, alllabels))

    print('pickling requests for faster subsequent runs')
    pickle.dump(train_request, open(join(extracted_path, f'wipo-train-{classification_level}-{config}.pickle'), 'wb'), pickle.HIGHEST_PROTOCOL)
    pickle.dump(test_request, open(join(extracted_path, f'wipo-test-{classification_level}-{config}.pickle'), 'wb'), pickle.HIGHEST_PROTOCOL)

    if subset == 'train':
        return train_request
    else:
        return test_request


if __name__=='__main__':
    data_home = '../../datasets/WIPO/wipo-gamma/en'
    extracted_path = '../../datasets/WIPO-extracted'

    # note the trailing comma: ('abstract',) is a one-element tuple, whereas ('abstract') is just a string
    train = fetch_WIPOgamma(subset='train', classification_level='subclass', data_home=data_home, extracted_path=extracted_path, text_fields=('abstract',))
    test = fetch_WIPOgamma(subset='test', classification_level='subclass', data_home=data_home, extracted_path=extracted_path, text_fields=('abstract',))
    # train = fetch_WIPOgamma(subset='train', classification_level='maingroup', data_home=data_home, extracted_path=extracted_path)
    # test = fetch_WIPOgamma(subset='test', classification_level='maingroup', data_home=data_home, extracted_path=extracted_path)

    print('Done')