# gfun_multimodal/dataManager/multiNewsDataset.py
import os
import sys
sys.path.append(os.getcwd())
import re
from os import listdir
from os.path import isdir, join
import requests
from bs4 import BeautifulSoup
from PIL import Image
from sklearn.preprocessing import MultiLabelBinarizer
# TODO: labels must be aligned between languages
# TODO: remove copyright and also tags (doc.split("More about:")[0])
# TODO: this should be an instance of an abstract MultimodalMultilingualDataset
def get_label_binarizer(cats):
    """Return a MultiLabelBinarizer fitted on the complete category set."""
    # MultiLabelBinarizer.fit returns the fitted estimator itself,
    # so the construction and fit can be chained into one expression.
    return MultiLabelBinarizer().fit([cats])
class MultiNewsDataset:
    """Multilingual, multimodal news dataset.

    Holds one ``MultiModalDataset`` per language found under ``data_dir``
    (one sub-directory per language) and a label binarizer fitted on the
    union of all labels across languages.
    """

    def __init__(self, data_dir, excluded_langs=None, debug=False):
        """
        Args:
            data_dir: root directory containing one sub-folder per language.
            excluded_langs: iterable of language codes to skip (default: none).
            debug: if True, only the hard-coded ["it", "en"] subset is loaded.
        """
        self.debug = debug
        self.data_dir = data_dir
        self.dataset_langs = self.get_langs()
        # FIX: avoid a mutable default argument (a shared `[]` persists
        # across instantiations); normalize None to an empty list instead.
        self.excluded_langs = [] if excluded_langs is None else excluded_langs
        self.lang_multiModalDataset = {}
        print(
            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {[l for l in self.dataset_langs if l not in self.excluded_langs]}]"
        )
        self.load_data()
        self.all_labels = self.get_labels()
        self.label_binarizer = get_label_binarizer(self.all_labels)
        self.print_stats()

    def load_data(self):
        """Build a MultiModalDataset for every non-excluded language."""
        for lang in self.dataset_langs:
            if lang not in self.excluded_langs:
                self.lang_multiModalDataset[lang] = MultiModalDataset(
                    lang, join(self.data_dir, lang)
                )

    def langs(self):
        """Return the languages actually loaded (excluded ones filtered out)."""
        # FIX: the original had an unreachable second `return self.get_langs()`
        # after this statement; dead code removed.
        return [l for l in self.dataset_langs if l not in self.excluded_langs]

    def get_langs(self):
        """Discover available languages from sub-directory names, sorted.

        Returns a list in debug mode, a tuple otherwise (kept as-is for
        backward compatibility with existing callers).
        """
        if self.debug:
            return ["it", "en"]
        # FIX: removed a redundant function-local `from os import listdir`
        # (already imported at module level) and skip plain files so stray
        # entries in data_dir are not mistaken for languages.
        return tuple(
            sorted(
                folder
                for folder in listdir(self.data_dir)
                if isdir(join(self.data_dir, folder))
            )
        )

    def print_stats(self):
        """Print per-language document and label counts plus a grand total."""
        print(f"[MultiNewsDataset stats]")
        total_docs = 0
        for lang in self.dataset_langs:
            if lang not in self.excluded_langs:
                _len = len(self.lang_multiModalDataset[lang].data)
                total_docs += _len
                print(
                    f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
                )
        print(f" - total docs: {total_docs}\n")

    def _count_lang_labels(self, labels):
        """Return the number of distinct labels in a list of label lists."""
        lang_labels = set()
        for l in labels:
            lang_labels.update(l)
        return len(lang_labels)

    def export_to_torch_dataset(self, tokenizer_id):
        raise NotImplementedError

    def save_to_disk(self):
        raise NotImplementedError

    def training(self):
        """Return (lXtr, lYtr).

        lXtr maps lang -> list of (clean_text, img) pairs;
        lYtr maps lang -> binarized label matrix.
        """
        # FIX: the original assigned `lYtr = {}` and then immediately rebuilt
        # it with a second pass over the same items; one loop suffices.
        lXtr = {}
        lYtr = {}
        for lang, data in self.lang_multiModalDataset.items():
            lXtr[lang] = [(clean_text, img) for _, clean_text, _, img in data.data]
            lYtr[lang] = self.label_binarizer.transform(data.labels)
        return lXtr, lYtr

    def testing(self):
        raise NotImplementedError

    def get_labels(self):
        """Return the set union of all labels across all loaded languages."""
        all_labels = set()
        for lang, data in self.lang_multiModalDataset.items():
            for label in data.labels:
                all_labels.update(label)
        return all_labels
