| import gc |
| import gradio as gr |
| from fastapi import FastAPI, Body,Header,status |
| from gradio import components,Blocks,Row |
| import json |
| from PIL import Image |
| import torchvision.transforms as transforms |
| import torch |
| from pathlib import Path |
| import os |
| import requests |
| from segment_anything import sam_model_registry, SamPredictor |
| import numpy as np |
| from modules.safe import unsafe_torch_load, load |
| from modules.devices import device, torch_gc, cpu |
|
|
| from modules.processing import process_images |
| import modules.scripts as scripts |
|
|
UNIT_DEBUG = False


def import_or_install(package, pip_name=None):
    """Import ``package``, installing it with pip first if the import fails.

    Args:
        package: importable module name to check.
        pip_name: requirement passed to ``pip install`` when it differs from the
            import name (e.g. a ``git+https://`` URL); defaults to ``package``.

    pip is invoked as ``sys.executable -m pip`` so the package is installed into
    the interpreter that is running this file. Installation failures are
    reported, not raised (best effort).
    """
    import importlib
    import subprocess
    import sys

    if pip_name is None:
        pip_name = package
    try:
        importlib.import_module(package)
        print(f"{package} is already installed")
    except ImportError:
        print(f"{package} is not installed, installing now...")
        # Install the requested requirement (pip_name), not the bare import name.
        exit_code = subprocess.call([sys.executable, "-m", "pip", "install", pip_name])
        if exit_code == 0:
            print(f"{package} has been installed")
        else:
            print(f"Failed to install {package} (pip exit code {exit_code})")
