|
import tarfile |
|
from collections import defaultdict |
|
from pathlib import Path |
|
|
|
import faiss |
|
import numpy as np |
|
import pyarrow as pa |
|
import requests |
|
from tqdm import tqdm |
|
|
|
# Public API of this module.
__all__ = [
    "RetrievalDatabase",
    "download_retrieval_databases",
]
|
|
|
# Registry of the retrieval databases that can be downloaded.
#   url:          location of the .tar.gz archive fetched by
#                 download_retrieval_databases.
#   cache_subdir: directory, relative to <cache_dir>/databases, from which
#                 RetrievalDatabase loads the indices and metadata.
RETRIEVAL_DATABASES_URLS = dict(
    cc12m=dict(
        url="https://storage-cased.alessandroconti.me/cc12m.tar.gz",
        cache_subdir="./cc12m/vit-l-14/",
    ),
)
|
|
|
|
|
def download_retrieval_databases(cache_dir: str):
    """Download and extract every retrieval database that is not cached yet.

    For each entry of ``RETRIEVAL_DATABASES_URLS`` the archive is streamed to
    ``<cache_dir>/databases/<name>.tar.gz``, extracted into
    ``<cache_dir>/databases/`` and then deleted. A database is skipped when
    ``<cache_dir>/databases/<name>`` already exists.

    Args:
        cache_dir (str): Path to the cache directory.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    databases_path = Path(cache_dir, "databases")

    for name, items in RETRIEVAL_DATABASES_URLS.items():
        database_path = databases_path / name
        if database_path.exists():
            continue

        target_path = databases_path / f"{name}.tar.gz"
        target_path.parent.mkdir(parents=True, exist_ok=True)

        with requests.get(items["url"], stream=True) as r:
            r.raise_for_status()
            # content-length may be absent; tqdm then shows an open-ended bar.
            total_bytes_size = int(r.headers.get("content-length", 0))
            # Context managers close the bar and the file even if the stream fails.
            with tqdm(
                desc=f"Downloading {name} index",
                total=total_bytes_size,
                unit="iB",
                unit_scale=True,
            ) as p_bar, open(target_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    p_bar.update(len(chunk))

        with tarfile.open(target_path, "r:gz") as tar:
            # The archive comes from the network: when the interpreter supports
            # extraction filters (3.12+ and security backports), use the "data"
            # filter so members cannot escape the destination directory.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(target_path.parent, filter="data")
            else:
                tar.extractall(target_path.parent)

        target_path.unlink()
|
|
|
|
|
class RetrievalDatabaseMetadataProvider:
    """Metadata provider for the retrieval database.

    Every file found (recursively) under ``metadata_dir`` is memory-mapped and
    read as an Arrow IPC file. The resulting tables are concatenated in sorted
    path order. Rows are addressed by position: ``get`` returns row ``i`` for
    id ``i``.

    Args:
        metadata_dir (str): Path to the metadata directory.

    Raises:
        ValueError: If ``metadata_dir`` contains no files.
    """

    def __init__(self, metadata_dir: str):
        metadata_files = [
            str(path) for path in sorted(Path(metadata_dir).glob("**/*")) if path.is_file()
        ]
        if not metadata_files:
            # pa.concat_tables([]) would fail with an opaque message.
            raise ValueError(f"No metadata files found in {metadata_dir}.")
        self.table = pa.concat_tables(
            [
                pa.ipc.RecordBatchFileReader(pa.memory_map(path, "r")).read_all()
                for path in metadata_files
            ]
        )

    def get(self, ids):
        """Get the metadata for the given ids.

        Args:
            ids (list): Row ids to look up (Python ints or numpy integers).

        Returns:
            list: One entry per id, in the order of ``ids``. Each entry is a
            ``column -> value`` dict (as produced by ``to_pandas()``) for an id
            in ``[0, num_rows)``, or ``None`` for an id outside that range.
            An empty ``ids`` yields an empty list.
        """
        ids = [int(i) for i in ids]
        num_rows = self.table.num_rows
        in_range = [i for i in ids if 0 <= i < num_rows]
        if in_range:
            # A single gather instead of one slice + concat per id.
            taken = self.table.take(pa.array(in_range, type=pa.int64()))
            rows = taken.to_pandas().to_dict("records")
        else:
            rows = []
        # Keep positions aligned with ``ids``: unknown ids get a None placeholder
        # instead of being dropped, which would shift every later row.
        rows_iter = iter(rows)
        return [next(rows_iter) if 0 <= i < num_rows else None for i in ids]
|
|
|
|
|
class RetrievalDatabase:
    """Retrieval database backed by FAISS indices and Arrow metadata.

    Files are read from ``<cache_dir>/databases/<cache_subdir>``, where
    ``cache_subdir`` comes from ``RETRIEVAL_DATABASES_URLS``:

    - ``image.index`` and ``text.index``: FAISS indices. Each is optional and
      is ``None`` when the file is missing.
    - ``metadata/``: a directory read by ``RetrievalDatabaseMetadataProvider``.

    Args:
        database_name (str): Name of the database; must be a key of
            ``RETRIEVAL_DATABASES_URLS``.
        cache_dir (str): Path to the cache directory the databases were
            downloaded to.

    Raises:
        ValueError: If ``database_name`` is not a known database.
    """

    def __init__(self, database_name: str, cache_dir: str):
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if database_name not in RETRIEVAL_DATABASES_URLS:
            raise ValueError(
                f"Database name should be one of "
                f"{list(RETRIEVAL_DATABASES_URLS.keys())}, got {database_name}."
            )

        cache_subdir = RETRIEVAL_DATABASES_URLS[database_name]["cache_subdir"]
        database_dir = Path(cache_dir) / "databases" / cache_subdir
        self._database_dir = database_dir

        self._image_index = self._read_index(database_dir / "image.index")
        self._text_index = self._read_index(database_dir / "text.index")
        self._metadata_provider = RetrievalDatabaseMetadataProvider(
            str(database_dir / "metadata")
        )

    @staticmethod
    def _read_index(index_fp: Path):
        """Open a FAISS index with the mmap + read-only flags, or return None if the file is missing."""
        if not index_fp.exists():
            return None
        return faiss.read_index(str(index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)

    def _map_to_metadata(self, indices: list, distances: list, embs: list, num_images: int):
        """Map search hits to result dicts enriched with their metadata.

        Args:
            indices (list): Hit ids (numpy integer scalars).
            distances (list): Scores of the hits (numpy float scalars).
            embs (list): Embeddings of the hits (numpy arrays).
            num_images (int): Metadata is looked up for at most this many
                leading hits.

        Returns:
            list[dict]: One dict per hit with the metadata columns (when
            available) plus ``id``, ``similarity`` and ``sample_z``.
        """
        # Nothing to look up: avoid calling the provider with an empty id list.
        if len(indices) == 0:
            return []

        metas = self._metadata_provider.get(indices[:num_images])
        results = []
        for pos, (dist, idx, emb) in enumerate(zip(distances, indices, embs)):
            output = {}
            meta = metas[pos] if pos < len(metas) else None
            if meta is not None:
                output.update(self._meta_to_dict(meta))
            output["id"] = idx.item()
            output["similarity"] = dist.item()
            output["sample_z"] = emb.tolist()
            results.append(output)
        return results

    def _meta_to_dict(self, metadata):
        """Convert a metadata record to a dict of plain Python values.

        ``bytes`` are decoded. Numpy scalars become Python scalars. Numpy arrays
        (e.g. list columns) become lists; calling ``.item()`` on them would fail
        for any size other than 1.

        Args:
            metadata (dict): Metadata record.

        Returns:
            dict: Converted record.
        """
        output = {}
        for key, value in metadata.items():
            if isinstance(value, bytes):
                value = value.decode()
            elif isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, np.generic):
                value = value.item()
            output[key] = value
        return output

    def _get_connected_components(self, neighbors):
        """Find connected components in a graph.

        Args:
            neighbors (dict): Adjacency mapping ``node -> iterable of neighbours``.
                Nodes that only appear as neighbours are treated as leaves.

        Returns:
            list[list]: One list per component, ordered by the key order of
            ``neighbors``. Each list starts with the node the search began from.
        """
        seen = set()
        components = []
        # Iterate over a snapshot and use .get() so a defaultdict passed in is
        # never mutated while it is being iterated.
        for start in list(neighbors):
            if start in seen:
                continue
            component = []
            frontier = {start}
            while frontier:
                node = frontier.pop()
                seen.add(node)
                component.append(node)
                frontier |= set(neighbors.get(node, ())) - seen
            components.append(component)
        return components

    def _deduplicate_embeddings(self, embeddings, threshold=0.94):
        """Find near-duplicate rows among ``embeddings``.

        Rows are linked when their inner product (a FAISS range search on an
        ``IndexFlatIP``) passes ``threshold``. Within each connected group only
        the first row is kept.

        Args:
            embeddings (np.ndarray): 2-D array of embeddings; FAISS expects float32.
            threshold (float): Inner-product radius for the range search.
                Default is 0.94.

        Returns:
            set[int]: Row positions to drop.
        """
        if embeddings.shape[0] == 0:
            return set()

        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        lims, _, neighbor_ids = index.range_search(embeddings, threshold)

        # Row i's neighbours are neighbor_ids[lims[i]:lims[i + 1]].
        same_mapping = defaultdict(list)
        for i in range(embeddings.shape[0]):
            for j in neighbor_ids[lims[i]:lims[i + 1]]:
                same_mapping[i].append(int(j))

        return {
            node
            for group in self._get_connected_components(same_mapping)
            for node in group[1:]
        }

    @staticmethod
    def _l2_normalize(embeddings):
        """Scale each row of ``embeddings`` to unit L2 norm, leaving all-zero rows unchanged."""
        norms = np.atleast_1d(np.linalg.norm(embeddings, 2, -1))
        norms[norms == 0] = 1
        return embeddings / np.expand_dims(norms, -1)

    def _filter_results(self, row_ids, row_dists, row_embs):
        """Trim padding, normalise, and drop near-duplicate or repeated hits for one query.

        Returns:
            tuple[list, list, list]: Kept ids, scores and normalised embeddings.
        """
        # FAISS pads missing hits with -1; keep everything before the first one.
        padding = np.flatnonzero(row_ids == -1)
        num_res = int(padding[0]) if padding.size else len(row_ids)
        row_ids = row_ids[:num_res]
        row_dists = row_dists[:num_res]
        row_embs = self._l2_normalize(row_embs[:num_res])

        # Drop near-duplicates by *position*. Dropping by id would also discard
        # the kept representative whenever it shares an id with a duplicate.
        near_duplicates = self._deduplicate_embeddings(row_embs)
        kept_ids, kept_dists, kept_embs = [], [], []
        seen_ids = set()
        for pos, (idx, dist, emb) in enumerate(zip(row_ids, row_dists, row_embs)):
            if pos in near_duplicates or idx in seen_ids:
                continue
            seen_ids.add(idx)
            kept_ids.append(idx)
            kept_dists.append(dist)
            kept_embs.append(emb)
        return kept_ids, kept_dists, kept_embs

    def query(
        self, query: np.matrix, modality: str = "text", num_samples: int = 10
    ) -> list[list[dict]]:
        """Query the database.

        Args:
            query (np.matrix): Query embeddings, one row per query. Presumably
                float32, as FAISS requires.
            modality (str): Modality to search. ``"image"`` searches the image
                index; any other value searches the text index. Default is
                ``"text"``.
            num_samples (int): Number of neighbours requested per query.
                Default is 10.

        Returns:
            list[list[dict]]: For each query row, its deduplicated hits. Each hit
            holds the metadata columns plus ``id``, ``similarity`` (the search
            score) and ``sample_z`` (the L2-normalised embedding).

        Raises:
            FileNotFoundError: If the selected index file was not present when
                the database was opened.
        """
        use_image = modality == "image"
        index = self._image_index if use_image else self._text_index
        if index is None:
            kind = "image" if use_image else "text"
            raise FileNotFoundError(f"No {kind} index found in {self._database_dir}.")

        distances, indices, embeddings = index.search_and_reconstruct(query, num_samples)

        total_results = []
        for row_ids, row_dists, row_embs in zip(indices, distances, embeddings):
            kept_ids, kept_dists, kept_embs = self._filter_results(row_ids, row_dists, row_embs)
            total_results.append(
                self._map_to_metadata(kept_ids, kept_dists, kept_embs, num_samples)
            )
        return total_results
|
|