191 lines
6.3 KiB
Python
191 lines
6.3 KiB
Python
import os
|
|
import sys
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
import re
|
|
from os import listdir
|
|
from os.path import isdir, join
|
|
|
|
import requests
|
|
from bs4 import BeautifulSoup
|
|
from PIL import Image
|
|
from sklearn.preprocessing import MultiLabelBinarizer
|
|
|
|
# TODO: labels must be aligned between languages
|
|
# TODO: remove copyright and also tags (doc.split("More about:")[0])
|
|
# TODO: this should be an instance of an abstract MultimodalMultilingualDataset
|
|
|
|
|
|
def get_label_binarizer(cats):
    """Return a MultiLabelBinarizer fitted on the given categories.

    `cats` is wrapped in a one-element list, so the whole collection is
    handed to the binarizer as a single multi-label sample.

    Args:
        cats: iterable of label names.

    Returns:
        The fitted MultiLabelBinarizer instance.
    """
    # fit() returns the estimator itself, so it can be returned directly.
    return MultiLabelBinarizer().fit([cats])
|
|
|
|
|
|
class MultiNewsDataset:
    """Multilingual news dataset holding one MultiModalDataset per language.

    Every language lives in its own sub-directory of `data_dir`. On
    construction each non-excluded language is loaded, the union of all
    document labels is collected and a label binarizer is fitted on it.
    """

    def __init__(self, data_dir, excluded_langs=None, debug=False):
        """Load the dataset rooted at `data_dir`.

        Args:
            data_dir: root directory with one sub-directory per language.
            excluded_langs: languages to skip; None (default) skips nothing.
            debug: when True only the hard-coded languages "it" and "en"
                are considered (see `get_langs`).
        """
        self.debug = debug
        self.data_dir = data_dir
        self.dataset_langs = self.get_langs()
        # A None default replaces the former mutable `[]` default, which
        # would have been shared by every instance built without it.
        self.excluded_langs = [] if excluded_langs is None else excluded_langs
        self.lang_multiModalDataset = {}
        print(
            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs()}]"
        )
        self.load_data()
        self.all_labels = self.get_labels()
        self.label_binarizer = get_label_binarizer(self.all_labels)
        self.print_stats()

    def load_data(self):
        """Build a MultiModalDataset for every non-excluded language."""
        for lang in self.langs():
            self.lang_multiModalDataset[lang] = MultiModalDataset(
                lang, join(self.data_dir, lang)
            )

    def langs(self):
        """Return the list of dataset languages that are not excluded."""
        return [l for l in self.dataset_langs if l not in self.excluded_langs]

    def get_langs(self):
        """Discover the languages available under `data_dir`.

        Returns:
            In debug mode the fixed list ["it", "en"]; otherwise a sorted
            tuple with the name of every sub-directory of `data_dir`.
        """
        if self.debug:
            # NOTE: debug mode returns a list while the normal path returns
            # a tuple; kept as-is so existing callers see the same values.
            return ["it", "en"]

        # Only directories are languages: a stray file in data_dir would
        # otherwise be handed to MultiModalDataset, which lists its path.
        return tuple(
            sorted(
                entry
                for entry in listdir(self.data_dir)
                if isdir(join(self.data_dir, entry))
            )
        )

    def print_stats(self):
        """Print per-language document/label counts and the total documents."""
        print("[MultiNewsDataset stats]")
        total_docs = 0
        for lang in self.langs():
            lang_dataset = self.lang_multiModalDataset[lang]
            _len = len(lang_dataset.data)
            total_docs += _len
            print(
                f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(lang_dataset.labels)}"
            )
        print(f" - total docs: {total_docs}\n")

    def _count_lang_labels(self, labels):
        """Return the number of distinct labels across the per-document lists."""
        lang_labels = set()
        for doc_labels in labels:
            lang_labels.update(doc_labels)
        return len(lang_labels)

    def export_to_torch_dataset(self, tokenizer_id):
        """Not implemented yet."""
        raise NotImplementedError

    def save_to_disk(self):
        """Not implemented yet."""
        raise NotImplementedError

    def training(self):
        """Return the training inputs and binarized targets per language.

        Returns:
            A pair `(lXtr, lYtr)` of dicts keyed by language. `lXtr[lang]`
            is a list of `(clean_text, img)` tuples, one per document;
            `lYtr[lang]` is the output of `label_binarizer.transform` on
            that language's label lists.
        """
        lXtr = {
            lang: [(clean_text, img) for _, clean_text, _, img in dataset.data]
            for lang, dataset in self.lang_multiModalDataset.items()
        }
        lYtr = {
            lang: self.label_binarizer.transform(dataset.labels)
            for lang, dataset in self.lang_multiModalDataset.items()
        }
        return lXtr, lYtr

    def testing(self):
        """Not implemented yet."""
        raise NotImplementedError

    def get_labels(self):
        """Return the set of all labels found in any loaded language."""
        all_labels = set()
        for dataset in self.lang_multiModalDataset.values():
            for doc_labels in dataset.labels:
                all_labels.update(doc_labels)
        return all_labels
