"""Text-to-image retrieval helpers for a trained CLIP-style model.

Loads a trained model plus pre-computed image embeddings, embeds a free-text
query, and displays the best-matching images in a matplotlib grid.
"""

import math
import pickle

import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import cv2

from clip_model import CLIPModel
from configuration import CFG

# find_matches requests `n * _RESULT_STRIDE` top candidates and keeps every
# `_RESULT_STRIDE`-th one.
# NOTE(review): presumably each image appears 5 times in the validation frame
# (one row per caption), so striding avoids showing the same image repeatedly
# -- confirm against how the embeddings pickle was built.
_RESULT_STRIDE = 5


def load_model(model_path):
    """Build a ``CLIPModel`` on ``CFG.device`` and load weights into it.

    Args:
        model_path: Path to a state-dict checkpoint readable by ``torch.load``.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()
    return model


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    ``pickle.load`` can execute arbitrary code, so this must only ever be
    pointed at files produced by this project.
    """
    with open(path, "rb") as file:
        return pickle.load(file)


def load_df():
    """Return the validation dataframe stored in ``pickles/valid_df.pkl``."""
    return _load_pickle("pickles/valid_df.pkl")


def load_image_embeddings():
    """Return the image embeddings stored in ``pickles/image_embeddings.pkl``."""
    return _load_pickle("pickles/image_embeddings.pkl")


def _embed_query(model, query):
    """Tokenize ``query`` and return its projected text embedding."""
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
        return model.text_projection(text_features)


def _read_rgb_image(image_path):
    """Read ``image_path`` with OpenCV and return it in RGB channel order."""
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # surface that here instead of failing later inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _show_image_grid(images, n):
    """Show ``images`` in a near-square grid that has room for ``n`` panels."""
    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    _, axes = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    panels = axes.flatten()
    for image, ax in zip(images, panels):
        ax.imshow(image)
    # Hide ticks and frames on every panel, including unused ones.
    for ax in panels:
        ax.axis("off")
    plt.show()


def find_matches(model, image_embeddings, query, image_filenames, n=9):
    """Display the images whose embeddings best match a text query.

    The query embedding and the image embeddings are L2-normalized, so the
    ranking score is their cosine similarity.

    Args:
        model: Object exposing ``text_encoder`` and ``text_projection``.
        image_embeddings: 2-D tensor of image embeddings; row ``i`` must
            correspond to ``image_filenames[i]``.
        query: Text to search for.
        image_filenames: Indexable sequence of file names, resolved relative
            to the ``Images/`` directory.
        n: Maximum number of images to show. Fewer are shown when the
            gallery is too small to supply ``n`` results.

    Raises:
        ValueError: If ``n`` is smaller than 1.
        FileNotFoundError: If a matched image cannot be read from disk.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    text_embeddings = _embed_query(model, query)
    # Unpickled embeddings may sit on a different device than the model.
    image_embeddings = image_embeddings.to(text_embeddings.device)

    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    scores = (text_embeddings_n @ image_embeddings_n.T).squeeze(0)

    # torch.topk raises when k exceeds the number of candidates, so clamp it.
    k = min(n * _RESULT_STRIDE, scores.shape[0])
    _, indices = torch.topk(scores, k)
    matches = [image_filenames[idx] for idx in indices[::_RESULT_STRIDE].tolist()]

    # Read every image before creating the figure so that a missing file does
    # not leave a half-drawn figure open.
    images = [_read_rgb_image(f"Images/{match}") for match in matches]
    _show_image_grid(images, n)


def inference(query):
    """Show the nine best validation images for ``query``.

    Reads the validation dataframe and image embeddings from the ``pickles/``
    directory and the model weights from ``model/best.pt``.
    """
    valid_df = load_df()
    image_embeddings = load_image_embeddings()
    find_matches(
        load_model(model_path="model/best.pt"),
        image_embeddings,
        query=query,
        image_filenames=valid_df["image"].values,
        n=9,
    )