|
|
# Best-effort install of Segment Anything from its GitHub repository, at import time.
# NOTE(review): the module-level `from segment_anything import ...` at the top of this
# file executes before this call, so if the package is missing that import raises
# first and this line is never reached; it only helps if that import is moved below it.
import_or_install("segment_anything","git+https://github.com/facebookresearch/segment-anything.git")
|
|
class InteractiveImageSegmentor:
    """Interactive point/box segmentation on top of a Segment Anything predictor.

    Call ``load_model()`` once, ``predictor.set_image(...)`` for the current
    image, then ``segment`` / ``preview_segment`` / ``remove_selected`` /
    ``remove_unselected`` with the collected prompts.
    """

    # Checkpoint download URLs for the known ``model_choice`` values.
    _CHECKPOINT_URLS = {
        "sam_vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "sam_vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "sam_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    }

    # Preview overlay colour (RGB in [0, 1]) and its opacity.
    _OVERLAY_RGB = (255 / 255, 155 / 255, 114 / 255)
    _OVERLAY_ALPHA = 0.6

    def __init__(self):
        # Populated by load_model(); None until a model has been loaded.
        self.sam = None
        self.predictor = None
        self.device = None

    @staticmethod
    def download_file_if_not_exists(file_url, file_name):
        """Download ``file_url`` to ``file_name`` unless ``file_name`` already exists.

        The response is streamed in chunks to ``<file_name>.part`` and renamed
        into place only after the whole body was written, so large checkpoints
        are not held in memory and an interrupted download never leaves a
        truncated file that would later be mistaken for a complete one.
        A non-200 response is reported and otherwise ignored (best effort).
        """
        if os.path.isfile(file_name):
            return
        partial_name = f"{file_name}.part"
        try:
            with requests.get(file_url, stream=True, timeout=60) as response:
                if response.status_code != 200:
                    print("Failed to download the file.")
                    return
                with open(partial_name, "wb") as file:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        file.write(chunk)
            os.replace(partial_name, file_name)
            print("File downloaded successfully!")
        finally:
            # Remove leftovers from a failed/interrupted download.
            if os.path.exists(partial_name):
                os.remove(partial_name)

    def load_model(self, model_choice="sam_vit_b"):
        """Load a SAM checkpoint and build ``self.sam`` / ``self.predictor``.

        The model is placed on CUDA when available, else CPU. Known choices
        (``sam_vit_b``, ``sam_vit_l``, ``sam_vit_h``) are downloaded to
        ``<model_choice>.pth`` in the working directory on first use; other
        choices fall back to the registry's ``"default"`` architecture and
        expect ``<model_choice>.pth`` to exist already.
        """
        sam_checkpoint = f"{model_choice}.pth"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        url = self._CHECKPOINT_URLS.get(model_choice)
        if url is not None:
            self.download_file_if_not_exists(url, sam_checkpoint)
        model_type = model_choice.replace("sam_", "")
        if model_type not in sam_model_registry:
            model_type = "default"
        print(f"Loading model {model_type} from {sam_checkpoint}")
        # The checkpoint is loaded through `unsafe_torch_load`; `modules.safe.load`
        # is restored in `finally` so torch.load is never left unpatched, even if
        # building the model raises.
        torch.load = unsafe_torch_load
        try:
            self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        finally:
            torch.load = load
        self.sam.to(self.device)
        self.predictor = SamPredictor(self.sam)

    def clear_sam_cache(self):
        """Drop the loaded model/predictor and free cached device memory.

        SAM's model class has no ``unload_model`` method, so the references are
        released instead. Call ``load_model`` again before segmenting.
        """
        self.predictor = None
        self.sam = None
        gc.collect()
        torch_gc()

    def mask2image_multi(self, mask: torch.Tensor):
        """Tint channel 0 of a 3-channel mask with the overlay colour.

        Accepts (3, H, W) or (H, W, 3); returns a (4, H, W) RGBA float tensor.

        Raises:
            ValueError: if ``mask`` has any other shape.
        """
        if mask.dim() == 3 and mask.shape[-1] == 3:
            mask = mask.permute(2, 0, 1)  # HWC -> CHW
        elif not (mask.dim() == 3 and mask.shape[0] == 3):
            print(mask.shape)
            raise ValueError("Mask tensor has an unexpected shape.")
        # Build the colour on the mask's own device to avoid cross-device products.
        color = torch.tensor([*self._OVERLAY_RGB, self._OVERLAY_ALPHA], device=mask.device)
        binary_mask = mask[0]
        return (binary_mask.unsqueeze(-1) * color).permute(2, 0, 1)

    def mask2image(self, mask: torch.Tensor):
        """Render one binary mask as a (4, H, W) RGBA overlay tensor in [0, 1].

        Masked pixels get the overlay colour at 60% opacity; unmasked pixels are
        fully transparent, so the result can be alpha-composited over the image.
        Accepts (H, W) or (1, H, W).

        Raises:
            ValueError: if ``mask`` has any other shape.
        """
        if mask.dim() == 3 and mask.shape[0] == 1:
            binary_mask = mask.squeeze(0)
        elif mask.dim() == 2:
            binary_mask = mask
        else:
            print(mask.shape)
            raise ValueError("Mask tensor has an unexpected shape.")
        weight = binary_mask.to(dtype=torch.float32)
        color = torch.tensor(self._OVERLAY_RGB, device=weight.device).reshape(3, 1, 1)
        rgb = color * weight                                 # (3, H, W)
        alpha = self._OVERLAY_ALPHA * weight.unsqueeze(0)    # (1, H, W)
        return torch.cat((rgb, alpha), dim=0)

    def preview_segment(self, image: "Image.Image", points: list[list[float]] = None, bbox=None, labels: list[int] = None):
        """Return ``image`` as RGBA with the predicted mask overlay composited on top."""
        base = image.convert("RGBA")
        mask = self.segment(points, bbox, labels)
        # ToPILImage moves the tensor to CPU and maps the float [0, 1] RGBA to 8-bit.
        overlay = transforms.ToPILImage()(self.mask2image(mask).cpu())
        return Image.alpha_composite(base, overlay)

    def segment(self, points: list[list[float]] = None, bbox=None, labels: list[int] = None) -> torch.Tensor:
        """Predict a single mask for the image currently set on ``self.predictor``.

        Args:
            points: ``[[x, y], ...]`` pixel coordinates in the original image.
            bbox: ``[[x0, y0, x1, y1], ...]`` pixel boxes in the original image.
            labels: one label per point (the UI uses 1 = select, 0 = deselect).

        ``None`` or empty prompts are omitted. Returns ``masks[0]`` of
        ``predict_torch(..., multimask_output=False)`` (per SAM's API, a
        (1, H, W) boolean mask at the original image size).
        """
        def _nonempty(value):
            return value if value is not None and len(value) > 0 else None

        points, labels, bbox = _nonempty(points), _nonempty(labels), _nonempty(bbox)
        predictor = self.predictor
        original_size = predictor.original_size
        point_coords = point_labels = boxes = None
        # predict_torch expects prompts already mapped into the model's resized
        # input frame, so apply the predictor's ResizeLongestSide transform.
        if points is not None:
            point_coords = torch.as_tensor(np.asarray(points), dtype=torch.float, device=self.device)
            point_coords = predictor.transform.apply_coords_torch(point_coords, original_size).unsqueeze(0)
        if labels is not None:
            point_labels = torch.as_tensor(np.asarray(labels), dtype=torch.float, device=self.device).unsqueeze(0)
        if bbox is not None:
            boxes = torch.as_tensor(np.asarray(bbox), dtype=torch.float, device=self.device)
            boxes = predictor.transform.apply_boxes_torch(boxes, original_size)
        masks, _scores, _logits = predictor.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=boxes,
            multimask_output=False,
        )
        return masks[0]

    def _apply_mask(self, image, keep):
        """Multiply every RGBA channel of ``image`` by boolean ``keep`` (False -> transparent black)."""
        image_tensor = transforms.PILToTensor()(image.convert("RGBA"))
        return transforms.ToPILImage()(image_tensor * keep.cpu())

    def remove_selected(self, image: "Image.Image", points: list[list[float]] = None, boxes=None, labels: list[int] = None):
        """Return ``image`` (RGBA) with the segmented region made transparent."""
        mask = self.segment(points=points, bbox=boxes, labels=labels)
        # Invert with logical not: `1 - bool_tensor` is rejected by PyTorch.
        return self._apply_mask(image, torch.logical_not(mask))

    def remove_unselected(self, image: "Image.Image", points: list[list[float]] = None, boxes=None, labels: list[int] = None):
        """Return ``image`` (RGBA) keeping only the segmented region; the rest becomes transparent."""
        mask = self.segment(points=points, bbox=boxes, labels=labels)
        return self._apply_mask(image, mask.bool())
