import warnings
warnings.filterwarnings('ignore')

import subprocess, io, os, sys, time

os.system("pip install gradio==3.50.2")

import gradio as gr
from loguru import logger

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# if os.environ.get('IS_MY_DEBUG') is None:
#     result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
#     print(f'pip install GroundingDINO = {result}')

logger.info("Start app...")
result = subprocess.run(['pip', 'list'], check=True)
print(f'pip list = {result}')

sys.path.insert(0, './GroundingDINO')

import argparse
import copy

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont, ImageOps

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# I2SB
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/SegFormer")

import torch.distributed as dist
import torchvision.transforms as transforms
import torchvision.utils as tu
from easydict import EasyDict as edict
from fastapi import (Body, Depends, FastAPI, File, Form, HTTPException,
                     Query, UploadFile)
from ipdb import set_trace as debug
from torch.multiprocessing import Process
from torch.utils.data import DataLoader, Subset
from torch_ema import ExponentialMovingAverage

import I2SB.distributed_util as dist_util
from I2SB.corruption import build_corruption
from I2SB.dataset import air_liquide
from I2SB.i2sb import Runner, ckpt_util, download_ckpt
from I2SB.logger import Logger
from I2SB.sample import *
from pathlib import Path

inpaint_checkpoint = Path("/home/ubuntu/Thesis-Demo/I2SB/results")
if not inpaint_checkpoint.exists():
    os.system("pip install transformers==4.32.0")

# SegFormer
from SegFormer.mmseg.apis import inference_segmentor, init_segmentor, visualize_result_pyplot
from SegFormer.mmseg.core.evaluation import get_palette
import cv2

import matplotlib
matplotlib.use('AGG')
import matplotlib.pyplot as plt

groundingdino_enable = True
sam_enable = True
inpainting_enable = True
ram_enable = False
lama_cleaner_enable = True
kosmos_enable = False
# qwen_enable = True
# from qwen_utils import *

if os.environ.get('IS_MY_DEBUG') is not None:
    sam_enable = False
    ram_enable = False
    inpainting_enable = False
    kosmos_enable = False

# segment anything
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator

# diffusers
import PIL
import requests
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import hf_hub_download

from util_computer import computer_info

# relate anything
from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
from ram_train_eval import RamModel, RamPredictor
from mmengine.config import Config as mmengine_Config

if lama_cleaner_enable:
    from lama_cleaner.helper import (
        load_img,
        numpy_to_bytes,
        resize_max_size,
    )

# from transformers import AutoProcessor, AutoModelForVision2Seq
import ast

if kosmos_enable:
    os.system("pip install transformers@git+https://github.com/huggingface/transformers.git@main")
    # os.system("pip install transformers==4.32.0")
    from kosmos_utils import *

from util_tencent import getTextTrans
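# --- Checkpoints and runtime configuration ----------------------------------
# GroundingDINO config/weights are pulled from the HuggingFace Hub on demand,
# SAM weights are fetched to ./sam_vit_h_4b8939.pth by get_sam_vit_h_4b8939(),
# and `device` defaults to 'cpu' until set_device() overrides it at startup.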
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swint_ogc.pth"
sam_checkpoint = './sam_vit_h_4b8939.pth'
output_dir = "outputs"
device = 'cpu'

os.makedirs(output_dir, exist_ok=True)

groundingdino_model = None
sam_device = None
sam_model = None
sam_predictor = None
sam_mask_generator = None
sd_model = None
lama_cleaner_model = None
ram_model = None
kosmos_model = None
kosmos_processor = None

colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]

i2sb_opt = edict(
    distributed=False,
    device="cuda",
    batch_size=1,
    nfe=10,
    dataset="sample",
    dataset_dir=Path("dataset/sample"),
    n_gpu_per_node=1,
    use_fp16=False,
    ckpt="inpaint-freeform2030",
    image_size=256,
    partition=None,
    global_size=1,
    global_rank=0,
    clip_denoise=True
)

i2sb_transforms = transforms.Compose([
    transforms.Resize(i2sb_opt.image_size),
    transforms.CenterCrop(i2sb_opt.image_size),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1)  # [0,1] --> [-1,1]
])

def get_point(img, sel_pix, evt: gr.SelectData):
    img = np.array(img, dtype=np.uint8)
    sel_pix.append(evt.index)
    # draw points
    print(sel_pix)
    for point in sel_pix:
        cv2.drawMarker(img, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
    return Image.fromarray(img).convert("RGB")

def undo_button(orig_img, sel_pix):
    if orig_img:
        temp = orig_img.copy()
        temp = np.array(temp, dtype=np.uint8)
        if len(sel_pix) != 0:
            sel_pix.pop()
        for point in sel_pix:
            cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
        return Image.fromarray(temp).convert("RGB")
    return orig_img

def clear_button(orig_img):
    return orig_img, []

def toggle_button(orig_img, task_type):
    print(task_type)
    if task_type == "segment":
        ret = gr.Image(value=orig_img, elem_id="image_upload", type='pil', label="Upload",
                       height=512, tool="editor")
    elif task_type == "inpainting":
        ret = gr.Image(value=orig_img, elem_id="image_upload", type='pil', label="Upload",
                       height=512, tool="sketch", brush_color='#00FFFF', mask_opacity=0.6)
    task_type = not task_type
    return ret, task_type

def store_img(img):
    print("call for store")
    return img, []  # when a new image is uploaded, `selected_points` should be empty

def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    args = SLConfig.fromfile(model_config_path)
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(cache_file, log))
    _ = model.eval()
    return model

def load_i2sb_model():
    RESULT_DIR = Path("I2SB/results")
    global i2sb_model
    global ckpt_opt
    global corrupt_type
    global nfe
    s = time.time()

    # main from here
    log = Logger(0, ".log")

    # get (default) ckpt option
    ckpt_opt = ckpt_util.build_ckpt_option(i2sb_opt, log, RESULT_DIR / i2sb_opt.ckpt)
    corrupt_type = ckpt_opt.corrupt
    nfe = i2sb_opt.nfe or ckpt_opt.interval - 1

    # build corruption method
    # corrupt_method = build_corruption(i2sb_opt, log, corrupt_type=corrupt_type)

    runner = Runner(ckpt_opt, log, save_opt=False)
    if i2sb_opt.use_fp16:
        runner.ema.copy_to()  # copy weight from ema to net
        runner.net.diffusion_model.convert_to_fp16()
        runner.ema = ExponentialMovingAverage(
            runner.net.parameters(), decay=0.99)  # re-init ema with fp16 weight

    logger.info(f"I2SB Loading time:\t {(time.time()-s)*1e3} ms.")
    print("Loading time:", (time.time()-s)*1e3, "ms.")
    i2sb_model = runner
    return runner
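# Hypothetical usage sketch (not part of the app flow): load_i2sb_model() both
# fills the module-level globals (i2sb_model, ckpt_opt, nfe) and returns the
# Runner, so a one-off restoration could look like:
#   runner = load_i2sb_model()
#   xs, _ = runner.ddpm_sampling(ckpt_opt, img_tensor, mask=mask_tensor,
#                                cond=None, clip_denoise=True, nfe=nfe)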
def load_segformer(device):
    global segformer_model
    s = time.time()
    config = "SegFormer/local_configs/segformer/B3/segformer.b3.256x256.wtm.160k.py"
    checkpoint = "SegFormer/work_dirs/segformer.b3.256x256.wtm.160k/iter_160000.pth"
    model = init_segmentor(config, checkpoint, device=device)
    logger.info(f"SegFormer Loading time:\t {(time.time()-s)*1e3} ms.")
    segformer_model = model
    return model

def plot_boxes_to_image(image_pil, tgt):
    H, W = tgt["size"]
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have same length"

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)

    # draw boxes and masks
    for box, label in zip(boxes, labels):
        # from 0..1 to 0..W, 0..H
        box = box * torch.Tensor([W, H, W, H])
        # from xywh to xyxy
        box[:2] -= box[2:] / 2
        box[2:] += box[:2]
        # random color
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        # draw
        x0, y0, x1, y1 = box
        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)

        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
        # draw.text((x0, y0), str(label), fill=color)

        font = ImageFont.load_default()
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), str(label), font)
        else:
            w, h = draw.textsize(str(label), font)
            bbox = (x0, y0, w + x0, y0 + h)
        # bbox = draw.textbbox((x0, y0), str(label))
        draw.rectangle(bbox, fill=color)

        try:
            font = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
            font_size = 36
            new_font = ImageFont.truetype(font, font_size)
            draw.text((x0+2, y0+2), str(label), font=new_font, fill="white")
        except Exception as e:
            pass

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)

    return image_pil, mask

def load_image(image_path):
    # load image
    if isinstance(image_path, PIL.Image.Image):
        image_pil = image_path
    else:
        image_pil = Image.open(image_path).convert("RGB")

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image_pil, image

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_box(box, ax, label):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
    ax.text(x0, y0, label)

def xywh_to_xyxy(box, sizeW, sizeH):
    if isinstance(box, list):
        box = torch.Tensor(box)
    box = box * torch.Tensor([sizeW, sizeH, sizeW, sizeH])
    box[:2] -= box[2:] / 2
    box[2:] += box[:2]
    box = box.numpy()
    return box

def mask_extend(img, box, extend_pixels=10, useRectangle=True):
    box[0] = int(box[0])
    box[1] = int(box[1])
    box[2] = int(box[2])
    box[3] = int(box[3])
    region = img.crop(tuple(box))
    new_width = box[2] - box[0] + 2*extend_pixels
    new_height = box[3] - box[1] + 2*extend_pixels

    region_BILINEAR = region.resize((int(new_width), int(new_height)))
    if useRectangle:
        region_draw = ImageDraw.Draw(region_BILINEAR)
        region_draw.rectangle((0, 0, new_width, new_height), fill=(255, 255, 255))
    img.paste(region_BILINEAR, (int(box[0]-extend_pixels), int(box[1]-extend_pixels)))
    return img
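# mix_masks unions a list of binary PIL masks via De Morgan's law: invert each
# mask, multiply the complements (intersection of "not covered"), then invert
# back, so any pixel covered by at least one input mask comes out white (255).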
def mix_masks(imgs):
    re_img = 1 - np.asarray(imgs[0].convert("1"))
    for i in range(len(imgs) - 1):
        re_img = np.multiply(re_img, 1 - np.asarray(imgs[i + 1].convert("1")))
    re_img = 1 - re_img
    return Image.fromarray(np.uint8(255 * re_img))

def set_device(args):
    global device
    if os.environ.get('IS_MY_DEBUG') is None:
        device = args.cuda if torch.cuda.is_available() else 'cpu'
    else:
        device = 'cpu'
    print(f'device={device}')

def get_sam_vit_h_4b8939():
    if not os.path.exists('./sam_vit_h_4b8939.pth'):
        logger.info("get sam_vit_h_4b8939.pth...")
        result = subprocess.run(['wget', '-nv',
                                 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'],
                                check=True)
        print(f'wget sam_vit_h_4b8939.pth result = {result}')

def load_sam_model(device):
    # initialize SAM
    global sam_model, sam_predictor, sam_mask_generator, sam_device
    get_sam_vit_h_4b8939()
    logger.info("initialize SAM model...")
    sam_device = device
    sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
    sam_predictor = SamPredictor(sam_model)
    sam_mask_generator = SamAutomaticMaskGenerator(sam_model)

def load_sd_model(device):
    # initialize stable-diffusion-inpainting
    global sd_model
    logger.info("initialize stable-diffusion-inpainting...")
    sd_model = None
    if os.environ.get('IS_MY_DEBUG') is None:
        sd_model = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            revision="fp16",
            # "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        )
        sd_model = sd_model.to(device)

def forward_i2sb(img, mask, dilation_mask_extend):
    print(np.unique(mask), mask.shape)
    mask = np.where(mask > 0, 1, 0)
    print(np.unique(mask), mask.shape)
    mask = mask.astype(np.uint8)

    if dilation_mask_extend.isdigit():
        kernel_size = int(dilation_mask_extend)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
        mask = cv2.dilate(mask, kernel, iterations=1)

    img_tensor = i2sb_transforms(img).to(i2sb_opt.device).unsqueeze(0)
    # resize the mask to the model resolution; nearest-neighbour keeps it binary
    # (the original used np.resize, which tiles/truncates instead of resampling)
    mask_resized = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)
    mask_tensor = torch.from_numpy(mask_resized).to(i2sb_opt.device).unsqueeze(0).unsqueeze(0)
    # print("POST PROCESSING\t", torch.unique(img_tensor))
    corrupt_tensor = img_tensor * (1. - mask_tensor) + mask_tensor
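    # I2SB corruption convention: keep the original pixels outside the mask and
    # paint the masked region solid white (+1 in the [-1, 1] value range); the
    # sampler then restores the masked region conditioned on the visible rest.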
    print("DOUBLE CHECK:\t", corrupt_tensor.shape)
    print("DOUBLE CHECK:\t", img_tensor.shape)
    print("DOUBLE CHECK:\t", mask_tensor.shape)

    f = time.time()
    xs, _ = i2sb_model.ddpm_sampling(
        ckpt_opt, img_tensor, mask=mask_tensor, cond=None,
        clip_denoise=i2sb_opt.clip_denoise, nfe=nfe,
        verbose=i2sb_opt.n_gpu_per_node == 1)
    recon_img = xs[:, 0, ...].to(i2sb_opt.device)
    # tu.save_image((recon_img+1)/2, "output.png")
    # tu.save_image((corrupt_tensor+1)/2, "output.png")
    print(recon_img.shape)
    return (transforms.ToPILImage()(((recon_img + 1) / 2)[0]),
            transforms.ToPILImage()(((corrupt_tensor + 1) / 2)[0]))

def forward_segformer(img):
    img_np = np.array(img)
    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    result = inference_segmentor(segformer_model, img_np)
    return np.asarray(result[0], dtype=np.uint8)

# visualization
def draw_selected_mask(mask, draw):
    color = (255, 0, 0, 153)
    nonzero_coords = np.transpose(np.nonzero(mask))
    for coord in nonzero_coords:
        draw.point(coord[::-1], fill=color)

def draw_object_mask(mask, draw):
    color = (0, 0, 255, 153)
    nonzero_coords = np.transpose(np.nonzero(mask))
    for coord in nonzero_coords:
        draw.point(coord[::-1], fill=color)

def create_title_image(word1, word2, word3, width, font_path='./assets/OpenSans-Bold.ttf'):
    # Define the colors to use for each word
    color_red = (255, 0, 0)
    color_black = (0, 0, 0)
    color_blue = (0, 0, 255)

    # Define the initial font size and spacing between words
    font_size = 40

    # Create a new image with the specified width and white background
    image = Image.new('RGB', (width, 60), (255, 255, 255))

    try:
        # Load the specified font
        font = ImageFont.truetype(font_path, font_size)

        # Keep decreasing the font size until all words fit within the desired width
        while True:
            # Create a draw object for the image
            draw = ImageDraw.Draw(image)
            word_spacing = font_size / 2

            # Draw each word in the appropriate color
            x_offset = word_spacing
            draw.text((x_offset, 0), word1, color_red, font=font)
            x_offset += font.getsize(word1)[0] + word_spacing
            draw.text((x_offset, 0), word2, color_black, font=font)
            x_offset += font.getsize(word2)[0] + word_spacing
            draw.text((x_offset, 0), word3, color_blue, font=font)

            word_sizes = [font.getsize(word) for word in [word1, word2, word3]]
            total_width = sum([size[0] for size in word_sizes]) + word_spacing * 3

            # Stop shrinking once the text fits within the desired width
            if total_width <= width:
                break

            # Decrease the font size and reset the draw object
            font_size -= 1
            image = Image.new('RGB', (width, 50), (255, 255, 255))
            font = ImageFont.truetype(font_path, font_size)
            draw = None
    except Exception as e:
        pass

    return image

def concatenate_images_vertical(image1, image2):
    # Get the dimensions of the two images
    width1, height1 = image1.size
    width2, height2 = image2.size

    # Create a new image with the combined height and the maximum width
    new_image = Image.new('RGBA', (max(width1, width2), height1 + height2))

    # Paste the first image at the top of the new image
    new_image.paste(image1, (0, 0))
    # Paste the second image below the first image
    new_image.paste(image2, (0, height1))

    return new_image

mask_source_draw = "draw a mask on input image"
mask_source_segment = "upload a mask"

def get_time_cost(run_task_time, time_cost_str):
    now_time = int(time.time() * 1000)
    if run_task_time == 0:
        time_cost_str = 'start'
    else:
        if time_cost_str != '':
            time_cost_str += '-->'
        time_cost_str += f'{now_time - run_task_time}'
    run_task_time = now_time
    return run_task_time, time_cost_str

def run_anything_task(input_image, input_points, origin_image, task_type,
                      mask_source_radio, segmentation_radio, dilation_mask_extend):
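    """Dispatch a single UI request.

    Depending on `task_type`, segment the clicked points (SAM) or the whole
    image (SegFormer), and for 'pipeline' additionally inpaint the predicted
    mask with I2SB. Returns (gallery images, gallery update, timing string,
    textbox update) for the Gradio outputs."""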
    run_task_time = 0
    time_cost_str = ''
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    print("HERE................", task_type)

    if input_image is None:
        return [], gr.Gallery.update(label='Please upload an image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str != ''))

    file_temp = int(time.time())
    logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')

    output_images = []

    # load image
    if isinstance(input_image, dict):
        image_pil, image = load_image(input_image['image'].convert("RGB"))
        input_img = input_image['image']
        output_images.append(input_image['image'])
    else:
        image_pil, image = load_image(input_image.convert("RGB"))
        input_img = input_image
        output_images.append(input_image)
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    size = image_pil.size
    H, W = size[1], size[0]

    # run grounding dino model
    if (task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw:
        pass
    else:
        groundingdino_device = 'cpu'

    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
    if task_type == 'segment' or task_type == 'pipeline':
        image = np.array(origin_image)
        if segmentation_radio == "SAM":
            if sam_predictor:
                sam_predictor.set_image(image)
                logger.info(f"Forward with: {input_points}")
                # SAM point-prompt mode: every clicked point is a positive label;
                # multimask_output=False returns a single mask per prompt set.
                masks, _, _, _ = sam_predictor.predict(
                    point_coords=np.array(input_points),
                    point_labels=np.array([1 for _ in range(len(input_points))]),
                    # boxes=transformed_boxes,
                    multimask_output=False,
                )
                # masks: [9, 1, 512, 512]
                assert sam_checkpoint, 'sam_checkpoint is not found!'
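            # No SAM predictor is available in CPU/debug mode; fall back to a
            # rectangle mask mode instead of point-prompted segmentation.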
            else:
                run_mode = "rectangle"

            # draw output image
            plt.figure(figsize=(10, 10))
            plt.imshow(origin_image)
            for mask in masks:
                show_mask(mask, plt.gca(), random_color=True)
            # for box, label in zip(boxes_filt, pred_phrases):
            #     show_box(box.cpu().numpy(), plt.gca(), label)
            plt.axis('off')
            image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
            plt.savefig(image_path, bbox_inches="tight")
            plt.clf()
            plt.close('all')

            segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
            os.remove(image_path)
        else:
            masks = forward_segformer(image)
            segment_image_result = visualize_result_pyplot(segformer_model, image, masks,
                                                           get_palette("wtm"),
                                                           dilation=dilation_mask_extend)
        output_images.append(Image.fromarray(segment_image_result))
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
    if task_type == 'detection' or task_type == 'segment':
        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
        return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str != ''))
    elif task_type in ['inpainting', 'outpainting'] or task_type == 'pipeline':
        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
        if task_type == "pipeline":
            if segmentation_radio == "SAM":
                masks_ori = copy.deepcopy(masks)
                print(masks.shape)
                # masks = torch.where(masks > 0, True, False)
                mask = masks[0]
                mask = np.where(mask == True, 1, 0)
            else:
                mask = masks
            save_mask = copy.deepcopy(mask)
            save_mask = np.where(mask > 0, 255, 0).astype(np.uint8)
            print(save_mask.dtype)
            mask_pil = Image.fromarray(save_mask)
        else:
            if mask_source_radio == mask_source_draw:
                input_mask_pil = input_image['mask']
                input_mask = np.array(input_mask_pil.convert("L"))
                mask_pil = input_mask_pil
                mask = input_mask
            else:
                pass
                # masks_ori = copy.deepcopy(masks)
                # masks = torch.where(masks > 0, True, False)
                # mask = masks[0][0].cpu().numpy()
                # mask_pil = Image.fromarray(mask)

        output_images.append(mask_pil.convert("RGB"))
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

        image_inpainting = None  # stays None if no branch below produces a result
        if task_type in ['inpainting', 'pipeline']:
            # image_inpainting = sd_model(prompt="", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
            # input_img.save("test.png")
            w, h = input_img.size
            input_img = input_img.resize((256, 256))
            image_inpainting, corrupted = forward_i2sb(input_img, mask, dilation_mask_extend)
            input_img = input_img.resize((w, h))
            corrupted = corrupted.resize((w, h))
            image_inpainting = image_inpainting.resize((w, h))
            # print("RESULT\t", np.array(image_inpainting))
        else:
            # remove from mask: not implemented in this demo
            pass

        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
        if image_inpainting is None:
            logger.info('run_anything_task_failed_')
            return None, None, None, None

        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
        image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
        output_images.append(corrupted)
        output_images.append(image_inpainting)
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
        return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str != ''))
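    # Any other task type is unexpected here and falls through to the error
    # branch below, echoing the collected images back unchanged.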
error!") logger.info(f'run_anything_task_[{file_temp}]_9_9_') return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')) def change_radio_display(task_type, mask_source_radio, orig_img): mask_source_radio_visible = False num_relation_visible = False image_gallery_visible = True kosmos_input_visible = False kosmos_output_visible = False kosmos_text_output_visible = False print(task_type) if task_type == "Kosmos-2": if kosmos_enable: image_gallery_visible = False kosmos_input_visible = True kosmos_output_visible = True kosmos_text_output_visible = True if task_type in ['inpainting', 'outpainting'] or task_type == "remove": mask_source_radio_visible = True if task_type == "relate anything": num_relation_visible = True if task_type == "inpainting": ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6) elif task_type in ["segment", "pipeline"]: ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6) return (gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible), gr.Gallery.update(visible=image_gallery_visible), gr.Radio(["SegFormer", "SAM"], value="SAM", label="Segementation Model", visible= task_type != "inpainting"), gr.Textbox(label="Dilation kernel size", value='7', visible= task_type == "pipeline"), ret, [], gr.Button("Undo point", visible = task_type != "inpainting"), gr.Button("Clear point", visible = task_type != "inpainting"),) def get_model_device(module): try: if module is None: return 'None' if isinstance(module, torch.nn.DataParallel): module = module.module for submodule in module.children(): if hasattr(submodule, "_parameters"): parameters = submodule._parameters if "weight" in parameters: return parameters["weight"].device return 'UnKnown' except Exception as e: return 'Error' def click_callback(coords): print("Clicked at here: ", coords) def main_gradio(args): block = gr.Blocks( title="Thesis-Demo", # theme="shivi/calm_seafoam@>=0.0.1,<1.0.0", ) with block: with gr.Row(): with gr.Column(): selected_points = gr.State([]) original_image = gr.State(None) task_types = ["segment"] if inpainting_enable: task_types.append("inpainting") task_types.append("pipeline") input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512) input_image.upload( store_img, [input_image], [original_image, selected_points] ) input_image.select( get_point, [input_image, selected_points], [input_image] ) with gr.Row(): with gr.Column(): undo_point_button = gr.Button("Undo point", visible= True if original_image is not None else False) undo_point_button.click( fn= undo_button, inputs=[original_image, selected_points], outputs=[input_image] ) with gr.Column(): clear_point_button = gr.Button("Clear point", visible= True if original_image is not None else False) clear_point_button.click( fn= clear_button, inputs=[original_image], outputs=[input_image, selected_points] ) print(dir(input_image)) task_type = gr.Radio(task_types, value="segment", label='Task type', visible=True) mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment], value=mask_source_draw, label="Mask from", visible=False) segmentation_radio = gr.Radio(["SegFormer", "SAM"], value="SAM", label="Segementation Model", visible=True) dilation_mask_extend = gr.Textbox(label="Dilation kernel 
size", value='5', visible=False) num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False) run_button = gr.Button(label="Run", visible=True) # with gr.Accordion("Advanced options", open=False) as advanced_options: # box_threshold = gr.Slider( # label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001 # ) # text_threshold = gr.Slider( # label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001 # ) # iou_threshold = gr.Slider( # label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001 # ) # inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode") # with gr.Row(): # with gr.Column(scale=1): # remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode') # with gr.Column(scale=1): # remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10') with gr.Column(): image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True ).style(preview=True, columns=[5], object_fit="scale-down", height=512) time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False) run_button.click(fn=run_anything_task, inputs=[ input_image, selected_points, original_image, task_type, mask_source_radio, segmentation_radio, dilation_mask_extend], outputs=[image_gallery, image_gallery, time_cost, time_cost], show_progress=True, queue=True) mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image], outputs=[mask_source_radio, num_relation]) task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image], outputs=[mask_source_radio, num_relation, image_gallery, segmentation_radio, dilation_mask_extend, input_image, selected_points, undo_point_button, clear_point_button ]) # DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything).
' # if lama_cleaner_enable: # DESCRIPTION += f'Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner).
' # if kosmos_enable: # DESCRIPTION += f'Kosmos-2 from [Kosmos-2](https://github.com/microsoft/unilm/tree/master/kosmos-2).
' # if ram_enable: # DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything).
' # DESCRIPTION += f'Thanks for their excellent work.' # DESCRIPTION += f'

For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. \ # Duplicate Space

' # gr.Markdown(DESCRIPTION) print(f'device = {device}') print(f'torch.cuda.is_available = {torch.cuda.is_available()}') computer_info() block.queue(max_size=10, api_open=False) block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share) if __name__ == "__main__": parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True) parser.add_argument("--debug", action="store_true", help="using debug mode") parser.add_argument("--share", action="store_true", help="share the app") parser.add_argument("--port", "-p", type=int, default=7860, help="port") parser.add_argument("--cuda", "-c", type=str, default='cuda:0', help="cuda") args, _ = parser.parse_known_args() print(f'args = {args}') # if os.environ.get('IS_MY_DEBUG') is None: # os.system("pip list") set_device(args) if device == 'cpu': kosmos_enable = False # if kosmos_enable: # kosmos_model, kosmos_processor = load_kosmos_model(device) # if groundingdino_enable: # load_groundingdino_model('cpu') if sam_enable: load_sam_model(device) load_segformer(device) if inpainting_enable: load_sd_model(device) load_i2sb_model() # if lama_cleaner_enable: # load_lama_cleaner_model(device) # if ram_enable: # load_ram_model(device) # if os.environ.get('IS_MY_DEBUG') is None: # os.system("pip list") main_gradio(args)
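# Assumed launch command for this script (the file name is a guess):
#   python app.py --cuda cuda:0 --port 7860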