import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from transformers import (
    LlamaPreTrainedModel,
    LlamaConfig,
    LlamaTokenizer,
    ViTModel,
    ViTImageProcessor,
    AutoTokenizer,
    LlamaForCausalLM,
    get_linear_schedule_with_warmup,
)
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaAttention,
    LlamaMLP,
    LlamaRMSNorm,
)

from tqdm import tqdm
from IPython.display import clear_output, display


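# Cross-attention block grafted onto LlamaAttention: queries come from the decoder
# hidden states, while keys/values are projected from encoder_hidden_states (the
# ViT image features mapped to the LLaMA hidden size). If num_key_value_heads is
# smaller than the number of query heads, the K/V heads are repeated
# (grouped-query-attention style) so the shapes line up.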
|
class CustomCrossAttention(LlamaAttention):
    def __init__(self, config, layer_idx, num_key_value_heads=None):
        super().__init__(config, layer_idx=layer_idx)
        self.layer_idx = layer_idx
        self.num_key_value_heads = num_key_value_heads or config.num_attention_heads
        self.k_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.q_proj = nn.Linear(
            config.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, config.hidden_size, bias=False
        )

|
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        use_cache=False,
    ):
        if encoder_hidden_states is None:
            raise ValueError(
                f"Cross-attention layer {self.layer_idx} requires encoder_hidden_states"
            )

        bsz, tgt_len, _ = hidden_states.size()
        src_len = encoder_hidden_states.size(1)

        query_states = self.q_proj(hidden_states).view(
            bsz, tgt_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        if past_key_value is not None:
            # The encoder states are fixed during generation, so reuse the cached
            # key/value projections instead of recomputing them at every step.
            key_states, value_states = past_key_value
        else:
            key_states = self.k_proj(encoder_hidden_states).view(
                bsz, src_len, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            value_states = self.v_proj(encoder_hidden_states).view(
                bsz, src_len, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)

            if self.num_key_value_heads != self.num_heads:
                # Repeat K/V heads so every query head has a matching key/value head.
                key_states = key_states.repeat_interleave(
                    self.num_heads // self.num_key_value_heads, dim=1
                )
                value_states = value_states.repeat_interleave(
                    self.num_heads // self.num_key_value_heads, dim=1
                )

        next_past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(
            query_states, key_states.transpose(-1, -2)
        ) / math.sqrt(self.head_dim)

        if encoder_attention_mask is not None:
            attn_weights = attn_weights + encoder_attention_mask

        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = (
            attn_output.transpose(1, 2).contiguous().view(bsz, tgt_len, -1)
        )
        attn_output = self.o_proj(attn_output)

        # Output layout is always (attn_output, attn_weights or None, present_key_value or None).
        outputs = (attn_output,)
        outputs += (attn_weights,) if output_attentions else (None,)
        outputs += (next_past_key_value,)
        return outputs


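# One decoder block: pre-norm self-attention, pre-norm cross-attention over the
# image features, then the feed-forward MLP, each wrapped in a residual
# connection. The per-layer cache is a 4-tuple (self_k, self_v, cross_k, cross_v).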
|
class CustomLlamaDecoderLayer(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = LlamaAttention(config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.cross_attn = CustomCrossAttention(config, layer_idx=layer_idx)
        self.cross_attn_layer_norm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

|
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        use_cache=False,
    ):
        if encoder_hidden_states is None:
            raise ValueError(
                f"Cross-attention layer {self.layer_idx} requires encoder_hidden_states"
            )

        # Self-attention (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + self_attn_outputs[0]

        # Cross-attention over the image features (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.cross_attn_layer_norm(hidden_states)
        cross_attn_past_key_value = (
            past_key_value[2:] if past_key_value is not None else None
        )
        cross_attn_outputs = self.cross_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            position_ids=position_ids,
            past_key_value=cross_attn_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + cross_attn_outputs[0]

        # Feed-forward (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        present_key_value = None
        if use_cache:
            self_pkv = self_attn_outputs[2] if len(self_attn_outputs) > 2 else None
            cross_pkv = cross_attn_outputs[2] if len(cross_attn_outputs) > 2 else None

            # Per-layer cache layout: (self_k, self_v, cross_k, cross_v).
            if self_pkv is not None and cross_pkv is not None:
                present_key_value = self_pkv + cross_pkv
            elif self_pkv is not None:
                present_key_value = self_pkv + (None, None)
            elif cross_pkv is not None:
                present_key_value = (None, None) + cross_pkv

        outputs = (hidden_states, present_key_value)

        if output_attentions:
            outputs += (
                {
                    "self_attn": self_attn_outputs[1],
                    "cross_attn": cross_attn_outputs[1],
                },
            )

        return outputs


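# Decoder-only LLaMA stack in which every layer also cross-attends to the image
# features. Attention masks are built as additive masks (0 = attend, large
# negative = blocked) and the per-layer caches are plain tuples.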
|
class CustomLlamaModel(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.layers = nn.ModuleList(
            [
                CustomLlamaDecoderLayer(config, layer_idx=i)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

|
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, hidden_states, past_key_values_length
    ):
        # Build the additive self-attention mask: 0 where attention is allowed,
        # a large negative value where it is blocked.
        dtype = hidden_states.dtype
        device = hidden_states.device
        tgt_len = input_shape[-1]

        causal_mask = self._make_causal_mask(
            input_shape, dtype, device, past_key_values_length
        )

        if attention_mask is not None:
            expanded_attn_mask = self._expand_mask(
                attention_mask, dtype, tgt_len=tgt_len
            )
            combined_attention_mask = expanded_attn_mask + causal_mask
        else:
            combined_attention_mask = causal_mask

        return combined_attention_mask

|
    def _expand_mask(self, mask, dtype, tgt_len=None):
        """
        Expands a padding mask from `[batch_size, seq_len]` to an additive mask of
        shape `[batch_size, 1, tgt_seq_len, src_seq_len]` (0 for real tokens, a
        large negative value for padding).
        """
        batch_size, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

|
    def _make_causal_mask(self, input_shape, dtype, device, past_key_values_length):
        # Additive causal mask: 0 on and below the diagonal (visible positions),
        # a large negative value above it (future positions).
        batch_size, tgt_len = input_shape
        total_len = tgt_len + past_key_values_length
        mask = torch.full(
            (total_len, total_len), torch.finfo(dtype).min, dtype=dtype, device=device
        )
        mask = torch.triu(mask, diagonal=1)
        if past_key_values_length > 0:
            mask = mask[past_key_values_length:, :]
        return mask[None, None, :, :]

|
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
    ):
        if input_ids is None:
            raise ValueError("You have to specify input_ids")
        batch_size, seq_length = input_ids.size()

        if past_key_values is None:
            past_key_values = [None] * len(self.layers)
            past_key_values_length = 0
        elif past_key_values[0] is not None and past_key_values[0][0] is not None:
            # Length of the cached self-attention keys from previous steps.
            past_key_values_length = past_key_values[0][0].shape[2]
        else:
            past_key_values_length = 0

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length,
                dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        hidden_states = self.embed_tokens(input_ids)

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length), device=input_ids.device
            )

        combined_attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx]

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=combined_attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                past_key_values[idx] = layer_outputs[1]
            if output_attentions:
                all_self_attns += (layer_outputs[2],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            outputs = (hidden_states,)
            if use_cache:
                outputs += (past_key_values,)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if output_attentions:
                outputs += (all_self_attns,)
            return outputs

        return {
            "last_hidden_state": hidden_states,
            "past_key_values": past_key_values if use_cache else None,
            "hidden_states": all_hidden_states,
            "attentions": all_self_attns,
        }


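# Language-modeling head on top of the cross-attending decoder: projects the final
# hidden states to vocabulary logits and computes the caption loss when labels are
# provided.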
|
class CustomLlamaForConditionalGeneration(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = CustomLlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

|
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
        position_ids=None,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            position_ids=position_ids,
        )
        hidden_states = outputs["last_hidden_state"] if return_dict else outputs[0]

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Labels are used as-is (no internal shift); ignore_index=-100 masks
            # padding positions.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            logits_flat = logits.view(-1, logits.size(-1))
            labels_flat = labels.view(-1)
            loss = loss_fct(logits_flat, labels_flat)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return {
            "loss": loss,
            "logits": logits,
            "past_key_values": outputs["past_key_values"],
            "hidden_states": outputs["hidden_states"],
        }


|
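# Two-layer GELU MLP that maps ViT patch embeddings into the LLaMA hidden space.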
|
class ImageProjection(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
            nn.GELU(),
        )

    def forward(self, x):
        return self.proj(x)


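# End-to-end captioner: the ViT encoder produces patch features, which are
# projected, normalized, scaled, activated, and passed through a residual linear
# layer before being fed as encoder_hidden_states to the cross-attending decoder.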
|
class ImageCaptioningModel(nn.Module):
    def __init__(
        self, vit_model, llama_config, llama_model, tokenizer,
    ):
        super().__init__()
        self.vit_model = vit_model
        self.tokenizer = tokenizer
        self.llama_model = llama_model

        # Project ViT features to the LLaMA hidden size through a widened GELU MLP.
        self.image_proj = ImageProjection(
            input_dim=vit_model.config.hidden_size,
            hidden_dim=llama_config.hidden_size * 2,
            output_dim=llama_config.hidden_size,
        )
        nn.init.xavier_uniform_(self.image_proj.proj[0].weight, gain=math.sqrt(2))
        nn.init.zeros_(self.image_proj.proj[0].bias)
        nn.init.xavier_uniform_(self.image_proj.proj[2].weight, gain=math.sqrt(2))
        nn.init.zeros_(self.image_proj.proj[2].bias)

        self.image_layer_norm = nn.LayerNorm(llama_config.hidden_size)

        # Learned per-channel gain on the projected image features (initialized to 5.0).
        self.scale = nn.Parameter(torch.ones(1, 1, llama_config.hidden_size) * 5.0)

        self.image_activation = nn.GELU()

        # Residual linear transform applied on top of the scaled, activated features.
        self.image_residual = nn.Linear(llama_config.hidden_size, llama_config.hidden_size)
        nn.init.xavier_uniform_(self.image_residual.weight, gain=1.0)
        nn.init.zeros_(self.image_residual.bias)

|
    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None):
        # Encode the image and build the conditioning sequence for cross-attention.
        encoder_outputs = self.vit_model(pixel_values=pixel_values)
        encoder_hidden_states = self.image_proj(encoder_outputs.last_hidden_state)
        encoder_hidden_states = self.image_layer_norm(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states * self.scale
        encoder_hidden_states = self.image_activation(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + self.image_residual(encoder_hidden_states)

        if input_ids is not None:
            batch_size, seq_length = input_ids.size()
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        else:
            position_ids = None

        outputs = self.llama_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=None,
            labels=labels,
            use_cache=False,
            return_dict=True,
            position_ids=position_ids,
        )

        return outputs

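    # Standard top-k / nucleus (top-p) filtering: logits outside the top-k set, or
    # beyond the smallest set whose cumulative probability exceeds top_p, are pushed
    # to -inf so they are never sampled.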
|
    def top_k_top_p_filtering(self, logits, top_k=0, top_p=1.0, filter_value=-float("Inf")):
        """
        Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
        """
        logits = logits.clone()
        batch_size, vocab_size = logits.size()

        if top_k > 0:
            top_k = min(max(top_k, 1), vocab_size)
            top_k_values, _ = torch.topk(logits, top_k, dim=-1)
            min_top_k = top_k_values[:, -1, None]
            indices_to_remove = logits < min_top_k
            logits[indices_to_remove] = filter_value

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(
                nn.functional.softmax(sorted_logits, dim=-1), dim=-1
            )

            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept.
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
            indices_to_remove.scatter_(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = filter_value

        return logits

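    # Autoregressive caption sampling: encode the image once, then repeatedly feed
    # the most recent token (with cached key/values), apply temperature and
    # top-k/top-p filtering, and sample the next token until EOS or max_length.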
|
    def generate_caption(
        self,
        pixel_values,
        max_length=16,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
    ):
        self.eval()
        with torch.no_grad():
            # Encode the image once; the resulting states are reused at every step.
            encoder_outputs = self.vit_model(pixel_values=pixel_values)
            encoder_hidden_states = self.image_proj(encoder_outputs.last_hidden_state)
            encoder_hidden_states = self.image_layer_norm(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states * self.scale
            encoder_hidden_states = self.image_activation(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + self.image_residual(encoder_hidden_states)

            generated_ids = torch.full(
                (pixel_values.size(0), 1),
                self.tokenizer.bos_token_id,
                dtype=torch.long,
                device=pixel_values.device,
            )
            past_key_values = None

            for _ in range(max_length):
                # With a populated cache, only the newest token needs to be fed;
                # position ids are inferred from the cache length inside the model.
                if past_key_values is None:
                    input_ids = generated_ids
                else:
                    input_ids = generated_ids[:, -1:]

                outputs = self.llama_model(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    encoder_hidden_states=encoder_hidden_states,
                    use_cache=True,
                    past_key_values=past_key_values,
                    return_dict=True,
                )
                logits = outputs["logits"][:, -1, :] / temperature
                past_key_values = outputs["past_key_values"]

                filtered_logits = self.top_k_top_p_filtering(
                    logits, top_k=top_k, top_p=top_p
                )
                probabilities = nn.functional.softmax(filtered_logits, dim=-1)
                next_token_id = torch.multinomial(probabilities, num_samples=1)
                generated_ids = torch.cat((generated_ids, next_token_id), dim=-1)

                if (next_token_id.squeeze(-1) == self.tokenizer.eos_token_id).all():
                    break

        captions = [
            self.tokenizer.decode(generated_id, skip_special_tokens=True)
            for generated_id in generated_ids
        ]
        return captions


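# Illustrative wiring (sketch): a minimal example of how the classes above fit
# together. The helper below and the checkpoint names are placeholders for
# illustration, not the original training configuration; swap in whatever ViT
# checkpoint, tokenizer, and LLaMA config the rest of the script uses.
def build_captioning_model(
    vit_name="google/vit-base-patch16-224-in21k",
    tokenizer_name="hf-internal-testing/llama-tokenizer",
):
    vit_model = ViTModel.from_pretrained(vit_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Small decoder config chosen for illustration only.
    llama_config = LlamaConfig(
        vocab_size=len(tokenizer),
        hidden_size=512,
        intermediate_size=1376,
        num_hidden_layers=4,
        num_attention_heads=8,
        pad_token_id=tokenizer.pad_token_id,
    )
    llama_model = CustomLlamaForConditionalGeneration(llama_config)
    return ImageCaptioningModel(vit_model, llama_config, llama_model, tokenizer)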