import logging
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from .llava.model.language_model.llava_llama import (LlavaLlamaForCausalLM,
                                                     LlavaLlamaModel)
from .segment_anything import build_sam_vit_h

# Cache of SAM image embeddings, keyed by the ``embedding_key`` argument of
# ``LISAForCausalLM.evaluate``.
# NOTE(review): entries are never evicted, so the dict grows with every new key
# and keeps the cached tensors (and any device memory they hold) alive for the
# lifetime of the process.
embedding_dict = {}

# hack for IMAGE_TOKEN_INDEX: the [SEG] mask built from token ids is left-padded
# with this many False columns so that it has the same length as the hidden-state
# sequence it indexes (the code assumes exactly one image, placed at the front).
# NOTE(review): 255 presumably equals (number of image patch tokens - 1), e.g.
# 256 - 1 for a 224px input to clip-vit-large-patch14 -- confirm against the
# vision tower configuration before changing the image size or tower.
_IMAGE_TOKEN_PAD = 255


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,  # 100000.0,
    eps=1e-6,
) -> torch.Tensor:
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Arguments 'num_masks', 'scale' and 'eps' are undocumented in the original
    project https://github.com/dvlab-research/LISA; their behaviour below is
    read directly from the code.

    Args:
        inputs: Float tensor of mask logits (``sigmoid`` is applied here).
            Dims 1 and 2 are flattened and summed, so each index along dim 0
            contributes one loss term; presumably shaped (num_masks, H, W).
        targets: Float tensor with the same shape as ``inputs`` holding the
            binary label for each element (0 = negative, 1 = positive).
        num_masks: Divisor for the summed per-mask losses, i.e. an averaging
            factor (compare ``avg_factor`` in ``weight_reduce_loss()`` from
            https://github.com/open-mmlab/mmdetection/blob/e9cae2d0787cd5c2fc6165a6061f92fa09e48fb1/mmdet/models/losses/utils.py#L30).
            ``1e-8`` is added to it to avoid division by zero.
        scale: Probabilities and targets are divided by ``scale`` before the
            ratio is formed. Because numerator and denominator shrink by the
            same factor, this equals an unscaled dice ratio whose smoothing
            term is ``eps * scale`` (1e-3 with the defaults):
            ``(2*sum(p*t) + eps*scale) / (sum(p) + sum(t) + eps*scale)``.
        eps: Smoothing term added to numerator and denominator (in scaled
            units); keeps the ratio finite when both masks are empty.

    Returns:
        Scalar tensor: the sum of per-mask ``1 - dice`` values divided by
        ``num_masks + 1e-8``.
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1, 2)
    targets = targets.flatten(1, 2)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    loss = loss.sum() / (num_masks + 1e-8)
    return loss


def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Binary cross-entropy (with logits) averaged per mask, then over masks.

    Args:
        inputs: Float tensor of mask logits. Dims 1 and 2 are flattened and
            averaged, so each index along dim 0 contributes one loss term;
            presumably shaped (num_masks, H, W).
        targets: Float tensor with the same shape as ``inputs`` holding the
            binary label for each element (0 = negative, 1 = positive).
        num_masks: Divisor for the summed per-mask losses; ``1e-8`` is added
            to avoid division by zero.

    Returns:
        Scalar loss tensor: sum over masks of the per-mask mean BCE, divided
        by ``num_masks + 1e-8``.
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
    return loss


class LisaMetaModel:
    """Mixin that adds the SAM visual model and the text projection head.

    Meant to be combined with a model class through multiple inheritance (see
    ``LisaModel``); ``__init__`` forwards ``config`` to the next class in the
    MRO before building the LISA-specific modules.
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(LisaMetaModel, self).__init__(config)

        self.config = config
        # Only a config that does not yet carry ``train_mask_decoder`` takes the
        # LISA fields from kwargs; otherwise the values already on the config
        # are kept.
        if not hasattr(self.config, "train_mask_decoder"):
            self.config.train_mask_decoder = kwargs["train_mask_decoder"]
            self.config.out_dim = kwargs["out_dim"]
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.initialize_lisa_modules(self.config)

    def initialize_lisa_modules(self, config):
        """Build SAM and the projection from LLM hidden size to ``out_dim``.

        All SAM parameters are frozen; the mask decoder is unfrozen (and put in
        train mode) only when ``config.train_mask_decoder`` is truthy. The
        projection head is always trainable.
        """
        # SAM
        self.visual_model = build_sam_vit_h(self.vision_pretrained)
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if config.train_mask_decoder:
            self.visual_model.mask_decoder.train()
            for param in self.visual_model.mask_decoder.parameters():
                param.requires_grad = True

        # Projection layer: hidden_size -> hidden_size -> out_dim
        in_dim = config.hidden_size
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True


class LisaModel(LisaMetaModel, LlavaLlamaModel):
    """LLaVA-LLaMA backbone extended with the LISA modules.

    After construction, the config is overwritten with fixed multimodal
    settings (no KV cache, patch features, square images, frozen MM adapter).
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(LisaModel, self).__init__(config, **kwargs)

        self.config.use_cache = False
        self.config.vision_tower = self.config.mm_vision_tower
        self.config.mm_vision_select_feature = "patch"
        self.config.image_aspect_ratio = "square"
        self.config.image_grid_pinpoints = None
        self.config.tune_mm_mlp_adapter = False
        self.config.freeze_mm_mlp_adapter = True
        self.config.pretrain_mm_mlp_adapter = None
        self.config.mm_use_im_patch_token = False


