bala1802 committed
Commit
d4e8957
1 Parent(s): 83526ef

Upload 3 files


added inferencing files

Files changed (3)
  1. clip_inferencing.py +65 -0
  2. clip_model.py +53 -0
  3. configuration.py +33 -0
clip_inferencing.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torch.nn.functional as F
+ from transformers import DistilBertTokenizer
+ from tqdm.autonotebook import tqdm
+ import pickle
+
+ from clip_model import CLIPModel
+ from configuration import CFG
+
+ import matplotlib.pyplot as plt
+ import cv2
+
+ def load_model(model_path):
+     model = CLIPModel().to(CFG.device)
+     model.load_state_dict(torch.load(model_path, map_location=CFG.device))
+     model.eval()
+     return model
+
+ def load_df():
+     with open("pickles/valid_df.pkl", 'rb') as file:
+         valid_df = pickle.load(file)
+     return valid_df
+
+ def load_image_embeddings():
+     with open("pickles/image_embeddings.pkl", 'rb') as file:
+         image_embeddings = pickle.load(file)
+     return image_embeddings
+
+ def find_matches(model, image_embeddings, query, image_filenames, n=9):
+     # Tokenize the text query and encode it with the trained text encoder.
+     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+     encoded_query = tokenizer([query])
+     batch = {
+         key: torch.tensor(values).to(CFG.device)
+         for key, values in encoded_query.items()
+     }
+     with torch.no_grad():
+         text_features = model.text_encoder(
+             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+         )
+         text_embeddings = model.text_projection(text_features)
+
+     # Cosine similarity between the query and every image embedding.
+     image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
+     text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
+     dot_similarity = text_embeddings_n @ image_embeddings_n.T
+
+     # Take the top n*5 hits, then keep every 5th for more varied results.
+     values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
+     matches = [image_filenames[idx] for idx in indices[::5]]
+
+     _, axes = plt.subplots(3, 3, figsize=(10, 10))
+     for match, ax in zip(matches, axes.flatten()):
+         image = cv2.imread(f"Images/{match}")
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         ax.imshow(image)
+         ax.axis("off")
+
+     plt.show()
+
+ def inference():
+     valid_df = load_df()
+     image_embeddings = load_image_embeddings()
+     find_matches(load_model(model_path="model/best.pt"),
+                  image_embeddings,
+                  query="dogs on the grass",
+                  image_filenames=valid_df['image'].values, n=9)
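
Note: `find_matches` ranks images by the dot product of L2-normalized text and image embeddings (i.e. cosine similarity) and keeps every 5th of the top n*5 hits so near-duplicates don't fill the grid. A minimal driver for a custom query, as a sketch: it assumes the `model/best.pt` checkpoint, the `pickles/` files, and the `Images/` directory produced during training are in place.

    from clip_inferencing import load_model, load_df, load_image_embeddings, find_matches

    model = load_model(model_path="model/best.pt")
    valid_df = load_df()
    image_embeddings = load_image_embeddings()

    # Any free-text query works; the DistilBERT tokenizer handles it.
    find_matches(model, image_embeddings,
                 query="a dog running on the beach",
                 image_filenames=valid_df['image'].values, n=9)
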
clip_model.py ADDED
@@ -0,0 +1,53 @@
+ from torch import nn
+ import torch.nn.functional as F
+
+ from image_encoder import ImageEncoder
+ from text_encoder import TextEncoder
+ from projection_head import ProjectionHead
+ from configuration import CFG
+
+
+ class CLIPModel(nn.Module):
+     def __init__(
+         self,
+         temperature=CFG.temperature,
+         image_embedding=CFG.image_embedding,
+         text_embedding=CFG.text_embedding,
+     ):
+         super().__init__()
+         self.image_encoder = ImageEncoder()
+         self.text_encoder = TextEncoder()
+         self.image_projection = ProjectionHead(embedding_dim=image_embedding)
+         self.text_projection = ProjectionHead(embedding_dim=text_embedding)
+         self.temperature = temperature
+
+     def forward(self, batch):
+         # Getting Image and Text Features
+         image_features = self.image_encoder(batch["image"])
+         text_features = self.text_encoder(
+             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+         )
+         # Getting Image and Text Embeddings (with same dimension)
+         image_embeddings = self.image_projection(image_features)
+         text_embeddings = self.text_projection(text_features)
+
+         # Calculating the Loss
+         logits = (text_embeddings @ image_embeddings.T) / self.temperature
+         images_similarity = image_embeddings @ image_embeddings.T
+         texts_similarity = text_embeddings @ text_embeddings.T
+         targets = F.softmax(
+             (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
+         )
+         texts_loss = cross_entropy(logits, targets, reduction='none')
+         images_loss = cross_entropy(logits.T, targets.T, reduction='none')
+         loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
+         return loss.mean()
+
+
+ def cross_entropy(preds, targets, reduction='none'):
+     # Cross-entropy against soft (probability) targets.
+     log_softmax = nn.LogSoftmax(dim=-1)
+     loss = (-targets * log_softmax(preds)).sum(1)
+     if reduction == "none":
+         return loss
+     elif reduction == "mean":
+         return loss.mean()
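
The forward pass implements the symmetric contrastive objective: `logits` compares every text in the batch to every image, soft targets come from averaging the intra-modal similarity matrices, and cross-entropy is applied in both directions. A self-contained check of the same arithmetic on random embeddings, as a sketch (independent of the encoders above):

    import torch
    import torch.nn.functional as F

    batch_size, dim, temperature = 4, 256, 1.0
    image_embeddings = torch.randn(batch_size, dim)
    text_embeddings = torch.randn(batch_size, dim)

    # Text-to-image logits and soft targets, as in CLIPModel.forward.
    logits = (text_embeddings @ image_embeddings.T) / temperature
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2 * temperature, dim=-1)

    # Cross-entropy against soft targets, in both directions.
    texts_loss = (-targets * F.log_softmax(logits, dim=-1)).sum(1)
    images_loss = (-targets.T * F.log_softmax(logits.T, dim=-1)).sum(1)
    print(((texts_loss + images_loss) / 2.0).mean())  # scalar training loss
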
configuration.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+
+ class CFG:
+     debug = False
+     batch_size = 32
+     num_workers = 2
+     head_lr = 1e-3
+     image_encoder_lr = 1e-4
+     text_encoder_lr = 1e-5
+     weight_decay = 1e-3
+     patience = 1
+     factor = 0.8
+     epochs = 1  # 4
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model_name = 'resnet50'
+     image_embedding = 2048
+     text_encoder_model = "distilbert-base-uncased"
+     text_embedding = 768
+     text_tokenizer = "distilbert-base-uncased"
+     max_length = 200
+
+     pretrained = True  # for both image encoder and text encoder
+     trainable = True   # for both image encoder and text encoder
+     temperature = 1.0
+
+     # image size
+     size = 224
+
+     # for projection head; used for both image and text encoders
+     num_projection_layers = 1
+     projection_dim = 256
+     dropout = 0.1
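
The three learning rates suggest per-module parameter groups, and `patience`/`factor` match `torch.optim.lr_scheduler.ReduceLROnPlateau`. The training script is not part of this commit, so the wiring below is a sketch of one plausible setup, not the author's code:

    import torch
    from configuration import CFG
    from clip_model import CLIPModel

    model = CLIPModel().to(CFG.device)
    # Assumed grouping: slow LR for pretrained encoders, fast LR for the heads.
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        {"params": list(model.image_projection.parameters())
                 + list(model.text_projection.parameters()),
         "lr": CFG.head_lr, "weight_decay": CFG.weight_decay},
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )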