Create text_encoder.py
text_encoder.py  ADDED  +29 -0
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
from transformers import AutoModel


class TextEncoder(nn.Module):
    def __init__(self, output_dim=64, lang_model="sentence-transformers/all-MiniLM-L6-v2", unfreeze_n_blocks=4):
        super().__init__()
        self.lang_model = lang_model
        self.encoder = AutoModel.from_pretrained(lang_model)

        # freeze all parameters
        for param in self.encoder.parameters():
            param.requires_grad = False

        # unfreeze the last few encoder layers
        for layer in self.encoder.encoder.layer[-unfreeze_n_blocks:]:
            for param in layer.parameters():
                param.requires_grad = True

        # unfreeze the pooler layer
        for param in self.encoder.pooler.parameters():
            param.requires_grad = True

        # project the encoder's hidden size down to the desired output dimension
        self.fc = nn.Linear(self.encoder.config.hidden_size, output_dim)

    def forward(self, input_ids, attention_mask=None):
        # use the hidden state of the first ([CLS]) token as the sentence representation
        x = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        x = self.fc(x)
        return x
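A minimal usage sketch, not part of the original commit: it assumes the matching AutoTokenizer for the same MiniLM checkpoint, and the example sentences and batch handling are purely illustrative.

from transformers import AutoTokenizer

# hypothetical usage: tokenize a small batch and project it to 64-d embeddings
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = TextEncoder(output_dim=64)

batch = tokenizer(
    ["a photo of a cat", "a photo of a dog"],
    padding=True, truncation=True, return_tensors="pt",
)
with torch.no_grad():
    embeddings = model(batch["input_ids"], attention_mask=batch["attention_mask"])
print(embeddings.shape)  # torch.Size([2, 64])

Freezing everything except the last unfreeze_n_blocks transformer blocks and the pooler keeps most of the pretrained weights fixed, which cuts the number of trainable parameters during fine-tuning while still letting the top of the encoder adapt to the new task. Note that the self.encoder.encoder.layer access assumes a BERT-style backbone, which holds for the default MiniLM checkpoint.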