class LISAForCausalLM(LlavaLlamaForCausalLM):
    """Causal LM that emits [SEG] tokens and decodes them into masks with SAM."""

    def __init__(
        self,
        config,
        **kwargs,
    ):
        if not hasattr(config, "train_mask_decoder"):
            config.mm_use_im_start_end = kwargs.pop("use_mm_start_end", True)
            config.mm_vision_tower = kwargs.get(
                "vision_tower", "openai/clip-vit-large-patch14"
            )
        else:
            config.mm_vision_tower = config.vision_tower
        # The loss weights are read for either kind of config so the training
        # path of ``model_forward`` has them available; each stays None unless
        # the caller passes it.
        self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
        self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        self.seg_token_idx = kwargs.pop("seg_token_idx")

        super().__init__(config)

        self.model = LisaModel(config, **kwargs)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        """Encode images with SAM's image encoder, one image at a time.

        Runs under ``torch.no_grad()``; the per-image outputs are concatenated
        along dim 0. ``torch.cuda.empty_cache()`` is called around each image,
        presumably to keep peak GPU memory low.
        """
        with torch.no_grad():
            image_embeddings_list = []
            for i in range(pixel_values.shape[0]):
                torch.cuda.empty_cache()
                image_embeddings = self.model.visual_model.image_encoder(
                    pixel_values[i].unsqueeze(0)
                )
                image_embeddings_list.append(image_embeddings)
            torch.cuda.empty_cache()
            image_embeddings = torch.cat(image_embeddings_list, 0)
        return image_embeddings

    def forward(self, **kwargs):
        """Route calls carrying ``past_key_values`` to the LLaVA forward, all
        others to ``model_forward``."""
        if "past_key_values" in kwargs:
            return super().forward(**kwargs)
        return self.model_forward(**kwargs)

    def model_forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        offset: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        label_list: List[torch.Tensor],
        resize_list: List[tuple],
        inference: bool = False,
        **kwargs,
    ):
        """Run the LLM, turn [SEG] hidden states into SAM prompts and decode masks.

        Args:
            images: SAM inputs, one per image (encoded by ``get_visual_embs``).
            images_clip: LLaVA image inputs, one per image.
            input_ids: Token ids; rows ``offset[i]:offset[i + 1]`` belong to
                image ``i``.
            labels: LM labels forwarded to the LLaVA forward (training only).
            attention_masks: Attention mask for ``input_ids``.
            offset: Row boundaries of ``input_ids`` per image; must have one
                more entry than there are images.
            masks_list: Ground-truth masks per image; the first dim must match
                the number of predicted masks for that image.
            label_list: Per-image tensors whose ``.shape`` is used as the
                original size when post-processing masks.
            resize_list: Per-image ``input_size`` for mask post-processing.
            inference: If True, skip the losses and return masks only. In
                this mode ``images_clip`` must hold exactly one image.

        Returns:
            With ``inference=True``: ``{"pred_masks", "gt_masks"}``.
            Otherwise a dict with ``loss``, ``ce_loss``, ``mask_bce_loss``,
            ``mask_dice_loss`` and ``mask_loss``.
        """
        image_embeddings = self.get_visual_embs(images)
        batch_size = image_embeddings.shape[0]
        assert batch_size == len(offset) - 1

        # True at position j when token j + 1 is [SEG] (presumably selecting
        # the hidden state that predicts it); the final position never
        # precedes a token, hence the appended False column.
        seg_token_mask = input_ids[:, 1:] == self.seg_token_idx
        seg_token_mask = torch.cat(
            [
                seg_token_mask,
                torch.zeros(
                    (seg_token_mask.shape[0], 1),
                    dtype=torch.bool,
                    device=seg_token_mask.device,
                ),
            ],
            dim=1,
        )
        # hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
        seg_token_mask = torch.cat(
            [
                torch.zeros(
                    (seg_token_mask.shape[0], _IMAGE_TOKEN_PAD),
                    dtype=torch.bool,
                    device=seg_token_mask.device,
                ),
                seg_token_mask,
            ],
            dim=1,
        )

        if inference:
            n_batch = 1
            length = input_ids.shape[0]
            assert images_clip.shape[0] == 1
            images_clip_extend = images_clip.expand(length, -1, -1, -1).contiguous()

            output_hidden_states = []
            for i in range(n_batch):
                start_i, end_i = i * length, min((i + 1) * length, input_ids.shape[0])
                output_i = super().forward(
                    images=images_clip_extend[: end_i - start_i],
                    attention_mask=attention_masks[start_i:end_i],
                    input_ids=input_ids[start_i:end_i],
                    output_hidden_states=True,
                )
                output_hidden_states.append(output_i.hidden_states)
                torch.cuda.empty_cache()

            output_hidden_states_list = []
            output_hidden_states_level = torch.cat(output_hidden_states, dim=0)
            output_hidden_states_list.append(output_hidden_states_level)
            output_hidden_states = output_hidden_states_list
            output = None

        else:
            # Repeat each image once per conversation row that refers to it.
            images_clip_list = []
            for i in range(len(offset) - 1):
                start_i, end_i = offset[i], offset[i + 1]
                images_clip_i = (
                    images_clip[i]
                    .unsqueeze(0)
                    .expand(end_i - start_i, -1, -1, -1)
                    .contiguous()
                )
                images_clip_list.append(images_clip_i)
            images_clip = torch.cat(images_clip_list, dim=0)

            output = super().forward(
                images=images_clip,
                attention_mask=attention_masks,
                input_ids=input_ids,
                labels=labels,
                output_hidden_states=True,
            )
            output_hidden_states = output.hidden_states

        hidden_states = []

        assert len(self.model.text_hidden_fcs) == 1
        hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states[-1]))

        last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
        pred_embeddings = last_hidden_state[seg_token_mask]
        seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]

        # Cumulative [SEG] counts per row, prefixed with 0, then sampled at the
        # per-image row boundaries to get per-image slices of pred_embeddings.
        seg_token_offset = seg_token_counts.cumsum(-1)
        seg_token_offset = torch.cat(
            [
                torch.zeros(1, dtype=torch.long, device=seg_token_offset.device),
                seg_token_offset,
            ],
            dim=0,
        )

        seg_token_offset = seg_token_offset[offset]

        pred_embeddings = [
            pred_embeddings[seg_token_offset[i]:seg_token_offset[i + 1]]
            for i in range(len(seg_token_offset) - 1)
        ]

        multimask_output = False
        pred_masks = []
        for i in range(len(pred_embeddings)):
            (
                sparse_embeddings,
                dense_embeddings,
            ) = self.model.visual_model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=pred_embeddings[i].unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
            low_res_masks, _ = self.model.visual_model.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            pred_mask = self.model.visual_model.postprocess_masks(
                low_res_masks,
                input_size=resize_list[i],
                original_size=label_list[i].shape,
            )
            pred_masks.append(pred_mask[:, 0])

        model_output = output
        gt_masks = masks_list

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        output = model_output.logits

        ce_loss = model_output.loss
        ce_loss = ce_loss * self.ce_loss_weight
        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0
        for batch_idx in range(len(pred_masks)):
            gt_mask = gt_masks[batch_idx]
            pred_mask = pred_masks[batch_idx]

            assert (
                gt_mask.shape[0] == pred_mask.shape[0]
            ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
                gt_mask.shape, pred_mask.shape
            )
            # Each helper averages over this image's masks; multiplying back by
            # the count and dividing by the total below averages over all masks.
            mask_bce_loss += (
                sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            mask_dice_loss += (
                dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            num_masks += gt_mask.shape[0]

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        loss = ce_loss + mask_loss

        return {
            "loss": loss,
            "ce_loss": ce_loss,
            "mask_bce_loss": mask_bce_loss,
            "mask_dice_loss": mask_dice_loss,
            "mask_loss": mask_loss,
        }

    def evaluate(
        self,
        images_clip,
        images,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=32,
        tokenizer=None,
        model_logger=None,
        embedding_key=None,
    ):
        """Generate a response and decode a mask for every generated [SEG] token.

        Args:
            images_clip: LLaVA image input passed to ``generate``.
            images: SAM image input, encoded unless a cached embedding is used.
            input_ids: Prompt token ids.
            resize_list: Per-image ``input_size`` for mask post-processing.
            original_size_list: Per-image ``original_size`` for mask
                post-processing.
            max_new_tokens: Generation length limit.
            tokenizer: Unused here; kept for interface compatibility.
            model_logger: Object with a ``debug(str)`` method used for
                progress messages; defaults to this module's logger.
            embedding_key: If given, SAM embeddings are looked up in (or stored
                into) the module-level ``embedding_dict`` under this key
                instead of being recomputed from ``images``.

        Returns:
            ``(output_ids, pred_masks)``: the generated sequences and a list of
            post-processed mask tensors (one entry per row of ``output_ids``).
        """
        with torch.no_grad():
            if model_logger is None:
                # A named logger avoids the implicit logging.basicConfig() that
                # the module-level logging.debug() performs on the root logger.
                model_logger = logging.getLogger(__name__)
            model_logger.debug("start output generation...")
            outputs = self.generate(
                images=images_clip,
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            model_logger.debug("done output generation...")
            # NOTE(review): this is the entry for the last generation step and
            # is fed straight into a Linear layer below, so it is presumably a
            # single tensor covering the whole sequence (the model config sets
            # use_cache=False) -- confirm against the LLaVA forward.
            output_hidden_states = outputs.hidden_states[-1]
            output_ids = outputs.sequences

            # Unlike model_forward, no trailing False column is appended here.
            seg_token_mask = output_ids[:, 1:] == self.seg_token_idx
            # hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
            model_logger.debug("start torch.cat to seg_token_mask...")
            seg_token_mask = torch.cat(
                [
                    torch.zeros(
                        (seg_token_mask.shape[0], _IMAGE_TOKEN_PAD),
                        dtype=torch.bool,
                        device=seg_token_mask.device,
                    ),
                    seg_token_mask,
                ],
                dim=1,
            )
            model_logger.debug("done torch.cat to seg_token_mask...")

            hidden_states = []

            assert len(self.model.text_hidden_fcs) == 1
            hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states))

            model_logger.debug("start torch.stack to last_hidden_state...")
            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
            model_logger.debug("done torch.stack to last_hidden_state...")
            pred_embeddings = last_hidden_state[seg_token_mask]

            seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]
            seg_token_offset = seg_token_counts.cumsum(-1)
            model_logger.debug("start torch.cat to seg_token_offset...")
            seg_token_offset = torch.cat(
                [
                    torch.zeros(1, dtype=torch.long, device=seg_token_offset.device),
                    seg_token_offset,
                ],
                dim=0,
            )
            model_logger.debug("done torch.cat to seg_token_offset...")

            # One slice of [SEG] embeddings per row of output_ids.
            pred_embeddings = [
                pred_embeddings[seg_token_offset[i]:seg_token_offset[i + 1]]
                for i in range(len(seg_token_offset) - 1)
            ]

            model_logger.debug(f"start get_visual_embs to image_embeddings with embedding_key {embedding_key}.")
            if embedding_key is None:
                image_embeddings = self.get_visual_embs(images)
            else:
                try:
                    image_embeddings = embedding_dict[embedding_key]
                except KeyError:
                    model_logger.debug(f"embedding_key {embedding_key} not in embedding_dict, creating embedding now!")
                    image_embeddings = self.get_visual_embs(images)
                    embedding_dict[embedding_key] = image_embeddings
                    model_logger.debug(f"image embedding added in embedding_dict with embedding_key {embedding_key}!")
            model_logger.debug("done get_visual_embs to image_embeddings...")

            multimask_output = False
            pred_masks = []
            for i in range(len(pred_embeddings)):
                model_logger.debug(f"start ({i}nth time) visual_model.prompt_encoder to sparse/dense")
                (
                    sparse_embeddings,
                    dense_embeddings,
                ) = self.model.visual_model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=pred_embeddings[i].unsqueeze(1),
                )
                model_logger.debug(f"done ({i}nth) visual_model.prompt_encoder to sparse/dense, start sparse2sparse")

                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                model_logger.debug(f"done ({i}nth) sparse2sparse, start visual_model.mask_decoder")
                low_res_masks, _ = self.model.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i].unsqueeze(0),
                    image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=multimask_output,
                )
                model_logger.debug(f"done ({i}nth) visual_model.mask_decoder, start postprocess_masks")
                pred_mask = self.model.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=original_size_list[i],
                )
                model_logger.debug(f"done ({i}nth) postprocess_masks")
                pred_masks.append(pred_mask[:, 0])
            model_logger.debug("env evaluate! ")

        return output_ids, pred_masks