import gzip
import os
import re
import warnings
from argparse import ArgumentParser
from collections import Counter

import numpy as np
from bs4 import BeautifulSoup
from sklearn.preprocessing import MultiLabelBinarizer

from plotters.distributions import plot_distribution

# TODO: AmazonDataset should be an instance of MultimodalDataset
warnings.filterwarnings("ignore", category=UserWarning, module="bs4")
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

BASEPATH = "/home/moreo/Datasets/raw"
with open("dataManager/excluded.csv", "r") as f:
    EXCLUDED = f.read().splitlines()
REGEX = re.compile(r"\s{2,}", re.MULTILINE)


def parse(dataset_name, ext="json.gz", nrows=0):
    dataset_name = dataset_name.replace(" ", "_")
    meta_path = os.path.join(BASEPATH, f"meta_{dataset_name}.{ext}")
    path = os.path.join(BASEPATH, f"{dataset_name}.{ext}")

    mapper = {"false": False, "true": True}
    data = []
    metadata = []

    _data = gzip.open(path, "r")
    _metadata = gzip.open(meta_path, "r")
    for i, (d, m) in enumerate(zip(_data, _metadata)):
        data.append(eval(d.replace(b"&amp;", b"&"), mapper))
        metadata.append(eval(m.replace(b"&amp;", b"&"), mapper))
        if i + 1 == nrows:
            break

    return data, metadata
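
# A hedged usage sketch for `parse` (not part of the original pipeline): it expects the
# raw per-category dump and its metadata counterpart to sit side by side under BASEPATH,
# e.g. "Appliances.json.gz" and "meta_Appliances.json.gz", each holding one Python-literal
# record per line. Assuming such files exist locally:
#
#     reviews, metadata = parse("Appliances", nrows=100)
#     print(metadata[0]["asin"], metadata[0]["main_cat"])
#
# Note that eval() trusts the dumps entirely; for untrusted input, a JSON parser or
# ast.literal_eval would be a safer design choice.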


def get_categories(data, min_count=0):
    if data[0].get("category", None) is None:
        return [], set()

    categories = []
    for item in data:
        if item["category"] != "":
            categories.extend(item["category"])
    categories = list(filter(lambda x: x not in EXCLUDED, categories))
    # return categories, sorted(set(categories))
    return categories, _filter_counter(Counter(categories), min_count)


def _filter_counter(counter, min_count):
    return {k: v for k, v in counter.items() if v >= min_count}


def get_main_cat(data, min_count=0):
    if data[0].get("main_cat", None) is None:
        return [], set()

    main_cats = [item["main_cat"] for item in data if item["main_cat"] != ""]
    main_cats = list(filter(lambda x: x not in EXCLUDED, main_cats))
    # return main_cats, sorted(set(main_cats))
    return main_cats, _filter_counter(Counter(main_cats), min_count)


def filter_sample_with_images(metadata):
    # TODO: check whether images are really available and store them locally
    # print(f"(Pre-filter) Total items: {len(metadata)}")
    data = []
    for i, m in enumerate(metadata):
        if "imageURL" not in m.keys():
            continue
        if len(m["imageURL"]) != 0 or len(m["imageURLHighRes"]) != 0:
            data.append(m)
    # print(f"(Post-filter) Total items: {len(data)}")
    return data


def select_description(descriptions):
    """
    Some items have multiple descriptions (len(item["description"]) > 1).
    Most of these descriptions are just empty strings; some items, however,
    actually have multiple strings describing them.
    At the moment, we rely on a simple heuristic: select the longest string
    and use it as the only description.
    """
    if len(descriptions) == 0:
        return [""]
    return [max(descriptions, key=len)]
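
# A minimal illustration of the heuristic above (input strings are made up):
#
#     select_description(["", "Short blurb.", "A much longer marketing description."])
#     # -> ["A much longer marketing description."]
#     select_description([])
#     # -> [""]
#
# The result is always a one-element list, so downstream code can keep indexing
# item["description"][0] without special-casing items that had no description.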


def build_product_json(metadata, binarizer):
    data = []
    for item in metadata:
        if len(item["description"]) != 1:
            item["description"] = select_description(item["description"])

        product = {
            "asin": item["asin"],
            "title": item["title"],
            "description": item["description"],
            # TODO: some items have multiple descriptions (len(item["description"]) > 1)
            "cleaned_description": clean_description(
                BeautifulSoup(
                    item["title"] + ". " + item["description"][0],
                    features="html.parser",
                ).text
            ),
            # TODO: is it faster to call transform on the whole dataset?
            "main_category": item["main_cat"],
            "categories": item["category"],
            "all_categories": _get_cats(item["main_cat"], item["category"]),
            "vect_categories": binarizer.transform(
                [_get_cats(item["main_cat"], item["category"])]
            )[0],
        }
        data.append(product)
    return data


def _get_cats(main_cat, cats):
    return [main_cat] + cats


def get_label_binarizer(cats):
    mlb = MultiLabelBinarizer()
    mlb.fit([cats])
    return mlb
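
# A small sketch of how the binarizer is used downstream (category names are
# illustrative, not taken from the dataset):
#
#     mlb = get_label_binarizer(["Appliances", "Automotive", "Movies and TV"])
#     mlb.transform([["Automotive", "Appliances"]])[0]
#     # -> array([1, 1, 0])   # columns follow mlb.classes_, which is sorted
#
# Fitting once on the full label set and reusing the same binarizer keeps column
# order consistent across items and domains.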


def clean_description(description):
    description = re.sub(REGEX, " ", description)
    description = description.rstrip()
    description = description.replace("\t", "")
    description = description.replace("\n", " ")
    return description
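
# Illustrative example of the cleaning step (the input string is made up):
#
#     clean_description("Stainless  steel blender.\n\nGreat  for smoothies.  ")
#     # -> "Stainless steel blender. Great for smoothies."
#
# Runs of two or more whitespace characters collapse into a single space first;
# any remaining single tabs are then removed and single newlines become spaces.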


def construct_target_matrix(data):
    return np.stack([d["vect_categories"] for d in data], axis=0)
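
# Shape sketch (illustrative): with a binarizer fitted on L labels, stacking the
# per-item binary vectors yields one row per product, i.e.
#
#     Y = construct_target_matrix(data)   # Y.shape == (len(data), L)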


def get_all_classes(counter_cats, counter_sub_cats):
    if len(counter_cats) == 0:
        return counter_sub_cats.keys()
    elif len(counter_sub_cats) == 0:
        return counter_cats.keys()
    else:
        return list(counter_cats.keys()) + list(counter_sub_cats.keys())


class AmazonDataset:
    def __init__(
        self,
        domains=["Appliances", "Automotive", "Movies and TV"],
        basepath="/home/moreo/Datasets/raw",
        min_count=10,
        max_labels=50,
        nrows=1000,
    ):
        print("[Init AmazonDataset]")
        print(f"- Domains: {domains}")
        self.REGEX = re.compile(r"\s{2,}", re.MULTILINE)
        with open("dataManager/excluded.csv", "r") as f:
            self.EXCLUDED = f.read().splitlines()
        self.basepath = basepath
        self.domains = self.parse_domains(domains)
        self.nrows = nrows
        self.min_count = min_count
        self.max_labels = max_labels
        self.len_data = 0
        self.domain_data = self.load_data()
        self.labels, self.domain_labels = self.get_all_cats()
        self.label_binarizer = get_label_binarizer(self.labels)
        self.vectorized_labels = self.vectorize_labels()
        self.dX = self.construct_data_matrix()
        self.dY = self.construct_target_matrix()
        self.langs = ["en"]

    def parse_domains(self, domains):
        with open("amazon_categories.txt", "r") as f:
            all_domains = f.read().splitlines()
        if domains == "all":
            return all_domains
        else:
            assert all([d in all_domains for d in domains]), "Invalid domain name"
            return domains

    def parse(self, dataset_name, nrows, ext="json.gz"):
        dataset_name = dataset_name.replace(" ", "_")
        meta_path = os.path.join(self.basepath, f"meta_{dataset_name}.{ext}")
        path = os.path.join(self.basepath, f"{dataset_name}.{ext}")

        mapper = {"false": False, "true": True}
        data = []
        metadata = []

        _data = gzip.open(path, "r")
        _metadata = gzip.open(meta_path, "r")
        for i, (d, m) in enumerate(zip(_data, _metadata)):
            data.append(eval(d.replace(b"&amp;", b"&"), mapper))
            metadata.append(eval(m.replace(b"&amp;", b"&"), mapper))
            if i + 1 == nrows:
                break

        return data, metadata

    def load_data(self):
        print(f"- Loading up to {self.nrows} items per domain")
        domain_data = {}
        for domain in self.domains:
            _, metadata = self.parse(domain, nrows=self.nrows)
            metadata = filter_sample_with_images(metadata)
            domain_data[domain] = self.build_product_scheme(metadata)
            self.len_data += len(metadata)
        print(f"- Loaded {self.len_data} items")
        return domain_data

    def get_all_cats(self):
        assert len(self.domain_data) != 0, "Load data first"
        labels = set()
        domain_labels = {}
        for domain, data in self.domain_data.items():
            _, counter_cats = self._get_counter_cats(data, self.min_count)
            labels.update(counter_cats.keys())
            domain_labels[domain] = counter_cats
        print(f"- Found {len(labels)} labels")
        return labels, domain_labels

    def export_to_torch(self):
        pass

    def get_label_binarizer(self):
        mlb = MultiLabelBinarizer()
        mlb.fit([self.labels])
        return mlb

    def vectorize_labels(self):
        for domain, data in self.domain_data.items():
            for item in data:
                item["vect_categories"] = self.label_binarizer.transform(
                    [item["all_categories"]]
                )[0]

    def build_product_scheme(self, metadata):
        data = []
        for item in metadata:
            if len(item["description"]) != 1:
                _desc = self._select_description(item["description"])
            else:
                _desc = item["description"][0]

            product = {
                "asin": item["asin"],
                "title": item["title"],
                "description": _desc,
                # TODO: some items have multiple descriptions (len(item["description"]) > 1)
                "cleaned_text": self._clean_description(
                    BeautifulSoup(
                        item["title"] + ". " + _desc,
                        features="html.parser",
                    ).text
                ),
                # TODO: is it faster to call transform on the whole dataset?
                "main_category": item["main_cat"],
                "categories": item["category"],
                "all_categories": self._get_cats(item["main_cat"], item["category"]),
                # "vect_categories": binarizer.transform(
                #     [_get_cats(item["main_cat"], item["category"])]
                # )[0],
            }
            data.append(product)
        return data

    def construct_data_matrix(self):
        dX = {}
        for domain, data in self.domain_data.items():
            dX[domain] = [d["cleaned_text"] for d in data]
        return dX

    def construct_target_matrix(self):
        dY = {}
        for domain, data in self.domain_data.items():
            dY[domain] = np.stack([d["vect_categories"] for d in data], axis=0)
        return dY

    def get_overall_label_matrix(self):
        assert hasattr(self, "dY"), "Init label matrices first"
        return np.vstack([x for x in self.dY.values()])

    def _get_counter_cats(self, data, min_count):
        cats = []
        for item in data:
            cats.extend(item["all_categories"])
        cats = list(filter(lambda x: x not in self.EXCLUDED, cats))
        return cats, self._filter_counter(Counter(cats), min_count)

    def _filter_counter(self, counter, min_count):
        return {k: v for k, v in counter.items() if v >= min_count}

    def _clean_description(self, description):
        description = re.sub(self.REGEX, " ", description)
        description = description.rstrip()
        description = description.replace("\t", "")
        description = description.replace("\n", " ")
        return description

    def _get_cats(self, main_cat, cats):
        return [main_cat] + cats

    def _select_description(self, descriptions) -> str:
        """
        Some items have multiple descriptions (len(item["description"]) > 1).
        Most of these descriptions are just empty strings; some items, however,
        actually have multiple strings describing them.
        At the moment, we rely on a simple heuristic: select the longest string
        and use it as the only description.
        """
        if len(descriptions) == 0:
            return ""
        return max(descriptions, key=len)

    def plot_label_distribution(self):
        overall_mat = self.get_overall_label_matrix()
        plot_distribution(
            np.arange(len(self.labels)),
            np.sum(overall_mat, axis=0),
            title="Amazon Dataset",
            labels=self.labels,
            notes=overall_mat.shape,
            max_labels=self.max_labels,
            figsize=(10, 10),
            save=True,
            path="out",
        )

    def plot_per_domain_label_distribution(self):
        for domain, matrix in self.dY.items():
            pass


def main(args):
    dataset = AmazonDataset(
        domains=args.domains,
        nrows=args.nrows,
        min_count=args.min_count,
        max_labels=args.max_labels,
    )

    dataset.plot_label_distribution()
    exit()


if __name__ == "__main__":
    import sys

    sys.path.append("/home/andreapdr/devel/gFunMultiModal/")

    parser = ArgumentParser()
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=10000)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    args = parser.parse_args()
    main(args)
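
# Hedged example invocation (the file name is illustrative; run from the project root
# so that "dataManager/excluded.csv" and "amazon_categories.txt" resolve):
#
#     python amazonDataset.py --domains all --nrows 1000 --min_count 10 --max_labels 50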