"""This file contains the implementation of the Megrez-Omni model.""" |
|
|
|
|
|
import torch |
|
from transformers import AutoProcessor |
|
from transformers import LlamaForCausalLM |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import add_start_docstrings |
|
|
|
from .audio import AudioEncoder |
|
from .configuration_megrezo import MegrezOConfig |
|
from .modeling_navit_siglip import SiglipVisionTransformer |
|
from .resampler import Resampler |
|
|
|
|
|
def insert_audio_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
    """Write each audio embedding into `text_embeddings` in place; bounds are (batch_idx, start, end) rows."""
    inserted_bounds = inserted_bounds.long()

    for idx in range(len(inserted_embeddings)):
        bid = inserted_bounds[idx][0]
        start_id = inserted_bounds[idx][1]
        end_id = inserted_bounds[idx][2]
        embedding = inserted_embeddings[idx]
        text_embeddings[bid, start_id + 1 : end_id] = embedding

    return text_embeddings
|
|
def insert_image_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
    """Write each image embedding into `text_embeddings` in place; bounds are (batch_idx, start, end) rows."""
    inserted_bounds = inserted_bounds.long()
    for idx in range(len(inserted_embeddings)):
        bid = inserted_bounds[idx][0]
        start_id = inserted_bounds[idx][1]
        end_id = inserted_bounds[idx][2]
        embedding = inserted_embeddings[idx]
        text_embeddings[bid, start_id:end_id] = embedding

    return text_embeddings
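
# A minimal sketch (illustrative only, not part of the model code) of the bounds convention used by
# the two helpers above. Each bounds row is (batch_index, start, end); image features fill positions
# start .. end - 1, while audio features fill start + 1 .. end - 1, leaving the token at `start` untouched:
#
#     text_embeddings = torch.zeros(1, 10, 4)      # (batch, seq_len, hidden)
#     image_embeddings = [torch.ones(3, 4)]        # one image span covering 3 token slots
#     image_bounds = torch.tensor([[0, 2, 5]])     # slots 2, 3 and 4 of sequence 0
#     insert_image_embeddings(text_embeddings, image_embeddings, image_bounds)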
|
|
MegrezO_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`MegrezOConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
@add_start_docstrings(
    "The bare MegrezO Model outputting raw hidden-states without any specific head on top.",
    MegrezO_START_DOCSTRING,
)
class MegrezOPreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    config_class = MegrezOConfig
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
|
|
class AudioModel(torch.nn.Module):
    """Thin wrapper around the `AudioEncoder` built from `config.audio_config`."""

    def __init__(self, config: MegrezOConfig):
        super().__init__()
        self.config = config
        self.audio = AudioEncoder(**config.audio_config.to_dict())

    def forward(self, audio_info):
        audios = audio_info["input_audios"]
        input_audio_lengths = audio_info["input_audio_lengths"]
        audio_span_tokens = audio_info["audio_span_tokens"]
        audios_features = self.audio.encode(audios, input_audio_lengths, audio_span_tokens)
        return audios_features
|
|
class VisionModel(torch.nn.Module):
    """SigLIP vision encoder plus a resampler that projects patch features to the LLM hidden size."""

    def __init__(self, config: MegrezOConfig):
        super().__init__()
        self.config = config
        self.vpm = self.init_vision_module()
        self.resampler = self.init_resampler(self.config.hidden_size, self.vpm.embed_dim)
|
    def init_vision_module(self):
        # Propagate the attention implementation chosen for the top-level model to the vision config.
        if self.config._attn_implementation == "flash_attention_2":
            self.config.vision_config._attn_implementation = "flash_attention_2"
        else:
            self.config.vision_config._attn_implementation = "eager"
        model = SiglipVisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        setattr(model, "embed_dim", model.embeddings.embed_dim)
        setattr(model, "patch_size", model.embeddings.patch_size)

        return model
|
    def init_resampler(self, embed_dim, vision_dim):
        return Resampler(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vision_dim,
            adaptive=True,
        )
|
    def get_vision_embedding(
        self,
        all_pixel_values: torch.Tensor,
        patch_attention_mask: torch.Tensor,
        tgt_sizes: torch.Tensor,
    ):
        B = all_pixel_values.size(0)
        vision_batch_size = self.config.vision_batch_size
        if B > vision_batch_size:
            # Run the vision tower in chunks to keep peak memory bounded when many image slices are batched.
            hs = []
            for i in range(0, B, vision_batch_size):
                start_idx = i
                end_idx = i + vision_batch_size
                tmp_hs = self.vpm(
                    all_pixel_values[start_idx:end_idx],
                    patch_attention_mask=patch_attention_mask[start_idx:end_idx],
                    tgt_sizes=tgt_sizes[start_idx:end_idx],
                ).last_hidden_state
                hs.append(tmp_hs)
            vision_embedding = torch.cat(hs, dim=0)
        else:
            vision_embedding = self.vpm(
                all_pixel_values,
                patch_attention_mask=patch_attention_mask,
                tgt_sizes=tgt_sizes,
            ).last_hidden_state

        return vision_embedding
|
    def _prepare_vision_input(self, images, patch_attention_mask, tgt_sizes):
        device = self.vpm.device
        dtype = self.vpm.dtype

        # Normalize raw pixel values from [0, 255] to [-1, 1] and move them to the vision tower's device/dtype.
        pixel_values = torch.stack([(image.to(device) - 127.5) / 127.5 for image in images]).type(dtype)
        patch_attention_mask = patch_attention_mask.to(device)
        return pixel_values, patch_attention_mask, tgt_sizes
|
    def forward(self, images, tgt_sizes, patch_attention_mask):
        pixel_values, patch_attention_mask, tgt_sizes = self._prepare_vision_input(
            images, patch_attention_mask, tgt_sizes
        )
        embedding = self.get_vision_embedding(pixel_values, patch_attention_mask, tgt_sizes)
        embedding = self.resampler(embedding, tgt_sizes)
        return embedding
|
|
class MegrezO(MegrezOPreTrainedModel):
    """Megrez-Omni model: a Llama language model extended with SigLIP vision and audio encoders."""

    def __init__(self, config):
        super().__init__(config)
        self.llm = LlamaForCausalLM(config)
        self.vision = VisionModel(config)
        self.audio = AudioModel(config)
        self.post_init()
        self.processor = None

        self.tune_vision = False
        self.tune_audio = False
|
    def _get_or_init_processor(self):
        # Lazily load the multimodal processor from the same checkpoint the model was loaded from.
        if self.processor is None:
            self.processor = AutoProcessor.from_pretrained(
                self.config._name_or_path,
                trust_remote_code=True,
            )

        return self.processor
|
    def convert_to_device(self, mini_batch):
        # Move every tensor, including tensors nested one level deep inside lists, onto the model's device.
        for key in mini_batch:
            if isinstance(mini_batch[key], torch.Tensor):
                mini_batch[key] = mini_batch[key].to(self.device)
            if isinstance(mini_batch[key], list):
                return_value = []
                for value in mini_batch[key]:
                    if isinstance(value, torch.Tensor):
                        value = value.to(self.device)
                    return_value.append(value)
                mini_batch[key] = return_value

        return mini_batch
|
    def compose_embeddings(self, mini_batch):
        position_ids = mini_batch["position_ids"]
        input_ids = mini_batch["input_ids"]
        image_encoding = mini_batch.get("image_encoding")
        audio_encoding = mini_batch.get("audio_encoding")
        if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()

        embeddings_text = self.llm.model.embed_tokens(input_ids)
        # The insertion helpers write into `embeddings_text` in place, so `input_embeds` is an alias of it.
        input_embeds = embeddings_text
        if image_encoding:
            pixel_values = image_encoding["pixel_values"]
            tgt_sizes = image_encoding["tgt_sizes"]
            patch_attention_mask = image_encoding["patch_attention_mask"]
            bounds_image = image_encoding["image_bounds"]
            embeddings_image = self.vision(pixel_values, tgt_sizes, patch_attention_mask=patch_attention_mask)
            input_embeds = insert_image_embeddings(embeddings_text, embeddings_image, bounds_image)
        elif self.training and self.tune_vision:
            # Dummy forward pass whose contribution is scaled to zero, so the vision parameters still
            # take part in the backward pass even when a batch contains no images.
            pixel_values = torch.zeros((3, 14, 3584), dtype=torch.float32)
            tgt_sizes = torch.tensor([[16, 16]], dtype=torch.int64)
            patch_attention_mask = torch.ones((3, 14), dtype=torch.float32)
            embeddings_image = self.vision(pixel_values, tgt_sizes, patch_attention_mask=patch_attention_mask)
            input_embeds += embeddings_image[0].sum() * 0.0

        if audio_encoding:
            embeddings_audio = self.audio(audio_encoding)
            bounds_audio = audio_encoding["audio_bounds"]
            input_embeds = insert_audio_embeddings(embeddings_text, embeddings_audio, bounds_audio)
        elif self.training and self.tune_audio:
            # Same zero-weighted dummy forward pass for the audio encoder. `AudioModel.forward` expects a
            # dict (it indexes by key), so the dummy encoding is built as a single dict rather than a list.
            dummy_audio = torch.zeros((1, 128, 3000), dtype=torch.float32)
            dummy_audio_lengths = torch.tensor([[125, 62]], dtype=torch.int32)
            dummy_span_tokens = [64]
            dummy_audio_encoding = {
                "input_audios": dummy_audio,
                "input_audio_lengths": dummy_audio_lengths,
                "audio_span_tokens": dummy_span_tokens,
            }
            embeddings_audio = self.audio(dummy_audio_encoding)
            input_embeds += embeddings_audio[0].sum() * 0.0

        return input_ids, input_embeds, position_ids
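
    # For reference, `compose_embeddings` expects a `mini_batch` shaped roughly as below. The keys are
    # taken from the access patterns above; the tensor shapes are assumptions about the processor's
    # output, not guarantees:
    #
    #     {
    #         "input_ids": LongTensor of shape (batch, seq_len),
    #         "position_ids": LongTensor of shape (batch, seq_len),
    #         "image_encoding": None or {
    #             "pixel_values": ...,                 # images for VisionModel.forward
    #             "tgt_sizes": ...,
    #             "patch_attention_mask": ...,
    #             "image_bounds": LongTensor of shape (num_image_spans, 3),
    #         },
    #         "audio_encoding": None or {
    #             "input_audios": ...,                 # features for AudioModel.forward
    #             "input_audio_lengths": ...,
    #             "audio_span_tokens": ...,
    #             "audio_bounds": LongTensor of shape (num_audio_spans, 3),
    #         },
    #     }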
|
    def forward(self, data, **kwargs):
        if self.training:
            _, input_embeds, position_ids = self.compose_embeddings(data)
            return self.llm.forward(
                input_ids=None,
                position_ids=position_ids,
                inputs_embeds=input_embeds,
                **kwargs,
            )
        return self.llm.forward(**kwargs)
|
    def generate(
        self,
        input_ids,
        position_ids,
        attention_mask,
        image_encoding=None,
        audio_encoding=None,
        **kwargs,
    ):
        tokenizer = self._get_or_init_processor().tokenizer
        data = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "image_encoding": image_encoding,
            "audio_encoding": audio_encoding,
        }
        data = self.convert_to_device(data)
        input_ids, input_embeds, position_ids = self.compose_embeddings(data)

        output = self.llm.generate(
            inputs_embeds=input_embeds,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            **kwargs,
        )
        return output
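
    # Note: `generate` above passes only `inputs_embeds` (no `input_ids`) to `self.llm.generate`, so in
    # recent transformers releases the returned ids contain just the newly generated tokens; `chat` below
    # can therefore decode `output_ids[0]` directly without stripping the prompt first.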
|
    def trim_stop_words(self, response, stop_words):
        # Truncate the decoded response at the first occurrence of any stop word.
        if stop_words:
            for stop in stop_words:
                idx = response.find(stop)
                if idx != -1:
                    response = response[:idx]
        return response
|
    @torch.inference_mode()
    def chat(self, input_msgs, processor=None, sampling=False, **kwargs):
        if processor is None:
            processor = self._get_or_init_processor()

        if sampling:
            generation_config = {
                "top_p": 0.8,
                "top_k": 100,
                "temperature": 0.7,
                "do_sample": True,
                "repetition_penalty": 1.05,
            }
        else:
            generation_config = {
                "num_beams": 3,
                "repetition_penalty": 1.2,
            }

        # Caller-supplied kwargs override the defaults; a temperature of 0 is treated as greedy decoding.
        generation_config.update(kwargs)
        if generation_config.get("temperature") == 0:
            generation_config["do_sample"] = False

        data = processor(input_msgs)
        output_ids = self.generate(**data, **generation_config)
        tokenizer = processor.tokenizer
        answer = tokenizer.decode(output_ids[0])
        return answer
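
# Minimal usage sketch. The checkpoint path, the `auto_map` entry that exposes this class through
# `AutoModel`, and the exact `input_msgs` format accepted by the bundled processor are assumptions
# that depend on the released checkpoint, not on this file:
#
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained(
#         "<path-to-megrez-omni-checkpoint>",
#         trust_remote_code=True,
#         torch_dtype=torch.bfloat16,
#     ).eval().cuda()
#     input_msgs = [...]  # conversation in the processor's expected format
#     answer = model.chat(input_msgs, sampling=True, max_new_tokens=512)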
|
|