# Modified version of the code originally implemented by Eustache Diemert <eustache@diemert.fr>
#          @FedericoV <https://github.com/FedericoV/>
# with License: BSD 3 clause

import os.path
import re
import tarfile
from sklearn.datasets import get_data_home
from six.moves import html_parser
from six.moves import urllib
import pickle
from glob import glob
import numpy as np
from data.labeled import LabelledDocuments


def _not_in_sphinx():
    """Tell whether this module seems to run outside the sphinx builder.

    This is a heuristic hack: presence of the module-level name ``__file__``
    is taken to mean "not executed by sphinx".

    Returns:
        bool: True if ``__file__`` is defined in this module's globals.
    """
    module_globals = globals()
    has_file_name = '__file__' in module_globals
    return has_file_name


class ReutersParser(html_parser.HTMLParser):
    """Utility class to parse the Reuters-21578 SGML files and collect documents.

    Feed files with :meth:`parse`. A ``<REUTERS>`` record is kept only when its
    ``TOPICS`` attribute is ``"YES"`` and its ``LEWISSPLIT`` attribute is
    ``"TRAIN"`` or ``"TEST"``. Each kept record becomes a dict with keys
    ``'title'``, ``'body'``, ``'unproc'`` and ``'topics'``. Training records
    are appended to :attr:`tr_docs` and test records to :attr:`te_docs`.

    Creating a parser downloads and extracts the corpus into ``data_path``
    when no ``*.sgm`` file is present there yet.
    """

    def __init__(self, encoding='latin-1', data_path=None):
        """
        :param encoding: codec used to decode the raw byte lines given to :meth:`parse`.
        :param data_path: directory that holds (or will receive) the ``*.sgm``
            files; defaults to ``<sklearn data home>/reuters``.
        """
        self.data_path = data_path
        self.download_if_not_exist()
        self.tr_docs = []  # kept records with LEWISSPLIT="TRAIN"
        self.te_docs = []  # kept records with LEWISSPLIT="TEST"
        html_parser.HTMLParser.__init__(self)
        self._reset()
        self.encoding = encoding
        # number of kept records whose title, body and unprocessed text are all blank
        self.empty_docs = 0

    def handle_starttag(self, tag, attrs):
        # Dispatch to start_<tag> (HTMLParser passes tag names lower-cased);
        # tags without a handler are ignored.
        method = 'start_' + tag
        getattr(self, method, lambda x: None)(attrs)

    def handle_endtag(self, tag):
        # Dispatch to end_<tag>; tags without a handler are ignored.
        method = 'end_' + tag
        getattr(self, method, lambda: None)()

    def _reset(self):
        """Clear the per-record parsing state (called at start-up and after each record)."""
        self.in_title = 0
        self.in_body = 0
        self.in_topics = 0
        self.in_topic_d = 0
        self.in_unproc_text = 0
        self.title = ""
        self.body = ""
        self.topics = []
        self.topic_d = ""
        self.text = ""
        # Default split, so end_reuters never reads an unset attribute.
        self.lewissplit = u'unused'

    def parse(self, fd):
        """Parse one SGML file opened in binary mode, decoding it line by line."""
        for chunk in fd:
            self.feed(chunk.decode(self.encoding))
        self.close()

    def handle_data(self, data):
        # Route text to whichever field the parser is currently inside.
        if self.in_body:
            self.body += data
        elif self.in_title:
            self.title += data
        elif self.in_topic_d:
            self.topic_d += data
        elif self.in_unproc_text:
            self.text += data

    def start_reuters(self, attributes):
        """Decide from the record's attributes which split (if any) it belongs to."""
        # Look attributes up by name (HTMLParser lower-cases attribute names)
        # instead of by position, so the attribute order in the file does not matter.
        attrs = dict(attributes)
        self.lewissplit = u'unused'
        if attrs.get('topics') == u'YES':
            split = attrs.get('lewissplit')
            if split == u'TRAIN':
                self.lewissplit = 'train'
            elif split == u'TEST':
                self.lewissplit = 'test'

    def end_reuters(self):
        """Store the finished record in its split (if kept) and reset the state."""
        self.body = re.sub(r'\s+', r' ', self.body)
        if self.lewissplit != u'unused':
            parsed_doc = {'title': self.title, 'body': self.body, 'unproc': self.text, 'topics': self.topics}
            if (self.title + self.body + self.text).strip() == '':
                self.empty_docs += 1
            if self.lewissplit == u'train':
                self.tr_docs.append(parsed_doc)
            elif self.lewissplit == u'test':
                self.te_docs.append(parsed_doc)
        self._reset()

    def start_title(self, attributes):
        self.in_title = 1

    def end_title(self):
        self.in_title = 0

    def start_body(self, attributes):
        self.in_body = 1

    def end_body(self):
        self.in_body = 0

    def start_topics(self, attributes):
        self.in_topics = 1

    def end_topics(self):
        self.in_topics = 0

    def start_text(self, attributes):
        # Only <TEXT TYPE="UNPROC"> content is captured into self.text.
        if dict(attributes).get('type') == u'UNPROC':
            self.in_unproc_text = 1

    def end_text(self):
        self.in_unproc_text = 0

    def start_d(self, attributes):
        self.in_topic_d = 1

    def end_d(self):
        # <D> entries count as topics only inside a <TOPICS> element.
        if self.in_topics:
            self.topics.append(self.topic_d)
        self.in_topic_d = 0
        self.topic_d = ""

    def download_if_not_exist(self):
        """Download and extract the corpus unless ``*.sgm`` files are already present.

        If ``self.data_path`` is None, it is first set to ``<sklearn data home>/reuters``.
        """
        DOWNLOAD_URL = ('http://archive.ics.uci.edu/ml/machine-learning-databases/'
                        'reuters21578-mld/reuters21578.tar.gz')
        ARCHIVE_FILENAME = 'reuters21578.tar.gz'

        if self.data_path is None:
            self.data_path = os.path.join(get_data_home(), "reuters")
        # Check for the extracted files, not just the directory. This also
        # repairs an interrupted download/extraction or an empty directory.
        if glob(os.path.join(self.data_path, '*.sgm')):
            return

        print("downloading dataset (once and for all) into %s" % self.data_path)
        os.makedirs(self.data_path, exist_ok=True)

        def progress(blocknum, bs, size):
            total_sz_mb = '%.2f MB' % (size / 1e6)
            current_sz_mb = '%.2f MB' % ((blocknum * bs) / 1e6)
            if _not_in_sphinx():
                print('\rdownloaded %s / %s' % (current_sz_mb, total_sz_mb), end='')

        archive_path = os.path.join(self.data_path, ARCHIVE_FILENAME)
        urllib.request.urlretrieve(DOWNLOAD_URL, filename=archive_path,
                                   reporthook=progress)
        if _not_in_sphinx():
            print('\r', end='')
        print("untarring Reuters dataset...")
        with tarfile.open(archive_path, 'r:gz') as archive:
            # The archive is fetched over plain HTTP. Where supported, use the
            # 'data' extraction filter so members cannot escape data_path.
            if hasattr(tarfile, 'data_filter'):
                archive.extractall(self.data_path, filter='data')
            else:
                archive.extractall(self.data_path)
        print("done.")


