| | """PyTorch MarkupDM model.""" |
| |
|
| | import contextlib |
| | import math |
| | import os |
| | from typing import Any |
| |
|
| | import rff.layers |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import ( |
| | AutoModel, |
| | AutoModelForCausalLM, |
| | GenerationMixin, |
| | PreTrainedModel, |
| | ) |
| | from transformers.loss.loss_utils import LOSS_MAPPING |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from transformers.utils import logging |
| |
|
| | from .configuration_markupdm import MarkupDMConfig |
| | from .loss_utils import WeightedCausalLMLoss |
| |
|
logger = logging.get_logger(__name__)

# Make the custom weighted loss available in transformers' LOSS_MAPPING
# registry under the key "WeightedCausalLMLoss".
LOSS_MAPPING.update({"WeightedCausalLMLoss": WeightedCausalLMLoss})
| |
|
| |
|
class MarkupDMForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM over interleaved text tokens and VQ image tokens.

    Token ids ``< config.vocab_size`` are text tokens embedded by ``text_model``.
    Ids ``>= config.vocab_size`` are image tokens: codebook indices of the
    frozen ``vision_model`` shifted up by ``config.vocab_size``.

    The output logits are the text-model logits, truncated to
    ``config.vocab_size``, concatenated on the last dimension with the logits
    of a separate image head (``vis_head``, ``image_vocab_size`` classes).
    """

    config: MarkupDMConfig
    config_class = MarkupDMConfig

    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def __init__(
        self,
        config: MarkupDMConfig,
        text_model: PreTrainedModel,
        vision_model: PreTrainedModel,
    ) -> None:
        """Wire the text backbone and the frozen VQ vision model together.

        Args:
            config: Shared MarkupDM config. Its ``text_model`` /
                ``vision_model`` sub-configs replace the sub-models' own configs.
            text_model: Causal LM used as the backbone. Put in train mode.
            vision_model: VQ model providing the image codebook. Put in eval
                mode with gradients disabled.

        Raises:
            ValueError: If ``config`` is not a ``MarkupDMConfig`` or the vision
                model's ``model_type`` is not ``"vqmodel"``.
        """
        if not isinstance(config, self.config_class):
            raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        logger.info(f"MarkupDM config: {config}")
        super().__init__(config)

        self.text_model = text_model.train()
        self.vision_model = vision_model.eval().requires_grad_(False)

        if self.text_model.config.to_dict() != self.config.text_model.to_dict():
            logger.warning(
                f"Config of the text model: {self.text_model.__class__} is "
                f"overwritten by shared text config: {self.config.text_model}"
            )
        if self.vision_model.config.to_dict() != self.config.vision_model.to_dict():
            logger.warning(
                f"Config of the vision model: {self.vision_model.__class__} is "
                f"overwritten by shared vision config: {self.config.vision_model}"
            )

        # The shared config is the single source of truth for both sub-models.
        self.text_model.config = self.config.text_model
        self.vision_model.config = self.config.vision_model

        # Grow the text embedding table if the shared vocab is larger.
        base_size = self.text_model.config.vocab_size
        if base_size < self.config.vocab_size:
            self.text_model.resize_token_embeddings(self.config.vocab_size)
            new_size = self.text_model.get_input_embeddings().num_embeddings
            logger.info(f"Resize embedding layer from {base_size} to {new_size} tokens")

        d_text = self.text_model.config.hidden_size
        model_type = self.vision_model.config.model_type
        if model_type != "vqmodel":
            raise ValueError(
                f"Vision model must have model_type 'vqmodel', got {model_type!r}"
            )
        d_vision = self.vision_model.model.embed_dim
        image_pos_size = self.config.image_pos_size
        sigma = self.config.image_pos_sigma
        # PositionalEncoding(sigma, m) yields 2*m features per scalar input.
        # Choosing m = ceil(size / 2) gives at least image_pos_size features;
        # embed_tokens truncates to exactly image_pos_size.
        m = math.ceil(image_pos_size / 2)
        self.image_vocab_size = self.vision_model.model.n_embed

        self.proj_vpos = rff.layers.PositionalEncoding(sigma, m)
        # Maps [codebook vector ; position features] into the text hidden space.
        self.proj_vt = nn.Linear(d_vision + image_pos_size, d_text)
        # Predicts image codebook indices from the text model's last hidden state.
        self.vis_head = nn.Linear(d_text, self.image_vocab_size)

        # Each encoder resolution level after the first halves the side length,
        # so an image becomes a (latent_size x latent_size) grid of tokens.
        scale_factor = 2 ** (self.vision_model.model.encoder.num_resolutions - 1)
        latent_size = self.config.image_size // scale_factor
        self.num_image_tokens = latent_size**2

        self.post_init()

        # Done after post_init() so that initialization does not undo it.
        if config.freeze_text_embeddings:
            self.text_model.get_input_embeddings().requires_grad_(False)

    def tie_weights(self, *args: Any, **kwargs: Any) -> None:
        """Tie weights of the wrapped text model only.

        Any arguments are forwarded unchanged to ``text_model.tie_weights``.
        """
        self.text_model.tie_weights(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM":
        """Load a MarkupDM checkpoint.

        A ``config`` keyword argument (a ``MarkupDMConfig``) is required. The
        text and vision sub-models are built from its sub-configs with
        ``from_config`` and handed to ``PreTrainedModel.from_pretrained`` as the
        extra keyword arguments ``text_model`` / ``vision_model``.

        Raises:
            ValueError: If ``config`` is not passed as a keyword argument.
        """
        if "config" not in kwargs:
            raise ValueError("Config must be provided")
        config = kwargs["config"]
        dtype = kwargs.get("dtype", kwargs.get("torch_dtype", None))

        text_model = AutoModelForCausalLM.from_config(
            config.text_model,
            dtype=dtype,
            attn_implementation=config._attn_implementation,
        )

        # Discard anything printed to stdout while the remote-code vision model
        # is built. The devnull handle is closed when the block exits.
        with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
            vision_model = AutoModel.from_config(
                config.vision_model,
                trust_remote_code=True,
                dtype=dtype,
            )

        return super().from_pretrained(
            *args,
            **kwargs,
            text_model=text_model,
            vision_model=vision_model,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: tuple[tuple[torch.Tensor]] | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.Tensor | None = None,
        num_items_in_batch: int | None = None,
        **kwargs: Any,
    ) -> CausalLMOutputWithPast | tuple:
        """Run the text backbone on mixed text/image tokens.

        Args:
            input_ids: Token ids of shape (batch, seq_len). Ids
                ``>= config.vocab_size`` are image tokens.
            inputs_embeds: Precomputed input embeddings. Built with
                ``embed_tokens`` when omitted.
            image_mask: Boolean mask of image-token positions. Defaults to
                ``input_ids >= config.vocab_size``.
            image_pos_ids: Position of each image token inside its image.
                Required when ``input_ids`` contains image tokens and
                ``inputs_embeds`` is not given.
            labels: Training targets. When given, text logits are set to -inf
                at positions whose next token is an image token, and image
                logits are set to -inf everywhere else. Each position is
                therefore scored only against the vocabulary of its target.
            num_items_in_batch: Forwarded to ``self.loss_function``.
            **kwargs: Accepted only if every value is ``None``.

        Remaining arguments are forwarded to the text model.

        Returns:
            ``CausalLMOutputWithPast`` whose logits have shape
            (batch, seq_len, vocab_size + image_vocab_size). When
            ``return_dict`` is false, the same fields are returned as a tuple
            (``None`` fields dropped).

        Raises:
            ValueError: On a non-``None`` unknown keyword argument.
        """
        for key, value in kwargs.items():
            if value is not None:
                raise ValueError(f"Unknown argument: {key}={value}")

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if image_mask is None:
            image_mask = input_ids >= self.config.vocab_size

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(
                input_ids,
                image_mask=image_mask,
                image_pos_ids=image_pos_ids,
            )

        # Hidden states are always requested from the backbone because the
        # image head reads the last one. Whether they are returned to the
        # caller is decided by output_hidden_states below.
        fwd_kwargs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "output_hidden_states": True,
            "output_attentions": output_attentions,
        }
        if self.config.text_model.model_type == "starcoder2":
            fwd_kwargs["cache_position"] = cache_position
        outputs = self.text_model(**fwd_kwargs)

        # Drop logits for any extra (resized) embedding rows beyond vocab_size.
        text_logits = outputs.logits[:, :, : self.config.vocab_size]
        vision_logits = self.vis_head(outputs.hidden_states[-1])

        if labels is not None:
            # Position t predicts token t+1: mask by the *next* token's type.
            shift_mask = F.pad(image_mask[:, 1:], (0, 1), value=False)
            text_logits[shift_mask] = float("-inf")
            vision_logits[~shift_mask] = float("-inf")

        logits = torch.cat([text_logits, vision_logits], dim=-1)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                image_vocab_size=self.image_vocab_size,
                image_loss_weight=self.config.image_loss_weight,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
        if not return_dict:
            # Same fields as the dict form, so hidden states appear in the
            # tuple only when they were requested.
            return output.to_tuple()
        return output

    def embed_tokens(
        self,
        input_ids: torch.Tensor,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed a mixed sequence of text and image tokens.

        Text tokens use the text model's input embeddings.

        Image tokens (where ``image_mask`` is true) are handled as follows:
          1. Shift down by ``config.vocab_size`` and look up in the VQ codebook.
          2. Concatenate with Fourier features of
             ``image_pos_ids / num_image_tokens``, truncated to
             ``config.image_pos_size``.
          3. Project to the text hidden size with ``proj_vt``.

        If ``image_mask`` is ``None``, every token is treated as text.

        Raises:
            ValueError: If ``image_mask`` marks image tokens but
                ``image_pos_ids`` is ``None``.
        """
        if image_mask is None:
            return self.text_embed(input_ids)

        if image_pos_ids is None:
            if image_mask.any():
                raise ValueError(
                    "image_pos_ids is required when input_ids contains image tokens"
                )
            # Text-only input: positions are never read, but keep the image
            # branch in the graph so every projection parameter is still used.
            image_pos_ids = torch.zeros_like(input_ids)

        size = input_ids.size() + (self.text_model.config.hidden_size,)
        inputs_embeds = torch.zeros(size, device=self.device, dtype=self.dtype)

        text_mask = ~image_mask
        inputs_embeds[text_mask] = self.text_embed(input_ids[text_mask])

        image_embeds = self.vis_embed(input_ids[image_mask] - self.config.vocab_size)

        image_pos = image_pos_ids / self.num_image_tokens
        image_pos = self.proj_vpos(image_pos.unsqueeze(-1)).to(image_embeds)
        image_pos = image_pos[image_mask][:, : self.config.image_pos_size]

        image_embeds = torch.cat([image_embeds, image_pos], dim=-1)
        inputs_embeds[image_mask] = self.proj_vt(image_embeds)

        return inputs_embeds

    def text_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up text token ids in the text model's input embedding table."""
        return self.text_model.get_input_embeddings()(input_ids)

    def vis_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up codebook vectors for image ids already shifted to start at 0."""
        return self.vision_model.model.quantize.embedding(input_ids)

    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, **model_kwargs: Any
    ) -> dict:
        """Prepare one generation step's inputs, adding image positions and mask.

        Delegates to the text model's ``prepare_inputs_for_generation``, then
        adds two entries:

        * ``image_pos_ids``: the k-th image token (0-based) of each row gets
          ``k % num_image_tokens``, and non-image positions get 0. This assumes
          every image occupies exactly ``num_image_tokens`` consecutive slots.
          The ids are computed on the full ``input_ids`` and then cropped to
          the trailing columns kept in ``inputs["input_ids"]``.
        * ``image_mask``: ``inputs["input_ids"] >= config.vocab_size``.
        """
        inputs = self.text_model.prepare_inputs_for_generation(
            input_ids, **model_kwargs
        )

        image_mask_all = input_ids >= self.config.vocab_size
        # 0-based running index of each image token within its row, computed
        # with a cumulative sum instead of a per-row Python loop.
        running_idx = image_mask_all.cumsum(dim=1) - 1
        image_pos_ids = torch.where(
            image_mask_all,
            running_idx % self.num_image_tokens,
            torch.zeros_like(running_idx),
        ).to(input_ids.dtype)

        length = inputs["input_ids"].size(1)
        inputs["image_pos_ids"] = image_pos_ids[:, -length:]
        inputs["image_mask"] = inputs["input_ids"] >= self.config.vocab_size

        return inputs
| |
|