HemanM commited on
Commit
2c48bcc
·
verified ·
1 Parent(s): db0ea86

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +9 -5
model.py CHANGED
@@ -1,12 +1,16 @@
1
  import torch.nn as nn
2
 
3
- class EvoTransformer(nn.Module):
4
- def __init__(self, d_model=768, n_classes=2):
5
- super(EvoTransformer, self).__init__()
6
  self.classifier = nn.Sequential(
7
- nn.Linear(d_model, 384),
8
  nn.ReLU(),
9
- nn.Linear(384, n_classes)
 
 
 
 
10
  )
11
 
12
  def forward(self, x):
 
1
  import torch.nn as nn
2
 
3
+ class EvoTransformerArabic(nn.Module):
4
+ def __init__(self, d_model=768, hidden_dim=1024, n_classes=2, dropout=0.1):
5
+ super(EvoTransformerArabic, self).__init__()
6
  self.classifier = nn.Sequential(
7
+ nn.Linear(d_model, hidden_dim),
8
  nn.ReLU(),
9
+ nn.Dropout(dropout),
10
+ nn.Linear(hidden_dim, hidden_dim // 2),
11
+ nn.ReLU(),
12
+ nn.Dropout(dropout),
13
+ nn.Linear(hidden_dim // 2, n_classes)
14
  )
15
 
16
  def forward(self, x):