"""Gradio demo: segment objects in an image with Florence-2 + SAM.

Florence-2 is run on the image with the selected task prompt and text prompt;
its result is turned into detections, SAM produces a mask per detection, and
each mask is shown in a gallery as a black/white image.
"""

from io import BytesIO
from typing import List, Optional

import gradio as gr
import numpy as np
import PIL.Image
import requests
import spaces
import supervision as sv
import torch
from PIL import Image

from utils.florence import (
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    load_florence_model,
    run_florence_inference,
)
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda")

# Enter a bfloat16 autocast context for the whole process (it is never exited).
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Allow TF32 matmuls/convolutions on GPUs with compute capability >= 8.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

# Seconds to wait when downloading an image from a user-supplied URL, so a
# slow or unresponsive host cannot hold a GPU worker indefinitely.
IMAGE_FETCH_TIMEOUT = 30


def _fetch_image(image_url: str) -> Image.Image:
    """Download *image_url* and return it as an RGB PIL image.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status.
        PIL.UnidentifiedImageError: if the payload is not a readable image.
    """
    # NOTE(review): the URL comes straight from the UI, so the server will
    # fetch any address a user types (SSRF risk outside a sandboxed host).
    print("start to fetch image from url", image_url)
    response = requests.get(image_url, timeout=IMAGE_FETCH_TIMEOUT)
    response.raise_for_status()
    # convert() forces a full decode and normalises palette/RGBA/grayscale
    # images to RGB before they reach the models.
    image = PIL.Image.open(BytesIO(response.content)).convert("RGB")
    print("fetch image success")
    return image


@spaces.GPU(duration=20)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, image_url, task_prompt, text_input) -> Optional[List[Image.Image]]:
    """Segment the objects described by *text_input* and return their masks.

    Args:
        image_input: Uploaded PIL image, or None if nothing was uploaded.
        image_url: Optional URL; when non-empty it takes precedence over
            *image_input*.
        task_prompt: Florence-2 task token selected in the UI.
        text_input: Free-text prompt passed to Florence-2 alongside the task.

    Returns:
        One black/white PIL image per detected object, or None when input is
        missing or nothing is detected (a ``gr.Info`` toast explains why).

    Raises:
        requests.RequestException: if downloading *image_url* fails.
    """
    # Either an upload or a URL is enough; checking only the upload made the
    # URL-only path unreachable.
    if image_input is None and not image_url:
        gr.Info("Please upload an image.")
        return None
    if not task_prompt:
        gr.Info("Please enter a task prompt.")
        return None
    if not text_input:
        gr.Info("Please enter a text prompt.")
        return None

    if image_url:
        image_input = _fetch_image(image_url)

    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=task_prompt,
        text=text_input,
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size,
    )
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None

    print("mask generated:", len(detections.mask))
    # Masks are presumably boolean / 0-1 arrays; scale to 0/255 so they
    # render as black-and-white images in the gallery.
    return [Image.fromarray(mask.astype(np.uint8) * 255) for mask in detections.mask]


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', label='Upload image')
            image_url = gr.Textbox(
                label='Image url', placeholder='Enter image URL (Optional)')
            # NOTE(review): every choice here is an empty string, so
            # process_image always stops at "Please enter a task prompt."
            # Fill in the real Florence-2 task tokens.
            task_prompt = gr.Dropdown(
                ["", "", "", "", "", ""], value="", label="Task Prompt", info="task prompts"
            )
            text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
            submit_button = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            image_gallery = gr.Gallery(
                label="Generated images",
                show_label=False,
                elem_id="gallery",
                columns=[3],
                rows=[1],
                object_fit="contain",
                height="auto",
            )
    print(image, image_url, task_prompt, text_prompt, image_gallery)
    submit_button.click(
        fn=process_image,
        inputs=[image, image_url, task_prompt, text_prompt],
        outputs=[image_gallery],
        show_api=False,
    )

demo.launch(debug=True, show_error=True)