|
import copy |
|
from collections import OrderedDict |
|
import torch |
|
import torch.nn as nn |
|
from mmengine.config import Config, ConfigDict |
|
from mmengine.model import BaseModel |
|
from peft import get_peft_model, prepare_model_for_kbit_training |
|
|
|
from xtuner.registry import BUILDER |
|
from .modules import ProjectorConfig_OMG_LLaVA, ProjectorModel_OMG_LLaVA |
|
from xtuner.model.modules import ProjectorModel, ProjectorConfig |
|
from xtuner.model.modules import dispatch_modules |
|
from .utils import (LoadWoInit, find_all_linear_names, |
|
get_peft_model_state_dict, guess_load_checkpoint, |
|
make_inputs_require_grad, |
|
traverse_dict, |
|
prepare_inputs_labels_for_multimodal_with_visual_prompts) |
|
from .convnext_clip import OpenCLIPBackbone |
|
from .omg_seg import OMGSegVisualEncoder |
|
|
|
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, |
|
PROMPT_TEMPLATE) |
|
from xtuner.tools.utils import get_stop_criteria, is_cn_string |
|
from transformers import GenerationConfig |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from pycocotools import mask as _mask |
|
|
|
class OMG_LLaVA(BaseModel): |
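    """OMG-LLaVA: an LLM paired with an OMG-Seg style visual encoder.

    Visual features (and optional region / point prompt features) are projected
    into the LLM embedding space, and the hidden states of [SEG] tokens can be
    projected back to the visual decoder to predict segmentation masks."""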
|
def __init__(self, |
|
llm, |
|
visual_encoder, |
|
visual_select_layer=-2, |
|
freeze_llm=False, |
|
freeze_visual_encoder=False, |
|
require_omg_decoder=False, |
|
pretrained_pth=None, |
|
llm_lora=None, |
|
visual_encoder_lora=None, |
|
use_activation_checkpointing=True, |
|
projector_depth=2, |
|
text2vision_projector=False, |
|
tokenizer=None, |
|
keep_omg_decoder_frozen=False, |
|
add_seg_pretrain=False, |
|
additional_cross_attn_layers=False, |
|
pixel_shuffle_ratio=None, |
|
train_vocabulary=False, |
|
freeze_llm_with_lora=False, |
|
freeze_visual_projector=False, |
|
rm_prior_embedding=False, |
|
rm_query=False, |
|
clip_feat_channel=1536, |
|
): |
|
super().__init__() |
|
|
|
self.freeze_llm_with_lora = freeze_llm_with_lora |
|
self.freeze_visual_projector = freeze_visual_projector |
|
|
|
self.freeze_llm = freeze_llm |
|
self.freeze_visual_encoder = freeze_visual_encoder |
|
with LoadWoInit(): |
|
self.llm = self._build_from_cfg_or_module(llm) |
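            # OpenCLIP / OMG-Seg encoders are instantiated directly from their
            # config; anything else goes through ``_build_from_cfg_or_module``.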
|
if visual_encoder.type == OpenCLIPBackbone or visual_encoder.type == OMGSegVisualEncoder: |
|
self.visual_encoder = visual_encoder.type(**visual_encoder) |
|
else: |
|
self.visual_encoder = self._build_from_cfg_or_module( |
|
visual_encoder) |
|
self.llm.config.use_cache = False |
|
dispatch_modules(self.llm) |
|
|
|
projector_config = ProjectorConfig_OMG_LLaVA( |
|
query_channels=256, |
|
feat_channels=clip_feat_channel, |
|
llm_hidden_size=self.llm.config.hidden_size, |
|
depth=projector_depth, |
|
pixel_shuffle_ratio=pixel_shuffle_ratio, |
|
) |
|
self.projector = ProjectorModel_OMG_LLaVA(projector_config).to( |
|
self.visual_encoder.dtype) |
|
|
|
self.text2vision_projector = text2vision_projector |
|
if text2vision_projector: |
|
projector_config = ProjectorConfig( |
|
visual_hidden_size=self.llm.config.hidden_size, |
|
llm_hidden_size=256 * 2, |
|
depth=projector_depth) |
|
self.projector_text2vision = ProjectorModel(projector_config).to( |
|
self.visual_encoder.dtype) |
|
|
|
if rm_query: |
|
self.projector.model.rm_query = rm_query |
|
if rm_prior_embedding: |
|
self.projector.model.rm_prior_embedding = rm_prior_embedding |
|
|
|
if self.freeze_llm: |
|
self.llm.requires_grad_(False) |
|
if self.freeze_visual_encoder: |
|
self.visual_encoder.requires_grad_(False) |
|
|
|
self.use_activation_checkpointing = use_activation_checkpointing |
|
if use_activation_checkpointing: |
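            # Make the input embeddings require grad so activation checkpointing
            # can backpropagate through (partially) frozen modules.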
|
|
|
if hasattr(self.llm, 'enable_input_require_grads'): |
|
self.llm.enable_input_require_grads() |
|
else: |
|
self.llm.get_input_embeddings().register_forward_hook( |
|
make_inputs_require_grad) |
|
if hasattr(self.visual_encoder, 'enable_input_require_grads'): |
|
self.visual_encoder.enable_input_require_grads() |
|
else: |
|
self.visual_encoder.get_input_embeddings( |
|
).register_forward_hook(make_inputs_require_grad) |
|
self.projector.enable_input_require_grads() |
|
if text2vision_projector: |
|
self.projector_text2vision.enable_input_require_grads() |
|
|
|
|
|
self.gradient_checkpointing_enable() |
|
|
|
|
|
self.added_special_token = False |
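        # Build the tokenizer from its config and register the special tokens
        # used for segmentation and visual prompts.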
|
if tokenizer is not None: |
|
self.tokenizer = tokenizer |
|
tokenizer_type = self.tokenizer['type'] |
|
del self.tokenizer['type'] |
|
self.tokenizer = tokenizer_type(**self.tokenizer) |
|
self._add_special_tokens() |
|
|
|
self.use_llm_lora = llm_lora is not None |
|
self.use_visual_encoder_lora = visual_encoder_lora is not None |
|
|
|
if self.use_llm_lora: |
|
self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing) |
|
if self.freeze_llm_with_lora: |
|
for name, param in self.llm.named_parameters(): |
|
param.requires_grad_(False) |
|
else: |
|
if train_vocabulary: |
|
|
|
for name, param in self.named_parameters(): |
|
if ('tok_' in name or 'embed_tokens' in name) or 'lm_head' in name: |
|
print("Unfrozen {} !!!".format(name)) |
|
param.requires_grad_(True) |
|
if ('output.' in name or 'lm_head' in name) and 'llm' in name and 'lora' not in name: |
|
print("Unfrozen {} !!!".format(name)) |
|
param.requires_grad_(True) |
|
|
|
if self.use_visual_encoder_lora: |
|
self._prepare_visual_encoder_for_lora( |
|
visual_encoder_lora, use_activation_checkpointing) |
|
|
|
if pretrained_pth is not None: |
|
pretrained_state_dict = guess_load_checkpoint(pretrained_pth) |
|
self.load_state_dict(pretrained_state_dict, strict=False) |
|
print(f'Load pretrained weight from {pretrained_pth}') |
|
|
|
self.visual_select_layer = visual_select_layer |
|
|
|
self._is_init = True |
|
|
|
self.require_omg_decoder = require_omg_decoder |
|
if require_omg_decoder: |
|
self.visual_encoder.init_new_decoder() |
|
if keep_omg_decoder_frozen: |
|
for name, param in self.visual_encoder.panoptic_head.transformer_decoder_llm.named_parameters(): |
|
param.requires_grad_(False) |
|
print("Frozen all the omg seg decoder !!!") |
|
|
|
self.additional_cross_attn_layers = additional_cross_attn_layers |
|
if self.additional_cross_attn_layers: |
|
self.visual_encoder.init_cross_attn_layer() |
|
|
|
if self.freeze_visual_projector: |
|
for name, param in self.projector.named_parameters(): |
|
param.requires_grad_(False) |
|
|
|
self.add_seg_pretrain = add_seg_pretrain |
|
self.init_prediction_config = False |
|
|
|
|
|
def _add_special_tokens(self): |
|
assert hasattr(self, "tokenizer") |
|
|
|
segmentation_tokens = ['[SEG]'] |
|
|
|
phrase_tokens = ['<p>', '</p>'] |
|
|
|
region_tokens = ['<region>'] |
|
point_tokens = ['<mark>'] |
|
special_tokens = segmentation_tokens + phrase_tokens + region_tokens |
|
self.tokenizer.add_tokens(special_tokens, special_tokens=True) |
|
|
|
self.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0] |
|
self.bop_token_idx = self.tokenizer("<p>", add_special_tokens=False).input_ids[0] |
|
self.eop_token_idx = self.tokenizer("</p>", add_special_tokens=False).input_ids[0] |
|
self.region_token_idx = self.tokenizer("<region>", add_special_tokens=False).input_ids[0] |
|
|
|
self.llm.resize_token_embeddings(len(self.tokenizer)) |
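        # '<mark>' is added after resizing, so no embedding row is allocated for
        # it; mark_token_idx only marks positions that are later replaced by
        # point-prompt features.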
|
|
|
self.tokenizer.add_tokens(point_tokens, special_tokens=True) |
|
self.mark_token_idx = self.tokenizer("<mark>", add_special_tokens=False).input_ids[0] |
|
        # ``use_llm_lora`` is only set after ``_add_special_tokens`` runs in
        # ``__init__``, so fall back to False if it is not defined yet.
        if self.use_activation_checkpointing or getattr(self, 'use_llm_lora', False) or not self.freeze_llm:
|
self.llm.enable_input_require_grads() |
|
self.added_special_token = True |
|
print("[SEG]: {}, <p>: {}, </p>: {}, <region>: {}, <mark>: {}" \ |
|
.format(self.seg_token_idx, self.bop_token_idx, |
|
self.eop_token_idx, self.region_token_idx, self.mark_token_idx)) |
|
print('****************************Add special tokens ********************************************') |
|
return |
|
|
|
def _parse_lora_config(self, lora_config): |
|
        if isinstance(lora_config, (dict, Config, ConfigDict)):
|
lora_config = BUILDER.build(lora_config) |
|
return lora_config |
|
|
|
def _prepare_llm_for_lora(self, |
|
lora_config, |
|
use_activation_checkpointing=True): |
|
lora_config = self._parse_lora_config(lora_config) |
|
self.llm = prepare_model_for_kbit_training( |
|
self.llm, use_activation_checkpointing) |
|
if lora_config.target_modules is None: |
|
modules = find_all_linear_names(self.llm) |
|
lora_config.target_modules = modules |
|
self.llm = get_peft_model(self.llm, lora_config) |
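        # Besides the LoRA weights, keep the token embeddings and LM head
        # trainable (they were resized for the new special tokens).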
|
for name, param in self.named_parameters(): |
|
if 'tok_' in name or 'lm_head' in name: |
|
print("Unfrozen {} !!!".format(name)) |
|
param.requires_grad_(True) |
|
if 'output.' in name and 'llm' in name and 'lora' not in name: |
|
print("Unfrozen {} !!!".format(name)) |
|
param.requires_grad_(True) |
|
|
|
def _prepare_visual_encoder_for_lora(self, |
|
lora_config, |
|
use_activation_checkpointing=True): |
|
lora_config = self._parse_lora_config(lora_config) |
|
if lora_config.target_modules is None: |
|
modules = find_all_linear_names(self.visual_encoder) |
|
lora_config.target_modules = modules |
|
self.visual_encoder = get_peft_model(self.visual_encoder, lora_config) |
|
|
|
def gradient_checkpointing_enable(self): |
|
self.activation_checkpointing_enable() |
|
|
|
def activation_checkpointing_enable(self): |
|
self.llm.gradient_checkpointing_enable() |
|
if hasattr(self.visual_encoder, 'gradient_checkpointing_enable'): |
|
self.visual_encoder.gradient_checkpointing_enable() |
|
elif hasattr(self.visual_encoder, 'clip_model'): |
|
if self.visual_encoder.clip_model is not None: |
|
self.visual_encoder.clip_model.gradient_checkpointing_enable() |
|
if hasattr(self.projector, 'gradient_checkpointing_enable'): |
|
self.projector.gradient_checkpointing_enable() |
|
if self.text2vision_projector and hasattr(self.projector_text2vision, 'gradient_checkpointing_enable'): |
|
self.projector_text2vision.gradient_checkpointing_enable() |
|
|
|
def gradient_checkpointing_disable(self): |
|
self.activation_checkpointing_disable() |
|
|
|
def activation_checkpointing_disable(self): |
|
self.llm.gradient_checkpointing_disable() |
|
if hasattr(self.visual_encoder, 'gradient_checkpointing_disable'): |
|
self.visual_encoder.gradient_checkpointing_disable() |
|
if hasattr(self.projector, 'gradient_checkpointing_disable'): |
|
self.projector.gradient_checkpointing_disable() |
|
if self.text2vision_projector and hasattr(self.projector_text2vision, 'gradient_checkpointing_disable'): |
|
self.projector_text2vision.gradient_checkpointing_disable() |
|
|
|
def init_weights(self): |
|
pass |
|
|
|
def state_dict(self, *args, **kwargs): |
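        """Collect only the parameters that may have been trained: token
        embeddings / LM head, LoRA weights, the projectors, and the unfrozen or
        newly added parts of the visual encoder."""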
|
state_dict = super().state_dict(*args, **kwargs) |
|
|
|
to_return = OrderedDict() |
|
|
|
|
|
to_return.update( |
|
{k: v for k, v in state_dict.items() if 'tok_' in k or 'embed_tokens' in k} |
|
) |
|
|
|
to_return.update( |
|
{k: v for k, v in state_dict.items() if ('output.' in k or 'lm_head' in k) and 'llm' in k and 'lora' not in k} |
|
) |
|
|
|
|
|
if self.use_visual_encoder_lora: |
|
to_return.update( |
|
get_peft_model_state_dict( |
|
self.visual_encoder, state_dict=state_dict)) |
|
elif not self.freeze_visual_encoder: |
|
to_return.update({ |
|
k: v |
|
for k, v in state_dict.items() if 'visual_encoder.' in k |
|
}) |
|
|
|
if self.use_llm_lora: |
|
to_return.update( |
|
get_peft_model_state_dict(self.llm, state_dict=state_dict)) |
|
elif not self.freeze_llm: |
|
to_return.update( |
|
{k: v |
|
for k, v in state_dict.items() if 'llm.' in k}) |
|
|
|
to_return.update( |
|
{k: v |
|
for k, v in state_dict.items() if 'projector.' in k}) |
|
|
|
to_return.update( |
|
{k: v |
|
for k, v in state_dict.items() if 'projector_text2vision' in k}) |
|
|
|
|
|
if self.freeze_visual_encoder: |
|
to_return.update( |
|
{k: v |
|
for k, v in state_dict.items() if 'visual_encoder.adapter_proj' in k}) |
|
|
|
|
|
if hasattr(self.visual_encoder, 'clip_model'): |
|
if self.visual_encoder.clip_lora is not None: |
|
to_return.update( |
|
get_peft_model_state_dict(self.visual_encoder.clip_model, |
|
state_dict=state_dict)) |
|
|
|
if self.require_omg_decoder: |
|
to_return.update( |
|
{k: v |
|
for k, v in state_dict.items() |
|
if 'visual_encoder.panoptic_head.transformer_decoder_llm' in k or |
|
'visual_encoder.panoptic_head.mask_embed_llm' in k or |
|
'visual_encoder.panoptic_head.pixel_decoder_llm' in k or |
|
'visual_encoder.panoptic_head.additional_cross_attn_layers' in k or |
|
'visual_encoder.panoptic_head.additional_ffn' in k or |
|
'visual_encoder.downsample_layer' in k |
|
}) |
|
|
|
return to_return |
|
|
|
def _build_from_cfg_or_module(self, cfg_or_mod): |
|
if isinstance(cfg_or_mod, nn.Module): |
|
return cfg_or_mod |
|
elif isinstance(cfg_or_mod, dict): |
|
traverse_dict(cfg_or_mod) |
|
return BUILDER.build(cfg_or_mod) |
|
else: |
|
raise NotImplementedError |
|
|
|
def forward(self, data, data_samples=None, mode='loss'): |
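        """Encode the image, splice visual / region / point-prompt features into
        the LLM inputs, then dispatch to ``compute_loss``, ``predict`` or the
        raw ``tensor`` forward depending on ``mode``."""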
|
if 'pixel_values' in data: |
|
if 'masks' in data: |
|
masks = data['masks'] |
|
del data['masks'] |
|
else: |
|
masks = None |
|
if 'regions' in data: |
|
regions = data['regions'] |
|
del data['regions'] |
|
else: |
|
regions = None |
|
if 'points' in data: |
|
points = data['points'] |
|
del data['points'] |
|
else: |
|
points = None |
|
|
|
visual_outputs = self.visual_encoder( |
|
data['pixel_values'].to(self.visual_encoder.dtype), |
|
output_hidden_states=True) |
|
|
|
if self.add_seg_pretrain: |
|
pred_obj_query, gt_obj_query = prepare_seg_pretrain_data( |
|
visual_outputs, |
|
[self.projector.model.query_proj, self.projector.model.model], |
|
self.projector_text2vision.model |
|
) |
|
|
|
            if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
|
pixel_values = self.projector(visual_outputs) |
|
else: |
|
pixel_values = self.projector( |
|
visual_outputs.hidden_states[self.visual_select_layer][:, 1:]) |
|
|
|
if regions is not None: |
|
region_embeddings, region_success = self.get_region_embeddings( |
|
regions, data['input_ids'], |
|
) |
|
none_region_embeddings = region_embeddings |
|
del regions |
|
else: |
|
region_success = True |
|
region_embeddings = [] |
|
none_region_embeddings = self.get_none_region_embeddings( |
|
input_ids=data['input_ids'], |
|
) |
|
|
|
if points is not None: |
|
points_mark_embedding, mark_success = self.get_points_embeddings( |
|
points, data['input_ids'], |
|
width=data['pixel_values'].shape[-1], |
|
height=data['pixel_values'].shape[-2], |
|
) |
|
none_points_mark_embedding = points_mark_embedding |
|
else: |
|
none_points_mark_embedding = self.get_none_points_embeddings( |
|
data['input_ids'], |
|
width=data['pixel_values'].shape[-1], |
|
height=data['pixel_values'].shape[-2], |
|
) |
|
points_mark_embedding = [] |
|
mark_success = True |
|
|
|
data['pixel_values'] = pixel_values |
|
data = prepare_inputs_labels_for_multimodal_with_visual_prompts( |
|
llm=self.llm, region_id=self.region_token_idx, |
|
regions_feats=region_embeddings, |
|
mark_id=self.mark_token_idx, |
|
mark_feats=points_mark_embedding, |
|
**data) |
|
        else:
            masks = None
            region_success = True
            mark_success = True
            none_region_embeddings = None
            none_points_mark_embedding = None

        # Add a zero-weighted term so the (otherwise unused) visual-prompt
        # branches stay in the computation graph.
        if none_region_embeddings is not None and none_points_mark_embedding is not None:
            _zero = none_points_mark_embedding.sum() * 0.0 + none_region_embeddings.sum() * 0.0
        else:
            _zero = 0.0
|
|
|
if mode == 'loss': |
|
if self.add_seg_pretrain: |
|
return self.compute_loss(data, data_samples, masks=masks, region_success=region_success, |
|
pred_gt_obj_query=(pred_obj_query, gt_obj_query), |
|
mark_success=mark_success, _zero=_zero) |
|
else: |
|
return self.compute_loss(data, data_samples, masks=masks, |
|
pred_gt_obj_query=None, |
|
region_success=region_success, |
|
mark_success=mark_success, |
|
_zero=_zero) |
|
elif mode == 'predict': |
|
return self.predict(data, data_samples) |
|
elif mode == 'tensor': |
|
return self._forward(data, data_samples) |
|
else: |
|
raise NotImplementedError |
|
|
|
def _forward(self, data, data_samples=None): |
|
|
|
outputs = self.llm(**data) |
|
|
|
return outputs |
|
|
|
def predict(self, data, data_samples=None): |
|
outputs = self.llm(**data) |
|
logits_dict = [{'logits': logits} for logits in outputs.logits] |
|
return logits_dict |
|
|
|
def compute_loss(self, data, data_samples=None, masks=None, pred_gt_obj_query=None, |
|
region_success=True, mark_success=True, _zero=0): |
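        """LLM cross-entropy plus the [SEG] segmentation losses and, during seg
        pretraining, an MSE loss between projected and original object queries;
        batches with mismatched visual prompts have their LLM loss zeroed."""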
|
if 'original_labels' in data.keys(): |
|
input_ids = data['original_labels'] |
|
del data['original_labels'] |
|
else: |
|
input_ids = data['labels'] |
|
outputs = self.llm(**data, output_hidden_states=True) |
|
|
|
loss_dice, loss_mask = self.compute_seg_loss( |
|
input_ids, outputs.hidden_states[-1], masks) |
|
|
|
if pred_gt_obj_query is not None: |
|
pred_obj_query, gt_obj_query = pred_gt_obj_query |
|
proj_loss = torch.mean((pred_obj_query - gt_obj_query) ** 2) * 10 |
|
else: |
|
proj_loss = 0 |
|
|
|
if not region_success: |
|
loss = outputs.loss * 0 |
|
else: |
|
loss = outputs.loss |
|
|
|
if not mark_success: |
|
loss = outputs.loss * 0 |
|
|
|
|
|
loss = loss + _zero |
|
|
|
        # Add ``outputs.loss * 0`` so every entry is a tensor attached to the
        # graph even when the auxiliary loss is a plain Python float.
        loss_dict = {'loss': loss,
                     'loss_dice': outputs.loss * 0 + loss_dice * 0.1,
                     'loss_mask': outputs.loss * 0 + loss_mask * 0.4,
                     'loss_proj': outputs.loss * 0 + proj_loss}
|
return loss_dict |
|
|
|
def __getattr__(self, name: str): |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
return getattr(self.llm, name) |
|
|
|
def get_region_embeddings(self, regions, input_ids): |
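        """Encode region (mask) prompts with the visual encoder and project them
        into the LLM embedding space.

        Returns ``(embeddings, success)``; ``success`` is False when the number
        of regions does not match the number of <region> tokens, in which case
        the regions are truncated or padded so the forward pass can continue."""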
|
success = True |
|
if regions is None or len(regions) == 0: |
|
return [], success |
|
else: |
|
region_token_mask = input_ids == self.region_token_idx |
|
batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( |
|
input_ids.device) |
|
batch_idxs = batch_idxs[region_token_mask] |
|
if len(regions) != len(batch_idxs): |
|
|
|
success = False |
|
if len(regions) > len(batch_idxs): |
|
regions = regions[:len(batch_idxs)] |
|
else: |
|
n_pad = len(batch_idxs) - len(regions) |
|
pad_region = regions[:1].repeat(n_pad, 1, 1) |
|
regions = torch.cat([pad_region, regions]) |
|
|
|
regions_embeddings = self.visual_encoder.forward_region_sam( |
|
regions, batch_idxs |
|
)[:, 0] |
|
|
|
regions_embeddings = self.projector.model.forward_visual_prompts_embeddings( |
|
regions_embeddings, batch_idxs) |
|
return regions_embeddings, success |
|
|
|
def get_none_region_embeddings(self, input_ids): |
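        """Encode a dummy full-image region so the region-prompt branch always
        runs; its output only enters the loss through the zero-weighted
        ``_zero`` term in ``forward``."""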
|
batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( |
|
input_ids.device) |
|
batch_idxs = batch_idxs[0, :1] |
|
|
|
regions = torch.ones((1, 50, 50)).to(torch.float32).to(input_ids.device) |
|
|
|
regions_embeddings = self.visual_encoder.forward_region_sam( |
|
regions, batch_idxs |
|
)[:, 0] |
|
|
|
regions_embeddings = self.projector.model.forward_visual_prompts_embeddings( |
|
regions_embeddings, batch_idxs) |
|
return regions_embeddings |
|
|
|
def get_points_embeddings(self, points, input_ids, width, height): |
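        """Point-prompt counterpart of ``get_region_embeddings``, keyed on the
        <mark> token; returns ``(embeddings, success)``."""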
|
success = True |
|
if points is None or len(points) == 0: |
|
            return [], success
|
|
|
mark_token_mask = input_ids == self.mark_token_idx |
|
batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( |
|
input_ids.device) |
|
batch_idxs = batch_idxs[mark_token_mask] |
|
|
|
if len(points) != len(batch_idxs): |
|
|
|
success = False |
|
if len(points) > len(batch_idxs): |
|
points = points[:len(batch_idxs)] |
|
else: |
|
n_pad = len(batch_idxs) - len(points) |
|
pad_region = points[:1].repeat(n_pad, 1, 1) |
|
points = torch.cat([pad_region, points]) |
|
|
|
marks_embeddings = self.visual_encoder.forward_point_sam( |
|
points, batch_idxs, width=width, height=height |
|
)[:, 0] |
|
|
|
marks_embeddings = self.projector.model.forward_visual_prompts_embeddings( |
|
marks_embeddings, batch_idxs) |
|
return marks_embeddings, success |
|
|
|
def get_none_points_embeddings(self, input_ids, width, height): |
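        """Encode a dummy point prompt; the point-prompt counterpart of
        ``get_none_region_embeddings``."""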
|
batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( |
|
input_ids.device) |
|
batch_idxs = batch_idxs[0, :1] |
|
|
|
marks_embeddings = self.visual_encoder.forward_point_sam( |
|
torch.zeros((1, 2)).to(input_ids), batch_idxs, width=width, height=height |
|
)[:, 0] |
|
|
|
marks_embeddings = self.projector.model.forward_visual_prompts_embeddings( |
|
marks_embeddings, batch_idxs) |
|
return marks_embeddings |
|
|
|
def get_visual_prompts_projector_zero(self): |
|
return self.projector.model.visual_prompt_zero |
|
|
|
def compute_seg_loss(self, input_ids, hidden_states, gt_masks): |
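        """Project [SEG] hidden states back to the visual decoder and compute
        dice / mask losses against ``gt_masks``.

        When there are no ground-truth masks, or the [SEG] / mask counts do not
        match, a dummy forward pass is run and both losses are multiplied by
        0.0 so the segmentation branch still stays in the computation graph."""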
|
if not self.text2vision_projector or self.add_seg_pretrain: |
|
return 0.0, 0.0 |
|
success = True |
|
if gt_masks is None or len(gt_masks) == 0: |
|
batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( |
|
input_ids.device) |
|
batch_idxs = batch_idxs[0, :1] |
|
gt_masks = [None] |
|
hidden_states = hidden_states[0, :1] |
|
hidden_states = self.projector_text2vision(hidden_states) |
|
|
|
pred_masks_list = self.visual_encoder.forward_llm_seg(hidden_states, batch_idxs) |
|
dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks) |
|
|
|
return dice_loss * 0.0, mask_loss * 0.0 |
|
|
|
|
|
seg_tokens_mask = input_ids == self.seg_token_idx |
|
batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to(seg_tokens_mask.device) |
|
|
|
ori_hidden_states = hidden_states |
|
hidden_states = hidden_states[seg_tokens_mask] |
|
batch_idxs = batch_idxs[seg_tokens_mask] |
|
|
|
if len(hidden_states) != len(gt_masks) or len(hidden_states) == 0: |
|
|
|
print("Drop the batch because the number of [SEG] and masks not equal !!!") |
|
hidden_states = ori_hidden_states |
|
batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to( |
|
input_ids.device) |
|
batch_idxs = batch_idxs[0, :1] |
|
gt_masks = [None] |
|
hidden_states = hidden_states[0, :1] |
|
hidden_states = self.projector_text2vision(hidden_states) |
|
|
|
pred_masks_list = self.visual_encoder.forward_llm_seg(hidden_states, batch_idxs) |
|
dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks) |
|
|
|
return dice_loss * 0.0, mask_loss * 0.0 |
|
|
|
        assert len(hidden_states) == len(gt_masks), \
            "expected the number of [SEG] tokens to equal the number of masks, " \
            "but got {} [SEG] tokens and {} masks".format(len(hidden_states), len(gt_masks))
|
hidden_states = self.projector_text2vision(hidden_states) |
|
|
|
pred_masks_list = self.visual_encoder.forward_llm_seg(hidden_states, batch_idxs) |
|
dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks) |
|
|
|
if not success: |
|
return dice_loss * 0.0, mask_loss * 0.0 |
|
|
|
return dice_loss, mask_loss |
|
|
|
def preparing_for_generation(self, metainfo, **kwargs): |
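        """Prepare the chat template, stopping criteria, generation config and
        module dtypes before calling ``predict_forward``."""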
|
|
|
        assert hasattr(self, 'tokenizer'), "The model has no tokenizer; pass `tokenizer` when building it."
|
self.bot_name = 'BOT' |
|
if 'template' in metainfo.keys(): |
|
template = metainfo['template'] |
|
else: |
|
template = PROMPT_TEMPLATE['internlm2_chat'] |
|
self.template = template |
|
stop_words = [] |
|
stop_words += template.get('STOP_WORDS', []) |
|
stop_criteria = get_stop_criteria( |
|
tokenizer=self.tokenizer, stop_words=stop_words) |
|
self.stop_criteria = stop_criteria |
|
|
|
default_generation_kwargs = dict( |
|
max_new_tokens=2048, |
|
do_sample=False, |
|
eos_token_id=self.tokenizer.eos_token_id, |
|
pad_token_id=( |
|
self.tokenizer.pad_token_id |
|
if self.tokenizer.pad_token_id is not None |
|
else self.tokenizer.eos_token_id |
|
), |
|
) |
|
default_generation_kwargs.update(metainfo.get('generation_kwargs', {})) |
|
self.gen_config = GenerationConfig(**default_generation_kwargs) |
|
self.init_prediction_config = True |
|
|
|
self.llm.to(self.visual_encoder.dtype) |
|
self.visual_encoder.to(self.visual_encoder.dtype) |
|
self.projector.to(self.visual_encoder.dtype) |
|
        if self.text2vision_projector:
            self.projector_text2vision.to(self.visual_encoder.dtype)
|
return |
|
|
|
def predict_forward( |
|
self, pixel_values, text_prompts, |
|
ori_image_size=None, |
|
box_prompts=None, points_prompts=None, mask_prompts=None, **kwargs): |
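        """Generate answers for one image and one or more text prompts.

        Returns the decoded text and, when ``ori_image_size`` and ground-truth
        masks are provided, RLE-encoded masks decoded from the [SEG] hidden
        states of the generated sequence."""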
|
|
|
|
|
assert self.init_prediction_config, "Please set prediction configs using self.preparing_for_generation()" |
|
|
|
ret_predictions = [] |
|
ret_masks = [] |
|
|
|
image = pixel_values.cuda().unsqueeze(0).to(self.visual_encoder.dtype) |
|
visual_outputs = self.visual_encoder(image, output_hidden_states=True) |
|
        if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
|
pixel_values = self.projector(visual_outputs) |
|
else: |
|
pixel_values = self.projector( |
|
visual_outputs.hidden_states[self.visual_select_layer][:, 1:]) |
|
|
|
if isinstance(text_prompts, str): |
|
text_prompts = [text_prompts] |
|
for text_prompt in text_prompts: |
|
|
|
input_text = '' |
|
input_text += self.template['INSTRUCTION'].format( |
|
input=text_prompt, round=1, bot_name=self.bot_name) |
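            # Split the prompt at the image placeholder, encode each chunk, and
            # re-insert IMAGE_TOKEN_INDEX between them (exactly one image token
            # is expected).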
|
|
|
chunk_encode = [] |
|
for idx, chunk in enumerate(input_text.split(DEFAULT_IMAGE_TOKEN)): |
|
if idx == 0: |
|
cur_encode = self.tokenizer.encode(chunk) |
|
else: |
|
cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False) |
|
chunk_encode.append(cur_encode) |
|
assert len(chunk_encode) == 2 |
|
ids = [] |
|
for idx, cur_chunk_encode in enumerate(chunk_encode): |
|
ids.extend(cur_chunk_encode) |
|
if idx != len(chunk_encode) - 1: |
|
ids.append(IMAGE_TOKEN_INDEX) |
|
ids = torch.tensor(ids).cuda().unsqueeze(0) |
|
|
|
mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts( |
|
llm=self.llm, input_ids=ids, pixel_values=pixel_values, |
|
region_id=self.region_token_idx, |
|
regions_feats=[], |
|
mark_id=self.mark_token_idx, |
|
mark_feats=[], |
|
) |
|
|
|
generate_output = self.llm.generate( |
|
**mm_inputs, |
|
generation_config=self.gen_config, |
|
streamer=None, |
|
bos_token_id=self.tokenizer.bos_token_id, |
|
stopping_criteria=self.stop_criteria, |
|
output_hidden_states=True, |
|
return_dict_in_generate=True |
|
) |
|
predict = self.tokenizer.decode( |
|
generate_output.sequences[0], skip_special_tokens=True).strip() |
|
ret_predictions.append(predict) |
|
|
|
if ori_image_size is not None and 'masks' in kwargs.keys(): |
|
hidden_states = generate_output.hidden_states |
|
last_hidden_states = [item[-1][0] for item in hidden_states] |
|
last_hidden_states = torch.cat(last_hidden_states, dim=0) |
|
seg_hidden_states = get_seg_hidden_states( |
|
last_hidden_states, generate_output.sequences[0][:-1], |
|
seg_id=self.seg_token_idx |
|
) |
|
|
|
if len(seg_hidden_states) == 0: |
|
print("Warning, no [SEG] tokens !!!") |
|
ret_masks.append(None) |
|
continue |
|
elif len(seg_hidden_states) > 1: |
|
print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states))) |
|
seg_hidden_states = seg_hidden_states[:1] |
|
seg_hidden_states = self.projector_text2vision(seg_hidden_states) |
|
batch_idxs = torch.zeros((seg_hidden_states.shape[0],), |
|
dtype=torch.int64).to(seg_hidden_states.device) |
|
pred_masks_list = self.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs) |
|
pred_masks = pred_masks_list[-1] |
|
w, h = copy.deepcopy(ori_image_size) |
|
masks = F.interpolate(pred_masks, size=(max(w, h), max(w, h)), |
|
mode='bilinear', align_corners=False) |
|
masks = masks[:, 0] |
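                # The input image was padded to a square; crop the predicted
                # masks back to the original aspect ratio.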
|
|
|
if w == h: |
|
pass |
|
elif w > h: |
|
n_pad = w - h |
|
n_pad_1 = n_pad // 2 |
|
n_pad_2 = n_pad - n_pad_1 |
|
masks = masks[:, n_pad_1: w - n_pad_2] |
|
else: |
|
n_pad = h - w |
|
n_pad_1 = n_pad // 2 |
|
n_pad_2 = n_pad - n_pad_1 |
|
masks = masks[:, :, n_pad_1: h - n_pad_2] |
|
|
|
masks = masks.sigmoid() > 0.5 |
|
masks = masks.int() |
|
ret_masks.append(masks) |
|
|
|
if len(ret_predictions) == 1: |
|
ret_predictions = ret_predictions[0] |
|
if len(ret_masks) == 0: |
|
return {'prediction': ret_predictions} |
|
|
|
_ret_masks = [] |
|
for i, ret_mask in enumerate(ret_masks): |
|
if ret_mask is None: |
|
_ret_masks.append(None) |
|
else: |
|
ret_mask = ret_mask.cpu().numpy() |
|
_ret_masks.append(mask_to_rle(ret_mask)) |
|
|
|
if 'masks' not in kwargs.keys(): |
|
gt_masks = None |
|
else: |
|
gt_masks = mask_to_rle(kwargs['masks'].cpu().numpy()) |
|
|
|
return { |
|
'prediction': ret_predictions, 'prediction_masks': _ret_masks, |
|
'gt_masks': gt_masks, |
|
} |
|
|
|
def prepare_seg_pretrain_data(visual_outputs, |
|
query_in_proj, query_out_proj): |
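    """Round-trip the OMG-Seg object queries through the projectors.

    Valid queries are mapped into the LLM space by ``query_in_proj`` and back
    by ``query_out_proj``; the caller supervises the reconstruction against the
    original queries with an MSE loss during segmentation pretraining."""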
|
clip_feature, query_feat, attention_mask = visual_outputs |
|
|
|
|
|
|
|
bs, q, _ = query_feat.shape |
|
pred_query_embed = [] |
|
gt_query_embed = [] |
|
for i in range(bs): |
|
valid = attention_mask[i].sum(-1) > 0 |
|
valid_query_feat = query_feat[i][valid] |
|
gt_query_embed.append(valid_query_feat) |
|
|
|
if isinstance(query_in_proj, list): |
|
llm_query = valid_query_feat |
|
for proj in query_in_proj: |
|
llm_query = proj(llm_query) |
|
else: |
|
llm_query = query_in_proj(valid_query_feat) |
|
|
|
pred_query_embed.append(query_out_proj(llm_query)) |
|
|
|
pred_query_embed = torch.cat(pred_query_embed, dim=0) |
|
gt_query_embed = torch.cat(gt_query_embed, dim=0) |
|
return pred_query_embed, gt_query_embed |
|
|
|
def get_seg_hidden_states(hidden_states, output_ids, seg_id): |
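    """Select the hidden states at the positions where ``seg_id`` was generated."""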
|
seg_mask = output_ids == seg_id |
|
n_out = len(seg_mask) |
|
return hidden_states[-n_out:][seg_mask] |
|
|
|
|
|
def mask_to_rle(mask): |
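    """Encode a batch of binary masks into COCO RLE format (pycocotools)."""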
|
rle = [] |
|
for m in mask: |
|
rle.append(_mask.encode(np.asfortranarray(m.astype(np.uint8)))) |
|
return rle |
|
|