"""Gradio demo that runs several Florence-2-large vision tasks on an uploaded image."""

import os
import subprocess
import sys
from io import BytesIO

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Ensure flash-attn is installed correctly.
# The variable is merged into the current environment: passing a bare dict as
# ``env`` would drop PATH/HOME/etc. for the child process. Failure is tolerated
# (check=False), matching the original best-effort install.
subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation'],
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    check=False,
)

# Initialize Florence-2-large model and processor
model_id = 'microsoft/Florence-2-large'
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Florence-2 task prompt tokens. Each token doubles as the key under which
# ``processor.post_process_generation`` returns the parsed answer.
CAPTION_TASK = '<CAPTION>'
DETAILED_CAPTION_TASK = '<DETAILED_CAPTION>'
OD_TASK = '<OD>'
OCR_TASK = '<OCR>'
PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'


# Function to resize and preprocess image
def preprocess_image(image_path, max_size=(800, 800)):
    """Load an image as an RGB float32 array, downscaled to fit ``max_size``.

    Args:
        image_path: Path of the image file to open.
        max_size: ``(max_width, max_height)`` bound; larger images are shrunk
            in place with aspect ratio preserved.

    Returns:
        A tuple ``(image_np, image_size)`` where ``image_np`` is a float32
        array of shape ``(height, width, 3)`` and ``image_size`` is the PIL
        ``(width, height)`` size of the (possibly resized) image.
    """
    # ``convert`` returns an independent, fully loaded copy, so the file
    # handle can be released as soon as the conversion is done.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
        image.thumbnail(max_size, Image.LANCZOS)

    # An RGB PIL image always converts to [height, width, channels]; no axis
    # juggling is needed (a shape[0] == 3 check would corrupt 3-pixel-tall images).
    image_np = np.array(image, dtype=np.float32)
    return image_np, image.size


# Function to run Florence-2-large model
@spaces.GPU
def run_florence_model(image_np, image_size, task_prompt, text_input=None):
    """Run one Florence-2 task on an image and parse the generated output.

    Args:
        image_np: Image array as produced by ``preprocess_image``.
        image_size: ``(width, height)`` of the image, forwarded to the
            processor's post-processing step.
        task_prompt: Florence-2 task token (e.g. ``'<OD>'``).
        text_input: Optional extra text appended to the task token.

    Returns:
        A tuple ``(parsed_answer, generated_text)``: the post-processed result
        dict keyed by ``task_prompt`` and the raw decoded model output.
    """
    prompt = task_prompt if text_input is None else task_prompt + text_input

    inputs = processor(text=prompt, images=image_np, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"].cuda(),
            pixel_values=inputs["pixel_values"].cuda(),
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )

    # Special tokens are kept: post-processing reads them to parse the answer.
    generated_text = processor.batch_decode(outputs, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=image_size,
    )
    return parsed_answer, generated_text


# Function to plot image with bounding boxes
def plot_image_with_bboxes(image_np, bboxes, labels=None):
    """Render the image with labelled bounding boxes and return it as a PIL image.

    Args:
        image_np: Image array with values in the 0-255 range.
        bboxes: Iterable of ``(x1, y1, x2, y2)`` boxes.
        labels: Optional labels, matched to ``bboxes`` by position.

    Returns:
        A PIL image of the rendered PNG plot.
    """
    fig, ax = plt.subplots(1)
    ax.imshow(image_np / 255.0)  # Normalize the image for plotting

    colors = ['red', 'blue', 'green', 'yellow', 'purple', 'cyan']
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        x, y = bbox[0], bbox[1]
        width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
        rect = patches.Rectangle(
            (x, y), width, height, linewidth=2, edgecolor=color, facecolor='none'
        )
        ax.add_patch(rect)
        if labels and i < len(labels):
            ax.text(x, y, labels[i], color=color, fontsize=8,
                    bbox=dict(facecolor='white', alpha=0.7))
    ax.axis('off')

    # Save the plot to a BytesIO object. Operating on ``fig`` explicitly
    # (instead of pyplot's "current figure") guarantees this figure is the one
    # saved and closed.
    buf = BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)

    # Convert the BytesIO buffer to PIL Image
    return Image.open(buf)


