from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

import modules.utils as utils
from modules.caption_model import CaptionModel
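
# Caption decoding head for a CLIP-prefix captioning model. The attribute
# path self.decoder.transformer.wte used below matches the token-embedding
# table of a Hugging Face GPT-2 model, so a GPT-2-style decoder is assumed.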
class AttModel(CaptionModel):
    def __init__(self, args, tokenizer):
        super(AttModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.idx2token)
        self.max_seq_length = 60
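
    # Note: self.clip_project (the CLIP-feature-to-embedding projection) and
    # self.decoder (the language model) are referenced by the sampling methods
    # below but are assumed to be defined elsewhere in the model.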

    def _sample(self, clip_features, gpt_tokens, update_opts=None):
        # Copy the configured options before merging per-call overrides so
        # self.args is not mutated (also avoids a mutable default argument).
        opt = dict(self.args.__dict__)
        opt.update(**(update_opts or {}))
        sample_method = opt.get('sample_method', 'greedy')
        if sample_method == 'greedy':
            return self._greedy_sample(clip_features, gpt_tokens)
        elif sample_method == 'beam_search':
            return self._beam_search_sample(clip_features, gpt_tokens)
        else:
            raise ValueError("Unknown sample_method: " + sample_method)
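
    # Usage sketch (hypothetical 512-d CLIP features, batch of 8):
    #   seqs = model._sample(torch.randn(8, 512), None,
    #                        update_opts={'sample_method': 'greedy'})
    # returns an (8, 200) LongTensor padded with tokenizer.pad_token_id.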

    def _greedy_sample(self, clip_features, gpt_tokens, temperature=1.0):
        batch_size = clip_features.size(0)
        # Project the CLIP features into the decoder's embedding space as a
        # one-token prefix; the embedding sequence grows one token per step.
        embeds = self.clip_project(clip_features).reshape(batch_size, 1, -1)
        tokens = [None for _ in range(batch_size)]
        finished = [False for _ in range(batch_size)]
        max_length = 200
        for _ in range(max_length):
            # Re-run the decoder over the full prefix each step (no KV cache).
            outputs = self.decoder(inputs_embeds=embeds)
            logits = outputs.logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            next_tokens = torch.argmax(logits, -1).unsqueeze(1)
            next_token_embeds = self.decoder.transformer.wte(next_tokens)
            for j in range(batch_size):
                if finished[j]:
                    continue
                if tokens[j] is None:
                    tokens[j] = next_tokens[j]
                else:
                    tokens[j] = torch.cat((tokens[j], next_tokens[j]), dim=0)
                if next_tokens[j].item() == self.tokenizer.eos_token_id:
                    finished[j] = True
            embeds = torch.cat((embeds, next_token_embeds), dim=1)
            if all(finished):
                break
        outputs = []
        for token in tokens:
            # view(-1) keeps one-token sequences 1-d; squeeze() would collapse
            # them to a 0-d scalar and break the list concatenation below.
            output_list = [] if token is None else token.view(-1).cpu().tolist()
            # Pad or truncate every sequence to max_length.
            output_list = (output_list + [self.tokenizer.pad_token_id] * max_length)[:max_length]
            outputs.append(output_list)
        return torch.tensor(outputs, device=clip_features.device)
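
    # Beam search keeps the beam_size highest-scoring partial sequences per
    # batch element, scored by cumulative log-probability. Finished beams are
    # frozen: their only allowed continuation is EOS at zero cost, which keeps
    # every tensor a flat (batch_size * beam_size, ...) view.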

    def _beam_search_sample(self, clip_features, gpt_tokens, beam_size=5):
        batch_size = clip_features.size(0)
        device = clip_features.device
        # One BOS token per beam; beams of a batch element stay contiguous.
        input_ids = torch.full((batch_size * beam_size, 1),
                               self.tokenizer.bos_token_id,
                               dtype=torch.long, device=device)
        # Scores are cumulative log-probabilities. All but the first beam start
        # at -inf so the first step does not pick beam_size identical tokens.
        beam_scores = torch.full((batch_size, beam_size), -float('inf'), device=device)
        beam_scores[:, 0] = 0.0
        done = torch.zeros(batch_size * beam_size, dtype=torch.bool, device=device)
        for _ in range(self.max_seq_length):
            outputs = self._forward(clip_features.repeat_interleave(beam_size, 0), input_ids)
            next_token_scores = F.log_softmax(outputs.logits[:, -1, :], dim=-1)
            # Freeze finished beams: their only continuation is EOS at zero
            # cost, so their score and decoded prefix stay fixed.
            next_token_scores[done] = -float('inf')
            next_token_scores[done, self.tokenizer.eos_token_id] = 0.0
            # Add the running beam scores to the candidate token scores.
            scores = beam_scores.unsqueeze(2) + next_token_scores.view(batch_size, beam_size, -1)
            scores = scores.view(batch_size, -1)
            # Keep the beam_size best continuations per batch element.
            beam_scores, top_indices = scores.topk(beam_size, dim=1)
            next_token_ids = top_indices % self.vocab_size
            beam_indices = torch.div(top_indices, self.vocab_size, rounding_mode='floor')
            # Reorder the surviving beams and append the chosen tokens.
            input_ids = input_ids.view(batch_size, beam_size, -1)
            input_ids = input_ids.gather(1, beam_indices.unsqueeze(2).expand(-1, -1, input_ids.size(2)))
            input_ids = torch.cat([input_ids, next_token_ids.unsqueeze(2)], dim=2)
            input_ids = input_ids.view(batch_size * beam_size, -1)
            # A beam is finished once it emits EOS (frozen beams re-emit it).
            done = (next_token_ids == self.tokenizer.eos_token_id).view(-1)
            if done.all():
                break
        return input_ids.view(batch_size, beam_size, -1)
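
    # topk returns scores in descending order, so beam 0 of the returned
    # (batch_size, beam_size, seq_len) tensor is each sample's best hypothesis.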