def fetch_reuters21578(data_path=None, subset='train'):
    """Load the train or test split of the Reuters-21578 corpus.

    On the first call, the SGML files in ``data_path`` are parsed with
    :class:`ReutersParser`, which downloads them first if needed. Both splits
    are then cached as ``reuters.train.pickle`` / ``reuters.test.pickle`` in
    ``data_path``, and later calls load the cached split directly.

    :param data_path: directory holding the corpus and the pickle caches;
        defaults to ``<sklearn data home>/reuters21578``.
    :param subset: ``'train'`` or ``'test'``.
    :return: ``LabelledDocuments`` with these fields:
        ``data`` is a tuple of texts (title, body and unprocessed text joined
        by newlines); ``target`` is a tuple of lists of integer category codes;
        ``target_names`` is the list of category names that the codes index
        into. Only categories occurring in the training split get a code.
    :raises ValueError: if ``subset`` is neither ``'train'`` nor ``'test'``.
    """
    if subset not in ('train', 'test'):
        raise ValueError("subset must be 'train' or 'test', got %r" % (subset,))
    if data_path is None:
        data_path = os.path.join(get_data_home(), 'reuters21578')
    reuters_pickle_path = os.path.join(data_path, "reuters." + subset + ".pickle")
    if not os.path.exists(reuters_pickle_path):
        parser = ReutersParser(data_path=data_path)
        # sorted() makes the document order independent of the filesystem's listing order
        for filename in sorted(glob(os.path.join(data_path, "*.sgm"))):
            with open(filename, 'rb') as fd:
                parser.parse(fd)
        # Give each category name a unique numerical code. Only categories with
        # training examples count. Names are sorted, as np.unique would sort them.
        tr_categories = sorted({topic for doc in parser.tr_docs for topic in doc['topics']})
        category_code = {name: code for code, name in enumerate(tr_categories)}

        def pickle_documents(docs, split):
            # Replace topic names by their codes (dropping unknown ones) and cache the split.
            for doc in docs:
                doc['topics'] = [category_code[t] for t in doc['topics'] if t in category_code]
            pickle_docs = {'categories': tr_categories, 'documents': docs}
            with open(os.path.join(data_path, "reuters." + split + ".pickle"), 'wb') as fo:
                pickle.dump(pickle_docs, fo, protocol=pickle.HIGHEST_PROTOCOL)
            return pickle_docs

        pickle_tr = pickle_documents(parser.tr_docs, "train")
        pickle_te = pickle_documents(parser.te_docs, "test")
        requested_subset = pickle_tr if subset == 'train' else pickle_te
    else:
        # NOTE: unpickling can execute arbitrary code. Only use a data_path
        # whose cache files were written by this function.
        with open(reuters_pickle_path, 'rb') as fi:
            requested_subset = pickle.load(fi)

    documents = requested_subset['documents']
    # Built separately (instead of zip(*...)) so an empty split yields empty tuples.
    text_data = tuple(u'{title}\n{body}\n{unproc}'.format(**doc) for doc in documents)
    topics = tuple(doc['topics'] for doc in documents)
    return LabelledDocuments(data=text_data, target=topics, target_names=requested_subset['categories'])



if __name__ == '__main__':
    # Manual smoke run: load the training split and print its document texts.
    train_split = fetch_reuters21578(subset='train')
    print(train_split.data)