|
|
def reset_image(image: "Image.Image"):
    """Make ``image`` the predictor's current image, loading the default model on first use.

    A cleared image (``None``, e.g. when the gradio image is removed) is returned
    as-is without touching the model. Returns ``image``.
    """
    global image_segmentor
    if image is None:
        return image
    if image_segmentor is None:
        image_segmentor = InteractiveImageSegmentor()
        image_segmentor.load_model()
    image_segmentor.predictor.reset_image()
    # Convert so RGBA / palette / greyscale inputs still give SAM a 3-channel RGB array.
    image_segmentor.predictor.set_image(np.array(image.convert("RGB")))
    return image
|
|
def on_image_changed(image:Image):
    """Gradio ``change`` handler: forget all prompts, load the new image, echo it back."""
    global points, labels, box_cache, boxes
    # Rebind each global to its own fresh, independent list.
    points, labels, boxes, box_cache = [], [], [], []
    reset_image(image)
    return image
|
|
def on_image_clicked(image:Image,choice,input_type,event_data:gr.SelectData):
    """Gradio ``select`` handler: record the click as a SAM prompt and return a preview.

    "Point" mode adds ``event_data.index`` as a point, labelled 1 for "Select"
    or 0 for "Deselect". "Box" mode accumulates click coordinates until four
    have been collected, which then form one box. The mask preview is returned
    after every point and every completed box; otherwise ``image`` is returned.
    """
    global box_cache, boxes, points, labels
    if isinstance(choice, str):
        # Unrecognised strings are passed through unchanged.
        choice = {"Select": 1, "Deselect": 0}.get(choice, choice)

    if input_type == "Point":
        points.append(event_data.index)
        labels.append(choice)
    elif input_type == "Box":
        box_cache.extend(event_data.index)
        if len(box_cache) != 4:
            # Box is still waiting for its remaining corner.
            return image
        boxes.append(box_cache)
        box_cache = []
    else:
        return image

    return image_segmentor.preview_segment(image, points=points, bbox=boxes, labels=labels)
