# Author: anthony.galtier
# Commit 06a851e: "Added light code files"
# Original file size: 827 bytes.
import torch.nn as nn
from transformers import CamembertModel
class CamembertRegressor(nn.Module):
    """Scalar regression head on top of a pretrained CamemBERT encoder.

    The encoder's pooled output (element 1 of its outputs) goes through
    dropout and a single linear layer, producing one value per input sequence.

    Args:
        drop_rate: Dropout probability applied before the linear layer.
        freeze_camembert: If True, every encoder parameter gets
            ``requires_grad = False``, so only the regression head is trained.
        pretrained_model_name: Name or path passed to
            ``CamembertModel.from_pretrained``. Defaults to ``'camembert-base'``,
            which matches the previous hard-coded behaviour.
    """

    def __init__(self, drop_rate=0.2, freeze_camembert=True,
                 pretrained_model_name='camembert-base'):
        super().__init__()
        # The encoder is registered before the head, as before, so
        # state_dict keys ("camembert.*", "regressor.1.*") are unchanged.
        self.camembert = CamembertModel.from_pretrained(pretrained_model_name)
        # Take the input width from the encoder config instead of hard-coding
        # 768, so checkpoints with a different hidden size also work.
        d_in = self.camembert.config.hidden_size
        d_out = 1
        self.regressor = nn.Sequential(
            nn.Dropout(drop_rate),
            nn.Linear(d_in, d_out),
        )
        if freeze_camembert:
            for param in self.camembert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_masks):
        """Predict one scalar per sequence.

        Args:
            input_ids: Token ids for the encoder; presumably shaped
                (batch, seq_len) — confirm against the data pipeline.
            attention_masks: Attention mask matching ``input_ids``.

        Returns:
            The regression head's output, one value per sequence
            (presumably shaped (batch, 1)).
        """
        # Pass the mask by keyword so this does not depend on attention_mask
        # being the encoder's second positional parameter.
        outputs = self.camembert(input_ids, attention_mask=attention_masks)
        # Element 1 is the pooled output (pooler_output). Integer indexing
        # works whether the encoder returns a tuple or a ModelOutput.
        # NOTE(review): the pooler is a dense+tanh layer over the first token.
        # If the checkpoint does not ship trained pooler weights, it is
        # randomly initialised and, with freeze_camembert=True, stays random.
        # Confirm, or consider outputs[0][:, 0] (raw first-token state).
        pooled = outputs[1]
        return self.regressor(pooled)