# Last change: commit 9d196bb, "Update fudge/model.py" (Dusan)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, GPT2Config, GPT2ForSequenceClassification, GPT2LMHeadModel, MarianTokenizer
from fudge.constants import *
from fudge.util import pad_mask
from fudge.clickbait_classifier import BertClickbaitClassifier, ClickbaitConfig
class Model(nn.Module):
    """Per-task predictor heads.

    Exactly one head is built, selected by ``args.task``: ``'formality'``,
    ``'iambic'``, ``'rhyme'``, ``'newline'`` or ``'clickbait'``. Any other
    task raises ``NotImplementedError`` at construction time.
    """

    def __init__(self, args, gpt_pad_id, vocab_size, rhyme_group_size=None, glove_embeddings=None, verbose=True):
        """Build the layers for the head selected by ``args.task``.

        Args:
            args: namespace with a ``task`` attribute; the clickbait head also
                reads ``args.checkpoint`` (path passed to
                ``BertClickbaitClassifier.from_pretrained``) and ``args.device``.
            gpt_pad_id: padding token id; token embedding tables are built
                with ``gpt_pad_id + 1`` rows.
            vocab_size: unused by the currently enabled heads (kept for
                interface compatibility).
            rhyme_group_size: number of rhyme groups; required by the
                ``'rhyme'`` head (its embedding has ``rhyme_group_size + 1`` rows).
            glove_embeddings: unused by the currently enabled heads.
            verbose: unused by the currently enabled heads.

        Raises:
            NotImplementedError: if ``args.task`` is not a supported task.
        """
        super().__init__()
        # A 'topic' head used to be defined here but was commented out; it is
        # not supported by this model.
        self.formality = args.task == 'formality'
        self.iambic = args.task == 'iambic'
        self.rhyme = args.task == 'rhyme'
        self.newline = args.task == 'newline'
        self.clickbait = args.task == 'clickbait'

        if self.formality:
            # Index 0 is the '' token in the Marian vocabulary, used as padding.
            self.marian_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=0)
            # Unidirectional (causal) so every prefix position can be trained at once.
            self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0.5)
            self.out_linear = nn.Linear(HIDDEN_DIM, 1)
        elif self.iambic:
            self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id)
            # Unidirectional (causal) so every prefix position can be trained at once.
            self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0)
            self.out_linear = nn.Linear(HIDDEN_DIM, 1)
        elif self.rhyme:
            # Subword (GPT token) embeddings for the prefix.
            self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id)
            # "Future words" are rhyme-group indices; 0 is padding.
            self.word_embed = nn.Embedding(rhyme_group_size + 1, GLOVE_DIM, padding_idx=0)
            # Bidirectional: output width is 2 * RNN_DIM, which feeds
            # attention_linear(HIDDEN_DIM), so this assumes 2 * RNN_DIM == HIDDEN_DIM.
            self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
            self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            # embed_key_linear receives cat([word_embed, syllable_embed]), so this
            # also assumes GLOVE_DIM == HIDDEN_DIM.
            large_hidden_dim = HIDDEN_DIM + COUNT_SYLLABLE_DIM
            self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
            self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
            self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
            self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST + 1, COUNT_SYLLABLE_DIM)
            self.nonlinear = nn.ReLU()
        elif self.newline:
            # Subword (GPT token) embeddings for the prefix.
            self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id)
            self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False)
            self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST + 1, COUNT_SYLLABLE_DIM)
            self.out_linear = nn.Linear(HIDDEN_DIM + COUNT_SYLLABLE_DIM, HIDDEN_DIM)
            self.out_linear2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
            self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
            self.nonlinear = nn.ReLU()
        elif self.clickbait:
            # Load a fine-tuned classifier from a saved checkpoint directory,
            # e.g. 'ckpt/clickbait_classifier/checkpoint-1464'.
            checkpoint = args.checkpoint
            self.classifier = BertClickbaitClassifier.from_pretrained(checkpoint).to(torch.device(args.device))
        else:
            # TODO: each task could be refactored into its own model class.
            raise NotImplementedError(f'unsupported task: {args.task!r}')

    def _run_rnn(self, embedded, lengths):
        """Run ``self.rnn`` over right-padded embeddings.

        ``embedded`` is batch x seq x dim; returns batch x max(lengths) x
        rnn_output_dim. Padded positions are zero-filled by
        ``pad_packed_sequence`` and rows come back in the original batch order.
        """
        packed = pack_padded_sequence(embedded.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
        rnn_output, _ = self.rnn(packed)
        rnn_output, _ = pad_packed_sequence(rnn_output)  # seq x batch x dim
        return rnn_output.permute(1, 0, 2)

    def forward(self, inputs, lengths=None, future_words=None, log_probs=None, syllables_to_go=None, future_word_num_syllables=None, rhyme_group_index=None, run_classifier=False, attention_mask=None):
        """Score ``inputs`` with the head chosen at construction time.

        Args:
            inputs: token ids, batch x seq, right-padded. For clickbait, anything
                ``torch.as_tensor`` accepts.
            lengths: unpadded length of each sequence, shape (batch,). Needed by
                every head except clickbait.
            future_words: rhyme only; batch x N rhyme-group indices to score.
            log_probs: rhyme only; N values subtracted from every row of scores.
            syllables_to_go: rhyme/newline; (batch,) indices into
                ``count_syllable_embed``.
            future_word_num_syllables, rhyme_group_index, run_classifier:
                currently unused.
            attention_mask: clickbait only; passed through to the classifier.

        Returns:
            formality / iambic / newline: batch x seq scores (one per prefix position).
            rhyme: batch x N scores.
            clickbait: classifier logits with a leading singleton dim added and
                dim 2 squeezed (presumably 1 x batch for batch x 1 logits — TODO confirm).
        """
        if self.formality:
            rnn_output = self._run_rnn(self.marian_embed(inputs), lengths)
            return self.out_linear(rnn_output).squeeze(2)
        elif self.iambic:
            rnn_output = self._run_rnn(self.gpt_embed(inputs), lengths)
            return self.out_linear(rnn_output).squeeze(2)
        elif self.rhyme:
            hidden = self._run_rnn(self.gpt_embed(inputs), lengths)  # batch x seq x 2*RNN_DIM
            seq_mask = pad_mask(lengths).permute(1, 0)  # batch x seq
            embed = self.word_embed(future_words)  # batch x N x GLOVE_DIM
            # batch x N x COUNT_SYLLABLE_DIM: the same syllable count for every candidate
            syllable_embed = self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, embed.shape[1], -1)
            embed_query = self.embed_key_linear(torch.cat([embed, syllable_embed], dim=2))
            # batch x seq x N x HIDDEN_DIM
            attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1)
            # Softmax over the sequence dim, then zero padded positions
            # (the weights are not renormalized after masking).
            attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1)  # batch x seq x N
            attention_weights = attention_weights * seq_mask.unsqueeze(2)
            hidden = self.attention_value_linear(hidden)
            weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1)  # batch x N x HIDDEN_DIM
            scores = self.out_linear(weighted_hidden) * self.out_embed_linear(embed)
            scores = torch.cat([scores, embed, syllable_embed], dim=2)
            scores = self.nonlinear(self.out_linear2(self.nonlinear(scores)))
            scores = self.out_linear3(scores)
            return scores.squeeze(2) - log_probs.unsqueeze(0)
        elif self.newline:
            rnn_output = self._run_rnn(self.gpt_embed(inputs), lengths)  # batch x seq x HIDDEN_DIM
            syllable_embed = self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, rnn_output.shape[1], -1)
            hidden = torch.cat([rnn_output, syllable_embed], dim=2)
            hidden = self.nonlinear(self.out_linear(hidden))
            hidden = self.nonlinear(self.out_linear2(hidden))
            return self.out_linear3(hidden).squeeze(2)
        elif self.clickbait:
            # as_tensor avoids copy-constructing (and the PyTorch warning) when
            # inputs is already a tensor.
            input_ids = torch.as_tensor(inputs)
            classifier_output = self.classifier(input_ids=input_ids, attention_mask=attention_mask).logits
            classifier_output = classifier_output[None, :, :]  # 1 x <logits dims>
            return classifier_output.squeeze(2)
        else:
            raise NotImplementedError