from typing import Optional

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image

from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda")
# DEVICE = torch.device("cpu")

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # Ampere (compute capability 8.x) and newer GPUs support TF32 matmuls.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)


@spaces.GPU(duration=20)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, task_prompt, text_input) -> Optional[list[Image.Image]]:
    """Run Florence-2 detection on the image, refine the results with SAM,
    and return one binary mask image per detected object."""
    if not image_input:
        gr.Info("Please upload an image.")
        return None
    if not task_prompt:
        gr.Info("Please select a task prompt.")
        return None
    if not text_input:
        gr.Info("Please enter a text prompt.")
        return None

    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=task_prompt,
        text=text_input
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size
    )
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None

    print("masks generated:", len(detections.mask))
    # Convert each boolean mask to a black-and-white PIL image for the gallery.
    images = []
    for i in range(len(detections.mask)):
        img = Image.fromarray(detections.mask[i].astype(np.uint8) * 255)
        images.append(img)
    return images


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', label='Upload image')
            # NOTE: this field is not wired to process_image yet.
            image_url = gr.Textbox(
                label='Image URL',
                placeholder='Enter an image URL (optional)')
            # NOTE: the original dropdown choices (Florence-2 angle-bracket task
            # tokens) were lost; the imported open-vocabulary detection task is
            # restored here as a sensible default.
            task_prompt = gr.Dropdown(
                [FLORENCE_OPEN_VOCABULARY_DETECTION_TASK],
                value=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
                label="Task Prompt",
                info="Florence-2 task prompt"
            )
            text_prompt = gr.Textbox(
                label='Text prompt', placeholder='Enter text prompts')
            submit_button = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            image_gallery = gr.Gallery(label="Generated images")

    text_prompt.submit(
        fn=process_image,
        inputs=[image, task_prompt, text_prompt],
        outputs=image_gallery
    )
    submit_button.click(
        fn=process_image,
        inputs=[image, task_prompt, text_prompt],
        outputs=image_gallery
    )

demo.launch(debug=True, show_error=True)