File size: 35,876 Bytes
032e687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 |
import copy
from collections import OrderedDict
import torch
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from xtuner.registry import BUILDER
from .modules import ProjectorConfig_OMG_LLaVA, ProjectorModel_OMG_LLaVA
from xtuner.model.modules import ProjectorModel, ProjectorConfig
from xtuner.model.modules import dispatch_modules
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, guess_load_checkpoint,
make_inputs_require_grad,
traverse_dict,
prepare_inputs_labels_for_multimodal_with_visual_prompts)
from .convnext_clip import OpenCLIPBackbone
from .omg_seg import OMGSegVisualEncoder
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
PROMPT_TEMPLATE)
from xtuner.tools.utils import get_stop_criteria, is_cn_string
from transformers import GenerationConfig
import torch.nn.functional as F
import numpy as np
from pycocotools import mask as _mask
class OMG_LLaVA(BaseModel):
    """Multimodal model combining an LLM, a visual encoder and projectors.

    ``self.projector`` maps visual-encoder outputs into the LLM hidden size.
    When ``text2vision_projector`` is enabled, ``self.projector_text2vision``
    maps LLM hidden states of ``[SEG]`` tokens back to ``256 * 2`` channels.
    They are decoded into masks by ``visual_encoder.forward_llm_seg``.
    Region (``<region>``) and point (``<mark>``) visual prompts are embedded
    through the visual encoder and spliced into the LLM input.
    """

    def __init__(self,
                 llm,
                 visual_encoder,
                 visual_select_layer=-2,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 require_omg_decoder=False,
                 pretrained_pth=None,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True,
                 projector_depth=2,
                 text2vision_projector=False,
                 tokenizer=None,
                 keep_omg_decoder_frozen=False,
                 add_seg_pretrain=False,
                 additional_cross_attn_layers=False,
                 pixel_shuffle_ratio=None,
                 train_vocabulary=False,
                 freeze_llm_with_lora=False,
                 freeze_visual_projector=False,
                 rm_prior_embedding=False,
                 rm_query=False,
                 clip_feat_channel=1536,
                 ):
        """Build the LLM, visual encoder, projectors and optional adapters.

        Args:
            llm: ``nn.Module`` or config dict (built through ``BUILDER``).
            visual_encoder: config whose ``type`` is ``OpenCLIPBackbone`` or
                ``OMGSegVisualEncoder`` (instantiated directly). Otherwise an
                ``nn.Module`` or a config dict built through ``BUILDER``.
            visual_select_layer: index into ``hidden_states``. Used when the
                encoder output is not a list/tuple/tensor.
            freeze_llm / freeze_visual_encoder: disable grads on that module.
            require_omg_decoder: call ``visual_encoder.init_new_decoder()``.
            pretrained_pth: checkpoint loaded with ``strict=False``.
            llm_lora / visual_encoder_lora: LoRA configs (dict or built).
            use_activation_checkpointing: enable input grads and gradient
                checkpointing on all sub-modules that support it.
            projector_depth: ``depth`` passed to both projector configs.
            text2vision_projector: also build a projector from the LLM hidden
                size to ``256 * 2`` channels, used for ``[SEG]`` decoding.
            tokenizer: config dict with a ``type`` key. When given, special
                tokens are added and the LLM input embeddings are resized.
            keep_omg_decoder_frozen: freeze ``transformer_decoder_llm``
                (only applies together with ``require_omg_decoder``).
            add_seg_pretrain: enable the query-reconstruction loss. This also
                makes ``compute_seg_loss`` return zeros.
            additional_cross_attn_layers: call
                ``visual_encoder.init_cross_attn_layer()``.
            pixel_shuffle_ratio: forwarded to the projector config.
            train_vocabulary: without LLM LoRA, unfreeze parameters whose
                names match the token-embedding / LM-head patterns.
            freeze_llm_with_lora: with LLM LoRA, freeze every LLM parameter
                (LoRA weights included).
            freeze_visual_projector: freeze ``self.projector``.
            rm_prior_embedding / rm_query: flags set on ``projector.model``.
            clip_feat_channel: ``feat_channels`` of the projector config.
        """
        super().__init__()
        self.freeze_llm_with_lora = freeze_llm_with_lora
        self.freeze_visual_projector = freeze_visual_projector
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        with LoadWoInit():
            self.llm = self._build_from_cfg_or_module(llm)
            if visual_encoder.type == OpenCLIPBackbone or visual_encoder.type == OMGSegVisualEncoder:
                # NOTE(review): unlike the tokenizer path below, ``type`` is
                # forwarded as a kwarg here; confirm these classes accept it.
                self.visual_encoder = visual_encoder.type(**visual_encoder)
            else:
                self.visual_encoder = self._build_from_cfg_or_module(
                    visual_encoder)
        self.llm.config.use_cache = False
        dispatch_modules(self.llm)

        projector_config = ProjectorConfig_OMG_LLaVA(
            query_channels=256,
            feat_channels=clip_feat_channel,
            llm_hidden_size=self.llm.config.hidden_size,
            depth=projector_depth,
            pixel_shuffle_ratio=pixel_shuffle_ratio,
        )
        self.projector = ProjectorModel_OMG_LLaVA(projector_config).to(
            self.visual_encoder.dtype)

        self.text2vision_projector = text2vision_projector
        if text2vision_projector:
            projector_config = ProjectorConfig(
                visual_hidden_size=self.llm.config.hidden_size,
                llm_hidden_size=256 * 2,
                depth=projector_depth)
            self.projector_text2vision = ProjectorModel(projector_config).to(
                self.visual_encoder.dtype)

        if rm_query:
            self.projector.model.rm_query = rm_query
        if rm_prior_embedding:
            self.projector.model.rm_prior_embedding = rm_prior_embedding

        if self.freeze_llm:
            self.llm.requires_grad_(False)
        if self.freeze_visual_encoder:
            self.visual_encoder.requires_grad_(False)

        self.use_activation_checkpointing = use_activation_checkpointing
        if use_activation_checkpointing:
            # For backward compatibility
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            if hasattr(self.visual_encoder, 'enable_input_require_grads'):
                self.visual_encoder.enable_input_require_grads()
            else:
                self.visual_encoder.get_input_embeddings(
                ).register_forward_hook(make_inputs_require_grad)
            self.projector.enable_input_require_grads()
            if text2vision_projector:
                self.projector_text2vision.enable_input_require_grads()
            # enable gradient (activation) checkpointing for memory efficiency
            self.gradient_checkpointing_enable()

        # _add_special_tokens() reads self.use_llm_lora, so the LoRA flags
        # must exist before the tokenizer block below runs.
        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None

        # resize input embed before add llm lora
        self.added_special_token = False
        if tokenizer is not None:
            # Work on a copy so the caller's config keeps its 'type' key
            # (building twice from the same config would otherwise fail).
            tokenizer_cfg = dict(tokenizer)
            tokenizer_type = tokenizer_cfg.pop('type')
            self.tokenizer = tokenizer_type(**tokenizer_cfg)
            self._add_special_tokens()

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
            if self.freeze_llm_with_lora:
                self.llm.requires_grad_(False)
        elif train_vocabulary:
            # train vocabulary embedding and logit head when pretrain
            for name, param in self.named_parameters():
                if ('tok_' in name or 'embed_tokens' in name) or 'lm_head' in name:
                    print("Unfrozen {} !!!".format(name))
                    param.requires_grad_(True)
                if ('output.' in name or 'lm_head' in name) and 'llm' in name and 'lora' not in name:
                    print("Unfrozen {} !!!".format(name))
                    param.requires_grad_(True)

        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(
                visual_encoder_lora, use_activation_checkpointing)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Load pretrained weight from {pretrained_pth}')

        self.visual_select_layer = visual_select_layer
        self._is_init = True

        self.require_omg_decoder = require_omg_decoder
        if require_omg_decoder:
            self.visual_encoder.init_new_decoder()
            if keep_omg_decoder_frozen:
                self.visual_encoder.panoptic_head.transformer_decoder_llm.requires_grad_(False)
                print("Frozen all the omg seg decoder !!!")

        self.additional_cross_attn_layers = additional_cross_attn_layers
        if self.additional_cross_attn_layers:
            self.visual_encoder.init_cross_attn_layer()

        if self.freeze_visual_projector:
            self.projector.requires_grad_(False)

        self.add_seg_pretrain = add_seg_pretrain
        self.init_prediction_config = False

    def _add_special_tokens(self):
        """Register the [SEG], <p>, </p>, <region> and <mark> tokens.

        Stores their ids as ``seg_token_idx``, ``bop_token_idx``,
        ``eop_token_idx``, ``region_token_idx`` and ``mark_token_idx``, and
        resizes the LLM input embeddings.
        """
        assert hasattr(self, "tokenizer")
        segmentation_tokens = ['[SEG]']
        # Adding tokens for GCG
        phrase_tokens = ['<p>', '</p>']
        # add for visual prompt
        region_tokens = ['<region>']
        point_tokens = ['<mark>']
        special_tokens = segmentation_tokens + phrase_tokens + region_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        def _token_id(token):
            return self.tokenizer(token, add_special_tokens=False).input_ids[0]

        self.seg_token_idx = _token_id("[SEG]")
        self.bop_token_idx = _token_id("<p>")
        self.eop_token_idx = _token_id("</p>")
        self.region_token_idx = _token_id("<region>")
        self.llm.resize_token_embeddings(len(self.tokenizer))
        # NOTE(review): '<mark>' is added *after* the resize, so no embedding
        # row is reserved for it. Presumably intentional: it keeps the
        # embedding size compatible with existing checkpoints, and <mark>
        # positions are replaced by mark_feats before lookup. Confirm.
        self.tokenizer.add_tokens(point_tokens, special_tokens=True)
        self.mark_token_idx = _token_id("<mark>")
        if self.use_activation_checkpointing or self.use_llm_lora or not self.freeze_llm:
            self.llm.enable_input_require_grads()
        self.added_special_token = True
        print("[SEG]: {}, <p>: {}, </p>: {}, <region>: {}, <mark>: {}".format(
            self.seg_token_idx, self.bop_token_idx,
            self.eop_token_idx, self.region_token_idx, self.mark_token_idx))
        print('****************************Add special tokens ********************************************')
        return

    def _parse_lora_config(self, lora_config):
        """Build ``lora_config`` through ``BUILDER`` if it is still a config."""
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        """Wrap ``self.llm`` with LoRA and re-enable vocab / head gradients.

        If ``target_modules`` is unset, every linear layer found by
        ``find_all_linear_names`` is targeted.
        """
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            lora_config.target_modules = find_all_linear_names(self.llm)
        self.llm = get_peft_model(self.llm, lora_config)
        for name, param in self.named_parameters():
            if 'tok_' in name or 'lm_head' in name:
                print("Unfrozen {} !!!".format(name))
                param.requires_grad_(True)
            if 'output.' in name and 'llm' in name and 'lora' not in name:
                print("Unfrozen {} !!!".format(name))
                param.requires_grad_(True)

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        """Wrap ``self.visual_encoder`` with LoRA (all linears by default)."""
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            lora_config.target_modules = find_all_linear_names(self.visual_encoder)
        self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)

    def gradient_checkpointing_enable(self):
        """Alias of :meth:`activation_checkpointing_enable`."""
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        """Enable gradient checkpointing on every sub-module that supports it."""
        self.llm.gradient_checkpointing_enable()
        if hasattr(self.visual_encoder, 'gradient_checkpointing_enable'):
            self.visual_encoder.gradient_checkpointing_enable()
        elif hasattr(self.visual_encoder, 'clip_model'):
            if self.visual_encoder.clip_model is not None:
                self.visual_encoder.clip_model.gradient_checkpointing_enable()
        if hasattr(self.projector, 'gradient_checkpointing_enable'):
            self.projector.gradient_checkpointing_enable()
        if self.text2vision_projector and hasattr(self.projector_text2vision, 'gradient_checkpointing_enable'):
            self.projector_text2vision.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        """Alias of :meth:`activation_checkpointing_disable`."""
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        """Disable gradient checkpointing on every sub-module that supports it.

        Mirrors :meth:`activation_checkpointing_enable`, including the
        ``clip_model`` fallback for encoders without their own toggle.
        """
        self.llm.gradient_checkpointing_disable()
        if hasattr(self.visual_encoder, 'gradient_checkpointing_disable'):
            self.visual_encoder.gradient_checkpointing_disable()
        elif getattr(self.visual_encoder, 'clip_model', None) is not None:
            clip_model = self.visual_encoder.clip_model
            if hasattr(clip_model, 'gradient_checkpointing_disable'):
                clip_model.gradient_checkpointing_disable()
        if hasattr(self.projector, 'gradient_checkpointing_disable'):
            self.projector.gradient_checkpointing_disable()
        if self.text2vision_projector and hasattr(self.projector_text2vision, 'gradient_checkpointing_disable'):
            self.projector_text2vision.gradient_checkpointing_disable()

    def init_weights(self):
        """No-op override: skip BaseModel weight initialisation."""
        pass

    def state_dict(self, *args, **kwargs):
        """Return only the trained / newly added weights.

        The result keeps vocabulary embeddings, the LM head and the
        projectors. It also keeps the visual encoder / LLM weights (full or
        LoRA) unless frozen, plus the OMG decoder parts when
        ``require_omg_decoder`` is set.
        """
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()

        def _select(predicate):
            return {k: v for k, v in state_dict.items() if predicate(k)}

        # vocabulary embedding
        to_return.update(_select(lambda k: 'tok_' in k or 'embed_tokens' in k))
        # logit head
        to_return.update(_select(
            lambda k: ('output.' in k or 'lm_head' in k) and 'llm' in k and 'lora' not in k))
        # Step 1. visual_encoder
        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(self.visual_encoder, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            to_return.update(_select(lambda k: 'visual_encoder.' in k))
        # Step 2. LLM
        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update(_select(lambda k: 'llm.' in k))
        # Step 3. Projector
        to_return.update(_select(lambda k: 'projector.' in k))
        # projector text2vision
        to_return.update(_select(lambda k: 'projector_text2vision' in k))
        # visual_encoder.adapter_proj
        if self.freeze_visual_encoder:
            to_return.update(_select(lambda k: 'visual_encoder.adapter_proj' in k))
        # git_clip lora (clip_lora may be absent on encoders exposing clip_model)
        if (hasattr(self.visual_encoder, 'clip_model')
                and getattr(self.visual_encoder, 'clip_lora', None) is not None):
            to_return.update(
                get_peft_model_state_dict(self.visual_encoder.clip_model,
                                          state_dict=state_dict))
        # omg decoder for llm
        if self.require_omg_decoder:
            omg_decoder_keys = (
                'visual_encoder.panoptic_head.transformer_decoder_llm',
                'visual_encoder.panoptic_head.mask_embed_llm',
                'visual_encoder.panoptic_head.pixel_decoder_llm',
                'visual_encoder.panoptic_head.additional_cross_attn_layers',
                'visual_encoder.panoptic_head.additional_ffn',
                'visual_encoder.downsample_layer',
            )
            to_return.update(_select(lambda k: any(p in k for p in omg_decoder_keys)))
        return to_return

    def _build_from_cfg_or_module(self, cfg_or_mod):
        """Return ``cfg_or_mod`` if it is a module, else build it from config.

        Raises:
            NotImplementedError: if it is neither an ``nn.Module`` nor a dict.
        """
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError

    def forward(self, data, data_samples=None, mode='loss'):
        """Run the model in ``'loss'``, ``'predict'`` or ``'tensor'`` mode.

        Args:
            data: dict with at least ``input_ids``. If ``pixel_values`` is
                present, the optional ``masks``, ``regions`` and ``points``
                entries are popped. The image and visual prompts are then
                merged into the LLM inputs.
            data_samples: forwarded to the mode-specific method.
            mode: ``'loss'`` | ``'predict'`` | ``'tensor'``.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        masks = None
        region_success = True
        mark_success = True
        pred_gt_obj_query = None
        # Zero-valued term tied to the (dummy) visual-prompt embeddings; stays
        # a plain 0 for batches without an image.
        _zero = 0
        if 'pixel_values' in data:
            masks = data.pop('masks', None)
            regions = data.pop('regions', None)
            points = data.pop('points', None)

            visual_outputs = self.visual_encoder(
                data['pixel_values'].to(self.visual_encoder.dtype),
                output_hidden_states=True)
            if self.add_seg_pretrain:
                pred_gt_obj_query = prepare_seg_pretrain_data(
                    visual_outputs,
                    [self.projector.model.query_proj, self.projector.model.model],
                    self.projector_text2vision.model
                )
            if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
                pixel_values = self.projector(visual_outputs)
            else:
                pixel_values = self.projector(
                    visual_outputs.hidden_states[self.visual_select_layer][:, 1:])

            # Empty prompt containers take the same path as "no prompts";
            # otherwise the `.sum()` below would be called on a list.
            if regions is not None and len(regions) > 0:
                region_embeddings, region_success = self.get_region_embeddings(
                    regions, data['input_ids'],
                )
                none_region_embeddings = region_embeddings
            else:
                region_embeddings = []
                none_region_embeddings = self.get_none_region_embeddings(
                    input_ids=data['input_ids'],
                )

            # width/height are read before pixel_values is replaced below.
            width = data['pixel_values'].shape[-1]
            height = data['pixel_values'].shape[-2]
            if points is not None and len(points) > 0:
                points_mark_embedding, mark_success = self.get_points_embeddings(
                    points, data['input_ids'], width=width, height=height,
                )
                none_points_mark_embedding = points_mark_embedding
            else:
                points_mark_embedding = []
                none_points_mark_embedding = self.get_none_points_embeddings(
                    data['input_ids'], width=width, height=height,
                )

            data['pixel_values'] = pixel_values
            data = prepare_inputs_labels_for_multimodal_with_visual_prompts(
                llm=self.llm, region_id=self.region_token_idx,
                regions_feats=region_embeddings,
                mark_id=self.mark_token_idx,
                mark_feats=points_mark_embedding,
                **data)
            _zero = (none_points_mark_embedding.sum() * 0.0
                     + none_region_embeddings.sum() * 0.0)

        if mode == 'loss':
            return self.compute_loss(data, data_samples, masks=masks,
                                     pred_gt_obj_query=pred_gt_obj_query,
                                     region_success=region_success,
                                     mark_success=mark_success,
                                     _zero=_zero)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        """Return the raw LLM outputs for ``data``."""
        return self.llm(**data)

    def predict(self, data, data_samples=None):
        """Return one ``{'logits': ...}`` dict per sample in the batch."""
        outputs = self.llm(**data)
        return [{'logits': logits} for logits in outputs.logits]

    def compute_loss(self, data, data_samples=None, masks=None, pred_gt_obj_query=None,
                     region_success=True, mark_success=True, _zero=0):
        """Compute the LM loss plus the weighted seg / projection losses.

        ``data['original_labels']`` (popped if present, else ``labels``) is
        used to locate ``[SEG]`` tokens. If a visual-prompt count mismatch
        was flagged (``region_success`` / ``mark_success`` False), the LM
        loss is zeroed for this batch.
        """
        if 'original_labels' in data:
            input_ids = data.pop('original_labels')
        else:
            input_ids = data['labels']
        outputs = self.llm(**data, output_hidden_states=True)
        loss_dice, loss_mask = self.compute_seg_loss(
            input_ids, outputs.hidden_states[-1], masks)
        if pred_gt_obj_query is not None:
            pred_obj_query, gt_obj_query = pred_gt_obj_query
            proj_loss = torch.mean((pred_obj_query - gt_obj_query) ** 2) * 10
        else:
            proj_loss = 0

        # Multiplying by 0 keeps every term attached to outputs.loss's graph.
        zero_lm = outputs.loss * 0
        loss = outputs.loss if (region_success and mark_success) else zero_lm
        loss = loss + _zero
        return {'loss': loss,
                'loss_dice': zero_lm + loss_dice * 0.1,
                'loss_mask': zero_lm + loss_mask * 0.4,
                'loss_proj': zero_lm + proj_loss}

    def __getattr__(self, name: str):
        """Fall back to attributes of ``self.llm`` when not found here."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Without this guard a missing `llm` (e.g. during deepcopy or
            # unpickling) would recurse forever through self.llm.
            if name == 'llm':
                raise
            return getattr(self.llm, name)

    @staticmethod
    def _token_batch_indices(input_ids, token_mask):
        """Return the batch index of every position selected by ``token_mask``."""
        batch_grid = torch.arange(
            input_ids.shape[0], device=input_ids.device).unsqueeze(1).expand_as(input_ids)
        return batch_grid[token_mask]

    @staticmethod
    def _single_batch_idx(input_ids):
        """Return the batch-index tensor ``[0]`` used by the dummy passes."""
        return torch.zeros((1,), dtype=torch.long, device=input_ids.device)

    @staticmethod
    def _align_prompts_to_tokens(prompts, n_tokens):
        """Truncate or front-pad ``prompts`` so ``len(prompts) == n_tokens``.

        Padding repeats the first prompt and works for any prompt rank.

        Returns:
            ``(prompts, success)``. ``success`` is False when the counts
            disagreed.
        """
        if len(prompts) == n_tokens:
            return prompts, True
        # Count mismatch between prompts and prompt tokens: patch the shapes
        # so the forward pass can run, and report failure.
        if len(prompts) > n_tokens:
            return prompts[:n_tokens], False
        n_pad = n_tokens - len(prompts)
        pad = prompts[:1].repeat(n_pad, *([1] * (prompts.dim() - 1)))
        return torch.cat([pad, prompts]), False

    def get_region_embeddings(self, regions, input_ids):
        """Embed region prompts, one per ``<region>`` token in ``input_ids``.

        Returns:
            ``(embeddings, success)``. ``embeddings`` is ``[]`` when
            ``regions`` is empty.
        """
        if regions is None or len(regions) == 0:
            return [], True
        region_token_mask = input_ids == self.region_token_idx
        batch_idxs = self._token_batch_indices(input_ids, region_token_mask)
        regions, success = self._align_prompts_to_tokens(regions, len(batch_idxs))
        regions_embeddings = self.visual_encoder.forward_region_sam(
            regions, batch_idxs
        )[:, 0]
        regions_embeddings = self.projector.model.forward_visual_prompts_embeddings(
            regions_embeddings, batch_idxs)
        return regions_embeddings, success

    def get_none_region_embeddings(self, input_ids):
        """Embed a dummy all-ones 50x50 region for batch index 0."""
        batch_idxs = self._single_batch_idx(input_ids)
        regions = torch.ones((1, 50, 50), dtype=torch.float32, device=input_ids.device)
        regions_embeddings = self.visual_encoder.forward_region_sam(
            regions, batch_idxs
        )[:, 0]
        return self.projector.model.forward_visual_prompts_embeddings(
            regions_embeddings, batch_idxs)

    def get_points_embeddings(self, points, input_ids, width, height):
        """Embed point prompts, one per ``<mark>`` token in ``input_ids``.

        Returns:
            ``(embeddings, success)``. ``embeddings`` is ``[]`` when
            ``points`` is empty (same contract as ``get_region_embeddings``).
        """
        if points is None or len(points) == 0:
            return [], True
        mark_token_mask = input_ids == self.mark_token_idx
        batch_idxs = self._token_batch_indices(input_ids, mark_token_mask)
        points, success = self._align_prompts_to_tokens(points, len(batch_idxs))
        marks_embeddings = self.visual_encoder.forward_point_sam(
            points, batch_idxs, width=width, height=height
        )[:, 0]
        marks_embeddings = self.projector.model.forward_visual_prompts_embeddings(
            marks_embeddings, batch_idxs)
        return marks_embeddings, success

    def get_none_points_embeddings(self, input_ids, width, height):
        """Embed a dummy zero point for batch index 0."""
        batch_idxs = self._single_batch_idx(input_ids)
        # NOTE(review): `.to(input_ids)` also casts the dummy point to
        # input_ids' (integer) dtype; confirm forward_point_sam expects that.
        marks_embeddings = self.visual_encoder.forward_point_sam(
            torch.zeros((1, 2)).to(input_ids), batch_idxs, width=width, height=height
        )[:, 0]
        return self.projector.model.forward_visual_prompts_embeddings(
            marks_embeddings, batch_idxs)

    def get_visual_prompts_projector_zero(self):
        """Return ``projector.model.visual_prompt_zero``."""
        return self.projector.model.visual_prompt_zero

    def _dummy_seg_loss(self, input_ids, hidden_states):
        """Run the seg branch on one token and return zero-weighted losses.

        Used when a batch has no usable [SEG]/mask pairs. Presumably this
        keeps the seg parameters in the graph (e.g. for DDP); confirm.
        """
        batch_idxs = self._single_batch_idx(input_ids)
        dummy_hidden = self.projector_text2vision(hidden_states[0, :1])
        pred_masks_list = self.visual_encoder.forward_llm_seg(dummy_hidden, batch_idxs)
        dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, [None])
        return dice_loss * 0.0, mask_loss * 0.0

    def compute_seg_loss(self, input_ids, hidden_states, gt_masks):
        """Decode ``[SEG]`` hidden states into masks and compute seg losses.

        Returns ``(0.0, 0.0)`` when the text2vision projector is disabled or
        ``add_seg_pretrain`` is set. Returns zero-weighted dummy losses when
        there are no masks, or when the ``[SEG]`` count does not match the
        mask count.
        """
        if not self.text2vision_projector or self.add_seg_pretrain:
            return 0.0, 0.0
        if gt_masks is None or len(gt_masks) == 0:
            return self._dummy_seg_loss(input_ids, hidden_states)
        seg_tokens_mask = input_ids == self.seg_token_idx
        seg_hidden_states = hidden_states[seg_tokens_mask]
        batch_idxs = self._token_batch_indices(input_ids, seg_tokens_mask)
        if len(seg_hidden_states) != len(gt_masks) or len(seg_hidden_states) == 0:
            # drop this batch
            print("Drop the batch because the number of [SEG] and masks not equal !!!")
            return self._dummy_seg_loss(input_ids, hidden_states)
        seg_hidden_states = self.projector_text2vision(seg_hidden_states)
        pred_masks_list = self.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs)
        dice_loss, mask_loss = self.visual_encoder.loss_llm_seg(pred_masks_list, gt_masks)
        return dice_loss, mask_loss

    def preparing_for_generation(self, metainfo, **kwargs):
        """Set the prompt template, stop criteria and generation config.

        ``metainfo`` may provide ``template`` (default
        ``PROMPT_TEMPLATE['internlm2_chat']``) and ``generation_kwargs``,
        which override the greedy defaults.
        """
        # set stop criteria and generation configs for model
        assert hasattr(self, 'tokenizer'), "The Model does not have the tokenizer!!!"
        self.bot_name = 'BOT'
        if 'template' in metainfo:
            template = metainfo['template']
        else:
            template = PROMPT_TEMPLATE['internlm2_chat']
        self.template = template
        stop_words = list(template.get('STOP_WORDS', []))
        self.stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=stop_words)
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id
        default_generation_kwargs = dict(
            max_new_tokens=2048,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )
        default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
        self.gen_config = GenerationConfig(**default_generation_kwargs)
        self.init_prediction_config = True
        dtype = self.visual_encoder.dtype
        self.llm.to(dtype)
        self.visual_encoder.to(dtype)
        self.projector.to(dtype)
        # projector_text2vision only exists when text2vision_projector=True.
        if self.text2vision_projector:
            self.projector_text2vision.to(dtype)
        return

    def _predict_mask(self, generate_output, ori_image_size):
        """Decode the first generated ``[SEG]`` token into a binary mask.

        Returns:
            An int tensor of masks cropped to ``ori_image_size``, or ``None``
            if no ``[SEG]`` token was generated.
        """
        last_hidden_states = torch.cat(
            [item[-1][0] for item in generate_output.hidden_states], dim=0)
        seg_hidden_states = get_seg_hidden_states(
            last_hidden_states, generate_output.sequences[0][:-1],
            seg_id=self.seg_token_idx
        )
        if len(seg_hidden_states) == 0:
            print("Warning, no [SEG] tokens !!!")
            return None
        if len(seg_hidden_states) > 1:
            print("Warning, {} [SEG] tokens !!!".format(len(seg_hidden_states)))
            seg_hidden_states = seg_hidden_states[:1]
        seg_hidden_states = self.projector_text2vision(seg_hidden_states)
        batch_idxs = torch.zeros((seg_hidden_states.shape[0],),
                                 dtype=torch.int64).to(seg_hidden_states.device)
        pred_masks_list = self.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs)
        pred_masks = pred_masks_list[-1]

        w, h = ori_image_size
        side = max(w, h)
        masks = F.interpolate(pred_masks, size=(side, side),
                              mode='bilinear', align_corners=False)
        masks = masks[:, 0]
        # Remove padding. The input is presumably padded to a square, split
        # evenly with the odd pixel at the end; crop the short side back.
        if w > h:
            n_pad = w - h
            n_pad_1 = n_pad // 2
            masks = masks[:, n_pad_1: w - (n_pad - n_pad_1)]
        elif h > w:
            n_pad = h - w
            n_pad_1 = n_pad // 2
            masks = masks[:, :, n_pad_1: h - (n_pad - n_pad_1)]
        # binary
        return (masks.sigmoid() > 0.5).int()

    def predict_forward(
            self, pixel_values, text_prompts,
            ori_image_size=None,
            box_prompts=None, points_prompts=None, mask_prompts=None, **kwargs):
        """Generate answers (and optionally masks) for a single image.

        Args:
            pixel_values: image tensor without batch dim (one is added here).
            text_prompts: question(s) without template; str or list of str.
            ori_image_size: unpacked as ``w, h`` to crop predicted masks.
                Masks are decoded only when this is given and ``kwargs``
                contains ``masks``.
            box_prompts / points_prompts / mask_prompts: accepted, unused.

        Returns:
            ``{'prediction': ...}``. When masks were decoded, the dict also
            has ``'prediction_masks'`` (RLE list per prompt, or ``None``) and
            ``'gt_masks'``. A single prompt yields a plain string prediction.
        """
        assert self.init_prediction_config, "Please set prediction configs using self.preparing_for_generation()"
        ret_predictions = []
        ret_masks = []

        image = pixel_values.cuda().unsqueeze(0).to(self.visual_encoder.dtype)
        visual_outputs = self.visual_encoder(image, output_hidden_states=True)
        if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
            pixel_values = self.projector(visual_outputs)
        else:
            pixel_values = self.projector(
                visual_outputs.hidden_states[self.visual_select_layer][:, 1:])

        if isinstance(text_prompts, str):
            text_prompts = [text_prompts]

        for text_prompt in text_prompts:
            # add template for text
            input_text = self.template['INSTRUCTION'].format(
                input=text_prompt, round=1, bot_name=self.bot_name)
            chunk_encode = []
            for idx, chunk in enumerate(input_text.split(DEFAULT_IMAGE_TOKEN)):
                if idx == 0:
                    chunk_encode.append(self.tokenizer.encode(chunk))
                else:
                    chunk_encode.append(self.tokenizer.encode(chunk, add_special_tokens=False))
            assert len(chunk_encode) == 2
            ids = []
            for idx, cur_chunk_encode in enumerate(chunk_encode):
                ids.extend(cur_chunk_encode)
                if idx != len(chunk_encode) - 1:
                    ids.append(IMAGE_TOKEN_INDEX)
            ids = torch.tensor(ids).cuda().unsqueeze(0)

            mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts(
                llm=self.llm, input_ids=ids, pixel_values=pixel_values,
                region_id=self.region_token_idx,
                regions_feats=[],
                mark_id=self.mark_token_idx,
                mark_feats=[],
            )
            generate_output = self.llm.generate(
                **mm_inputs,
                generation_config=self.gen_config,
                streamer=None,
                bos_token_id=self.tokenizer.bos_token_id,
                stopping_criteria=self.stop_criteria,
                output_hidden_states=True,
                return_dict_in_generate=True
            )
            predict = self.tokenizer.decode(
                generate_output.sequences[0], skip_special_tokens=True).strip()
            ret_predictions.append(predict)

            if ori_image_size is not None and 'masks' in kwargs:
                ret_masks.append(self._predict_mask(generate_output, ori_image_size))

        if len(ret_predictions) == 1:
            ret_predictions = ret_predictions[0]
        if len(ret_masks) == 0:
            return {'prediction': ret_predictions}

        _ret_masks = [None if ret_mask is None else mask_to_rle(ret_mask.cpu().numpy())
                      for ret_mask in ret_masks]
        if 'masks' not in kwargs:
            gt_masks = None
        else:
            gt_masks = mask_to_rle(kwargs['masks'].cpu().numpy())
        return {
            'prediction': ret_predictions, 'prediction_masks': _ret_masks,
            'gt_masks': gt_masks,
        }
