bala1802 committed on
Commit
6a3434a
1 Parent(s): 9d4e40b

Upload 3 files

Files changed (3)
  1. image_encoder.py +22 -0
  2. projection_head.py +25 -0
  3. text_encoder.py +22 -0
image_encoder.py ADDED
@@ -0,0 +1,22 @@
+ from torch import nn
+ import timm
+
+ from configuration import CFG
+
+ class ImageEncoder(nn.Module):
+     """
+     Encode images to a fixed-size vector.
+     """
+
+     def __init__(
+         self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
+     ):
+         super().__init__()
+         # num_classes=0 drops the classifier head; global_pool="avg" returns
+         # one pooled feature vector per image.
+         self.model = timm.create_model(
+             model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
+         )
+         # Freeze the backbone when trainable is False.
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+     def forward(self, x):
+         return self.model(x)
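A minimal usage sketch for ImageEncoder (hedged: the resnet50 backbone, the 224x224 input size, and the 2048-dim output are assumptions about CFG, not values taken from this commit):

import torch
from image_encoder import ImageEncoder

encoder = ImageEncoder()              # backbone chosen by CFG.model_name (assumed to be a timm model, e.g. resnet50)
images = torch.randn(8, 3, 224, 224)  # dummy batch of 8 RGB images (assumed input size)
features = encoder(images)            # shape (8, feature_dim), e.g. (8, 2048) for resnet50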
projection_head.py ADDED
@@ -0,0 +1,25 @@
+ from torch import nn
+ from configuration import CFG
+
+ class ProjectionHead(nn.Module):
+     def __init__(
+         self,
+         embedding_dim,
+         projection_dim=CFG.projection_dim,
+         dropout=CFG.dropout
+     ):
+         super().__init__()
+         # Project encoder features (embedding_dim) into the shared embedding space (projection_dim).
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()
+         self.fc = nn.Linear(projection_dim, projection_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(projection_dim)
+
+     def forward(self, x):
+         projected = self.projection(x)
+         x = self.gelu(projected)
+         x = self.fc(x)
+         x = self.dropout(x)
+         # Residual connection around the non-linear block, then LayerNorm.
+         x = x + projected
+         x = self.layer_norm(x)
+         return x
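A hedged sketch of how this head is typically wired up after an encoder (embedding_dim=2048 for a ResNet image feature is an assumption, not a value from the commit):

import torch
from projection_head import ProjectionHead

head = ProjectionHead(embedding_dim=2048)  # e.g. resnet50 feature size (assumed)
features = torch.randn(8, 2048)            # dummy encoder output for a batch of 8
embeddings = head(features)                # shape (8, CFG.projection_dim)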
text_encoder.py ADDED
@@ -0,0 +1,22 @@
+ from torch import nn
+ from configuration import CFG
+ from transformers import DistilBertModel, DistilBertConfig
+
+ class TextEncoder(nn.Module):
+     def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
+         super().__init__()
+         if pretrained:
+             self.model = DistilBertModel.from_pretrained(model_name)
+         else:
+             self.model = DistilBertModel(config=DistilBertConfig())
+
+         for p in self.model.parameters():
+             p.requires_grad = trainable
+
+         # We use the CLS token's hidden representation as the sentence embedding.
+         self.target_token_idx = 0
+
+     def forward(self, input_ids, attention_mask):
+         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_state = output.last_hidden_state
+         return last_hidden_state[:, self.target_token_idx, :]
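A minimal sketch of encoding a batch of captions with this module (hedged: the "distilbert-base-uncased" checkpoint is an assumption about CFG.text_encoder_model, not a value from the commit):

from transformers import DistilBertTokenizer
from text_encoder import TextEncoder

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")  # assumed checkpoint
encoder = TextEncoder()
batch = tokenizer(["a dog on the beach", "two cats"], padding=True, truncation=True, return_tensors="pt")
cls_embeddings = encoder(batch["input_ids"], batch["attention_mask"])  # shape (2, 768) for DistilBERT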