videollm-online / models /live_llama /modeling_live_llama.py
chenjoya's picture
Upload 9 files
7d1b5a5 verified
raw
history blame
6.95 kB
import torch
from torch import nn
from transformers import LlamaForCausalLM, Cache
from transformers.activations import GELUActivation
from transformers.utils import logging
from .configuration_live_llama import LiveLlamaConfig
from ..modeling_live import build_live, LiveMixin
logger = logging.get_logger(__name__)
class LiveLlamaForCausalLM(LlamaForCausalLM, LiveMixin):
    """Llama causal language model that also accepts video frames as input.

    A two-layer MLP ``connector`` projects vision features of width
    ``config.vision_hidden_size`` into the language model's ``config.hidden_size``.
    Text tokens and frames are merged into one embedding sequence by
    ``self.joint_embed`` (not defined in this class; presumably supplied by
    ``LiveMixin`` -- confirm in ``modeling_live``).
    """

    config_class = LiveLlamaConfig
    # These sub-modules may be absent from a checkpoint without that being
    # reported as an error when loading.
    _keys_to_ignore_on_load_missing = ['vision_encoder', 'connector']

    def __init__(self, config: LiveLlamaConfig):
        """Build the Llama backbone and the vision-to-text connector.

        Args:
            config: Model configuration; ``vision_hidden_size`` and
                ``hidden_size`` determine the connector's layer widths.
        """
        super().__init__(config)
        self.connector = torch.nn.Sequential(
            torch.nn.Linear(config.vision_hidden_size, config.hidden_size, bias=True),
            # NOTE(review): GELUActivation's constructor argument appears to be the
            # boolean `use_gelu_python`, not a width; a non-zero hidden_size would
            # then select the pure-Python GELU. Left unchanged to keep numerics
            # identical for existing checkpoints -- confirm against transformers.
            GELUActivation(config.hidden_size),
            torch.nn.Linear(config.hidden_size, config.hidden_size, bias=True),
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        frames: torch.FloatTensor = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_values: list[torch.FloatTensor] = None,
        inputs_embeds: torch.FloatTensor = None,
        labels: torch.LongTensor = None,
        use_cache: bool = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        cache_position: torch.LongTensor = None,
        **kwargs,
    ):
        """Run the language model on joint text/frame embeddings.

        When ``inputs_embeds`` is not given it is computed from ``input_ids``
        and ``frames`` via ``self.joint_embed``. The parent forward is called
        without labels; the loss is computed here so that positions holding
        the frame placeholder token can be weighted differently.

        Args:
            input_ids: Token ids. Required whenever ``labels`` is given,
                because the loss weighting is derived from them.
            frames: Video frames forwarded to ``self.joint_embed``.
            attention_mask, position_ids, past_key_values, use_cache,
                output_attentions, output_hidden_states, cache_position:
                Passed through unchanged to ``LlamaForCausalLM.forward``.
            inputs_embeds: Precomputed embeddings; skips ``joint_embed``.
            labels: Target ids, used exactly as given (no shifting is done
                here). Negative labels are excluded from the normaliser.
            return_dict: Return a model-output object instead of a tuple.
                ``None`` falls back to ``config.use_return_dict``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The parent's output with ``loss`` set when ``return_dict`` is
            true; otherwise a tuple, prefixed with the loss when one was
            computed: ``(loss, logits, ...)``.

        Raises:
            ValueError: If ``labels`` is given without ``input_ids``.
        """
        # Resolve the default once so the parent call and the return branch agree.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if labels is not None and input_ids is None:
            # Fail early with a clear message instead of an AttributeError on
            # `None.flatten` after the (expensive) forward pass.
            raise ValueError(
                "`input_ids` is required when `labels` is given: the loss weights "
                "are derived from the frame placeholder positions in `input_ids`."
            )

        if inputs_embeds is None:
            inputs_embeds = self.joint_embed(input_ids, frames)

        outputs = super().forward(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            # labels are deliberately not forwarded; the weighted loss is computed below
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        loss = None
        if labels is not None:
            # No loss came from the parent, so the first output entry is the logits.
            logits = outputs[0]
            # Positions whose input token is the frame placeholder get
            # `stream_loss_weight`; every other position gets weight 1.
            v_mask = input_ids.flatten(0, 1) == self.config.v_placeholder_id
            weight = v_mask * self.config.stream_loss_weight + ~v_mask
            loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction='none') * weight
            # Normalise by the number of non-negative (supervised) labels.
            loss = loss.sum() / (labels >= 0).sum()

        if not return_dict:
            # `outputs` already starts with the logits, so prepend the loss
            # without slicing anything off: (loss, logits, ...).
            return (loss,) + outputs if loss is not None else outputs

        outputs.loss = loss
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Trims ``input_ids`` / ``inputs_embeds`` / ``position_ids`` to the part
        not yet covered by ``past_key_values`` and derives ``cache_position``.

        Args:
            input_ids: Token ids accumulated so far.
            past_key_values: A ``Cache`` object or legacy nested tuples; ``None``
                on the first step.
            attention_mask: Mask over cache + new input positions.
            inputs_embeds: Optional embeddings; used while they extend past
                the cached length.
            cache_position: Optional positions of the new inputs in the cache.
            use_cache: Forwarded to the model.
            **kwargs: Only ``position_ids`` is read.

        Returns:
            Dict with either ``inputs_embeds`` or ``input_ids`` plus
            ``position_ids``, ``cache_position``, ``past_key_values``,
            ``use_cache`` and ``attention_mask``.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # NOTE: positions are cut at the cached length, not at the
                # number of remaining input_ids.
                position_ids = position_ids[:, past_length :]

        # NOTE: embeddings are used whenever they extend beyond the cached
        # length, not only on the first generation step.
        if inputs_embeds is not None and past_length < inputs_embeds.size(1):
            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,  # length of the new inputs, starting from the past length
                "cache_position": cache_position,  # the region not yet cached
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,  # length of cache + input
            }
        )
        return model_inputs
def build_live_llama(**kwargs):
    """Build a live Llama model through the shared ``build_live`` factory.

    Fixes the config class to ``LiveLlamaConfig`` and the model class to
    ``LiveLlamaForCausalLM``; every keyword argument is forwarded unchanged.

    Returns:
        Whatever ``build_live`` returns for these classes.
    """
    return build_live(
        config_class=LiveLlamaConfig,
        model_class=LiveLlamaForCausalLM,
        **kwargs,
    )
if __name__ == '__main__':
    # Manual smoke test: build the model with the default training arguments.
    # The relative import requires running this file as a module of its package.
    from ..arguments_live import LiveOnePlusTrainingArguments

    # Construct the arguments once so the printed dict is exactly the one
    # handed to the builder (previously they were constructed twice).
    training_kwargs = LiveOnePlusTrainingArguments().to_dict()
    print(training_kwargs)
    model, tokenizer = build_live_llama(is_training=True, **training_kwargs)
    print(model.config, tokenizer)