import torch
import torch.nn as nn

from transformers import AutoModel, AutoTokenizer, PretrainedConfig, PreTrainedModel

class AMC_OPT_conf(PretrainedConfig):
    model_type = "bert"

    def __init__(self,
                 out_labels=3,
                 emb_size=768,
                 drop_out=0.1,
                 pretrained_model="m-polignano-uniba/bert_uncased_L-12_H-768_A-12_italian_alb3rt0",
                 **kwargs):
        self.out_labels = out_labels
        self.drop_out = drop_out
        self.emb_size = emb_size
        self.pretrained_model = pretrained_model
        # Path to the locally fine-tuned checkpoint that AMC_OPT loads below.
        self.fine_tuned_model = "data/models/amc_opt_msmd_nocoadapt.pt"
        super().__init__(**kwargs)

class AMC_OPT_sub(nn.Module):
    def __init__(self, config):
        super().__init__()

        # from_pretrained already returns a fresh model instance, so the
        # copy.deepcopy wrapper previously used here was redundant.
        self.model = AutoModel.from_pretrained(config.pretrained_model)
        self.dropout1 = nn.Dropout(config.drop_out)
        self.linear1 = nn.Linear(config.emb_size, config.out_labels)

        self.loss_fct = nn.CrossEntropyLoss()
        # Training hyperparameters (the values suggest an automated search);
        # kept here for reference and reproducibility.
        self.hyper_params = {'learning_rate': 6.599917952321265e-05,
                             'weight_decay': 0.02157165894420757,
                             'warmup_steps': 0.8999999999999999,
                             'num_epochs': 11}
        self.params_name = "alberto_multiclass_opt_msmd.pt"

    def forward(self, labels=None, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        x = self.dropout1(outputs[1])  # outputs[1] is the pooled [CLS] representation
        logits = self.linear1(x)
        if labels is not None:
            loss = self.loss_fct(logits, labels)
            return {'logits': logits, 'loss': loss}
        return {'logits': logits}

class AMC_OPT(PreTrainedModel):
    config_class = AMC_OPT_conf

    def __init__(self, config):
        super().__init__(config)
        self.model = AMC_OPT_sub(config)
        # map_location lets the checkpoint load on CPU-only machines too.
        self.model.load_state_dict(torch.load(config.fine_tuned_model, map_location="cpu"))

    def forward(self, labels=None, input_ids=None, attention_mask=None, **kwargs):
        return self.model(labels, input_ids, attention_mask, **kwargs)

# Tokenizer for the same pretrained backbone; sequences are capped at 128 tokens.
tokenizer = AutoTokenizer.from_pretrained("m-polignano-uniba/bert_uncased_L-12_H-768_A-12_italian_alb3rt0")
tokenizer.model_max_length = 128
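
# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes the fine-tuned checkpoint at AMC_OPT_conf.fine_tuned_model exists
# locally; the Italian example sentence is hypothetical.
if __name__ == "__main__":
    config = AMC_OPT_conf()
    model = AMC_OPT(config)
    model.eval()

    encoded = tokenizer("Una frase di esempio da classificare.",
                        truncation=True, return_tensors="pt")
    with torch.no_grad():
        out = model(input_ids=encoded["input_ids"],
                    attention_mask=encoded["attention_mask"])
    print(out["logits"].argmax(dim=-1).item())  # predicted class index in [0, out_labels)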