|
|
def on_remove_btn_clicked(image: "Image.Image", remove_type: str):
    """Preview removing the "Selected" or "Unselected" region using the collected prompts.

    Returns ``image`` unchanged for an unknown ``remove_type``, when no image is
    loaded (``None``), or when no segmentor has been created yet.
    """
    method_name = {
        "Selected": "remove_selected",
        "Unselected": "remove_unselected",
    }.get(remove_type)
    if method_name is None or image is None:
        return image
    # Only reads the module-level prompt state, so no `global` declaration is needed.
    if image_segmentor is None:
        return image
    remove = getattr(image_segmentor, method_name)
    return remove(image, points=points, boxes=boxes, labels=labels)
|
|
class Script(scripts.Script):
    """img2img script: segment the image with the UI's SAM prompts and use the result as ``p.image_mask``."""

    def title(self):
        return "Interactive Image Segmentor"

    def show(self, is_img2img):
        # Only offered on the img2img tab.
        return is_img2img

    def ui(self, is_img2img):
        """Build the segmentation UI and return the controls whose values are passed to ``run``.

        Only gradio components are returned. The prompt lists are plain
        module-level Python lists mutated by the click handlers, so ``run``
        reads them directly instead of receiving them as control values.
        """
        if not is_img2img:
            return
        with Blocks():
            with Row(equal_height=True):
                choice = components.Radio(choices=["Select", "Deselect"], value="Select", label="Selection Type")
                input_type = components.Radio(choices=["Point", "Box"], value="Point", label="Input Type")
                remove_type = components.Radio(choices=["Selected", "Unselected"], value="Selected", label="Remove Type")
            with Row(equal_height=True):
                image = components.Image(type="pil", interactive=True, image_mode="RGB")
                resulting_image = components.Image(type="pil", image_mode="RGBA")
                image.change(on_image_changed, inputs=[image], outputs=[resulting_image])
                image.select(on_image_clicked, inputs=[image, choice, input_type], outputs=[resulting_image])
            with Row(equal_height=True):
                remove_btn = components.Button(value="Preview Remove Effect")
                remove_btn.click(on_remove_btn_clicked, inputs=[image, remove_type], outputs=[resulting_image])
        return [image]

    def run(self, p, image, points=None, labels=None, boxes=None):
        """Segment ``image`` (or ``p.init_images[0]``) and process ``p`` with the mask.

        ``points`` / ``labels`` / ``boxes`` default to the prompts collected by
        the UI callbacks (module-level state). With no points and no boxes, ``p``
        is processed unmasked. Otherwise the mask is set as ``p.image_mask`` and
        also appended to the returned images.
        NOTE(review): the mask has ``image``'s size; confirm it matches p's init image.
        """
        global image_segmentor
        module_state = globals()
        points = module_state["points"] if points is None else points
        labels = module_state["labels"] if labels is None else labels
        boxes = module_state["boxes"] if boxes is None else boxes

        if image is None:
            image = p.init_images[0]
        if not points and not boxes:
            return process_images(p)

        if image_segmentor is None:
            image_segmentor = InteractiveImageSegmentor()
            image_segmentor.load_model()
        image_segmentor.predictor.set_image(np.array(image.convert("RGB")))
        mask = image_segmentor.segment(points=points, bbox=boxes, labels=labels)

        # (1, H, W) bool mask -> 8-bit greyscale PIL image (255 = masked).
        mask_array = mask.reshape(mask.shape[-2:]).to(torch.uint8).mul(255).cpu().numpy()
        mask_image = Image.fromarray(mask_array)

        p.image_mask = mask_image
        proc = process_images(p)
        proc.images.append(mask_image)
        return proc
