|
import torch |
|
import transformers |
|
import jellyfish |
|
from tqdm import tqdm |
|
from transformers import AutoModelForMaskedLM |
|
from .poet_utils import RHYME_SCHEMES, METER_TYPES |
|
|
|
from torch.utils.data import DataLoader, Dataset |
|
from pytorch_optimizer import SAM,GSAM, ProportionScheduler, AdamP |
|
|
|
class ValidatorInterface(torch.nn.Module):
    """Abstract base class for all validator models.

    Inherits from torch.nn.Module so that concrete validators integrate
    with torch and huggingface tooling. Every method here is abstract and
    must be overridden by subclasses.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Constructor. Forwards all arguments to torch.nn.Module."""
        super().__init__(*args, **kwargs)

    def forward(self, input_ids=None, attention_mask=None, *args, **kwargs):
        """Compute model output and model loss.

        Args:
            input_ids: Model inputs. Defaults to None.
            attention_mask: Attention mask where padding starts. Defaults to None.

        Raises:
            NotImplementedError: abstract method, implement in subclass.
        """
        raise NotImplementedError()

    def predict(self, input_ids=None, *args, **kwargs):
        """Compute model outputs only (no loss).

        Args:
            input_ids: Model inputs. Defaults to None.

        Raises:
            NotImplementedError: abstract method, implement in subclass.
        """
        raise NotImplementedError()

    def validate(self, input_ids=None, *args, **kwargs):
        """Validate the model against given labels; does not use loss.

        Args:
            input_ids: Model inputs. Defaults to None.

        Raises:
            NotImplementedError: abstract method, implement in subclass.
        """
        raise NotImplementedError()
|
|
|
|
|
class RhymeValidator(ValidatorInterface):
    """Validator that classifies the rhyme scheme of an input poem.

    Uses a masked-LM backbone; the final hidden state of the first (CLS)
    token is fed through a linear head over all known rhyme schemes.
    """

    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        """Build the validator on top of a pretrained masked LM.

        Args:
            pretrained_model: Model name or path for AutoModelForMaskedLM.
        """
        super().__init__(*args, **kwargs)

        # output_hidden_states=True is required: the classification head
        # reads the last hidden layer, not the MLM logits.
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config
        self.model_size = self.config.hidden_size

        # Linear head mapping the CLS embedding to rhyme-scheme logits.
        self.rhyme_regressor = torch.nn.Linear(self.model_size, len(RHYME_SCHEMES))

        self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.05)

    def forward(self, input_ids=None, attention_mask=None, rhyme=None, *args, **kwargs):
        """Compute rhyme probabilities and combined loss (MLM + classification).

        Args:
            input_ids: Tokenized inputs.
            attention_mask: Attention mask where padding starts.
            rhyme: Rhyme-scheme target distribution/labels for CrossEntropyLoss.

        Returns:
            dict with "model_output" (softmaxed class probabilities) and "loss".
        """
        # .long() keeps the label tensor on the same device as input_ids;
        # the original .type(torch.LongTensor) silently moved it to CPU.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.long())

        last_hidden = outputs['hidden_states'][-1]

        rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))

        # CrossEntropyLoss expects raw logits: it applies log-softmax
        # internally. Softmaxing first (as before) double-softmaxes,
        # miscalibrating the loss and squashing gradients.
        rhyme_loss = self.loss_fnc(rhyme_regression, rhyme)

        softmaxed = torch.softmax(rhyme_regression, dim=1)

        return {"model_output": softmaxed,
                "loss": rhyme_loss + outputs.loss}

    def predict(self, input_ids=None, *args, **kwargs):
        """Return softmaxed rhyme-scheme probabilities for the input.

        Args:
            input_ids: Tokenized inputs.

        Returns:
            Tensor of per-class probabilities.
        """
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids)

            last_hidden = outputs['hidden_states'][-1]

            rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))

            softmaxed = torch.softmax(rhyme_regression, dim=1)

        return softmaxed

    def validate(self, input_ids=None, rhyme=None, k: int = 2, *args, **kwargs):
        """Validate one example against its rhyme label.

        Args:
            input_ids: Tokenized inputs (single example).
            rhyme: One-hot / distribution label; argmax gives the true class.
            k: Top-k cutoff for the top_k metric. Defaults to 2.

        Returns:
            dict with "acc", "top_k", "lev_distance" (edit distance between
            predicted and true scheme strings) and "predicted_label"
            (probability assigned to the true class).
        """
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids)

            last_hidden = outputs['hidden_states'][-1]

            rhyme_regression = self.rhyme_regressor(last_hidden[:, 0, :].view(-1, self.model_size))

            softmaxed = torch.softmax(rhyme_regression, dim=1)

        softmaxed = softmaxed.flatten()

        predicted_val = torch.argmax(softmaxed)

        predicted_top_k = torch.topk(softmaxed, k).indices

        label_val = torch.argmax(rhyme.flatten())

        validation_true_val = (label_val == predicted_val).float().sum().numpy()
        top_k_presence = 1 if label_val in predicted_top_k else 0

        # A scheme of None (unknown) is treated as the empty string for
        # edit-distance purposes. `is not None`, not `!= None` (PEP 8).
        predicted_scheme = RHYME_SCHEMES[predicted_val] if RHYME_SCHEMES[predicted_val] is not None else ""
        true_scheme = RHYME_SCHEMES[label_val] if RHYME_SCHEMES[label_val] is not None else ""
        levenshtein = jellyfish.levenshtein_distance(predicted_scheme, true_scheme)

        hit_pred = softmaxed[label_val].detach().numpy()

        return {"acc": validation_true_val,
                "top_k": top_k_presence,
                "lev_distance": levenshtein,
                "predicted_label": hit_pred
                }
