Brendar commited on
Commit
878acd5
1 Parent(s): 99303a6

Update README.md

Browse files

import torch
from torch import nn
# Definir la red, lo que en definitiva la hace siamesa es pasar los dos dic
class SBERT(nn.Module):
def __init__(self, base_model, dropout=0.1):
super().__init__()
self.base_model = base_model
self.dropout = nn.Dropout(dropout)
# Recordamos que la salida de la Bert es 768
#self.fc = nn.Linear(768, 3) #Cambio 13/6
self.fc = nn.Linear(768*3, 3)

def forward(self, premise, hypothesis):
out_u = self.base_model(**premise)
out_v = self.base_model(**hypothesis)
pooler_u = out_u.pooler_output
pooler_v = out_v.pooler_output
pooler_u = self.dropout(pooler_u)
pooler_v = self.dropout(pooler_v)
#concatenated = torch.cat([self.fc(pooler_u), self.fc(pooler_v), torch.abs(self.fc(pooler_u) - self.fc(pooler_v))], dim=0)
concatenated = torch.cat([pooler_u, pooler_v, torch.abs(pooler_u -pooler_v)], dim=1)
out=self.fc(concatenated)
return out

Files changed (1) hide show
  1. README.md +0 -23
README.md CHANGED
@@ -8,26 +8,3 @@ pipeline_tag: fill-mask
8
  tags:
9
  - code
10
  ---
11
- import torch
12
- from torch import nn
13
- # Definir la red, lo que en definitiva la hace siamesa es pasar los dos dic
14
- class SBERT(nn.Module):
15
- def __init__(self, base_model, dropout=0.1):
16
- super().__init__()
17
- self.base_model = base_model
18
- self.dropout = nn.Dropout(dropout)
19
- # Recordamos que la salida de la Bert es 768
20
- #self.fc = nn.Linear(768, 3) #Cambio 13/6
21
- self.fc = nn.Linear(768*3, 3)
22
-
23
- def forward(self, premise, hypothesis):
24
- out_u = self.base_model(**premise)
25
- out_v = self.base_model(**hypothesis)
26
- pooler_u = out_u.pooler_output
27
- pooler_v = out_v.pooler_output
28
- pooler_u = self.dropout(pooler_u)
29
- pooler_v = self.dropout(pooler_v)
30
- #concatenated = torch.cat([self.fc(pooler_u), self.fc(pooler_v), torch.abs(self.fc(pooler_u) - self.fc(pooler_v))], dim=0)
31
- concatenated = torch.cat([pooler_u, pooler_v, torch.abs(pooler_u -pooler_v)], dim=1)
32
- out=self.fc(concatenated)
33
- return out
 
8
  tags:
9
  - code
10
  ---