File size: 827 Bytes
06a851e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch.nn as nn
from transformers import CamembertModel


class CamembertRegressor(nn.Module):
    """Regression head on top of a pretrained CamemBERT encoder.

    The encoder's pooled output is passed through dropout and a single
    linear layer that produces one value per input sequence.

    Args:
        drop_rate: Dropout probability applied before the linear layer.
        freeze_camembert: If True, every encoder parameter has
            ``requires_grad`` set to False, so only the regression head
            is trained.
        model_name: Name or path passed to
            ``CamembertModel.from_pretrained``. Defaults to the original
            ``'camembert-base'`` checkpoint.
    """

    def __init__(self, drop_rate=0.2, freeze_camembert=True,
                 model_name='camembert-base'):
        super().__init__()

        self.camembert = CamembertModel.from_pretrained(model_name)

        # Read the encoder width from its config instead of hard-coding 768,
        # so checkpoints with a different hidden size also work.
        d_in = self.camembert.config.hidden_size
        d_out = 1
        self.regressor = nn.Sequential(
            nn.Dropout(drop_rate),
            nn.Linear(d_in, d_out),
        )

        if freeze_camembert:
            for param in self.camembert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_masks):
        """Return one regression value per sequence.

        Args:
            input_ids: Token ids for the encoder; presumably shaped
                (batch, seq_len) -- TODO confirm against the data pipeline.
            attention_masks: Attention mask matching ``input_ids``.

        Returns:
            The regressor's output (a final dimension of size 1).
        """
        # Pass the mask by keyword so it cannot bind to a different
        # positional parameter of the encoder's forward().
        outputs = self.camembert(input_ids, attention_mask=attention_masks)

        # Index 1 is the encoder's pooled output (dense + tanh over the first
        # token), not the raw first-token hidden state.
        # NOTE(review): for RoBERTa-style checkpoints the pooler may not be
        # pretrained (check from_pretrained's "newly initialized" warnings);
        # if it is, and the encoder is frozen, consider outputs[0][:, 0].
        pooled = outputs[1]
        return self.regressor(pooled)