File size: 6,520 Bytes
e48ca55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .modules import AudioEncoder
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
class BartCaptionModel(nn.Module):
    """Audio captioning model: an ``AudioEncoder`` feeding a BART encoder-decoder.

    The audio encoder turns the input audio into a sequence of embeddings
    whose width equals BART's hidden size. Those embeddings are given to the
    BART encoder as ``inputs_embeds`` and the BART decoder is trained to
    produce the caption tokens.

    Args:
        n_mels: Forwarded to ``AudioEncoder`` as ``n_mels``.
        num_of_conv: Number of convolution layers; ``num_of_conv - 1`` is
            forwarded to ``AudioEncoder`` as ``num_of_stride_conv``.
        sr: Sample rate; with ``duration`` it fixes ``n_sample``.
        duration: Clip duration, multiplied by ``sr`` to get ``n_sample``.
        max_length: Truncation length used when tokenizing captions.
        label_smoothing: Label smoothing of the cross-entropy loss.
        bart_type: Name or path given to ``from_pretrained`` for the BART
            config and tokenizer.
        audio_dim: Forwarded to ``AudioEncoder`` as ``audio_dim``.
    """

    def __init__(
        self,
        n_mels=128,
        num_of_conv=6,
        sr=16000,
        duration=10,
        max_length=128,
        label_smoothing=0.1,
        bart_type="facebook/bart-base",
        audio_dim=768,
    ):
        super().__init__()
        # "Non-finetuning case" (original comment): only the config and the
        # tokenizer come from `bart_type`; the BART network itself is built
        # from the config object, not through `from_pretrained`.
        bart_config = BartConfig.from_pretrained(bart_type)
        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
        self.bart = BartForConditionalGeneration(bart_config)

        # Sizes derived from the audio settings.
        self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr)  # hop size is hard-coded to 0.01 * sr
        self.n_frames = int(self.n_sample // self.hop_length)
        self.num_of_stride_conv = num_of_conv - 1
        # Frame count divided by 2 per strided conv, plus one.
        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1

        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=self.n_ctx,
            audio_dim=audio_dim,
            text_dim=self.bart.config.hidden_size,
            num_of_stride_conv=self.num_of_stride_conv,
        )
        self.max_length = max_length
        # Targets equal to -100 (padding, see `forward_decoder`) are ignored.
        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)

    @property
    def device(self):
        """Device of the module's first parameter."""
        # `next` avoids building a list of every parameter on each access.
        return next(self.parameters()).device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """Shift input ids one token to the right.

        Args:
            input_ids: 2-D tensor of token ids (indexed as ``[:, i]``); may
                contain ``-100`` at ignored positions.
            pad_token_id: Id that replaces any ``-100`` left after shifting.
            decoder_start_token_id: Id written into the first column.

        Returns:
            A new tensor with the same shape as ``input_ids``; the input is
            not modified.

        Raises:
            ValueError: If ``pad_token_id`` is ``None``.
        """
        # Validate before doing any work.
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward_encoder(self, audio):
        """Encode audio with the audio encoder followed by the BART encoder.

        Args:
            audio: Input handed unchanged to ``self.audio_encoder``
                (expected format is defined by ``AudioEncoder``).

        Returns:
            Tuple ``(encoder_outputs, audio_embs)``: the BART encoder's
            ``last_hidden_state`` and the raw audio-encoder embeddings.
        """
        audio_embs = self.audio_encoder(audio)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True,
        )["last_hidden_state"]
        return encoder_outputs, audio_embs

    def forward_decoder(self, text, encoder_outputs):
        """Compute the captioning loss for a batch of target captions.

        Args:
            text: Caption string(s) passed to the tokenizer.
            encoder_outputs: Encoder hidden states, as returned first by
                ``forward_encoder``.

        Returns:
            Scalar label-smoothed cross-entropy loss; padded target
            positions are ignored.
        """
        text = self.tokenizer(
            text,
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = text["input_ids"].to(self.device)
        attention_mask = text["attention_mask"].to(self.device)

        # Padding positions become -100 so the loss skips them.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )
        # Teacher forcing: the decoder sees the targets shifted right by one.
        decoder_input_ids = self.shift_tokens_right(
            decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
        )

        decoder_outputs = self.bart(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_outputs,),
            return_dict=True,
        )
        lm_logits = decoder_outputs["logits"]
        # Flatten with the logits' own vocabulary dimension. The tokenizer's
        # `vocab_size` is not guaranteed to equal the model's output size
        # (added tokens, resized embeddings), and a mismatch would make the
        # reshape fail or silently misalign rows with targets.
        loss = self.loss_fct(
            lm_logits.reshape(-1, lm_logits.size(-1)),
            decoder_targets.reshape(-1),
        )
        return loss

    def forward(self, audio, text):
        """Return the captioning loss for paired ``audio`` and ``text``."""
        encoder_outputs, _ = self.forward_encoder(audio)
        loss = self.forward_decoder(text, encoder_outputs)
        return loss

    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=5,
                 max_length=128,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        """Generate captions for a batch of audio inputs.

        Args:
            samples: Input handed unchanged to ``self.audio_encoder``.
            use_nucleus_sampling: If true, sample with ``top_p``; otherwise
                use beam search with ``num_beams`` beams.
            num_beams: Beam count (beam-search branch only).
            max_length: Maximum generated length (independent of
                ``self.max_length``).
            min_length: Minimum generated length.
            top_p: Nucleus probability mass (sampling branch only).
            repetition_penalty: Repetition penalty (beam-search branch only;
                see the note in the sampling branch).

        Returns:
            List of decoded caption strings with special tokens removed.
        """
        # self.bart.force_bos_token_to_be_generated = True
        audio_embs = self.audio_encoder(samples)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=audio_embs,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True)

        # One decoder start token per batch element, created directly on the
        # target device with the right dtype.
        batch_size = encoder_outputs['last_hidden_state'].size(0)
        input_ids = torch.full(
            (batch_size, 1),
            self.bart.config.decoder_start_token_id,
            dtype=torch.long,
            device=self.device,
        )
        decoder_attention_mask = torch.ones(
            (batch_size, 1), dtype=torch.long, device=self.device
        )

        if use_nucleus_sampling:
            # NOTE(review): this branch uses a fixed repetition penalty of 1.1
            # and ignores the `repetition_penalty` argument. Left unchanged so
            # existing sampling outputs are preserved -- confirm intent.
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1)
        else:
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                head_mask=None,
                decoder_head_mask=None,
                inputs_embeds=None,
                decoder_inputs_embeds=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty)

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions
|