class MultiModalDataset:
    """Single-language slice of MultiNews: (doc, image) pairs plus tag labels.

    Each news item lives in its own sub-folder of ``data_dir`` containing a
    ``text.<lang>`` HTML document and an ``index.html`` page from which the
    main image is scraped (and cached as ``img.jpg``) on first load.
    """

    def __init__(self, lang, data_dir):
        self.lang = lang
        self.data_dir = data_dir
        # Tags look like: <a rel="tag" href="/tag/<slug>/">label</a>
        self.re_labels = re.compile(r"<a rel=\"tag\" href=\"\/tag\/.+?\/\">(.+?)<\/a>")
        # Strips HTML tags and character entities (&amp;, &#123;, &#x1f;, ...).
        self.re_cleaner = re.compile(r"<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
        self.re_white = re.compile(r" +")
        self.data, self.labels = self.get_docs()

    def get_imgs(self):
        raise NotImplementedError

    def get_labels(self):
        raise NotImplementedError

    def get_ids(self):
        raise NotImplementedError

    def get_docs(self):
        """Load every news folder; return (data, labels).

        data: list of (doc_filename, clean_text, raw_html, PIL image);
        labels: list of per-document tag lists (aligned with data).
        """
        data = []
        labels = []
        # FIX: the original wrote `for news_folder in news_folder:`, making
        # the loop variable shadow the list it iterates — it worked only by
        # accident of Python's scoping; iterate the listing directly instead.
        for news_folder in listdir(self.data_dir):
            folder_path = join(self.data_dir, news_folder)
            if not isdir(folder_path):
                continue
            # Document file is named after the folder's extension suffix,
            # e.g. folder "1234.en" -> "text.en".
            fname_doc = f"text.{news_folder.split('.')[-1]}"
            with open(join(folder_path, fname_doc)) as f:
                html_doc = f.read()
            img_path = join(folder_path, "img.jpg")
            # Download the main image only when no .jpg is cached yet.
            if not any(fname.endswith(".jpg") for fname in listdir(folder_path)):
                _, img_bytes = self.get_images(join(folder_path, "index.html"))
                self.save_img(img_path, img_bytes)
            img = Image.open(img_path)
            clean_doc, doc_labels = self.preprocess_html(html_doc)
            data.append((fname_doc, clean_doc, html_doc, img))
            labels.append(doc_labels)
        return data, labels

    def save_img(self, path, img):
        """Write raw image bytes to `path`."""
        with open(path, "wb") as f:
            f.write(img)

    def get_images(self, index_path):
        """Scrape the main article image from index.html and download it.

        Returns (img_tag, raw_bytes).
        """
        # FIX: the original passed open(index_path) straight to BeautifulSoup,
        # leaking the file handle; use a context manager instead.
        with open(index_path) as f:
            soup = BeautifulSoup(f, "html.parser")
        imgs = soup.findAll("img")
        # TODO: forcing to take the first image (i.e. index 1 should be the main image)
        main_img = imgs[1]
        content = requests.get(main_img["src"]).content
        return main_img, content

    def preprocess_html(self, html_doc):
        """Return (clean_text, labels) extracted from a raw HTML document."""
        # TODO: this could be replaced by BeautifulSoup call or something similar
        labels = self._extract_labels(html_doc)
        cleaned = self._clean_up_str(self._remove_html_tags(html_doc))
        return cleaned, labels

    def _extract_labels(self, data):
        """Return all tag-anchor label strings found in `data`."""
        return re.findall(self.re_labels, data)

    def _remove_html_tags(self, data):
        """Strip HTML tags and character entities from `data`."""
        return re.sub(self.re_cleaner, "", data)

    def _clean_up_str(self, doc):
        """Collapse runs of spaces, trim, and flatten newlines/tabs to spaces."""
        doc = re.sub(self.re_white, " ", doc)
        doc = doc.strip()
        doc = doc.replace("\n", " ")
        doc = doc.replace("\t", " ")
        return doc
if __name__ == "__main__":
    from os.path import expanduser

    # Smoke test: load the dataset in debug mode and build the training views.
    _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"
    multinews = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
    lXtr, lYtr = multinews.training()
    exit()