# Gradio function to process uploaded images
@spaces.GPU
def process_image(image_path):
    """Run captioning, detection, OCR and phrase grounding on an uploaded image.

    Args:
        image_path: Filesystem path of the uploaded image.

    Returns:
        A tuple ``(response, od_fig, pg_fig, cascaded_fig)``: a text summary
        plus three PIL images showing object-detection boxes, boxes grounded
        from the simple caption, and boxes grounded from the detailed caption.
    """
    image_np, image_size = preprocess_image(image_path)

    # Image Captioning
    caption_result, _ = run_florence_model(image_np, image_size, CAPTION_TASK)
    detailed_caption_result, _ = run_florence_model(
        image_np, image_size, DETAILED_CAPTION_TASK
    )

    # Object Detection
    od_result, _ = run_florence_model(image_np, image_size, OD_TASK)
    od_bboxes = od_result[OD_TASK].get('bboxes', [])
    od_labels = od_result[OD_TASK].get('labels', [])

    # OCR
    ocr_result, _ = run_florence_model(image_np, image_size, OCR_TASK)

    # Phrase Grounding
    pg_result, _ = run_florence_model(
        image_np, image_size, PHRASE_GROUNDING_TASK,
        text_input=caption_result[CAPTION_TASK],
    )
    pg_bboxes = pg_result[PHRASE_GROUNDING_TASK].get('bboxes', [])
    pg_labels = pg_result[PHRASE_GROUNDING_TASK].get('labels', [])

    # Cascaded Tasks (Detailed Caption + Phrase Grounding)
    cascaded_result, _ = run_florence_model(
        image_np, image_size, PHRASE_GROUNDING_TASK,
        text_input=detailed_caption_result[DETAILED_CAPTION_TASK],
    )
    cascaded_bboxes = cascaded_result[PHRASE_GROUNDING_TASK].get('bboxes', [])
    cascaded_labels = cascaded_result[PHRASE_GROUNDING_TASK].get('labels', [])

    # Create plots
    od_fig = plot_image_with_bboxes(image_np, od_bboxes, od_labels)
    pg_fig = plot_image_with_bboxes(image_np, pg_bboxes, pg_labels)
    cascaded_fig = plot_image_with_bboxes(image_np, cascaded_bboxes, cascaded_labels)

    # Prepare response
    response = "\n".join([
        "",
        "Image Captioning:",
        f"- Simple Caption: {caption_result[CAPTION_TASK]}",
        f"- Detailed Caption: {detailed_caption_result[DETAILED_CAPTION_TASK]}",
        "",
        "Object Detection:",
        f"- Detected {len(od_bboxes)} objects",
        "",
        "OCR:",
        f"{ocr_result[OCR_TASK]}",
        "",
        "Phrase Grounding:",
        f"- Grounded {len(pg_bboxes)} phrases from the simple caption",
        "",
        "Cascaded Tasks:",
        f"- Grounded {len(cascaded_bboxes)} phrases from the detailed caption",
        "",
    ])

    return response, od_fig, pg_fig, cascaded_fig


# Gradio interface
with gr.Blocks(theme='NoCrypt/miku') as demo:
    gr.Markdown("""
    # Image Processing with Florence-2-large
    Upload an image to perform image captioning, object detection, OCR, phrase grounding, and cascaded tasks.
    """)

    image_input = gr.Image(type="filepath")
    text_output = gr.Textbox()
    plot_output_1 = gr.Image()
    plot_output_2 = gr.Image()
    plot_output_3 = gr.Image()

    image_input.upload(
        process_image,
        inputs=[image_input],
        outputs=[text_output, plot_output_1, plot_output_2, plot_output_3],
    )

    # NOTE(review): the footer reads like a row of links, but no anchor markup
    # or URLs are present in this text — add them here if links are intended.
    footer = """
LinkedIn | GitHub | Live demo of my PhD defense
Made with 💖 by Pejman Ebrahimi
"""
    gr.HTML(footer)

demo.launch()