from zipfile import ZipFile
import xml.etree.ElementTree as ET
from data.labeled import LabelledDocuments
from util.file import list_files
from os.path import join, exists
from util.file import download_file_if_not_exists
import re
from collections import Counter

# URL of the original RCV1 topic hierarchy file; fetch_topic_hierarchy passes it to
# download_file_if_not_exists so the file is fetched into the given local path.
RCV1_TOPICHIER_URL = "http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/a02-orig-topics-hierarchy/rcv1.topics.hier.orig"
# Reference page for the RCV1 collection; quoted in the error message fetch_RCV1 shows
# when a corpus archive is missing (the corpus itself cannot be downloaded automatically).
RCV1_BASE_URL = "http://www.daviddlewis.com/resources/testcollections/rcv1/"

# Names of gzipped lyrl2004 data files: the four test parts and the single train file.
# NOTE(review): none of these names is referenced in the code visible here — confirm
# whether other modules import them before changing or removing them.
rcv1_test_data_gz = ['lyrl2004_tokens_test_pt0.dat.gz',
             'lyrl2004_tokens_test_pt1.dat.gz',
             'lyrl2004_tokens_test_pt2.dat.gz',
             'lyrl2004_tokens_test_pt3.dat.gz']

rcv1_train_data_gz = ['lyrl2004_tokens_train.dat.gz']

# Name of the gzipped document-to-topic file (presumably the RCV1-v2 topic qrels — verify).
rcv1_doc_cats_data_gz = 'rcv1-v2.topics.qrels.gz'

class RCV_Document:
    """Plain record describing one parsed RCV1 news item.

    Holds the document identifier, its publication date string, the textual
    content and the list of topic category codes, exactly as passed in.
    """

    def __init__(self, id, text, categories, date=''):
        # Stored verbatim; no validation or normalization happens here.
        self.id, self.date = id, date
        self.text, self.categories = text, categories

class IDRangeException(Exception):
    """Signals that a document's id lies outside the requested id range."""

# NOTE(review): module-level list that no visible code appends to or reads; it may be a
# leftover from word-count statistics. Confirm no other module imports it before removing.
nwords = []

def _element_text(root, path):
    """Return the text of the first element under ``root`` matching ``path``, or ''.

    '' is returned both when nothing matches and when the element has no text.
    """
    node = root.find(path)
    if node is None or node.text is None:
        return ''
    return node.text


def parse_document(xml_content, valid_id_range=None):
    """Parse one RCV1 XML news item into an ``RCV_Document``.

    :param xml_content: raw XML of a single news item (``str`` or ``bytes``).
    :param valid_id_range: optional inclusive ``(low, high)`` pair; documents whose
        integer ``itemid`` falls outside it are rejected before further parsing.
    :return: an ``RCV_Document`` whose text is the title, the headline (unless it is
        contained in the title) and the body paragraphs, joined by newlines and stripped.
        Its categories are the ``code`` attributes of the ``bip:topics:1.0`` codes.
    :raises IDRangeException: if the id is outside ``valid_id_range``.
    :raises ValueError: if the document has no body paragraph with text.
    """
    root = ET.fromstring(xml_content)

    doc_id = root.attrib['itemid']
    # Range check comes first so out-of-split documents cost as little as possible.
    if valid_id_range is not None:
        if not valid_id_range[0] <= int(doc_id) <= valid_id_range[1]:
            raise IDRangeException

    doc_categories = [cat.attrib['code'] for cat in
                      root.findall('.//metadata/codes[@class="bip:topics:1.0"]/code')]

    doc_date = root.attrib['date']
    # Missing or empty <title>/<headline> elements become '' instead of raising AttributeError.
    doc_title = _element_text(root, './/title')
    doc_headline = _element_text(root, './/headline')
    # Empty <p/> elements have text None; skip them rather than letting join() raise TypeError.
    doc_body = '\n'.join(p.text for p in root.findall('.//text/p') if p.text is not None)

    if not doc_body:
        raise ValueError('Empty document')

    # Drop the headline when it only repeats (part of) the title.
    if doc_headline in doc_title:
        doc_headline = ''
    text = '\n'.join([doc_title, doc_headline, doc_body]).strip()

    return RCV_Document(id=doc_id, text=text, categories=doc_categories, date=doc_date)


