# gfun_multimodal/dataManager/amazonDataset.py
import gzip
import os
import re
import warnings
from argparse import ArgumentParser
from collections import Counter
import numpy as np
from bs4 import BeautifulSoup
from sklearn.preprocessing import MultiLabelBinarizer
from plotters.distributions import plot_distribution

# TODO: AmazonDataset should be an instance of MultimodalDataset

warnings.filterwarnings("ignore", category=UserWarning, module="bs4")
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

BASEPATH = "/home/moreo/Datasets/raw"

with open("dataManager/excluded.csv", "r") as f:
    EXCLUDED = f.read().splitlines()

REGEX = re.compile(r"\s{2,}", re.MULTILINE)


def parse(dataset_name, ext="json.gz", nrows=0):
    dataset_name = dataset_name.replace(" ", "_")
    meta_path = os.path.join(BASEPATH, f"meta_{dataset_name}.{ext}")
    path = os.path.join(BASEPATH, f"{dataset_name}.{ext}")
    # the dumps contain Python-literal dicts (not strict JSON); the mapper lets
    # eval resolve the lowercase JSON booleans
    mapper = {"false": False, "true": True}
    data = []
    metadata = []
    _data = gzip.open(path, "r")
    _metadata = gzip.open(meta_path, "r")
    for i, (d, m) in enumerate(zip(_data, _metadata)):
        # un-escape HTML ampersands before evaluating the record
        data.append(eval(d.replace(b"&amp;", b"&"), mapper))
        metadata.append(eval(m.replace(b"&amp;", b"&"), mapper))
        if i + 1 == nrows:
            break
    return data, metadata
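
# Minimal usage sketch (the "Appliances" dump under BASEPATH is an assumption about
# the local setup); reviews and product metadata are read line-by-line in parallel:
#   reviews, metadata = parse("Appliances", nrows=100)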


def get_categories(data, min_count=0):
    if data[0].get("category", None) is None:
        return [], set()
    categories = []
    for item in data:
        if item["category"] != "":
            categories.extend(item["category"])
    categories = list(filter(lambda x: x not in EXCLUDED, categories))
    # return categories, sorted(set(categories))
    return categories, _filter_counter(Counter(categories), min_count)


def _filter_counter(counter, min_count):
    return {k: v for k, v in counter.items() if v >= min_count}


def get_main_cat(data, min_count=0):
    if data[0].get("main_cat", None) is None:
        return [], set()
    main_cats = [item["main_cat"] for item in data if item["main_cat"] != ""]
    main_cats = list(filter(lambda x: x not in EXCLUDED, main_cats))
    # return main_cats, sorted(set(main_cats))
    return main_cats, _filter_counter(Counter(main_cats), min_count)


def filter_sample_with_images(metadata):
    # TODO: check whether images are really available and store them locally
    # print(f"(Pre-filter) Total items: {len(metadata)}")
    data = []
    for m in metadata:
        if "imageURL" not in m.keys():
            continue
        if len(m["imageURL"]) != 0 or len(m["imageURLHighRes"]) != 0:
            data.append(m)
    # print(f"(Post-filter) Total items: {len(data)}")
    return data


def select_description(descriptions):
    """
    Some items have multiple descriptions (len(item["description"]) > 1).
    Most of these descriptions are just empty strings; some items, however, actually
    have multiple strings describing them. At the moment, we rely on a simple
    heuristic: select the longest string and use it as the only description.
    """
    if len(descriptions) == 0:
        return [""]
    return [max(descriptions, key=len)]
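
# e.g. select_description(["", "A long description", "x"]) -> ["A long description"]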


def build_product_json(metadata, binarizer):
    data = []
    for item in metadata:
        if len(item["description"]) != 1:
            item["description"] = select_description(item["description"])
        product = {
            "asin": item["asin"],
            "title": item["title"],
            "description": item["description"],
            # TODO: some items have multiple descriptions (len(item["description"]) > 1)
            "cleaned_description": clean_description(
                BeautifulSoup(
                    item["title"] + ". " + item["description"][0],
                    features="html.parser",
                ).text
            ),
            # TODO: is it faster to call transform on the whole dataset?
            "main_category": item["main_cat"],
            "categories": item["category"],
            "all_categories": _get_cats(item["main_cat"], item["category"]),
            "vect_categories": binarizer.transform(
                [_get_cats(item["main_cat"], item["category"])]
            )[0],
        }
        data.append(product)
    return data


def _get_cats(main_cat, cats):
    return [main_cat] + cats


def get_label_binarizer(cats):
    mlb = MultiLabelBinarizer()
    mlb.fit([cats])
    return mlb
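
# Sketch with hypothetical labels: the binarizer maps a per-item label list to a
# fixed-length 0/1 vector over all known classes (alphabetical order), e.g.
#   mlb = get_label_binarizer(["Appliances", "Automotive", "Movies and TV"])
#   mlb.transform([["Appliances"]])[0]  # -> array([1, 0, 0])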


def clean_description(description):
    description = re.sub(REGEX, " ", description)
    description = description.rstrip()
    description = description.replace("\t", "")
    description = description.replace("\n", " ")
    return description


def construct_target_matrix(data):
    return np.stack([d["vect_categories"] for d in data], axis=0)


def get_all_classes(counter_cats, counter_sub_cats):
    if len(counter_cats) == 0:
        return counter_sub_cats.keys()
    elif len(counter_sub_cats) == 0:
        return counter_cats.keys()
    else:
        return list(counter_cats.keys()) + list(counter_sub_cats.keys())


class AmazonDataset:
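    """Multi-domain Amazon product dataset.

    For each requested domain, the review/metadata dumps are loaded, items without
    image URLs are discarded, a cleaned text field is built from title + description,
    and the (main_cat + category) labels are binarized into per-domain target
    matrices (dX: texts, dY: label matrices).
    """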

    def __init__(
        self,
        domains=["Appliances", "Automotive", "Movies and TV"],
        basepath="/home/moreo/Datasets/raw",
        min_count=10,
        max_labels=50,
        nrows=1000,
    ):
        print("[Init AmazonDataset]")
        print(f"- Domains: {domains}")
        self.REGEX = re.compile(r"\s{2,}", re.MULTILINE)
        with open("dataManager/excluded.csv", "r") as f:
            self.EXCLUDED = f.read().splitlines()
        self.basepath = basepath
        self.domains = self.parse_domains(domains)
        self.nrows = nrows
        self.min_count = min_count
        self.max_labels = max_labels
        self.len_data = 0
        self.domain_data = self.load_data()
        self.labels, self.domain_labels = self.get_all_cats()
        self.label_binarizer = get_label_binarizer(self.labels)
        # adds a "vect_categories" field to every item, in place
        self.vectorize_labels()
        self.dX = self.construct_data_matrix()
        self.dY = self.construct_target_matrix()
        self.langs = ["en"]

    def parse_domains(self, domains):
        with open("amazon_categories.txt", "r") as f:
            all_domains = f.read().splitlines()
        if domains == "all":
            return all_domains
        else:
            assert all([d in all_domains for d in domains]), "Invalid domain name"
            return domains

    def parse(self, dataset_name, nrows, ext="json.gz"):
        dataset_name = dataset_name.replace(" ", "_")
        meta_path = os.path.join(self.basepath, f"meta_{dataset_name}.{ext}")
        path = os.path.join(self.basepath, f"{dataset_name}.{ext}")
        mapper = {"false": False, "true": True}
        data = []
        metadata = []
        _data = gzip.open(path, "r")
        _metadata = gzip.open(meta_path, "r")
        for i, (d, m) in enumerate(zip(_data, _metadata)):
            # un-escape HTML ampersands before evaluating the record
            data.append(eval(d.replace(b"&amp;", b"&"), mapper))
            metadata.append(eval(m.replace(b"&amp;", b"&"), mapper))
            if i + 1 == nrows:
                break
        return data, metadata

    def load_data(self):
        print(f"- Loading up to {self.nrows} items per domain")
        domain_data = {}
        for domain in self.domains:
            _, metadata = self.parse(domain, nrows=self.nrows)
            metadata = filter_sample_with_images(metadata)
            domain_data[domain] = self.build_product_scheme(metadata)
            self.len_data += len(metadata)
        print(f"- Loaded {self.len_data} items")
        return domain_data

    def get_all_cats(self):
        assert len(self.domain_data) != 0, "Load data first"
        labels = set()
        domain_labels = {}
        for domain, data in self.domain_data.items():
            _, counter_cats = self._get_counter_cats(data, self.min_count)
            labels.update(counter_cats.keys())
            domain_labels[domain] = counter_cats
        print(f"- Found {len(labels)} labels")
        return labels, domain_labels

    def export_to_torch(self):
        pass

    def get_label_binarizer(self):
        mlb = MultiLabelBinarizer()
        mlb.fit([self.labels])
        return mlb

    def vectorize_labels(self):
        # adds the binarized label vector to each item, in place
        for domain, data in self.domain_data.items():
            for item in data:
                item["vect_categories"] = self.label_binarizer.transform(
                    [item["all_categories"]]
                )[0]

    def build_product_scheme(self, metadata):
        data = []
        for item in metadata:
            if len(item["description"]) != 1:
                _desc = self._select_description(item["description"])
            else:
                _desc = item["description"][0]
            product = {
                "asin": item["asin"],
                "title": item["title"],
                "description": _desc,
                # TODO: some items have multiple descriptions (len(item["description"]) > 1)
                "cleaned_text": self._clean_description(
                    BeautifulSoup(
                        item["title"] + ". " + _desc,
                        features="html.parser",
                    ).text
                ),
                # TODO: is it faster to call transform on the whole dataset?
                "main_category": item["main_cat"],
                "categories": item["category"],
                "all_categories": self._get_cats(item["main_cat"], item["category"]),
                # "vect_categories": binarizer.transform(
                #     [_get_cats(item["main_cat"], item["category"])]
                # )[0],
            }
            data.append(product)
        return data

    def construct_data_matrix(self):
        dX = {}
        for domain, data in self.domain_data.items():
            dX[domain] = [d["cleaned_text"] for d in data]
        return dX

    def construct_target_matrix(self):
        dY = {}
        for domain, data in self.domain_data.items():
            dY[domain] = np.stack([d["vect_categories"] for d in data], axis=0)
        return dY
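
    # Note: dY[domain] has shape (n_items_in_domain, n_labels) and is row-aligned
    # with the texts in dX[domain].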

    def get_overall_label_matrix(self):
        assert hasattr(self, "dY"), "Init label matrices first"
        return np.vstack([x for x in self.dY.values()])

    def _get_counter_cats(self, data, min_count):
        cats = []
        for item in data:
            cats.extend(item["all_categories"])
        cats = list(filter(lambda x: x not in self.EXCLUDED, cats))
        return cats, self._filter_counter(Counter(cats), min_count)

    def _filter_counter(self, counter, min_count):
        return {k: v for k, v in counter.items() if v >= min_count}

    def _clean_description(self, description):
        description = re.sub(self.REGEX, " ", description)
        description = description.rstrip()
        description = description.replace("\t", "")
        description = description.replace("\n", " ")
        return description

    def _get_cats(self, main_cat, cats):
        return [main_cat] + cats

    def _select_description(self, descriptions) -> str:
        """
        Some items have multiple descriptions (len(item["description"]) > 1).
        Most of these descriptions are just empty strings; some items, however,
        actually have multiple strings describing them. At the moment, we rely on a
        simple heuristic: select the longest string and use it as the only description.
        """
        if len(descriptions) == 0:
            return ""
        return max(descriptions, key=len)

    def plot_label_distribution(self):
        overall_mat = self.get_overall_label_matrix()
        plot_distribution(
            np.arange(len(self.labels)),
            np.sum(overall_mat, axis=0),
            title="Amazon Dataset",
            labels=self.labels,
            notes=overall_mat.shape,
            max_labels=self.max_labels,
            figsize=(10, 10),
            save=True,
            path="out",
        )

    def plot_per_domain_label_distribution(self):
        # TODO: not implemented yet; should plot one distribution per domain
        for domain, matrix in self.dY.items():
            pass


def main(args):
    dataset = AmazonDataset(
        domains=args.domains,
        nrows=args.nrows,
        min_count=args.min_count,
        max_labels=args.max_labels,
    )
    dataset.plot_label_distribution()
    exit()
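
# Example invocation (the data dumps under BASEPATH are local-setup assumptions):
#   python dataManager/amazonDataset.py --domains all --nrows 10000 --min_count 10 --max_labels 50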


if __name__ == "__main__":
    import sys

    sys.path.append("/home/andreapdr/devel/gFunMultiModal/")

    parser = ArgumentParser()
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=10000)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    args = parser.parse_args()
    main(args)