|
|
|
|
|
|
class MultiModalDataset:
    """Documents of a single language, each with text, an image and labels.

    `data` is a list of `(text_filename, clean_text, raw_html, image)` tuples
    and `labels` is the parallel list of per-document label lists; both are
    filled by `get_docs` at construction time.
    """

    def __init__(self, lang, data_dir):
        """Compile the parsing regexes and load every document in `data_dir`.

        Args:
            lang: language identifier of this dataset.
            data_dir: directory containing one sub-directory per document.
        """
        self.lang = lang
        self.data_dir = data_dir
        # Captures the visible text of every <a rel="tag" href="/tag/.../"> link.
        self.re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>")
        # Matches HTML tags and named/numeric character entities.
        self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
        # Matches runs of one or more spaces.
        self.re_white = re.compile(r" +")
        self.data, self.labels = self.get_docs()

    def get_imgs(self):
        """Not implemented yet."""
        raise NotImplementedError

    def get_labels(self):
        """Not implemented yet."""
        raise NotImplementedError

    def get_ids(self):
        """Not implemented yet."""
        raise NotImplementedError

    def get_docs(self):
        """Read every document folder under `data_dir`.

        For each sub-directory `<name>.<ext>` the file `text.<ext>` is read,
        the image `img.jpg` is downloaded if it is not there yet, and the
        HTML is turned into clean text plus labels. Entries of `data_dir`
        that are not directories are skipped.

        Returns:
            A pair `(data, labels)`: `data` is a list of
            `(text_filename, clean_text, raw_html, image)` tuples and
            `labels` the list of label lists, in the same order.
        """
        data = []
        labels = []
        for news_folder in listdir(self.data_dir):
            doc_dir = join(self.data_dir, news_folder)
            if not isdir(doc_dir):
                continue

            # The text file shares the folder's extension, e.g. text.<ext>.
            fname_doc = f"text.{news_folder.split('.')[-1]}"
            # NOTE(review): opened with the platform default encoding, as
            # before; confirm whether the corpus should be read as UTF-8.
            with open(join(doc_dir, fname_doc)) as f:
                html_doc = f.read()

            # Check for the exact file opened below: the previous test
            # ("any *.jpg present") skipped the download when only a
            # differently named jpg existed, and Image.open then failed.
            img_path = join(doc_dir, "img.jpg")
            if "img.jpg" not in listdir(doc_dir):
                _, img_bytes = self.get_images(join(doc_dir, "index.html"))
                self.save_img(img_path, img_bytes)

            img = Image.open(img_path)
            clean_doc, doc_labels = self.preprocess_html(html_doc)
            data.append((fname_doc, clean_doc, html_doc, img))
            labels.append(doc_labels)
        return data, labels

    def save_img(self, path, img):
        """Write the raw image bytes `img` to `path`."""
        with open(path, "wb") as f:
            f.write(img)

    def get_images(self, index_path):
        """Download the main image referenced by a document's index page.

        The second <img> tag of the page (index 1) is taken as the main
        image, following the original TODO note.

        Args:
            index_path: path of the document's `index.html`.

        Returns:
            A pair `(img_tag, content)`: the selected <img> tag and the
            downloaded bytes.

        Raises:
            IndexError: the page has fewer than two <img> tags.
            requests.HTTPError: the image URL answered with an error status.
        """
        # The context manager closes the file once parsing is done.
        with open(index_path) as index_file:
            img_tags = BeautifulSoup(index_file, "html.parser").find_all("img")

        if len(img_tags) < 2:
            raise IndexError(
                f"expected at least two <img> tags in {index_path}, found {len(img_tags)}"
            )
        main_img = img_tags[1]

        # Without a timeout a stalled server blocks dataset loading forever.
        response = requests.get(main_img["src"], timeout=30)
        # Fail here rather than saving an HTTP error page as img.jpg, which
        # would then be picked up as a valid image on every later run.
        response.raise_for_status()
        return main_img, response.content

    def preprocess_html(self, html_doc):
        """Return `(clean_text, labels)` extracted from a raw HTML document."""
        # TODO: this could be replaced by BeautifulSoup call or something similar
        labels = self._extract_labels(html_doc)
        cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
        return cleaned, labels

    def _extract_labels(self, data):
        """Return the text of every tag link found in `data`."""
        return self.re_labels.findall(data)

    def _remove_html_tags(self, data):
        """Strip HTML tags and character entities from `data`."""
        return self.re_cleaner.sub("", data)

    def _clean_up_str(self, doc):
        """Normalize whitespace: newlines/tabs become spaces, runs collapse.

        Newlines and tabs are replaced before collapsing; doing it after
        (as before) re-created runs of several spaces in the output.
        """
        doc = doc.replace("\n", " ").replace("\t", " ")
        doc = self.re_white.sub(" ", doc)
        return doc.strip()
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: load the dataset in debug mode and build the
    # training split. An optional first command-line argument overrides
    # the default dataset location.
    from os.path import expanduser

    _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
    _dataset_path = sys.argv[1] if len(sys.argv) > 1 else _dataset_path_hardcoded
    dataset = MultiNewsDataset(expanduser(_dataset_path), debug=True)
    lXtr, lYtr = dataset.training()
    # sys.exit() instead of the site-provided exit(), which is intended for
    # the interactive shell and is absent when `site` is not loaded.
    sys.exit()
|