import h5py
import numpy as np

import WebAppSettings as settings


class GEMSearcher:
    """In-memory nearest-neighbour index over a matrix of descriptors.

    ``self.descs`` holds one descriptor per row and ``self.ids`` holds the
    identifier of each row, in the same order.  Similarity is the plain dot
    product between the query and every stored descriptor.
    """

    def __init__(self):
        """Load the descriptor matrix and the matching id list from disk."""
        self.descs = np.load(settings.DATASET_GEM)
        # ``np.loadtxt`` yields a 0-d array for a one-line file, whose
        # ``tolist()`` is a bare string; force a 1-d array so ``self.ids`` is
        # always a list.
        self.ids = np.atleast_1d(
            np.loadtxt(settings.DATASET_IDS, dtype=str)
        ).tolist()

    def get_id(self, idx):
        """Return the identifier stored at row ``idx``."""
        return self.ids[idx]

    def add(self, desc, id):
        """Append a descriptor and its identifier, then persist the index.

        Args:
            desc: descriptor to stack under the existing matrix; it must be
                compatible with ``np.vstack`` against ``self.descs``.
            id: identifier associated with ``desc``.
        """
        self.ids.append(id)
        self.descs = np.vstack((self.descs, desc))
        self.save()

    def remove(self, id):
        """Drop the entry whose identifier is ``id`` from the in-memory index.

        The change is not written to disk here; call ``save()`` to persist it.

        Raises:
            ValueError: if ``id`` is not present in the index.
        """
        idx = self.ids.index(id)
        del self.ids[idx]
        self.descs = np.delete(self.descs, idx, axis=0)

    def search_by_id(self, query_id, k=10):
        """Return the ``k`` best matches for an item already in the index.

        Raises:
            ValueError: if ``query_id`` is not present in the index.
        """
        query_idx = self.ids.index(query_id)
        return self.search_by_img(self.descs[query_idx], k)

    def search_by_img(self, query, k=10):
        """Return the ``k`` entries with the highest dot product to ``query``.

        Args:
            query: either a single descriptor (1-d) or a batch-shaped array
                (2-d or more), in which case only its first row is used.
            k: maximum number of results.

        Returns:
            A list of ``(id, score)`` tuples sorted by decreasing score, with
            each score rounded to three decimals.
        """
        query = np.asarray(query)
        # Batch-shaped inputs such as (1, D) carry the descriptor in row 0.
        # A 1-d descriptor (as passed by ``search_by_id``) is used as is;
        # indexing it with [0] would reduce it to a scalar.
        if query.ndim > 1:
            query = query[0]
        # NOTE(review): scores are raw dot products; they equal cosine
        # similarity only if the descriptors are L2-normalised -- confirm.
        scores = np.dot(self.descs, query)
        top = scores.argsort()[::-1][:k]
        return [(self.ids[i], round(float(scores[i]), 3)) for i in top]

    def save(self, is_backup=False):
        """Write the descriptor matrix and the id list to disk.

        Descriptors go back to the file they are loaded from
        (``settings.DATASET_GEM``) so that a reload sees the same number of
        rows as there are ids.

        Args:
            is_backup: when true, ``'.bak'`` is appended to both target paths.
        """
        descs_file = settings.DATASET_GEM
        ids_file = settings.DATASET_IDS
        if is_backup:
            descs_file += '.bak'
            ids_file += '.bak'
        # NOTE: ``np.save`` appends ``.npy`` to names that lack that suffix,
        # so the backup descriptor file ends in ``.bak.npy``.
        np.save(descs_file, self.descs)
        np.savetxt(ids_file, self.ids, fmt='%s')