Update README.md
README.md
CHANGED
@@ -7,4 +7,27 @@ library_name: transformers
 pipeline_tag: fill-mask
 tags:
 - code
----
+---
+import torch
+from torch import nn
+# Define the network; what ultimately makes it siamese is passing the two input dicts through the same base model.
+class SBERT(nn.Module):
+    def __init__(self, base_model, dropout=0.1):
+        super().__init__()
+        self.base_model = base_model
+        self.dropout = nn.Dropout(dropout)
+        # Recall that BERT's pooled output size is 768.
+        # self.fc = nn.Linear(768, 3)  # changed 13/6
+        self.fc = nn.Linear(768 * 3, 3)
+
+    def forward(self, premise, hypothesis):
+        out_u = self.base_model(**premise)
+        out_v = self.base_model(**hypothesis)
+        pooler_u = out_u.pooler_output
+        pooler_v = out_v.pooler_output
+        pooler_u = self.dropout(pooler_u)
+        pooler_v = self.dropout(pooler_v)
+        # concatenated = torch.cat([self.fc(pooler_u), self.fc(pooler_v), torch.abs(self.fc(pooler_u) - self.fc(pooler_v))], dim=0)
+        concatenated = torch.cat([pooler_u, pooler_v, torch.abs(pooler_u - pooler_v)], dim=1)
+        out = self.fc(concatenated)
+        return out
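For context, here is a minimal usage sketch of the `SBERT` class added above, assuming that class is in scope and that the base encoder is a standard Hugging Face BERT with a 768-dimensional `pooler_output`. The `bert-base-uncased` checkpoint and the two example sentences are illustrative stand-ins, not necessarily this repository's model. The `[u; v; |u - v|]` concatenation feeding the 3-way head matches the classification objective from the Sentence-BERT paper (Reimers & Gurevych, 2019).

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Hypothetical encoder choice for illustration; any BERT-style model
# exposing `pooler_output` with hidden size 768 will work here.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
base_model = AutoModel.from_pretrained("bert-base-uncased")

model = SBERT(base_model)
model.eval()

# Tokenize each sentence separately; `forward` unpacks each dict
# into the shared (siamese) encoder.
premise = tokenizer("A man is playing a guitar.", return_tensors="pt")
hypothesis = tokenizer("Someone is making music.", return_tensors="pt")

with torch.no_grad():
    logits = model(premise, hypothesis)  # shape (1, 3): one score per NLI class

print(logits.softmax(dim=1))
```

Note that an untrained head produces arbitrary scores; the sketch only shows the calling convention, with fine-tuning on NLI-style pairs left to the training loop.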