Sentence-Entailment / models.py
Clemet's picture
Upload 6 files
f2a699d verified
raw
history blame
No virus
1.34 kB
from torch import nn
from transformers import RobertaModel, RobertaConfig
class RobertaSNLI(nn.Module):
def __init__(self):
super(RobertaSNLI, self).__init__()
config = RobertaConfig.from_pretrained('roberta-base')
config.output_attentions = True # activer sortie des poids d'attention
config.max_position_embeddings = 130 # gérer la longueur des séquences
config.hidden_size = 256 # taille des états cachés du modèle
config.num_hidden_layers = 4 # nombre de couches cachées dans le transformateur
config.intermediate_size = 512 # taille couche intermédiaire dans modèle de transformateur
config.num_attention_heads = 4 # nombre de têtes d'attentions
self.roberta = RobertaModel(config)
self.roberta.requires_grad = True
self.output = nn.Linear(256, 3) # couche de sortie linéaire. Entrée la taille des états cachées et 3 sorties
def forward(self, input_ids, attention_mask=None):
outputs = self.roberta(input_ids, attention_mask=attention_mask)
roberta_out = outputs[0] # séquence des états cachés à la sortie de la dernière couche
attentions = outputs.attentions # poids d'attention du modèle RoBERTa
return self.output(roberta_out[:, 0]), attentions