import tarfile
from collections import defaultdict
from pathlib import Path

import faiss
import numpy as np
import pyarrow as pa
import requests
from tqdm import tqdm

__all__ = ["RetrievalDatabase", "download_retrieval_databases"]

RETRIEVAL_DATABASES_URLS = {
    "cc12m": {
        "url": "https://storage-cased.alessandroconti.me/cc12m.tar.gz",
        "cache_subdir": "./cc12m/vit-l-14/",
    },
}

# Seconds `requests` waits to connect and, while streaming, between received bytes.
# Without a timeout a stalled server would hang the download forever.
_DOWNLOAD_TIMEOUT = 60

# Bytes requested per chunk while streaming a download to disk.
_DOWNLOAD_CHUNK_SIZE = 8192


def _download_file(url: str, target_path: Path, desc: str) -> None:
    """Stream ``url`` into ``target_path`` while showing a tqdm progress bar.

    If the download fails (or is interrupted), the partially written file is removed
    and the exception is re-raised.
    """
    try:
        with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as r:
            r.raise_for_status()
            # 0 when the server sends no content-length header
            total_bytes_size = int(r.headers.get("content-length", 0))
            with open(target_path, "wb") as f, tqdm(
                desc=desc, total=total_bytes_size, unit="iB", unit_scale=True
            ) as p_bar:
                for chunk in r.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
                    f.write(chunk)
                    p_bar.update(len(chunk))
    except BaseException:
        # Don't leave a truncated archive behind; re-raise whatever happened.
        target_path.unlink(missing_ok=True)
        raise


def _extract_archive(archive_path: Path, destination: Path) -> None:
    """Extract the gzip-compressed tar archive ``archive_path`` into ``destination``."""
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # The archive comes from the network: where the interpreter supports
            # extraction filters (3.12+ and security backports), reject absolute paths,
            # `..` traversal, links escaping `destination` and special files.
            tar.extractall(destination, filter="data")
        else:
            tar.extractall(destination)


def download_retrieval_databases(cache_dir: str = "~/.cache/cased"):
    """Download and extract every retrieval database that is not cached yet.

    For each entry of ``RETRIEVAL_DATABASES_URLS``, the ``.tar.gz`` archive is downloaded
    to ``<cache_dir>/databases/<name>.tar.gz``, extracted into ``<cache_dir>/databases/``
    and then deleted. A database is skipped when ``<cache_dir>/databases/<name>``
    already exists.

    Args:
        cache_dir (str): Path to cache directory; a leading ``~`` is expanded to the
            user's home directory. Defaults to "~/.cache/cased".
    """
    databases_path = Path(cache_dir).expanduser() / "databases"

    for name, items in RETRIEVAL_DATABASES_URLS.items():
        database_path = databases_path / name
        if database_path.exists():
            continue

        # download data
        target_path = databases_path / f"{name}.tar.gz"
        target_path.parent.mkdir(parents=True, exist_ok=True)
        _download_file(items["url"], target_path, desc=f"Downloading {name} index")

        # extract data, then drop the archive
        _extract_archive(target_path, target_path.parent)
        target_path.unlink()


class RetrievalDatabaseMetadataProvider:
    """Metadata provider for the retrieval database.

    Every file found (recursively) under ``metadata_dir`` is memory-mapped as an Arrow IPC
    file, and all of them are concatenated, in sorted path order, into one table. Row ``k``
    of that table is returned for id ``k`` — presumably the row order matches the order the
    faiss index was built in; TODO confirm.

    Args:
        metadata_dir (str): Path to the metadata directory.
    """

    def __init__(self, metadata_dir: str):
        metadatas = [str(a) for a in sorted(Path(metadata_dir).glob("**/*")) if a.is_file()]
        self.table = pa.concat_tables(
            [
                pa.ipc.RecordBatchFileReader(pa.memory_map(metadata, "r")).read_all()
                for metadata in metadatas
            ]
        )

    def get(self, ids):
        """Get the metadata for the given ids.

        Args:
            ids (list): List of row ids (Python or NumPy integers).

        Returns:
            list[dict]: One ``{column: value}`` dict per id, in the order of ``ids``.
                An empty ``ids`` yields an empty list.
        """
        # `take` gathers the rows in one call and, unlike concatenating per-id slices,
        # also works for an empty id list. Explicit int64 so an empty list is not
        # inferred as a null-typed array.
        row_ids = pa.array([int(i) for i in ids], type=pa.int64())
        return self.table.take(row_ids).to_pandas().to_dict("records")


