# amc-opt-msmd / amc.py
import copy

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, PretrainedConfig, PreTrainedModel
class AMC_OPT_conf(PretrainedConfig):
    model_type = "bert"

    def __init__(self,
                 out_labels=3,
                 emb_size=768,
                 drop_out=0.1,
                 pretrained_model="m-polignano-uniba/bert_uncased_L-12_H-768_A-12_italian_alb3rt0",
                 **kwargs):
        # Classification head settings: number of output classes, encoder hidden
        # size, and dropout applied before the linear layer.
        self.out_labels = out_labels
        self.drop_out = drop_out
        self.emb_size = emb_size
        # Base encoder checkpoint and path to the fine-tuned weights on disk.
        self.pretrained_model = pretrained_model
        self.fine_tuned_model = "data/models/amc_opt_msmd_nocoadapt.pt"
        super().__init__(**kwargs)
class AMC_OPT_sub(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Pre-trained AlBERTo encoder followed by dropout and a linear
        # classification head over the pooled output.
        self.model = copy.deepcopy(AutoModel.from_pretrained(config.pretrained_model))
        self.dropout1 = nn.Dropout(config.drop_out)
        self.linear1 = nn.Linear(config.emb_size, config.out_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        # Hyper-parameters used during fine-tuning (kept for reference).
        self.hyper_params = {'learning_rate': 6.599917952321265e-05,
                             'weight_decay': 0.02157165894420757,
                             'warmup_steps': 0.8999999999999999,
                             'num_epochs': 11}
        self.params_name = "alberto_multiclass_opt_msmd.pt"

    def forward(self, labels, input_ids, attention_mask, **args):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **args)
        # outputs[1] is the pooled [CLS] representation of the encoder.
        x = self.dropout1(outputs[1])
        logits = self.linear1(x)
        if labels is not None:
            loss = self.loss_fct(logits, labels)
            return {'logits': logits, 'loss': loss}
        return {'logits': logits}
class AMC_OPT(PreTrainedModel):
    config_class = AMC_OPT_conf

    def __init__(self, config):
        super().__init__(config)
        # Wrap the sub-module and load the fine-tuned weights from disk.
        self.model = AMC_OPT_sub(config)
        self.model.load_state_dict(torch.load(config.fine_tuned_model))

    def forward(self, labels, input_ids, attention_mask, **args):
        return self.model(labels, input_ids, attention_mask, **args)
# Tokenizer matching the pre-trained AlBERTo encoder, capped at 128 tokens.
tokenizer = AutoTokenizer.from_pretrained("m-polignano-uniba/bert_uncased_L-12_H-768_A-12_italian_alb3rt0")
tokenizer.model_max_length = 128
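
# --- Usage sketch (not part of the original file) -----------------------------
# A minimal example of how this model might be loaded and queried, assuming the
# fine-tuned checkpoint referenced by AMC_OPT_conf.fine_tuned_model is available
# locally. The Italian sentence below is illustrative only.
if __name__ == "__main__":
    config = AMC_OPT_conf()
    model = AMC_OPT(config)  # loads weights from config.fine_tuned_model
    model.eval()

    encoded = tokenizer(
        "Questo è un esempio.",  # illustrative input sentence
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        out = model(
            labels=None,  # no labels at inference time, so only logits are returned
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
    predicted_class = out["logits"].argmax(dim=-1)
    print(predicted_class)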