# NOTE(review): "Spaces: Sleeping" below was Hugging Face page chrome captured
# by the scrape, not source code; kept as a comment so the module parses.
import base64
import os
from io import BytesIO
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from img2art_search.models.model import ViTImageSearchModel
from img2art_search.utils import (
    get_or_create_pinecone_index,
    get_pinecone_client,
    inverse_transform_img,
)
def extract_embedding(image_data_batch, fine_tuned_model):
    """Run a batch of preprocessed images through the model and return embeddings.

    Args:
        image_data_batch: Tensor batch of already-transformed images.
        fine_tuned_model: Model mapping an image batch to embedding vectors.

    Returns:
        numpy.ndarray with one embedding row per input image (on CPU).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = image_data_batch.to(device)
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        embeddings = fine_tuned_model(batch).cpu().numpy()
    return embeddings
def load_fine_tuned_model():
    """Load the fine-tuned ViT search model from models/model.pth.

    Returns:
        ViTImageSearchModel with loaded weights, switched to eval mode.
    """
    fine_tuned_model = ViTImageSearchModel()
    # map_location lets a checkpoint saved on a GPU machine load on a
    # CPU-only host; the original bare torch.load would raise there.
    state_dict = torch.load(
        "models/model.pth",
        map_location="cuda" if torch.cuda.is_available() else "cpu",
    )
    fine_tuned_model.load_state_dict(state_dict)
    fine_tuned_model.eval()
    return fine_tuned_model
def _clean_image_name(path: str) -> str:
    """Turn an image file path into a human-readable display title.

    Drops the directory and extension, then maps hyphens to spaces and
    underscores to " - " (same order as before) and title-cases the result.
    """
    base = path.split("/")[-1]
    # splitext strips only a trailing extension of any spelling, covering
    # cases (".PNG", ".webp", ...) the old hard-coded replace chain missed,
    # and no longer matches extension substrings in the middle of a name.
    stem = os.path.splitext(base)[0]
    return stem.replace("-", " ").replace("_", " - ").title()


def create_gallery(
    img_dataset: ImageRetrievalDataset,
    fine_tuned_model: ViTImageSearchModel,
    save: bool = True,
) -> list:
    """Embed every gallery image, upsert the vectors into Pinecone, and
    optionally persist all embeddings to models/embeddings.npy.

    Args:
        img_dataset: Dataset yielding (image_tensor, _, image_path, _) items.
        fine_tuned_model: Model mapping an image batch to embedding vectors.
        save: When True, also save the stacked embeddings to disk.

    Returns:
        List of per-batch embedding arrays (one numpy array per batch).
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 4
    fine_tuned_model.to(DEVICE)
    gallery_embeddings: list = []
    gallery_dataloader = DataLoader(
        img_dataset, batch_size=batch_size, num_workers=1, shuffle=False
    )
    pc = get_pinecone_client()
    gallery_index = get_or_create_pinecone_index(pc)
    try:
        count = 0
        for img_data, _, img_name, _ in tqdm(gallery_dataloader):
            data_objects = []
            batch_embedding = extract_embedding(img_data, fine_tuned_model)
            gallery_embeddings.append(batch_embedding)
            for idx, embedding in enumerate(batch_embedding):
                # Re-encode the (inverse-transformed) image as base64 JPEG so
                # search results can be served straight from vector metadata.
                image = Image.fromarray(
                    inverse_transform_img(img_data[idx]).numpy().astype("uint8"), "RGB"
                )
                buffered = BytesIO()
                image.save(buffered, format="JPEG")
                img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
                data_objects.append(
                    {
                        "id": str(count),
                        "values": embedding.tolist(),
                        "metadata": {
                            "image": img_base64,
                            "name": _clean_image_name(img_name[idx]),
                        },
                    }
                )
                count += 1
            gallery_index.upsert(vectors=data_objects)
    except Exception as e:
        # Best-effort ingest: report and fall through so whatever was
        # embedded before the failure can still be saved and returned.
        print(f"Error creating gallery: {e}")
    if save and gallery_embeddings:
        # vstack yields one (N, dim) array; saving the raw list would be
        # ragged when the last batch is short, which modern numpy rejects.
        np.save("models/embeddings", np.vstack(gallery_embeddings))
    return gallery_embeddings
def search_image(query_image_path: str, k: int = 4) -> tuple:
    """Find the k nearest gallery images to a query via Pinecone.

    Args:
        query_image_path: Query input forwarded to extract_embedding.
            NOTE(review): despite the name/annotation, extract_embedding calls
            .to(DEVICE) on this value, which a plain path string does not
            support — callers presumably pass a preprocessed tensor batch;
            confirm the actual caller contract.
        k: Number of nearest neighbours to retrieve.

    Returns:
        Tuple (results, distances): results is a list of raw JPEG bytes
        decoded from base64 metadata; distances is a list of
        "<score-percent> <name>" display strings.
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    fine_tuned_model = load_fine_tuned_model()
    fine_tuned_model.to(DEVICE)
    query_embedding = extract_embedding(query_image_path, fine_tuned_model)
    pc = get_pinecone_client()
    index = get_or_create_pinecone_index(pc)
    response = index.query(
        vector=[query_embedding.tolist()[0]], top_k=k, include_metadata=True
    )
    distances = []
    results = []
    for obj in response["matches"]:
        results.append(base64.b64decode(obj.metadata["image"]))
        # Scale to percent BEFORE rounding; the original rounded first and
        # then multiplied, reintroducing float noise ("87.60000000000001").
        distances.append(
            str(round(obj["score"] * 100, 2)) + " " + str(obj.metadata["name"])
        )
    return results, distances
def create_gallery_embeddings(folder: str) -> None:
    """Index every file in *folder* into the Pinecone gallery."""
    file_paths = np.array([f"{folder}/{name}" for name in os.listdir(folder)])
    # The dataset takes a 2-row array of paths; for gallery indexing the
    # same path list fills both rows (each image is its own pair partner).
    paired_paths = np.stack((file_paths, file_paths))
    dataset = ImageRetrievalDataset(paired_paths, transform=transform)
    model = load_fine_tuned_model()
    create_gallery(dataset, model)