def fetch_RCV1(data_path, subset='all'):
    """Load RCV1-v2 documents from the zipped XML archives found in ``data_path``.

    Every file whose whole name is a run of digits followed by ``.zip`` is opened,
    and each archive member is parsed with ``parse_document``. Documents outside the
    itemid range of the requested split, and documents with an empty body, are
    skipped. Reading stops once the expected number of documents for the split has
    been collected.

    :param data_path: directory holding the original RCV1 ``.zip`` archives.
    :param subset: ``'train'``, ``'test'`` or ``'all'``.
    :return: ``LabelledDocuments`` with the document texts, their topic codes, and the
        sorted list of every topic code seen (sorted so the order is reproducible).
    """
    assert subset in ['train', 'test', 'all'], 'split should either be "train", "test", or "all"'

    training_documents = 23149
    test_documents = 781265

    # Inclusive itemid bounds of the split and the number of documents it should contain.
    if subset == 'all':
        split_range = (2286, 810596)
        expected = training_documents + test_documents
    elif subset == 'train':
        split_range = (2286, 26150)
        expected = training_documents
    else:
        split_range = (26151, 810596)
        expected = test_documents

    # Raw string avoids the invalid '\d' escape; fullmatch rejects names like '1.zip.bak'.
    archive_name_pattern = re.compile(r'\d+\.zip')
    request = []
    labels = set()

    for part in list_files(data_path):
        if not archive_name_pattern.fullmatch(part):
            continue
        target_file = join(data_path, part)
        assert exists(target_file), \
            "You don't seem to have the file "+part+" in " + data_path + ", and the RCV1 corpus can not be downloaded"+\
            " w/o a formal permission. Please, refer to " + RCV1_BASE_URL + " for more information."
        # Context managers release the archive and member handles even on errors.
        with ZipFile(target_file) as archive:
            for xmlfile in archive.namelist():
                with archive.open(xmlfile) as member:
                    xmlcontent = member.read()
                try:
                    doc = parse_document(xmlcontent, valid_id_range=split_range)
                except (IDRangeException, ValueError):
                    # Documents from the other split or with an empty body are skipped on purpose.
                    pass
                else:
                    labels.update(doc.categories)
                    request.append(doc)
                print('\r[{}] read {} documents'.format(part, len(request)), end='')
                if len(request) == expected:
                    break
        if len(request) == expected:
            break

    print()

    return LabelledDocuments(data=[d.text for d in request],
                             target=[d.categories for d in request],
                             target_names=sorted(labels))



def fetch_topic_hierarchy(path, topics='all'):
    """Return RCV1 topic codes read from the original topic hierarchy file.

    The file is obtained via ``download_file_if_not_exists(RCV1_TOPICHIER_URL, path)``.
    Each line is split on whitespace. The second token is the parent code and the
    fourth token is the child code. Lines with fewer than four tokens (e.g. blank
    lines) are skipped.

    :param path: local path of the hierarchy file.
    :param topics: ``'all'`` for every code that appears as a parent or a child, or
        ``'leaves'`` for child codes that never appear as a parent.
    :return: a sorted list of topic codes (sorted so the order is reproducible).
    """
    assert topics in ['all', 'leaves']

    download_file_if_not_exists(RCV1_TOPICHIER_URL, path)
    hierarchy = {}
    # Closing the file deterministically instead of leaking the handle.
    with open(path, 'rt') as hierarchy_file:
        for line in hierarchy_file:
            parts = line.split()
            if len(parts) < 4:
                # Blank or truncated line: nothing to record (previously an IndexError).
                continue
            parent, child = parts[1], parts[3]
            hierarchy.setdefault(parent, []).append(child)

    # The 'None' and 'Root' parent entries are not reported as topics.
    del hierarchy['None']
    del hierarchy['Root']
    print(hierarchy)

    children = set()
    for child_codes in hierarchy.values():
        children.update(child_codes)

    if topics == 'all':
        # Iterating a dict yields its keys, i.e. the parent codes.
        return sorted(children.union(hierarchy))
    return sorted(children.difference(hierarchy))


if __name__=='__main__':

    # Example: load both splits, report their sizes, and tally training topic codes.

    RCV1_PATH = '../../datasets/RCV1-v2/unprocessed_corpus'

    rcv1_train = fetch_RCV1(RCV1_PATH, subset='train')
    rcv1_test = fetch_RCV1(RCV1_PATH, subset='test')

    print('read {} documents in rcv1-train, and {} labels'.format(len(rcv1_train.data), len(rcv1_train.target_names)))
    print('read {} documents in rcv1-test, and {} labels'.format(len(rcv1_test.data), len(rcv1_test.target_names)))

    # Distinct loop name: reusing `cats` here rebound the Counter to a plain list,
    # so `cats.update(...)` raised AttributeError.
    cats = Counter()
    for doc_cats in rcv1_train.target:
        cats.update(doc_cats)
    print('RCV1', cats)