|
|
|
|
|
|
|
class MeterValidator(ValidatorInterface):
    """Validator that classifies the meter of an input poem.

    Uses a masked-LM backbone; the final hidden state of the first (CLS)
    token is fed through a linear head over all known meter types.
    """

    def __init__(self, pretrained_model, *args, **kwargs) -> None:
        """Build the validator on top of a pretrained masked LM.

        Args:
            pretrained_model: Model name or path for AutoModelForMaskedLM.
        """
        super().__init__(*args, **kwargs)
        # output_hidden_states=True is required: the classification head
        # reads the last hidden layer, not the MLM logits.
        self.model = AutoModelForMaskedLM.from_pretrained(pretrained_model, output_hidden_states=True)

        self.config = self.model.config
        self.model_size = self.config.hidden_size

        # Linear head mapping the CLS embedding to meter-type logits.
        self.meter_regressor = torch.nn.Linear(self.model_size, len(METER_TYPES))

        self.loss_fnc = torch.nn.CrossEntropyLoss(label_smoothing=0.05)

    def forward(self, input_ids=None, attention_mask=None, metre=None, *args, **kwargs):
        """Compute meter probabilities and combined loss (MLM + classification).

        Args:
            input_ids: Tokenized inputs.
            attention_mask: Attention mask where padding starts.
            metre: Meter target distribution/labels for CrossEntropyLoss.

        Returns:
            dict with "model_output" (softmaxed class probabilities) and "loss".
        """
        # .long() keeps the label tensor on the same device as input_ids;
        # the original .type(torch.LongTensor) silently moved it to CPU.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.long())

        last_hidden = outputs['hidden_states'][-1]

        meter_regression = self.meter_regressor(last_hidden[:, 0, :].view(-1, self.model_size))

        # CrossEntropyLoss expects raw logits: it applies log-softmax
        # internally. Softmaxing first (as before) double-softmaxes,
        # miscalibrating the loss and squashing gradients.
        meter_loss = self.loss_fnc(meter_regression, metre)

        softmaxed = torch.softmax(meter_regression, dim=1)

        return {"model_output": softmaxed,
                "loss": meter_loss + outputs.loss}

    def predict(self, input_ids=None, *args, **kwargs):
        """Return softmaxed meter-type probabilities for the input.

        Args:
            input_ids: Tokenized inputs.

        Returns:
            Tensor of per-class probabilities.
        """
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids)

            last_hidden = outputs['hidden_states'][-1]

            meter_regression = self.meter_regressor(last_hidden[:, 0, :].view(-1, self.model_size))

            softmaxed = torch.softmax(meter_regression, dim=1)

        return softmaxed

    def validate(self, input_ids=None, metre=None, k: int = 2, *args, **kwargs):
        """Validate one example against its meter label.

        Args:
            input_ids: Tokenized inputs (single example).
            metre: One-hot / distribution label; argmax gives the true class.
            k: Top-k cutoff for the top_k metric. Defaults to 2.

        Returns:
            dict with "acc", "top_k" and "predicted_label" (probability
            assigned to the true class).
        """
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids)

            last_hidden = outputs['hidden_states'][-1]

            meter_regression = self.meter_regressor(last_hidden[:, 0, :].view(-1, self.model_size))

            softmaxed = torch.softmax(meter_regression, dim=1)

        softmaxed = softmaxed.flatten()

        predicted_val = torch.argmax(softmaxed)

        predicted_top_k = torch.topk(softmaxed, k).indices

        label_val = torch.argmax(metre.flatten())

        validation_true_val = (label_val == predicted_val).float().sum().numpy()
        top_k_presence = 1 if label_val in predicted_top_k else 0

        hit_pred = softmaxed[label_val].detach().numpy()

        return {"acc": validation_true_val,
                "top_k": top_k_presence,
                "predicted_label": hit_pred
                }