def prepare_seg_pretrain_data(visual_outputs,
                              query_in_proj, query_out_proj):
    """Build (prediction, target) pairs for the query-reconstruction loss.

    Args:
        visual_outputs: 3-tuple ``(clip_feature, query_feat, attention_mask)``.
            ``query_feat`` must be 3-D with the batch as its first dim.
            ``attention_mask[b]`` is summed over its last dim to decide which
            queries of sample ``b`` are valid (sum > 0). ``clip_feature`` is
            unused.
        query_in_proj: a callable, or a list of callables applied in order.
        query_out_proj: callable applied after ``query_in_proj``.

    Returns:
        ``(pred_query_embed, gt_query_embed)``: the valid queries of every
        sample, concatenated along dim 0, after and before the projections.
    """
    # original layout notes:
    # clip feature (bs, hw, c + 2 * q_c)
    # query_feat (bs, q, 2c)
    # attention_mask (bs, q, hw)
    _clip_feature, query_feat, attention_mask = visual_outputs
    batch_size, _num_queries, _ = query_feat.shape
    projections = query_in_proj if isinstance(query_in_proj, list) else [query_in_proj]

    predicted, targets = [], []
    for b in range(batch_size):
        keep = attention_mask[b].sum(-1) > 0
        selected = query_feat[b][keep]
        targets.append(selected)
        projected = selected
        for projection in projections:
            projected = projection(projected)
        predicted.append(query_out_proj(projected))
    return torch.cat(predicted, dim=0), torch.cat(targets, dim=0)
def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    """Select the hidden states of ``seg_id`` tokens among the generated ones.

    The last ``len(output_ids)`` rows of ``hidden_states`` are aligned with
    ``output_ids``. The rows whose id equals ``seg_id`` are returned.

    Args:
        hidden_states: tensor whose first dim indexes positions.
        output_ids: 1-D tensor of generated token ids.
        seg_id: token id to select.

    Returns:
        The selected rows; an empty slice when ``output_ids`` is empty.
    """
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
    if n_out == 0:
        # `hidden_states[-0:]` would be the *whole* tensor, which a length-0
        # boolean mask cannot index; return an empty selection instead.
        return hidden_states[:0]
    return hidden_states[-n_out:][seg_mask]
def mask_to_rle(mask):
    """Encode each 2-D mask in ``mask`` with pycocotools RLE.

    Args:
        mask: iterable of numpy arrays; each is cast to ``uint8`` and made
            Fortran-contiguous, as ``pycocotools.mask.encode`` requires.

    Returns:
        A list with one encoded RLE per input mask.
    """
    return [_mask.encode(np.asfortranarray(m.astype(np.uint8))) for m in mask]
|