import copy
import logging
import os
import zipfile
from tempfile import TemporaryFile
from typing import BinaryIO, Optional, Dict

import requests
from tqdm import tqdm

import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer

# From https://github.com/glami/glami-1m/blob/main/load_dataset.py

logger = logging.getLogger(__name__)

GLAMI_DATAPATH = os.path.expanduser("~/datasets/GLAMI-1M-dataset")

# DATASET_URL = os.environ.get(
#     "DATASET_URL",
#     "https://zenodo.org/record/7326406/files/GLAMI-1M-dataset.zip?download=1",
# )
# EXTRACT_DIR = os.environ.get("EXTRACT_DIR", "/tmp/GLAMI-1M")
# DATASET_SUBDIR = "GLAMI-1M-dataset"
# DATASET_DIR = dataset_dir = EXTRACT_DIR + "/" + DATASET_SUBDIR
# MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp/GLAMI-1M/models")
# EMBS_DIR = EXTRACT_DIR + "/embs"
# CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-visual"
# CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-textual"
# # CLIP_VISUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-visual"
# # CLIP_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-l5b-textual"
# CLIP_EN_TEXTUAL_EMBS_DIR = EXTRACT_DIR + "/embs-clip-en-textual"
# GENERATED_DIR = EXTRACT_DIR + "/generated_images"

COL_NAME_ITEM_ID = "item_id"
COL_NAME_IMAGE_ID = "image_id"
COL_NAME_IMAGE_FILE = "image_file"
COL_NAME_IMAGE_URL = "image_url"
COL_NAME_NAME = "name"
COL_NAME_DESCRIPTION = "description"
COL_NAME_GEO = "geo"
COL_NAME_CATEGORY = "category"
COL_NAME_CAT_NAME = "category_name"
COL_NAME_LABEL_SOURCE = "label_source"
COL_NAME_EMB_FILE = "emb_file"
COL_NAME_MASK_FILE = "mask_file"

DEFAULT_IMAGE_SIZE = (298, 228)

COUNTRY_CODE_TO_COUNTRY_NAME = {
    "cz": "Czechia",
    "sk": "Slovakia",
    "ro": "Romania",
    "gr": "Greece",
    "si": "Slovenia",
    "hu": "Hungary",
    "hr": "Croatia",
    "es": "Spain",
    "lt": "Lithuania",
    "lv": "Latvia",
    "tr": "Turkey",
    "ee": "Estonia",
    "bg": "Bulgaria",
}

# Set of display strings such as "Czechia (cz)". ``.items()`` is required:
# iterating the dict directly yields only the two-letter keys, which would be
# unpacked character-by-character into garbage like "z (c)".
COUNTRY_CODE_TO_COUNTRY_NAME_W_CC = {
    f"{name} ({cc})" for cc, name in COUNTRY_CODE_TO_COUNTRY_NAME.items()
}


def get_dataframe(split_type: str, dataset_dir=GLAMI_DATAPATH, nrows=None):
    """Load one GLAMI-1M split as a DataFrame with resolved image paths.

    Args:
        split_type: ``"train"`` or ``"test"``; selects
            ``GLAMI-1M-{split_type}.csv`` inside ``dataset_dir``.
        dataset_dir: Directory holding the CSV files and an ``images/``
            subdirectory.
        nrows: Optional cap on the number of CSV rows read (``None`` reads
            the whole file).

    Returns:
        A DataFrame restricted to the item/image id, name, description,
        geo, category name, label source and image-file columns. Missing
        names and descriptions are replaced by empty strings.

    Raises:
        AssertionError: If ``split_type`` is invalid, or if the first row's
            image file does not exist on disk.
    """
    assert split_type in ("train", "test")
    csv_path = os.path.join(dataset_dir, f"GLAMI-1M-{split_type}.csv")
    # read_csv treats nrows=None as "read everything", so no branch is needed.
    df = pd.read_csv(csv_path, nrows=nrows)

    # Trailing "" makes os.path.join append the separator once.
    images_prefix = os.path.join(dataset_dir, "images", "")
    df[COL_NAME_IMAGE_FILE] = (
        images_prefix + df[COL_NAME_IMAGE_ID].astype(str) + ".jpg"
    )
    # Text columns are later concatenated; NaN would poison the result.
    df[COL_NAME_NAME] = df[COL_NAME_NAME].fillna("")
    df[COL_NAME_DESCRIPTION] = df[COL_NAME_DESCRIPTION].fillna("")

    # Cheap sanity check that the images were extracted next to the CSVs.
    # iloc (not loc[0]) so this does not depend on the index labels.
    if not df.empty:
        assert os.path.exists(df[COL_NAME_IMAGE_FILE].iloc[0])

    return df[
        [
            COL_NAME_ITEM_ID,
            COL_NAME_IMAGE_ID,
            COL_NAME_NAME,
            COL_NAME_DESCRIPTION,
            COL_NAME_GEO,
            COL_NAME_CAT_NAME,
            COL_NAME_LABEL_SOURCE,
            COL_NAME_IMAGE_FILE,
        ]
    ]


