Brendar commited on
Commit
99303a6
1 Parent(s): b8785a1

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +24 -1
README.md CHANGED
@@ -7,4 +7,27 @@ library_name: transformers
7
  pipeline_tag: fill-mask
8
  tags:
9
  - code
10
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  pipeline_tag: fill-mask
8
  tags:
9
  - code
10
+ ---
11
+ import torch
12
+ from torch import nn
13
+ # Definir la red, lo que en definitiva la hace siamesa es pasar los dos dic
14
+ class SBERT(nn.Module):
15
+ def __init__(self, base_model, dropout=0.1):
16
+ super().__init__()
17
+ self.base_model = base_model
18
+ self.dropout = nn.Dropout(dropout)
19
+ # Recordamos que la salida de la Bert es 768
20
+ #self.fc = nn.Linear(768, 3) #Cambio 13/6
21
+ self.fc = nn.Linear(768*3, 3)
22
+
23
+ def forward(self, premise, hypothesis):
24
+ out_u = self.base_model(**premise)
25
+ out_v = self.base_model(**hypothesis)
26
+ pooler_u = out_u.pooler_output
27
+ pooler_v = out_v.pooler_output
28
+ pooler_u = self.dropout(pooler_u)
29
+ pooler_v = self.dropout(pooler_v)
30
+ #concatenated = torch.cat([self.fc(pooler_u), self.fc(pooler_v), torch.abs(self.fc(pooler_u) - self.fc(pooler_v))], dim=0)
31
+ concatenated = torch.cat([pooler_u, pooler_v, torch.abs(pooler_u -pooler_v)], dim=1)
32
+ out=self.fc(concatenated)
33
+ return out