bala1802 committed
Commit b26e169
Parent: b56e323

Update clip_inferencing.py

Files changed (1):
  1. clip_inferencing.py +0 -42
clip_inferencing.py CHANGED
@@ -1,15 +1,10 @@
 import torch
 import torch.nn.functional as F
-from transformers import DistilBertTokenizer
-from tqdm.autonotebook import tqdm
 import pickle
 
 from clip_model import CLIPModel
 from configuration import CFG
 
-import matplotlib.pyplot as plt
-import cv2
-
 def load_model(model_path):
     model = CLIPModel().to(CFG.device)
     model.load_state_dict(torch.load(model_path, map_location=CFG.device))
@@ -25,41 +20,4 @@ def load_image_embeddings():
     with open("pickles/image_embeddings.pkl", 'rb') as file:
         image_embeddings = pickle.load(file)
     return image_embeddings
-
-def find_matches(model, image_embeddings, query, image_filenames, n=9):
-    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
-    encoded_query = tokenizer([query])
-    batch = {
-        key: torch.tensor(values).to(CFG.device)
-        for key, values in encoded_query.items()
-    }
-    with torch.no_grad():
-        text_features = model.text_encoder(
-            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
-        )
-        text_embeddings = model.text_projection(text_features)
-
-    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
-    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
-    dot_similarity = text_embeddings_n @ image_embeddings_n.T
-
-    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
-    matches = [image_filenames[idx] for idx in indices[::5]]
-
-    _, axes = plt.subplots(3, 3, figsize=(10, 10))
-    for match, ax in zip(matches, axes.flatten()):
-        image = cv2.imread(f"Images/{match}")
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        ax.imshow(image)
-        ax.axis("off")
-
-    plt.show()
-
-def inference(query):
-    valid_df = load_df()
-    image_embeddings = load_image_embeddings()
-    find_matches(load_model(model_path="model/best.pt"),
-                 image_embeddings,
-                 query=query,
-                 image_filenames=valid_df['image'].values, n=9)
 
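This commit strips the retrieval and plotting code (find_matches, inference) from clip_inferencing.py, leaving only the loading helpers. For reference, here is a minimal sketch of how those surviving helpers could still drive text-to-image retrieval, reconstructed from the deleted lines. It assumes load_model() returns the model ready for inference and that the pickled image embeddings are a torch tensor; the CFG.text_tokenizer name, the "model/best.pt" path, and the five-captions-per-image stride are carried over from the removed code, not guaranteed by the current file.

import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer

from clip_inferencing import load_model, load_image_embeddings
from configuration import CFG

def retrieve(query, image_filenames, n=9):
    # Load the trained CLIP model and the precomputed image embeddings
    # (both helpers survive this commit).
    model = load_model("model/best.pt")  # assumed checkpoint path
    image_embeddings = load_image_embeddings()

    # Encode the text query with the tokenizer named in the config.
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded = tokenizer([query])
    batch = {k: torch.tensor(v).to(CFG.device) for k, v in encoded.items()}

    # Project the text features into the shared embedding space.
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    # L2-normalize both sides so the dot product equals cosine similarity.
    image_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_n = F.normalize(text_embeddings, p=2, dim=-1)
    similarity = text_n @ image_n.T  # shape: (1, num_embeddings)

    # The embedding table presumably holds ~5 rows per image (one per
    # caption), hence topk(n * 5) followed by a stride-5 slice to
    # deduplicate images, as in the removed find_matches.
    _, indices = torch.topk(similarity.squeeze(0), n * 5)
    return [image_filenames[idx] for idx in indices[::5]]

Mirroring the removed inference() entry point, this would be called as retrieve("a dog running on grass", valid_df['image'].values) with valid_df coming from load_df().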
 
 