|
|
|
|
|
class ValidatorTrainer:
    """Minimal SAM-based training loop for validator models.

    SAM (Sharpness-Aware Minimization) requires two forward/backward
    passes per optimization step: one to perturb the weights toward the
    sharpest direction, one on the perturbed weights to actually update.
    """

    def __init__(self, model: ValidatorInterface, args: dict, train_dataset: Dataset, data_collator, device):
        """Set up optimizer, scheduler and data loader.

        Args:
            model: Validator to train.
            args: Hyperparameter dict; recognized keys are "epochs",
                "batch_size", "lr", "weight_decay" (all optional).
            train_dataset: Dataset of training examples.
            data_collator: Collate function producing batches with
                "input_ids", "attention_mask", "rhyme", "metre" keys.
            device: Torch device to move batches to.
        """
        self.model = model
        self.args = args
        # dict.get with a default replaces the "key not in args.keys()" idiom.
        self.epochs = args.get("epochs", 1)
        self.batch_size = args.get("batch_size", 1)
        self.lr = args.get("lr", 3e-4)
        self.weight_decay = args.get("weight_decay", 0.0)

        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size,
                                       shuffle=True, collate_fn=data_collator)

        self.device = device
        # SAM wraps a base AdamW optimizer.
        self.optimizer = SAM(self.model.parameters(), torch.optim.AdamW, lr=self.lr, weight_decay=self.weight_decay)
        # Constant LR after a warmup of roughly one epoch of steps.
        self.scheduler = transformers.get_constant_schedule_with_warmup(self.optimizer, len(train_dataset) // self.batch_size)

    def _compute_loss(self, batch):
        """Run one forward pass over a batch and return the scalar loss.

        `is None`, not `== None`: `==` on a tensor performs elementwise
        comparison rather than an identity check.
        """
        return self.model(
            input_ids=batch["input_ids"].to(self.device),
            attention_mask=batch["attention_mask"].to(self.device),
            rhyme=None if batch["rhyme"] is None else batch["rhyme"].to(self.device),
            metre=None if batch["metre"] is None else batch["metre"].to(self.device),
        )["loss"]

    def train(self):
        """Train for self.epochs epochs, logging loss every 100 steps."""
        for epoch in tqdm(range(self.epochs)):
            self.model.train()

            for step, batch in enumerate(self.train_loader):
                # First SAM step: perturb weights toward the sharp direction.
                loss = self._compute_loss(batch)
                loss.backward()
                self.optimizer.first_step(zero_grad=True)

                # Second SAM step: gradient at the perturbed point updates
                # the original weights.
                loss = self._compute_loss(batch)
                loss.backward()
                self.optimizer.second_step(zero_grad=True)
                self.scheduler.step()

                if step % 100 == 0:
                    print(f'Step {step}, loss : {loss.item()}', flush=True)