import torch
from transformers import TextStreamer
import webcolors
import os
import random
from collections import Counter
import numpy as np
from torchvision import transforms

from .magic_utils import get_colored_contour, find_different_colors, get_bounding_box_from_mask
from .LLaVA.llava.conversation import conv_templates, SeparatorStyle
from .LLaVA.llava.model.builder import load_pretrained_model
from .LLaVA.llava.mm_utils import get_model_name_from_path, expand2square, tokenizer_image_token
from .LLaVA.llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
import re


class LLaVAModel:
    def __init__(self):
        # Load the fine-tuned LLaVA-v1.5-7B checkpoint in 4-bit mode to cut GPU memory use.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, "../models/llava-v1.5-7b-finetune-clean")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=get_model_name_from_path(model_path),
            load_4bit=True,
        )

    def generate_description(self, images, question):
        # Splice the image token(s) into the prompt, honoring the model's
        # mm_use_im_start_end setting.
        qs = question
        image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        if IMAGE_PLACEHOLDER in qs:
            if self.model.config.mm_use_im_start_end:
                qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
            else:
                qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
        else:
            if self.model.config.mm_use_im_start_end:
                qs = image_token_se + "\n" + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

        # Convert each HWC tensor to a PIL image, pad it to a square filled with
        # the processor's mean color, and run the CLIP preprocessing pipeline.
        images_tensor = []
        image_sizes = []
        to_pil = transforms.ToPILImage()
        for image in images:
            image = image.clone().permute(2, 0, 1).cpu()  # HWC -> CHW for ToPILImage
            image = to_pil(image)
            image_sizes.append(image.size)
            # image_mean is in [0, 1]; scale to 0-255 so the padding matches the
            # mean color rather than truncating every channel to black.
            image = expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            images_tensor.append(image.half())

        # Build the llava_v1 conversation prompt with an empty assistant turn.
        conv = conv_templates["llava_v1"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images_tensor,
                image_sizes=image_sizes,
                temperature=0.2,
                do_sample=True,
                use_cache=True,
            )

        # decode() keeps special tokens; extract the text between <s> and </s>.
        outputs = self.tokenizer.decode(output_ids[0]).strip()
        outputs = outputs.split('>')[1].split('<')[0]
        return outputs

    def process(self, image, colored_image, add_mask):
        description = ""
        answer1 = ""
        answer2 = ""
        image_with_sketch = image.clone()

        # Branch 1: the user drew a sketch (non-empty add_mask). Ask the model
        # to guess what the sketch depicts, given its bounding box.
        if torch.sum(add_mask).item() > 0:
            x_min, y_min, x_max, y_max = get_bounding_box_from_mask(add_mask)
            question = (
                "This is an 'I draw, you guess' game. I will upload an image "
                "containing some sketches. To help you locate the sketch, I will "
                "give you the normalized bounding box coordinates of the sketch, "
                "where the original coordinates are divided by the image width "
                "and height. The top-left corner of the bounding box is at "
                f"({x_min}, {y_min}), and the bottom-right corner is at "
                f"({x_max}, {y_max}). Now tell me, what am I trying to draw with "
                "these sketches in the image?"
            )
            # Draw the sketch in a contrast-adaptive color: black strokes on
            # bright regions, white strokes everywhere else.
            bool_add_mask = add_mask > 0.5
            mean_brightness = image_with_sketch[bool_add_mask].mean()
            if mean_brightness > 0.8:
                image_with_sketch[bool_add_mask] = 0.0
            else:
                image_with_sketch[bool_add_mask] = 1.0
            answer1 = self.generate_description([image_with_sketch.squeeze() * 255], question)
            print(answer1)

        # Branch 2: the user recolored part of the image. Outline the changed
        # region in red and ask the model to name what lies inside the contour.
        if not torch.equal(image, colored_image):
            color = find_different_colors(image.squeeze() * 255, colored_image.squeeze() * 255)
            image_with_bbox, colored_mask = get_colored_contour(
                colored_image.squeeze() * 255, image.squeeze() * 255
            )
            x_min, y_min, x_max, y_max = get_bounding_box_from_mask(colored_mask)
            question = (
                "The user will upload an image containing some contours in red "
                "color. To help you locate the contour, I will give you the "
                "normalized bounding box coordinates, where the original "
                "coordinates are divided by the image width and height. The "
                "top-left corner of the bounding box is at "
                f"({x_min}, {y_min}), and the bottom-right corner is at "
                f"({x_max}, {y_max}). You need to identify what is inside the "
                "contours using a single word or phrase."
            )
            answer2 = color + ', ' + self.generate_description([image_with_bbox.squeeze() * 255], question)
            print(answer2)

        return (description, answer1, answer2)
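

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API. It assumes the checkpoint
    # under ../models exists, a CUDA GPU is available, and that inputs are NHWC
    # float tensors in [0, 1] (inferred from the permute(2, 0, 1) and `* 255`
    # calls above); the 512x512 resolution and mask placement are made up for
    # the demo.
    model = LLaVAModel()

    image = torch.rand(1, 512, 512, 3)        # base image, values in [0, 1]
    colored_image = image.clone()             # unchanged, so the contour branch is skipped
    add_mask = torch.zeros(1, 512, 512, 3)
    add_mask[:, 100:200, 100:200, :] = 1.0    # pretend the user sketched inside this box

    description, answer1, answer2 = model.process(image, colored_image, add_mask)
    # answer1 holds the model's guess for the sketch; answer2 stays "" here.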