|
|
def _decode_png_base64(image_str):
    """Decode a base64-encoded PNG into a PIL image (a ``data:...;base64,`` prefix is tolerated)."""
    import base64
    import io

    if image_str.startswith("data:"):
        image_str = image_str.split(",", 1)[-1]
    return Image.open(io.BytesIO(base64.b64decode(image_str)), formats=["PNG"])


def _encode_png_base64(image):
    """Encode a PIL image as a base64 PNG string; a PIL image itself is not JSON-serialisable."""
    import base64
    import io

    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("ascii")


def _api_segmentor():
    """Return the shared segmentor, creating it and loading the default model on first use."""
    global image_segmentor
    if image_segmentor is None:
        image_segmentor = InteractiveImageSegmentor()
        image_segmentor.load_model()
    return image_segmentor


def interactive_image_segmentor_api(_: Blocks, app: FastAPI):
    """Register the ``/figma/interactive_image_segmentor/*`` endpoints on ``app``.

    NOTE(review): the handlers are ``async`` but run SAM synchronously, so they
    block the event loop while working; left as-is.
    """

    @app.post("/figma/interactive_image_segmentor/upload_image")
    async def upload_image(image_str: str = Body(...)):
        """Make the posted base64 PNG the predictor's current image; return it as base64 PNG."""
        image = _decode_png_base64(image_str)
        predictor = _api_segmentor().predictor
        predictor.reset_image()
        # SAM needs a 3-channel RGB array; PNGs may be RGBA or palette-based.
        predictor.set_image(np.array(image.convert("RGB")))
        return _encode_png_base64(image)

    @app.post("/figma/interactive_image_segmentor/image_x_mask")
    async def remove_selected(
        image_str: str = Body(...),
        points: list[list[float]] = Body(...),
        boxes: list[list[float]] = Body(...),
        labels: list[int] = Body(...),
        remove_type: str = Body(...),
    ):
        """Remove the "Selected" or "Unselected" region of the posted image; return base64 PNG.

        Prompts are resolved against the image last set via ``upload_image``; if
        none is set yet, the posted image is used. Any other ``remove_type``
        returns the image unchanged.
        """
        image = _decode_png_base64(image_str)
        segmentor = _api_segmentor()
        if not segmentor.predictor.is_image_set:
            segmentor.predictor.set_image(np.array(image.convert("RGB")))
        if remove_type == "Selected":
            image = segmentor.remove_selected(image, points=points, boxes=boxes, labels=labels)
        elif remove_type == "Unselected":
            image = segmentor.remove_unselected(image, points=points, boxes=boxes, labels=labels)
        return _encode_png_base64(image)
|
|
# Prompt state mutated by the gradio callbacks. These are plain module globals,
# so every UI session in this process shares them.
points, labels = [], []  # clicked points and their Select(1)/Deselect(0) labels
box_cache: list = []     # coordinates of the box currently being drawn
boxes = []               # completed boxes, one coordinate list per box

# Shared InteractiveImageSegmentor; created lazily on first use.
image_segmentor = None
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Register the HTTP API once the webui app starts. Outside the webui (no
# `modules.script_callbacks`), skip registration but say so instead of failing silently.
try:
    import modules.script_callbacks as script_callbacks
except ImportError:
    print("modules.script_callbacks unavailable; interactive image segmentor API not registered")
else:
    script_callbacks.on_app_started(interactive_image_segmentor_api)