amaralibey commited on
Commit
2a7a19d
·
verified ·
1 Parent(s): df4e74c

Create text_encoder.py

Browse files
Files changed (1) hide show
  1. text_encoder.py +29 -0
text_encoder.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModel
4
+
5
+ class TextEncoder(nn.Module):
6
+ def __init__(self, output_dim=64, lang_model="sentence-transformers/all-MiniLM-L6-v2", unfreeze_n_blocks=4):
7
+ super().__init__()
8
+ self.lang_model = lang_model
9
+ self.encoder = AutoModel.from_pretrained(lang_model)
10
+
11
+ # freeze all parameters
12
+ for param in self.encoder.parameters():
13
+ param.requires_grad = False
14
+
15
+ # unfreeze the last few encoder layers
16
+ for layer in self.encoder.encoder.layer[ - unfreeze_n_blocks :]:
17
+ for param in layer.parameters():
18
+ param.requires_grad = True
19
+
20
+ # unfreeze the pooler layer
21
+ for param in self.encoder.pooler.parameters():
22
+ param.requires_grad = True
23
+
24
+ self.fc = nn.Linear(self.encoder.config.hidden_size, output_dim)
25
+
26
+ def forward(self, input_ids, attention_mask=None):
27
+ x = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
28
+ x = self.fc(x)
29
+ return x