class RetrievalDatabase:
    """Retrieval database backed by faiss indices plus Arrow metadata.

    Expects the layout produced by ``download_retrieval_databases``:
    ``<cache_dir>/databases/<cache_subdir>/`` containing ``image.index`` and/or
    ``text.index`` (each optional) and a ``metadata/`` directory.

    Args:
        database_name (str): Name of the database.
        cache_dir (str): Path to cache directory; a leading ``~`` is expanded to the
            user's home directory. Defaults to "~/.cache/cased".

    Raises:
        ValueError: If ``database_name`` is not a key of ``RETRIEVAL_DATABASES_URLS``.
    """

    def __init__(self, database_name: str, cache_dir: str = "~/.cache/cased"):
        if database_name not in RETRIEVAL_DATABASES_URLS:
            raise ValueError(
                f"Database name should be one of "
                f"{list(RETRIEVAL_DATABASES_URLS.keys())}, got {database_name}."
            )

        database_dir = (
            Path(cache_dir).expanduser()
            / "databases"
            / RETRIEVAL_DATABASES_URLS[database_name]["cache_subdir"]
        )
        self._database_dir = database_dir

        self._image_index = self._read_index(database_dir / "image.index")
        self._text_index = self._read_index(database_dir / "text.index")
        self._metadata_provider = RetrievalDatabaseMetadataProvider(
            str(database_dir / "metadata")
        )

    @staticmethod
    def _read_index(index_fp: Path):
        """Open a faiss index memory-mapped and read-only, or return None if it is missing."""
        if not index_fp.exists():
            return None
        return faiss.read_index(str(index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)

    def _map_to_metadata(self, indices: list, distances: list, embs: list, num_images: int):
        """Map the indices to metadata.

        Args:
            indices (list): List of indices (NumPy integer scalars).
            distances (list): List of distances (NumPy float scalars), aligned with
                ``indices``.
            embs (list): List of results embeddings (NumPy arrays), aligned with
                ``indices``.
            num_images (int): Metadata is fetched only for the first ``num_images``
                indices; later results carry just ``id``/``similarity``/``sample_z``.

        Returns:
            list[dict]: One dict per result: the decoded metadata columns (when fetched)
                plus ``id``, ``similarity`` (the score returned by faiss) and
                ``sample_z`` (the embedding as a list).
        """
        results = []
        metas = self._metadata_provider.get(indices[:num_images])
        for key, (d, i, emb) in enumerate(zip(distances, indices, embs)):
            output = {}
            if key < len(metas):
                output.update(self._meta_to_dict(metas[key]))
            output["id"] = i.item()
            output["similarity"] = d.item()
            output["sample_z"] = emb.tolist()
            results.append(output)
        return results

    def _meta_to_dict(self, metadata):
        """Convert metadata to a plain-Python dict.

        ``bytes`` values are decoded (UTF-8 by default) and NumPy scalars are converted
        to Python scalars via ``.item()``; other values are kept as-is.

        Args:
            metadata (dict): Metadata.
        """
        output = {}
        for k, v in metadata.items():
            if isinstance(v, bytes):
                v = v.decode()
            elif type(v).__module__ == np.__name__:
                v = v.item()
            output[k] = v
        return output

    def _get_connected_components(self, neighbors):
        """Find connected components in a graph.

        Components are discovered in the iteration order of ``neighbors``, and the first
        element of each returned component is the node its search started from.

        Args:
            neighbors (dict): Mapping node -> iterable of adjacent nodes. Every node
                reachable through the lists must be a valid key (e.g. a defaultdict).

        Returns:
            list[list]: The connected components.
        """
        seen = set()

        def component(node):
            r = []
            nodes = {node}
            while nodes:
                node = nodes.pop()
                seen.add(node)
                nodes |= set(neighbors[node]) - seen
                r.append(node)
            return r

        u = []
        for node in neighbors:
            if node not in seen:
                u.append(component(node))
        return u

    def _deduplicate_embeddings(self, embeddings, threshold=0.94):
        """Deduplicate embeddings.

        Rows whose inner product exceeds ``threshold`` are linked, and linked rows are
        grouped into connected components. In each group, the row with the lowest
        position is kept and all other rows are reported as duplicates. With
        L2-normalised rows, the inner product is the cosine similarity.

        Args:
            embeddings (np.ndarray): 2-D array of embeddings, one per row.
            threshold (float): Threshold to use for deduplication. Default is 0.94.

        Returns:
            set[int]: Row positions of the embeddings to drop as duplicates.
        """
        if embeddings.shape[0] == 0:
            return set()

        # faiss requires C-contiguous float32 input.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        lims, _, neighbor_ids = index.range_search(embeddings, threshold)

        # Results for query row i are neighbor_ids[lims[i]:lims[i + 1]].
        same_mapping = defaultdict(list)
        for i in range(embeddings.shape[0]):
            for j in neighbor_ids[lims[i] : lims[i + 1]]:
                same_mapping[int(i)].append(int(j))

        groups = self._get_connected_components(same_mapping)
        return {e for g in groups for e in g[1:]}

    def _filter_query_results(self, indices, distances, embeddings):
        """Clean up the raw faiss results of a single query.

        This truncates at the first ``-1`` id (faiss's marker for "no more results"),
        L2-normalises the embeddings, and drops near-duplicates and repeated ids.

        Returns:
            tuple[list, list, list]: The kept ids, distances and normalised embeddings,
                in their original ranking order.
        """
        missing = np.where(indices == -1)[0]
        num_res = missing[0] if len(missing) > 0 else len(indices)
        result_indices = indices[:num_res]
        result_distances = distances[:num_res]
        result_embeddings = embeddings[:num_res]

        # normalise embeddings (zero vectors are left as zeros)
        l2 = np.atleast_1d(np.linalg.norm(result_embeddings, 2, -1))
        l2[l2 == 0] = 1
        result_embeddings = result_embeddings / np.expand_dims(l2, -1)

        # deduplicate embeddings
        indices_to_remove = {
            result_indices[local_index]
            for local_index in self._deduplicate_embeddings(result_embeddings)
        }

        kept_indices, kept_distances, kept_embeddings = [], [], []
        for ind, dis, emb in zip(result_indices, result_distances, result_embeddings):
            if ind in indices_to_remove:
                continue
            # Mark as seen so a repeated id is kept only once.
            indices_to_remove.add(ind)
            kept_indices.append(ind)
            kept_distances.append(dis)
            kept_embeddings.append(emb)
        return kept_indices, kept_distances, kept_embeddings

    def query(
        self, query: np.matrix, modality: str = "text", num_samples: int = 10
    ) -> list[list[dict]]:
        """Query the database.

        Args:
            query (np.matrix): Query embeddings to search, one per row.
            modality (str): Modality to search. One of `image` or `text`; any value other
                than `image` searches the text index. Default to `text`.
            num_samples (int): Number of nearest neighbours retrieved per query, before
                deduplication. Default is 10.

        Returns:
            list[list[dict]]: For each query row, the deduplicated results as produced
                by ``_map_to_metadata``.

        Raises:
            FileNotFoundError: If the index file for the selected modality was not found
                when the database was opened.
        """
        index_name = "image" if modality == "image" else "text"
        index = self._image_index if index_name == "image" else self._text_index
        if index is None:
            raise FileNotFoundError(
                f"No {index_name} index found in {self._database_dir} "
                f"(expected {index_name}.index)."
            )

        distances, indices, embeddings = index.search_and_reconstruct(query, num_samples)

        total_results = []
        for row_indices, row_distances, row_embeddings in zip(indices, distances, embeddings):
            kept_indices, kept_distances, kept_embeddings = self._filter_query_results(
                row_indices, row_distances, row_embeddings
            )
            total_results.append(
                self._map_to_metadata(kept_indices, kept_distances, kept_embeddings, num_samples)
            )
        return total_results