|
|
from typing import List, Optional, Tuple, Union

import torch
import torchvision
from einops import rearrange
from torch import nn
from transformers import GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_gar import GARConfig
from .modeling_perception_lm import PerceptionLMForConditionalGeneration

logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
class GARModel(PreTrainedModel): |
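    """Wraps a ``PerceptionLMForConditionalGeneration`` backbone with an auxiliary
    mask-patch embedding for region prompts and a crop-token mechanism that splices
    RoI-aligned image features into the token sequence at crop-token positions."""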
|
|
config_class = GARConfig |
|
|
main_input_name = 'pixel_values' |
|
|
base_model_prefix = 'language_model' |
|
|
_no_split_modules = ['LlamaDecoderLayer'] |
|
|
_supports_flash_attn_2 = True |
|
|
supports_gradient_checkpointing = True |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: GARConfig, |
|
|
mllm=None, |
|
|
mask_patch_embedding=None, |
|
|
use_flash_attn=True, |
|
|
): |
|
|
super().__init__(config) |
|
|
        # Flash attention applies to the language model only; the vision tower always
        # runs with eager attention.
        config.mllm_config.use_flash_attn = bool(use_flash_attn)
        config.mllm_config.text_config.use_flash_attn = bool(use_flash_attn)
        config.mllm_config.vision_config.use_flash_attn = False

        config.mllm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
        config.mllm_config.vision_config._attn_implementation = 'eager'
|
|
|
|
|
self.prompt_numbers = config.prompt_numbers |
|
|
|
|
|
if mllm is not None: |
|
|
self.mllm = mllm |
|
|
else: |
|
|
self.mllm = PerceptionLMForConditionalGeneration(config.mllm_config) |
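        # When no mask_patch_embedding module is supplied, build a non-overlapping
        # Conv2d (stride == kernel size) that patchifies the binary mask prompt into
        # embeddings passed to the MLLM's `get_image_features`.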
|
|
if mask_patch_embedding is not None: |
|
|
self.mask_patch_embedding = mask_patch_embedding |
|
|
else: |
|
|
self.mask_patch_embedding = nn.Conv2d( |
|
|
in_channels=3, |
|
|
out_channels=config.mask_path_embedding_out_channels, |
|
|
kernel_size=config.kernel_size, |
|
|
stride=config.kernel_size, |
|
|
bias=False, |
|
|
) |
|
|
|
|
|
self.crop_tokens_ids = config.crop_tokens_ids |
|
|
|
|
|
@property |
|
|
def lm_head(self): |
|
|
return self.mllm.model.language_model.get_output_embeddings() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.mllm.model.language_model.get_input_embeddings() |
|
|
|
|
|
def get_output_embeddings(self): |
|
|
return self.mllm.model.language_model.get_output_embeddings() |
|
|
|
|
|
def forward(self, data, data_samples=None, mode='loss'): |
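        """Training forward pass.

        `data` must provide `pixel_values`, `global_mask_values`, `input_ids`,
        `labels`, `aspect_ratios`, and `bboxes` (batch size 1 only). Only
        `mode='loss'` is implemented; other modes raise `NotImplementedError`.
        """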
|
|
crop_tokens = self.crop_tokens_ids |
|
|
|
|
|
pixel_values = data['pixel_values'].to(self.mllm.device).to(self.mllm.dtype) |
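        # `global_mask_values` is stored in [-1, 1]; map it back to integer region ids
        # in [0, prompt_numbers], where the value `prompt_numbers` marks pixels with no
        # region prompt. The mask patch embedding sees only the binary foreground map
        # (id != prompt_numbers).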
|
|
mask_values = torch.round((data['global_mask_values'] + 1.) / 2. * 255.).long().to(self.mllm.device) |
|
|
mask_values = torch.clamp(mask_values, min=0, max=self.prompt_numbers) |
|
|
assert mask_values.max() < self.prompt_numbers + 1 and mask_values.min() >= 0 |
|
|
|
|
|
mask_embeds = self.mask_patch_embedding((mask_values != self.prompt_numbers).to(self.mllm.dtype)) |
|
|
input_ids = data['input_ids'] |
|
|
aspect_ratios = data['aspect_ratios'] |
|
|
bboxes = data['bboxes'] |
|
|
assert input_ids.shape[0] == 1, "Currently only support batch_size=1" |
|
|
|
|
|
inputs_embeds = self.mllm.get_input_embeddings()(input_ids) |
|
|
labels = data['labels'] |
|
|
|
|
|
image_features = None |
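        # Encode the image together with the mask prompt, then scatter the projected
        # vision features into the image placeholder positions of the text sequence.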
|
|
if pixel_values is not None: |
|
|
image_features = self.mllm.get_image_features( |
|
|
pixel_values=pixel_values.unsqueeze(0), |
|
|
mask_embeds=mask_embeds, |
|
|
) |
|
|
image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype) |
|
|
special_image_mask, _ = self.mllm.get_placeholder_mask( |
|
|
input_ids, inputs_embeds=inputs_embeds, image_features=image_features |
|
|
) |
|
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
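        # Crop-token replay: for each crop token found in the prompt, rebuild the full
        # feature map from the image tiles, RoI-align the referenced bbox to a fixed
        # 16x16 grid, and splice those features over the crop-token span. The spliced
        # positions get label -100 so they are ignored by the LM loss.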
|
|
|
|
|
|
|
|
new_inputs_embeds = [] |
|
|
new_labels = [] |
|
|
image_features_tiles = rearrange(image_features[1:].unsqueeze(0), 'b n (h w) c -> b n c h w', h=16, w=16) |
|
|
for batch_idx in range(inputs_embeds.shape[0]): |
|
|
curr_inputs_embeds = inputs_embeds[batch_idx] |
|
|
curr_labels = labels[batch_idx] |
|
|
for crop_token in crop_tokens: |
|
|
if crop_token in input_ids[batch_idx]: |
|
|
target_mask = input_ids[batch_idx].eq(crop_token) |
|
|
target_indices = target_mask.nonzero().squeeze() |
|
|
head_idx = target_indices.min().item() |
|
|
tail_idx = target_indices.max().item() |
|
|
image_features_recover = self._merge(image_features_tiles, aspect_ratios[batch_idx][0], aspect_ratios[batch_idx][1]) |
|
|
feat_h, feat_w = image_features_recover.shape[2:] |
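                    # The bbox is normalized to [0, 1]; scale it to feature-map
                    # coordinates assuming a patch stride of 28 pixels per feature cell.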
|
|
|
|
|
x1, y1, x2, y2 = bboxes[batch_idx][str(crop_token)] |
|
|
orig_h, orig_w = feat_h * 28, feat_w * 28 |
|
|
|
|
|
|
|
|
roi_orig_x1 = x1 * orig_w |
|
|
roi_orig_y1 = y1 * orig_h |
|
|
roi_orig_x2 = x2 * orig_w |
|
|
roi_orig_y2 = y2 * orig_h |
|
|
|
|
|
|
|
|
spatial_scale = feat_w / orig_w |
|
|
roi_feat_x1 = roi_orig_x1 * spatial_scale |
|
|
roi_feat_y1 = roi_orig_y1 * spatial_scale |
|
|
roi_feat_x2 = roi_orig_x2 * spatial_scale |
|
|
roi_feat_y2 = roi_orig_y2 * spatial_scale |
|
|
|
|
|
roi = torch.tensor( |
|
|
[0, roi_feat_x1, roi_feat_y1, roi_feat_x2, roi_feat_y2], |
|
|
dtype=torch.float32, device=image_features_recover.device, |
|
|
) |
|
|
|
|
|
roi_features = torchvision.ops.roi_align( |
|
|
input=image_features_recover.float(), |
|
|
boxes=roi.unsqueeze(0), |
|
|
output_size=(16, 16), |
|
|
spatial_scale=spatial_scale, |
|
|
sampling_ratio=2, |
|
|
aligned=True, |
|
|
) |
|
|
|
|
|
image_features_replay = roi_features.permute(0, 2, 3, 1).flatten(1, 2).to(image_features_recover.dtype).squeeze() |
|
|
|
|
|
curr_inputs_embeds = torch.cat([ |
|
|
curr_inputs_embeds[:head_idx], |
|
|
image_features_replay, |
|
|
curr_inputs_embeds[tail_idx+1:], |
|
|
]) |
|
|
curr_labels = torch.cat([ |
|
|
curr_labels[:head_idx], |
|
|
-100 * torch.ones(image_features_replay.shape[0], dtype=torch.long, device=labels.device), |
|
|
curr_labels[tail_idx+1:], |
|
|
]) |
|
|
|
|
|
assert curr_inputs_embeds.shape[0] == curr_labels.shape[0], f"shape mismatch, got {curr_inputs_embeds.shape[0]} != {curr_labels.shape[0]}" |
|
|
|
|
|
new_inputs_embeds.append(curr_inputs_embeds.unsqueeze(0)) |
|
|
new_labels.append(curr_labels) |
|
|
|
|
|
inputs_embeds = torch.cat(new_inputs_embeds, dim=0) |
|
|
labels = torch.cat(new_labels, dim=0) |
|
|
|
|
|
skip_this_batch = False |
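        # Neither `skip_this_batch` nor the flag returned by `_llm_forward` is ever set
        # to True here, so the zero-loss skip branch below is currently inactive.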
|
|
|
|
|
if mode == "loss": |
|
|
position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=inputs_embeds.device).unsqueeze(0).repeat(inputs_embeds.shape[0], 1) |
|
|
attention_mask = torch.ones(inputs_embeds.shape[0], inputs_embeds.shape[1], dtype=torch.long, device=inputs_embeds.device) |
|
|
use_cache = False |
|
|
|
|
|
outputs, _skip_this_case = self._llm_forward( |
|
|
inputs_embeds=inputs_embeds, |
|
|
position_ids=position_ids, |
|
|
attention_mask=attention_mask, |
|
|
labels=labels, |
|
|
use_cache=use_cache |
|
|
) |
|
|
|
|
|
if skip_this_batch or _skip_this_case: |
|
|
print("skip this batch!") |
|
|
loss_dict = {'loss': outputs.loss * 0.0} |
|
|
else: |
|
|
loss_dict = {'loss': outputs.loss} |
|
|
return loss_dict |
|
|
|
|
|
elif mode == "predict": |
|
|
pass |
|
|
elif mode == "tensor": |
|
|
pass |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
return outputs |
|
|
|
|
|
def _merge(self, tiles: torch.Tensor, ncw: int, nch: int) -> torch.Tensor: |
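        """Reassemble a grid of `ncw` x `nch` tiles of shape (C, H, W) into a single
        feature map of shape (B, C, nch * H, ncw * W)."""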
|
|
batch_size, num_tiles, num_channels, tile_height, tile_width = tiles.size() |
|
|
assert num_tiles == ncw * nch, f"{ncw * nch} != {num_tiles}" |
|
|
|
|
|
tiles = tiles.view(batch_size, nch, ncw, num_channels, tile_height, tile_width) |
|
|
tiles = tiles.permute(0, 3, 1, 4, 2, 5).contiguous() |
|
|
|
|
|
original_height = nch * tile_height |
|
|
original_width = ncw * tile_width |
|
|
|
|
|
image = tiles.view(batch_size, num_channels, original_height, original_width) |
|
|
|
|
|
return image |
|
|
|
|
|
def _llm_forward( |
|
|
self, |
|
|
inputs_embeds: torch.FloatTensor, |
|
|
input_ids: torch.LongTensor = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
image_flags: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
|
labels: Optional[torch.LongTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = None, |
|
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
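        """Thin wrapper around the underlying MLLM forward; the second return value
        (`skip_this_case`) is always False in the current implementation."""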
|
|
        return_dict = (
            return_dict if return_dict is not None else self.mllm.config.use_return_dict
        )
|
|
skip_this_case = False |
|
|
|
|
|
outputs = self.mllm( |
|
|
inputs_embeds=inputs_embeds, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
labels=labels, |
|
|
past_key_values=past_key_values, |
|
|
use_cache=use_cache, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
|
|
|
return outputs, skip_this_case |
|
|
|
|
|
@torch.no_grad() |
|
|
def generate( |
|
|
self, |
|
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
|
global_mask_values: Optional[torch.LongTensor] = None, |
|
|
aspect_ratios: Optional[torch.FloatTensor] = None, |
|
|
bboxes: Optional[torch.FloatTensor] = None, |
|
|
        input_ids: Optional[torch.LongTensor] = None,
|
|
attention_mask: Optional[torch.LongTensor] = None, |
|
|
generation_config: Optional[GenerationConfig] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = None, |
|
|
**generate_kwargs, |
|
|
) -> torch.LongTensor: |
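        """Autoregressive generation with the same mask-prompt and crop-token replay
        preprocessing as `forward`. Returns the underlying MLLM's generate output
        object (since `return_dict_in_generate=True`)."""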
|
|
device = self.device |
|
|
|
|
|
if pixel_values is not None: |
|
|
pixel_values = pixel_values.to(device).to(self.mllm.dtype) |
|
|
if global_mask_values is not None: |
|
|
|
|
|
mask_values = torch.round((global_mask_values + 1.) / 2. * 255.).long().to(device) |
|
|
mask_values = torch.clamp(mask_values, min=0, max=self.prompt_numbers) |
|
|
|
|
|
assert mask_values.max() < self.prompt_numbers + 1 and mask_values.min() >= 0, f"max: {mask_values.max()}, min: {mask_values.min()}" |
|
|
mask_embeds = self.mask_patch_embedding((mask_values != self.prompt_numbers).to(self.mllm.dtype)) |
|
|
else: |
|
|
mask_embeds = None |
|
|
|
|
|
inputs_embeds = self.mllm.get_input_embeddings()(input_ids) |
|
|
|
|
|
image_features = self.mllm.get_image_features( |
|
|
pixel_values=pixel_values.unsqueeze(0), |
|
|
mask_embeds=mask_embeds, |
|
|
) |
|
|
image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype) |
|
|
special_image_mask, _ = self.mllm.get_placeholder_mask( |
|
|
input_ids, inputs_embeds=inputs_embeds, image_features=image_features |
|
|
) |
|
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
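            # Crop-token replay, mirroring `forward`: splice RoI-aligned features over
            # the span of crop-token placeholders.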
|
|
|
|
|
|
|
|
new_inputs_embeds = [] |
|
|
image_features_tiles = rearrange(image_features[1:].unsqueeze(0), 'b n (h w) c -> b n c h w', h=16, w=16) |
|
|
for batch_idx in range(inputs_embeds.shape[0]): |
|
|
curr_inputs_embeds = inputs_embeds[batch_idx] |
|
|
for crop_token in self.crop_tokens_ids: |
|
|
if crop_token in input_ids[batch_idx]: |
|
|
target_mask = input_ids[batch_idx].eq(crop_token) |
|
|
target_indices = target_mask.nonzero().squeeze() |
|
|
head_idx = target_indices.min().item() |
|
|
tail_idx = target_indices.max().item() |
|
|
image_features_recover = self._merge(image_features_tiles, aspect_ratios[batch_idx][0], aspect_ratios[batch_idx][1]) |
|
|
feat_h, feat_w = image_features_recover.shape[2:] |
|
|
x1, y1, x2, y2 = bboxes[batch_idx][str(crop_token)] |
|
|
orig_h, orig_w = feat_h * 28, feat_w * 28 |
|
|
|
|
|
|
|
|
roi_orig_x1 = x1 * orig_w |
|
|
roi_orig_y1 = y1 * orig_h |
|
|
roi_orig_x2 = x2 * orig_w |
|
|
roi_orig_y2 = y2 * orig_h |
|
|
|
|
|
|
|
|
spatial_scale = feat_w / orig_w |
|
|
roi_feat_x1 = roi_orig_x1 * spatial_scale |
|
|
roi_feat_y1 = roi_orig_y1 * spatial_scale |
|
|
roi_feat_x2 = roi_orig_x2 * spatial_scale |
|
|
roi_feat_y2 = roi_orig_y2 * spatial_scale |
|
|
|
|
|
roi = torch.tensor( |
|
|
[0, roi_feat_x1, roi_feat_y1, roi_feat_x2, roi_feat_y2], |
|
|
dtype=torch.float32, device=image_features_recover.device, |
|
|
) |
|
|
|
|
|
roi_features = torchvision.ops.roi_align( |
|
|
input=image_features_recover.float(), |
|
|
boxes=roi.unsqueeze(0), |
|
|
output_size=(16, 16), |
|
|
spatial_scale=spatial_scale, |
|
|
sampling_ratio=2, |
|
|
aligned=True, |
|
|
) |
|
|
|
|
|
image_features_replay = roi_features.permute(0, 2, 3, 1).flatten(1, 2).to(image_features_recover.dtype).squeeze() |
|
|
|
|
|
curr_inputs_embeds = torch.cat([ |
|
|
curr_inputs_embeds[:head_idx], |
|
|
image_features_replay, |
|
|
curr_inputs_embeds[tail_idx+1:], |
|
|
]) |
|
|
|
|
|
new_inputs_embeds.append(curr_inputs_embeds.unsqueeze(0)) |
|
|
inputs_embeds = torch.cat(new_inputs_embeds, dim=0) |
|
|
else: |
|
|
inputs_embeds = self.mllm.get_input_embeddings()(input_ids) |
|
|
|
|
|
outputs = self.mllm.generate( |
|
|
inputs_embeds=inputs_embeds, |
|
|
attention_mask=attention_mask, |
|
|
generation_config=generation_config, |
|
|
output_hidden_states=output_hidden_states, |
|
|
|
|
|
use_cache=True, |
|
|
return_dict_in_generate=True, |
|
|
) |
|
|
|
|
|
return outputs |