class GlamiDataset:
    """Per-language (``geo``) view of the GLAMI-1M train/test splits.

    Call :meth:`build_dataset` first. Afterwards :meth:`training` and
    :meth:`test` return ``({lang: texts}, {lang: binarized labels})``.
    """

    def __init__(self, dataset_dir, langs=None, labels=None, nrows=None):
        self.dataset_dir = dataset_dir
        # Geos to keep; None means "every geo present in the train split".
        self.data_langs = langs
        # Category names; None means "every category in the train split".
        self.labels = labels
        # Optional row cap forwarded to get_dataframe().
        self.nrows = nrows
        # lang -> [train_df, test_df]; populated by build_dataset().
        self.multilingual_dataset = {}

    def num_labels(self):
        """Return the number of configured labels (requires ``labels`` set)."""
        return len(self.labels)

    def langs(self):
        """Return the list of languages (geos) this dataset covers."""
        return self.data_langs

    def get_label_binarizer(self, labels):
        """Fit and return a LabelBinarizer over ``labels``."""
        mlb = LabelBinarizer()
        mlb.fit(labels)
        print(f"- Label binarizer initialized with {len(mlb.classes_)} labels")
        return mlb

    def binarize_labels(self, labels):
        """Transform category names with the fitted binarizer.

        Raises:
            ValueError: If :meth:`build_dataset` has not been called yet.
        """
        if hasattr(self, "mlb"):
            return self.mlb.transform(labels)
        raise ValueError("Label binarizer not initialized")

    def load_df(self, split, dataset_dir):
        """Load ``split`` from ``dataset_dir`` honoring ``self.nrows``."""
        return get_dataframe(split, dataset_dir=dataset_dir, nrows=self.nrows)

    def build_dataset(self):
        """Load both splits, fit the label binarizer and group rows by geo.

        Languages are matched by key, not by position. If a language is
        missing from either split, it is skipped with a warning. Positional
        zipping would silently pair one language's train data with another
        language's test data.
        """
        train_dataset = self.load_df("train", self.dataset_dir)
        test_dataset = self.load_df("test", self.dataset_dir)

        if self.data_langs is None:
            self.data_langs = train_dataset[COL_NAME_GEO].unique().tolist()
        if self.labels is None:
            self.labels = train_dataset[COL_NAME_CAT_NAME].unique().tolist()

        self.mlb = self.get_label_binarizer(self.labels)

        # A dict comprehension is used because dict(groupby) would try to call
        # the GroupBy's ``keys`` attribute instead of iterating the groups.
        train_groups = {
            lang: df for lang, df in train_dataset.groupby(COL_NAME_GEO)
        }
        test_groups = {lang: df for lang, df in test_dataset.groupby(COL_NAME_GEO)}

        wanted = set(self.data_langs)
        self.multilingual_dataset = {}
        for lang, data_tr in train_groups.items():
            if lang not in wanted:
                continue
            data_te = test_groups.get(lang)
            if data_te is None:
                logger.warning("Language %r has no test rows; skipping it", lang)
                continue
            self.multilingual_dataset[lang] = [data_tr, data_te]

    @staticmethod
    def _texts(df):
        """Return the ``"name description"`` strings of ``df`` as a list."""
        return (df[COL_NAME_NAME] + " " + df[COL_NAME_DESCRIPTION]).tolist()

    def _labels(self, df):
        """Return the binarized category labels of ``df``."""
        return self.binarize_labels(df[COL_NAME_CAT_NAME].tolist())

    def training(self):
        """Return ``(lXtr, lYtr)``: per-language train texts and labels."""
        lXtr = {
            lang: self._texts(df)
            for lang, (df, _) in self.multilingual_dataset.items()
        }
        lYtr = {
            lang: self._labels(df)
            for lang, (df, _) in self.multilingual_dataset.items()
        }
        return lXtr, lYtr

    def test(self):
        """Return ``(lXte, lYte)``: per-language test texts and labels."""
        lXte = {
            lang: self._texts(df)
            for lang, (_, df) in self.multilingual_dataset.items()
        }
        lYte = {
            lang: self._labels(df)
            for lang, (_, df) in self.multilingual_dataset.items()
        }
        return lXte, lYte


if __name__ == "__main__":
    print("Hello glamiDataset")
    dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=None)
    dataset.build_dataset()
    lXtr, lYtr = dataset.training()
    # The method is named test(); the previous call to testing() raised.
    lXte, lYte = dataset.test()
    raise SystemExit(0)