import copy
import json
import warnings

import cv2
import numpy as np
import torch
import torchvision.transforms as T
from distinctipy import distinctipy
from mmengine.model import is_model_wrapper
from mmengine.utils.misc import get_object_from_string
from PIL import Image
from pycocotools import mask
from torchvision.transforms.functional import InterpolationMode
from transformers import (AutoConfig, AutoTokenizer, GenerationConfig,
                          StoppingCriteriaList)
from xtuner.dataset.utils import expand2square, load_image
from xtuner.engine.hooks import EvaluateChatHook
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.registry import BUILDER
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          StopWordStoppingCriteria)

from ..dataset.process_functions import contour_rendering, dynamic_preprocess
from ..dataset.utils import (VPT_CONTEXT_TOKEN, VPT_END_TOKEN,
                             VPT_START_TOKEN)


class EvaluateChatHook_withSpecialTokens(EvaluateChatHook):
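    """`EvaluateChatHook` variant for visual-prompt models.

    It registers the visual-prompt context token with the tokenizer and
    evaluates samples that pair an image with region annotations, rendering
    the regions as visual prompts and querying the model via `batch_chat`.
    """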

    priority = 'LOW'
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 tokenizer,
                 evaluation_inputs,
                 evaluation_images=None,
                 evaluation_vprompts=None,
                 image_tokenize_config=None,
                 image_processor=None,
                 system='',
                 prompt_template=None,
                 every_n_iters=None,
                 max_new_tokens=600,
                 stop_word=None,
                 stop_words=[],
                 generation_kwargs={}):
        super().__init__(tokenizer, evaluation_inputs, evaluation_images,
                         image_processor, system, prompt_template,
                         every_n_iters, max_new_tokens, stop_word, stop_words,
                         generation_kwargs)
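
        # Normalize prompts, images and visual-prompt annotation files into
        # per-sample lists; a single image/visual prompt is broadcast across
        # all evaluation prompts.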
        self.evaluation_inputs = evaluation_inputs
        if isinstance(self.evaluation_inputs, str):
            self.evaluation_inputs = [self.evaluation_inputs]
        self.evaluation_images = evaluation_images
        self.evaluation_merged_visual_prompts = evaluation_images
        if isinstance(self.evaluation_images, str):
            self.evaluation_images = [self.evaluation_images]
            self.evaluation_merged_visual_prompts = [
                self.evaluation_merged_visual_prompts
            ]
        if self.evaluation_images is not None:
            assert len(self.evaluation_images) in [
                1, len(self.evaluation_inputs)
            ]
            if len(self.evaluation_images) == 1:
                self.evaluation_images = [
                    self.evaluation_images[0]
                ] * len(self.evaluation_inputs)
                self.evaluation_merged_visual_prompts = [
                    self.evaluation_merged_visual_prompts[0]
                ] * len(self.evaluation_inputs)
            self.evaluation_images = [
                load_image(img) for img in self.evaluation_images
            ]
            self.evaluation_merged_visual_prompts = [
                cv2.imread(img)
                for img in self.evaluation_merged_visual_prompts
            ]
        self.evaluation_vprompts = evaluation_vprompts
        if isinstance(self.evaluation_vprompts, str):
            self.evaluation_vprompts = [self.evaluation_vprompts]
        if self.evaluation_vprompts is not None:
            assert len(self.evaluation_vprompts) in [
                1, len(self.evaluation_inputs)
            ]
            if len(self.evaluation_vprompts) == 1:
                self.evaluation_vprompts = [
                    self.evaluation_vprompts[0]
                ] * len(self.evaluation_inputs)
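
        # Dynamic high-resolution tiling settings consumed by
        # `dynamic_preprocess`.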
        self.min_dynamic_patch = image_tokenize_config.min_dynamic_patch
        self.max_dynamic_patch = image_tokenize_config.max_dynamic_patch
        self.image_size = image_tokenize_config.force_image_size
        self.use_thumbnail = image_tokenize_config.use_thumbnail
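
        # Tensor transform for image tiles (ImageNet normalization); visual
        # prompt masks are resized with nearest-neighbor interpolation so
        # they stay binary.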
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')
                     if img.mode != 'RGB' else img),
            T.Resize((self.image_size, self.image_size)),
            T.ToTensor(),
            T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        ])
        self.vprompt_transform = T.Compose([
            T.ToTensor(),
            T.Resize((self.image_size, self.image_size),
                     interpolation=InterpolationMode.NEAREST_EXACT),
        ])
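
        # Generation settings passed to `model.batch_chat` during image
        # evaluation.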
        generation_config = dict(
            max_new_tokens=1024,
            do_sample=True,
        )
        self.generation_config = generation_config

        self.is_first_run = True

        self._add_special_tokens()

    def _add_special_tokens(self):
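        """Register the visual-prompt context token as a special token."""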
        special_tokens = [VPT_CONTEXT_TOKEN]
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

    def decode_mask(self, object_masks, ori_height, ori_width):
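        """Decode per-object segmentations (COCO-style polygons or RLE dicts)
        into a stacked binary mask array of shape
        (num_objects, ori_height, ori_width)."""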
        binary_masks = []
        for object_mask in object_masks:
            if isinstance(object_mask, dict):
                if isinstance(object_mask['counts'], list):
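                    # Uncompressed RLE: convert to the compressed format
                    # expected by `mask.decode`.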
                    object_mask = mask.frPyObjects(object_mask, ori_height,
                                                   ori_width)
                m = mask.decode(object_mask)
                m = m.astype(np.uint8).squeeze()
            elif object_mask:
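                # Polygon segmentation: rasterize the polygon(s) and merge
                # them into a single binary mask.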
                rles = mask.frPyObjects(object_mask, ori_height, ori_width)
                rle = mask.merge(rles)
                m = mask.decode(rle).astype(np.uint8).squeeze()
            else:
                m = np.zeros((ori_height, ori_width), dtype=np.uint8)
            binary_masks.append(m)
        masks = np.stack(binary_masks, axis=0)
        return masks

    def _eval_images(self, runner, model, device, max_new_tokens=None,
                     save_eval_output=False):
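        """Run the image + visual-prompt evaluation samples and log the
        model outputs."""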
        if save_eval_output:
            eval_outputs = []

        for idx, (sample_image, sample_vprompt, sample_input) in enumerate(
                zip(self.evaluation_images, self.evaluation_vprompts,
                    self.evaluation_inputs)):
            if isinstance(sample_input, str):
                sample_input = [sample_input]
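
            # Load the per-object annotations (bbox and polygon segmentation)
            # for this sample's visual prompts.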
            with open(sample_vprompt, 'r') as f:
                vprompt_data = json.load(f)

            ori_width, ori_height = sample_image.size

            annotations = []
            for anno in vprompt_data['objects']:
                annotation = dict()
                annotation['bbox'] = anno['bbox']
                annotation['segmentation'] = [
                    np.array(anno['segmentation']).flatten().tolist()
                ]
                annotations.append(annotation)
            segmentations = [anno['segmentation'] for anno in annotations]
            regions = self.decode_mask(segmentations, ori_height, ori_width)
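
            # Draw the region contours on the BGR copy of the image, then
            # convert it to an RGB PIL image for preprocessing.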
            merged_visual_prompts = self.evaluation_merged_visual_prompts[idx]
            contour_rendering(merged_visual_prompts, regions)
            merged_visual_prompts = Image.fromarray(
                cv2.cvtColor(merged_visual_prompts, cv2.COLOR_BGR2RGB))
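
            # Split the image, region masks and rendered visual prompts into
            # dynamic high-resolution tiles.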
            images, regions, merged_regions = dynamic_preprocess(
                sample_image, regions, merged_visual_prompts,
                min_num=self.min_dynamic_patch,
                max_num=self.max_dynamic_patch,
                image_size=self.image_size,
                use_thumbnail=self.use_thumbnail)
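
            # Convert the tiles to normalized tensors matching the vision
            # tower's dtype and move them to the GPU.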
            pixel_values = [self.transform(image) for image in images]
            pixel_values = torch.stack(pixel_values).to(
                model.model.vision_model.dtype).to('cuda')

            merged_visual_prompts = [
                self.transform(merged_region)
                for merged_region in merged_regions
            ]
            merged_visual_prompts = torch.stack(merged_visual_prompts).to(
                model.model.vision_model.dtype).to('cuda')

            num_patches_list = [pixel_values.shape[0]]
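
            # Generate a response for each prompt of this sample via the
            # model's `batch_chat` interface.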
            responses = model.batch_chat(
                pixel_values,
                sample_input,
                merged_visual_prompts,
                copy.deepcopy(self.generation_config),
                num_patches_list=num_patches_list,
            )

            runner.logger.info('Sample output:\n'
                               f'{sample_input[0] + responses[0]}\n')

            if save_eval_output:
                eval_outputs.append(f'{sample_input[0] + responses[0]}\n')

        if save_eval_output:
            self._save_eval_output(runner, eval_outputs)

    def _generate_samples(self,
                          runner,
                          max_new_tokens=None,
                          save_eval_output=False):
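        """Put the model into inference mode, run the evaluation samples and
        restore the training-time configuration afterwards."""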
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        device = next(iter(model.parameters())).device
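
        # On the first evaluation run, make sure the whole model is on the
        # runtime device.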
        if self.is_first_run:
            model.to(device)
            self.is_first_run = False
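
        # Remember the training-time settings so they can be restored after
        # evaluation.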
        is_checkpointing = model.model.language_model.is_gradient_checkpointing
        use_cache = model.model.language_model.config.use_cache
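
        # Switch to inference mode: disable activation checkpointing, enable
        # the KV cache and set the model to eval mode.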
        model.activation_checkpointing_disable()
        model.model.language_model.config.use_cache = True
        model.eval()
        if self.evaluation_images is not None:
            self._eval_images(runner, model, device, max_new_tokens,
                              save_eval_output)
        else:
            self._eval_language(runner, model, device, max_new_tokens,
                                save_eval_output)
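
        # Restore the training-time settings.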
        if is_checkpointing:
            model.activation_checkpointing_enable()
        model.model.language_model.config.use_cache = use_cache
        model.train()