Update README.md
Browse filesimport torch
from torch import nn
# Definir la red, lo que en definitiva la hace siamesa es pasar los dos dic
class SBERT(nn.Module):
def __init__(self, base_model, dropout=0.1):
super().__init__()
self.base_model = base_model
self.dropout = nn.Dropout(dropout)
# Recordamos que la salida de la Bert es 768
#self.fc = nn.Linear(768, 3) #Cambio 13/6
self.fc = nn.Linear(768*3, 3)
def forward(self, premise, hypothesis):
out_u = self.base_model(**premise)
out_v = self.base_model(**hypothesis)
pooler_u = out_u.pooler_output
pooler_v = out_v.pooler_output
pooler_u = self.dropout(pooler_u)
pooler_v = self.dropout(pooler_v)
#concatenated = torch.cat([self.fc(pooler_u), self.fc(pooler_v), torch.abs(self.fc(pooler_u) - self.fc(pooler_v))], dim=0)
concatenated = torch.cat([pooler_u, pooler_v, torch.abs(pooler_u -pooler_v)], dim=1)
out=self.fc(concatenated)
return out
@@ -8,26 +8,3 @@ pipeline_tag: fill-mask
|
|
8 |
tags:
|
9 |
- code
|
10 |
---
|
11 |
-
import torch
|
12 |
-
from torch import nn
|
13 |
-
# Definir la red, lo que en definitiva la hace siamesa es pasar los dos dic
|
14 |
-
class SBERT(nn.Module):
|
15 |
-
def __init__(self, base_model, dropout=0.1):
|
16 |
-
super().__init__()
|
17 |
-
self.base_model = base_model
|
18 |
-
self.dropout = nn.Dropout(dropout)
|
19 |
-
# Recordamos que la salida de la Bert es 768
|
20 |
-
#self.fc = nn.Linear(768, 3) #Cambio 13/6
|
21 |
-
self.fc = nn.Linear(768*3, 3)
|
22 |
-
|
23 |
-
def forward(self, premise, hypothesis):
|
24 |
-
out_u = self.base_model(**premise)
|
25 |
-
out_v = self.base_model(**hypothesis)
|
26 |
-
pooler_u = out_u.pooler_output
|
27 |
-
pooler_v = out_v.pooler_output
|
28 |
-
pooler_u = self.dropout(pooler_u)
|
29 |
-
pooler_v = self.dropout(pooler_v)
|
30 |
-
#concatenated = torch.cat([self.fc(pooler_u), self.fc(pooler_v), torch.abs(self.fc(pooler_u) - self.fc(pooler_v))], dim=0)
|
31 |
-
concatenated = torch.cat([pooler_u, pooler_v, torch.abs(pooler_u -pooler_v)], dim=1)
|
32 |
-
out=self.fc(concatenated)
|
33 |
-
return out
|
|
|
8 |
tags:
|
9 |
- code
|
10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|