import warnings warnings.filterwarnings('ignore') import subprocess, io, os, sys, time os.system("pip install gradio==3.50.2") import gradio as gr from loguru import logger os.environ["CUDA_VISIBLE_DEVICES"] = "0" # if os.environ.get('IS_MY_DEBUG') is None: # result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True) # print(f'pip install GroundingDINO = {result}') logger.info(f"Start app...") result = subprocess.run(['pip', 'list'], check=True) print(f'pip list = {result}') sys.path.insert(0, './GroundingDINO') import argparse import copy import numpy as np import torch from PIL import Image, ImageDraw, ImageFont, ImageOps # Grounding DINO import GroundingDINO.groundingdino.datasets.transforms as T from GroundingDINO.groundingdino.models import build_model from GroundingDINO.groundingdino.util import box_ops from GroundingDINO.groundingdino.util.slconfig import SLConfig from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap # I2SB import sys sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB") import numpy as np import torch import torch.distributed as dist import torchvision.transforms as transforms import torchvision.utils as tu from easydict import EasyDict as edict from fastapi import (Body, Depends, FastAPI, File, Form, HTTPException, Query, UploadFile) from ipdb import set_trace as debug from PIL import Image from torch.multiprocessing import Process from torch.utils.data import DataLoader, Subset from torch_ema import ExponentialMovingAverage import I2SB.distributed_util as dist_util from I2SB.corruption import build_corruption from I2SB.dataset import air_liquide from I2SB.i2sb import Runner, ckpt_util, download_ckpt from I2SB.logger import Logger from I2SB.sample import * import cv2 import numpy as np import matplotlib matplotlib.use('AGG') plt = matplotlib.pyplot # import matplotlib.pyplot as plt groundingdino_enable = True sam_enable = True inpainting_enable = True ram_enable = False lama_cleaner_enable = True kosmos_enable = False # qwen_enable = True # from qwen_utils import * if os.environ.get('IS_MY_DEBUG') is not None: sam_enable = False ram_enable = False inpainting_enable = False kosmos_enable = False if lama_cleaner_enable: try: from lama_cleaner.model_manager import ModelManager from lama_cleaner.schema import Config as lama_Config except Exception as e: lama_cleaner_enable = False # segment anything from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator # diffusers import PIL import requests import torch from io import BytesIO from diffusers import StableDiffusionInpaintPipeline from huggingface_hub import hf_hub_download from util_computer import computer_info # relate anything from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask from ram_train_eval import RamModel, RamPredictor from mmengine.config import Config as mmengine_Config if lama_cleaner_enable: from lama_cleaner.helper import ( load_img, numpy_to_bytes, resize_max_size, ) # from transformers import AutoProcessor, AutoModelForVision2Seq import ast if kosmos_enable: os.system("pip install transformers@git+https://github.com/huggingface/transformers.git@main") # os.system("pip install transformers==4.32.0") from kosmos_utils import * from util_tencent import getTextTrans config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py' ckpt_repo_id = "ShilongLiu/GroundingDINO" ckpt_filenmae = "groundingdino_swint_ogc.pth" sam_checkpoint = './sam_vit_h_4b8939.pth' output_dir = "outputs" device = 'cpu' os.makedirs(output_dir, exist_ok=True) groundingdino_model = None sam_device = None sam_model = None sam_predictor = None sam_mask_generator = None sd_model = None lama_cleaner_model= None ram_model = None kosmos_model = None kosmos_processor = None colors = [(255, 0, 0), (0, 255, 0)] markers = [1, 5] i2sb_opt = edict( distributed=False, device="cuda", batch_size=1, nfe=10, dataset="sample", dataset_dir=Path(f"dataset/sample"), n_gpu_per_node=1, use_fp16=False, ckpt="inpaint-freeform2030", image_size=256, partition=None, global_size=1, global_rank=0, clip_denoise=True ) i2sb_transforms = transforms.Compose([ transforms.Resize(i2sb_opt.image_size), transforms.CenterCrop(i2sb_opt.image_size), transforms.ToTensor(), transforms.Lambda(lambda t: (t * 2) - 1) # [0,1] --> [-1, 1] ]) def get_point(img, sel_pix, evt: gr.SelectData): img = np.array(img, dtype=np.uint8) sel_pix.append(evt.index) # draw points print(sel_pix) for point in sel_pix: cv2.drawMarker(img, point, colors[0], markerType=markers[0], markerSize=6, thickness=2) return Image.fromarray(img).convert("RGB") def undo_button(orig_img, sel_pix): temp = orig_img.copy() temp = np.array(temp, dtype=np.uint8) if len(sel_pix) != 0: sel_pix.pop() for point in sel_pix: cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2) return Image.fromarray(temp).convert("RGB") def clear_button(orig_img): return orig_img, [] def toggle_button(orig_img, task_type): print(task_type) if task_type == "segment": ret = gr.Image(value= orig_img,elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6) elif task_type == "inpainting": ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6) task_type = not task_type return ret, task_type def store_img(img): print("call for store") return img, [] # when new image is uploaded, `selected_points` should be empty def load_model_hf(model_config_path, repo_id, filename, device='cpu'): args = SLConfig.fromfile(model_config_path) model = build_model(args) args.device = device cache_file = hf_hub_download(repo_id=repo_id, filename=filename) checkpoint = torch.load(cache_file, map_location=device) log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False) print("Model loaded from {} \n => {}".format(cache_file, log)) _ = model.eval() return model def load_i2sb_model(): RESULT_DIR = Path("I2SB/results") global i2sb_model global ckpt_opt global corrupt_type global nfe s = time.time() # main from here log = Logger(0, ".log") # get (default) ckpt option ckpt_opt = ckpt_util.build_ckpt_option(i2sb_opt, log, RESULT_DIR / i2sb_opt.ckpt) corrupt_type = ckpt_opt.corrupt nfe = i2sb_opt.nfe or ckpt_opt.interval-1 # build corruption method # corrupt_method = build_corruption(i2sb_opt, log, corrupt_type=cor # rupt_type) runner = Runner(ckpt_opt, log, save_opt=False) if i2sb_opt.use_fp16: runner.ema.copy_to() # copy weight from ema to net runner.net.diffusion_model.convert_to_fp16() runner.ema = ExponentialMovingAverage( runner.net.parameters(), decay=0.99) # re-init ema with fp16 weight print("Loading time:", (time.time()-s)*1e3, "ms.") i2sb_model = runner return runner def plot_boxes_to_image(image_pil, tgt): H, W = tgt["size"] boxes = tgt["boxes"] labels = tgt["labels"] assert len(boxes) == len(labels), "boxes and labels must have same length" draw = ImageDraw.Draw(image_pil) mask = Image.new("L", image_pil.size, 0) mask_draw = ImageDraw.Draw(mask) # draw boxes and masks for box, label in zip(boxes, labels): # from 0..1 to 0..W, 0..H box = box * torch.Tensor([W, H, W, H]) # from xywh to xyxy box[:2] -= box[2:] / 2 box[2:] += box[:2] # random color color = tuple(np.random.randint(0, 255, size=3).tolist()) # draw x0, y0, x1, y1 = box x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1) draw.rectangle([x0, y0, x1, y1], outline=color, width=6) # draw.text((x0, y0), str(label), fill=color) font = ImageFont.load_default() if hasattr(font, "getbbox"): bbox = draw.textbbox((x0, y0), str(label), font) else: w, h = draw.textsize(str(label), font) bbox = (x0, y0, w + x0, y0 + h) # bbox = draw.textbbox((x0, y0), str(label)) draw.rectangle(bbox, fill=color) try: font = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf') font_size = 36 new_font = ImageFont.truetype(font, font_size) draw.text((x0+2, y0+2), str(label), font=new_font, fill="white") except Exception as e: pass mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6) return image_pil, mask def load_image(image_path): # # load image if isinstance(image_path, PIL.Image.Image): image_pil = image_path else: image_pil = Image.open(image_path).convert("RGB") # load image transform = T.Compose( [ T.RandomResize([800], max_size=1333), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) image, _ = transform(image_pil, None) # 3, h, w return image_pil, image def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"): caption = caption.lower() caption = caption.strip() if not caption.endswith("."): caption = caption + "." model = model.to(device) image = image.to(device) with torch.no_grad(): outputs = model(image[None], captions=[caption]) logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256) boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4) logits.shape[0] # filter output logits_filt = logits.clone() boxes_filt = boxes.clone() filt_mask = logits_filt.max(dim=1)[0] > box_threshold logits_filt = logits_filt[filt_mask] # num_filt, 256 boxes_filt = boxes_filt[filt_mask] # num_filt, 4 logits_filt.shape[0] # get phrase tokenlizer = model.tokenizer tokenized = tokenlizer(caption) # build pred pred_phrases = [] for logit, box in zip(logits_filt, boxes_filt): pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer) if with_logits: pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") else: pred_phrases.append(pred_phrase) return boxes_filt, pred_phrases def show_mask(mask, ax, random_color=False): if random_color: color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) else: color = np.array([30/255, 144/255, 255/255, 0.6]) h, w = mask.shape[-2:] mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) ax.imshow(mask_image) def show_box(box, ax, label): x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) ax.text(x0, y0, label) def xywh_to_xyxy(box, sizeW, sizeH): if isinstance(box, list): box = torch.Tensor(box) box = box * torch.Tensor([sizeW, sizeH, sizeW, sizeH]) box[:2] -= box[2:] / 2 box[2:] += box[:2] box = box.numpy() return box def mask_extend(img, box, extend_pixels=10, useRectangle=True): box[0] = int(box[0]) box[1] = int(box[1]) box[2] = int(box[2]) box[3] = int(box[3]) region = img.crop(tuple(box)) new_width = box[2] - box[0] + 2*extend_pixels new_height = box[3] - box[1] + 2*extend_pixels region_BILINEAR = region.resize((int(new_width), int(new_height))) if useRectangle: region_draw = ImageDraw.Draw(region_BILINEAR) region_draw.rectangle((0, 0, new_width, new_height), fill=(255, 255, 255)) img.paste(region_BILINEAR, (int(box[0]-extend_pixels), int(box[1]-extend_pixels))) return img def mix_masks(imgs): re_img = 1 - np.asarray(imgs[0].convert("1")) for i in range(len(imgs)-1): re_img = np.multiply(re_img, 1 - np.asarray(imgs[i+1].convert("1"))) re_img = 1 - re_img return Image.fromarray(np.uint8(255*re_img)) def set_device(args): global device if os.environ.get('IS_MY_DEBUG') is None: device = args.cuda if torch.cuda.is_available() else 'cpu' else: device = 'cpu' print(f'device={device}') def get_sam_vit_h_4b8939(): if not os.path.exists('./sam_vit_h_4b8939.pth'): logger.info(f"get sam_vit_h_4b8939.pth...") result = subprocess.run(['wget', '-nv', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True) print(f'wget sam_vit_h_4b8939.pth result = {result}') def load_sam_model(device): # initialize SAM global sam_model, sam_predictor, sam_mask_generator, sam_device get_sam_vit_h_4b8939() logger.info(f"initialize SAM model...") sam_device = device sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device) sam_predictor = SamPredictor(sam_model) sam_mask_generator = SamAutomaticMaskGenerator(sam_model) def load_sd_model(device): # initialize stable-diffusion-inpainting global sd_model logger.info(f"initialize stable-diffusion-inpainting...") sd_model = None if os.environ.get('IS_MY_DEBUG') is None: sd_model = StableDiffusionInpaintPipeline.from_pretrained( "runwayml/stable-diffusion-inpainting", revision="fp16", # "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16, ) sd_model = sd_model.to(device) def forward_i2sb(img, mask): print(np.unique(img),mask.shape) mask = np.where(mask > 0, 1, 0) img_tensor = i2sb_transforms(img).to( i2sb_opt.device).unsqueeze(0) mask_tensor = torch.from_numpy(np.resize(np.array(mask), (256,256))).to( i2sb_opt.device).unsqueeze(0).unsqueeze(0) print("POST PROCESSING\t", torch.unique(img_tensor)) # corrupt_tensor = img_tensor * (1. - mask_tensor) + mask_tensor f = time.time() xs, _ = i2sb_model.ddpm_sampling( ckpt_opt, img_tensor, mask=mask_tensor, cond=None, clip_denoise=i2sb_opt.clip_denoise, nfe=nfe, verbose=i2sb_opt.n_gpu_per_node == 1) recon_img = xs[:, 0, ...].to(i2sb_opt.device) tu.save_image((recon_img+1)/2, "output.png") print(recon_img.shape) return transforms.ToPILImage()(((recon_img+1)/2)[0]) def lama_cleaner_process(image, mask, cleaner_size_limit=1080): try: logger.info(f'_______lama_cleaner_process_______1____') ori_image = image if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]: # rotate image logger.info(f'_______lama_cleaner_process_______2____') ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...] logger.info(f'_______lama_cleaner_process_______3____') image = ori_image logger.info(f'_______lama_cleaner_process_______4____') original_shape = ori_image.shape logger.info(f'_______lama_cleaner_process_______5____') interpolation = cv2.INTER_CUBIC size_limit = cleaner_size_limit if size_limit == -1: logger.info(f'_______lama_cleaner_process_______6____') size_limit = max(image.shape) else: logger.info(f'_______lama_cleaner_process_______7____') size_limit = int(size_limit) logger.info(f'_______lama_cleaner_process_______8____') config = lama_Config( ldm_steps=25, ldm_sampler='plms', zits_wireframe=True, hd_strategy='Original', hd_strategy_crop_margin=196, hd_strategy_crop_trigger_size=1280, hd_strategy_resize_limit=2048, prompt='', use_croper=False, croper_x=0, croper_y=0, croper_height=512, croper_width=512, sd_mask_blur=5, sd_strength=0.75, sd_steps=50, sd_guidance_scale=7.5, sd_sampler='ddim', sd_seed=42, cv2_flag='INPAINT_NS', cv2_radius=5, ) logger.info(f'_______lama_cleaner_process_______9____') if config.sd_seed == -1: config.sd_seed = random.randint(1, 999999999) # logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}") logger.info(f'_______lama_cleaner_process_______10____') image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) # logger.info(f"Resized image shape_1_: {image.shape}") # logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}") logger.info(f'_______lama_cleaner_process_______11____') mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) # logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}") logger.info(f'_______lama_cleaner_process_______12____') res_np_img = lama_cleaner_model(image, mask, config) logger.info(f'_______lama_cleaner_process_______13____') torch.cuda.empty_cache() logger.info(f'_______lama_cleaner_process_______14____') image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png'))) logger.info(f'_______lama_cleaner_process_______15____') except Exception as e: logger.info(f'lama_cleaner_process[Error]:' + str(e)) image = None return image # visualization def draw_selected_mask(mask, draw): color = (255, 0, 0, 153) nonzero_coords = np.transpose(np.nonzero(mask)) for coord in nonzero_coords: draw.point(coord[::-1], fill=color) def draw_object_mask(mask, draw): color = (0, 0, 255, 153) nonzero_coords = np.transpose(np.nonzero(mask)) for coord in nonzero_coords: draw.point(coord[::-1], fill=color) def create_title_image(word1, word2, word3, width, font_path='./assets/OpenSans-Bold.ttf'): # Define the colors to use for each word color_red = (255, 0, 0) color_black = (0, 0, 0) color_blue = (0, 0, 255) # Define the initial font size and spacing between words font_size = 40 # Create a new image with the specified width and white background image = Image.new('RGB', (width, 60), (255, 255, 255)) try: # Load the specified font font = ImageFont.truetype(font_path, font_size) # Keep increasing the font size until all words fit within the desired width while True: # Create a draw object for the image draw = ImageDraw.Draw(image) word_spacing = font_size / 2 # Draw each word in the appropriate color x_offset = word_spacing draw.text((x_offset, 0), word1, color_red, font=font) x_offset += font.getsize(word1)[0] + word_spacing draw.text((x_offset, 0), word2, color_black, font=font) x_offset += font.getsize(word2)[0] + word_spacing draw.text((x_offset, 0), word3, color_blue, font=font) word_sizes = [font.getsize(word) for word in [word1, word2, word3]] total_width = sum([size[0] for size in word_sizes]) + word_spacing * 3 # Stop increasing font size if the image is within the desired width if total_width <= width: break # Increase font size and reset the draw object font_size -= 1 image = Image.new('RGB', (width, 50), (255, 255, 255)) font = ImageFont.truetype(font_path, font_size) draw = None except Exception as e: pass return image def concatenate_images_vertical(image1, image2): # Get the dimensions of the two images width1, height1 = image1.size width2, height2 = image2.size # Create a new image with the combined height and the maximum width new_image = Image.new('RGBA', (max(width1, width2), height1 + height2)) # Paste the first image at the top of the new image new_image.paste(image1, (0, 0)) # Paste the second image below the first image new_image.paste(image2, (0, height1)) return new_image mask_source_draw = "draw a mask on input image" mask_source_segment = "upload a mask" def get_time_cost(run_task_time, time_cost_str): now_time = int(time.time()*1000) if run_task_time == 0: time_cost_str = 'start' else: if time_cost_str != '': time_cost_str += f'-->' time_cost_str += f'{now_time - run_task_time}' run_task_time = now_time return run_task_time, time_cost_str def run_anything_task(input_image, input_points, origin_image, task_type, mask_source_radio, cleaner_size_limit=1080): run_task_time = 0 time_cost_str = '' run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str) print("HERE................", task_type) if (task_type == 'Kosmos-2'): global kosmos_model, kosmos_processor if isinstance(input_image, dict): image_pil, image = load_image(input_image['image'].convert("RGB")) input_img = input_image['image'] else: image_pil, image = load_image(input_image.convert("RGB")) input_img = input_image kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_model, kosmos_processor) run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str) return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities if input_image is None: return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None file_temp = int(time.time()) logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_') output_images = [] # load image if isinstance(input_image, dict): image_pil, image = load_image(input_image['image'].convert("RGB")) input_img = input_image['image'] output_images.append(input_image['image']) run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str) else: image_pil, image = load_image(input_image.convert("RGB")) input_img = input_image output_images.append(input_image) run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str) size = image_pil.size H, W = size[1], size[0] # run grounding dino model if (task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw: pass else: groundingdino_device = 'cpu' logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_') if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment): image = np.array(input_img) if sam_predictor: sam_predictor.set_image(image) if sam_predictor: logger.info(f"Forward with: {input_points}") masks, _, _, _ = sam_predictor.predict( point_coords = np.array(input_points), point_labels = np.array([1 for _ in range(len(input_points))]), # boxes = transformed_boxes, multimask_output = False, ) # masks: [9, 1, 512, 512] assert sam_checkpoint, 'sam_checkpoint is not found!' else: run_mode = "rectangle" # draw output image plt.figure(figsize=(10, 10)) plt.imshow(origin_image) for mask in masks: show_mask(mask, plt.gca(), random_color=True) # for box, label in zip(boxes_filt, pred_phrases): # show_box(box.cpu().numpy(), plt.gca(), label) plt.axis('off') image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg") plt.savefig(image_path, bbox_inches="tight") plt.clf() plt.close('all') segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) os.remove(image_path) output_images.append(Image.fromarray(segment_image_result)) run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str) logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_') if task_type == 'detection' or task_type == 'segment': logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_') return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove': if mask_source_radio == mask_source_segment: task_type = 'remove' logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_') if mask_source_radio == mask_source_draw: input_mask_pil = input_image['mask'] input_mask = np.array(input_mask_pil.convert("L")) mask_pil = input_mask_pil mask = input_mask else: masks_ori = copy.deepcopy(masks) masks = torch.where(masks > 0, True, False) mask = masks[0][0].cpu().numpy() mask_pil = Image.fromarray(mask) output_images.append(mask_pil.convert("RGB")) run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str) if task_type in ['inpainting', 'outpainting']: # image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0] input_img.save("test.png") image_inpainting = forward_i2sb(input_img, mask) print("RESULT\t", np.array(image_inpainting)) else: # remove from mask aasds = 1 logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_') image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit) if image_inpainting is None: logger.info(f'run_anything_task_failed_') return None, None, None, None, None, None, None # output_images.append(image_inpainting) # run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str) logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_') image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1])) output_images.append(image_inpainting) run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str) logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_') return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None else: logger.info(f"task_type:{task_type} error!") logger.info(f'run_anything_task_[{file_temp}]_9_9_') return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None def change_radio_display(task_type, mask_source_radio, orig_img): mask_source_radio_visible = False num_relation_visible = False image_gallery_visible = True kosmos_input_visible = False kosmos_output_visible = False kosmos_text_output_visible = False print(task_type) if task_type == "Kosmos-2": if kosmos_enable: image_gallery_visible = False kosmos_input_visible = True kosmos_output_visible = True kosmos_text_output_visible = True if task_type in ['inpainting', 'outpainting'] or task_type == "remove": mask_source_radio_visible = True if task_type == "relate anything": num_relation_visible = True if task_type == "segment": ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6) elif task_type == "inpainting": ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6) return (gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible), gr.Gallery.update(visible=image_gallery_visible), gr.Radio.update(visible=kosmos_input_visible), gr.Image.update(visible=kosmos_output_visible), gr.HighlightedText.update(visible=kosmos_text_output_visible), ret, [], gr.Button("Undo point", visible = task_type == "segment"), gr.Button("Clear point", visible = task_type == "segment"),) def get_model_device(module): try: if module is None: return 'None' if isinstance(module, torch.nn.DataParallel): module = module.module for submodule in module.children(): if hasattr(submodule, "_parameters"): parameters = submodule._parameters if "weight" in parameters: return parameters["weight"].device return 'UnKnown' except Exception as e: return 'Error' def click_callback(coords): print("Clicked at here: ", coords) def main_gradio(args): block = gr.Blocks( title="Thesis-Demo", # theme="shivi/calm_seafoam@>=0.0.1,<1.0.0", ) with block: with gr.Row(): with gr.Column(): selected_points = gr.State([]) original_image = gr.State() task_types = ["segment"] if inpainting_enable: task_types.append("inpainting") input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512) input_image.upload( store_img, [input_image], [original_image, selected_points] ) input_image.select( get_point, [input_image, selected_points], [input_image] ) with gr.Row(): with gr.Column(): undo_point_button = gr.Button("Undo point") undo_point_button.click( fn= undo_button, inputs=[original_image, selected_points], outputs=[input_image] ) with gr.Column(): clear_point_button = gr.Button("Clear point") clear_point_button.click( fn= clear_button, inputs=[original_image], outputs=[input_image, selected_points] ) print(dir(input_image)) task_type = gr.Radio(task_types, value="segment", label='Task type', visible=True) mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment], value=mask_source_draw, label="Mask from", visible=False) num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False) kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False) run_button = gr.Button(label="Run", visible=True) # with gr.Accordion("Advanced options", open=False) as advanced_options: # box_threshold = gr.Slider( # label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001 # ) # text_threshold = gr.Slider( # label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001 # ) # iou_threshold = gr.Slider( # label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001 # ) # inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode") # with gr.Row(): # with gr.Column(scale=1): # remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode') # with gr.Column(scale=1): # remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10') with gr.Column(): image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True ).style(preview=True, columns=[5], object_fit="scale-down", height="auto") time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False) kosmos_output = gr.Image(type="pil", label="result images", visible=False) kosmos_text_output = gr.HighlightedText( label="Generated Description", combine_adjacent=False, show_legend=True, visible=False, ).style(color_map=color_map) # record which text span (label) is selected selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False) # record the current `entities` entity_output = gr.Textbox(visible=False) # get the current selected span label def get_text_span_label(evt: gr.SelectData): if evt.value[-1] is None: return -1 return int(evt.value[-1]) # and set this information to `selected` kosmos_text_output.select(get_text_span_label, None, selected) # update output image when we change the span (enity) selection def update_output_image(img_input, image_output, entities, idx): entities = ast.literal_eval(entities) updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx) return updated_image selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output]) run_button.click(fn=run_anything_task, inputs=[ input_image, selected_points, original_image, task_type, mask_source_radio], outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True) mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image], outputs=[mask_source_radio, num_relation]) task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image], outputs=[mask_source_radio, num_relation, image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button, clear_point_button ]) # DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything).
' # if lama_cleaner_enable: # DESCRIPTION += f'Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner).
' # if kosmos_enable: # DESCRIPTION += f'Kosmos-2 from [Kosmos-2](https://github.com/microsoft/unilm/tree/master/kosmos-2).
' # if ram_enable: # DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything).
' # DESCRIPTION += f'Thanks for their excellent work.' # DESCRIPTION += f'

For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. \ # Duplicate Space

' # gr.Markdown(DESCRIPTION) print(f'device = {device}') print(f'torch.cuda.is_available = {torch.cuda.is_available()}') computer_info() block.queue(max_size=10, api_open=False) block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share) if __name__ == "__main__": parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True) parser.add_argument("--debug", action="store_true", help="using debug mode") parser.add_argument("--share", action="store_true", help="share the app") parser.add_argument("--port", "-p", type=int, default=7860, help="port") parser.add_argument("--cuda", "-c", type=str, default='cuda:0', help="cuda") args, _ = parser.parse_known_args() print(f'args = {args}') # if os.environ.get('IS_MY_DEBUG') is None: # os.system("pip list") set_device(args) if device == 'cpu': kosmos_enable = False # if kosmos_enable: # kosmos_model, kosmos_processor = load_kosmos_model(device) # if groundingdino_enable: # load_groundingdino_model('cpu') if sam_enable: load_sam_model(device) if inpainting_enable: load_sd_model(device) load_i2sb_model() # if lama_cleaner_enable: # load_lama_cleaner_model(device) # if ram_enable: # load_ram_model(device) # if os.environ.get('IS_MY_DEBUG') is None: # os.system("pip list") main_gradio(args)