# Hugging Face Space status: Running on Zero
# File size: 6,436 bytes
# Blame commits: 49da751, 219a1a5, 851d044, 96e1778, cbd54ec, 64ea3ca
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
import numpy as np
import spaces
import subprocess
from io import BytesIO
# Install flash-attn at startup without compiling its CUDA kernels.
# The extra variable is layered on top of the current environment: passing a
# bare dict as ``env`` would REPLACE the whole environment (dropping PATH,
# HOME, virtualenv settings, ...), which can make ``pip`` unresolvable or make
# it install into the wrong interpreter.
# Installation is best-effort: failures are not raised here (check=False).
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Load the Florence-2-large model and its processor once, at import time.
# trust_remote_code=True lets transformers execute the model/processor code
# hosted in the Hub repository.
model_id = 'microsoft/Florence-2-large'
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
def preprocess_image(image_path, max_size=(800, 800)):
    """Load an image from disk, downscale it if needed, and return it as an array.

    Args:
        image_path: Path to the image file (the Gradio input uses
            ``type="filepath"``).
        max_size: ``(max_width, max_height)`` bound. Images exceeding it are
            shrunk in place with ``Image.thumbnail`` (aspect ratio preserved).

    Returns:
        ``(image_np, size)`` where ``image_np`` is a float32 array of shape
        ``(height, width, 3)`` holding RGB values in ``[0, 255]`` and ``size``
        is the PIL ``(width, height)`` of the (possibly resized) image.
    """
    # Close the underlying file as soon as the pixels are decoded; convert()
    # returns a new image that no longer depends on the open file.
    with Image.open(image_path) as src:
        image = src.convert('RGB')
    if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
        image.thumbnail(max_size, Image.LANCZOS)
    # An RGB PIL image always converts to an (H, W, 3) array, so no layout
    # fix-up is needed. Guessing a channels-first layout from shape[0] == 3
    # would misfire on any image that is exactly 3 pixels tall.
    image_np = np.array(image, dtype=np.float32)
    return image_np, image.size
@spaces.GPU
def run_florence_model(image_np, image_size, task_prompt, text_input=None):
    """Run a single Florence-2 task on an image and parse the model output.

    Args:
        image_np: Image array handed directly to ``processor``.
        image_size: ``(width, height)`` of the image (PIL ``size``), forwarded
            to ``processor.post_process_generation``.
        task_prompt: Florence-2 task token, e.g. ``'<CAPTION>'`` or ``'<OD>'``.
        text_input: Optional text appended to ``task_prompt`` (for example the
            caption to ground with ``'<CAPTION_TO_PHRASE_GROUNDING>'``).

    Returns:
        ``(parsed_answer, generated_text)``: the processor's parsed result
        (indexed by callers with the task token) and the raw decoded text.
    """
    prompt = task_prompt if text_input is None else task_prompt + text_input
    encoded = processor(text=prompt, images=image_np, return_tensors="pt")

    # Deterministic beam search; no sampling.
    generation_kwargs = dict(
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    with torch.no_grad():
        token_ids = model.generate(
            input_ids=encoded["input_ids"].cuda(),
            pixel_values=encoded["pixel_values"].cuda(),
            **generation_kwargs,
        )

    # Special tokens are kept, presumably because post-processing parses the
    # task/location tokens out of the raw text.
    raw_text = processor.batch_decode(token_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        raw_text,
        task=task_prompt,
        image_size=image_size,
    )
    return parsed, raw_text
def plot_image_with_bboxes(image_np, bboxes, labels=None):
    """Render an image with coloured bounding boxes and return it as a PIL image.

    Args:
        image_np: ``(height, width, channels)`` array, assumed to hold values
            in ``[0, 255]`` (it is divided by 255 for display).
        bboxes: Sequence of ``[x_min, y_min, x_max, y_max]`` boxes in pixel
            coordinates.
        labels: Optional sequence of strings; ``labels[i]`` is drawn at the
            top-left corner of ``bboxes[i]``. Boxes beyond ``len(labels)``
            are left unlabelled.

    Returns:
        A PNG-backed ``PIL.Image.Image`` of the rendered figure.
    """
    colors = ['red', 'blue', 'green', 'yellow', 'purple', 'cyan']
    # Work on explicit figure/axes handles instead of pyplot's implicit
    # "current" figure, so concurrent callers cannot draw on, save, or close
    # each other's figures.
    fig, ax = plt.subplots(1)
    try:
        # Float RGB data for imshow is expected in [0, 1].
        ax.imshow(image_np / 255.0)
        for i, bbox in enumerate(bboxes):
            color = colors[i % len(colors)]
            x, y = bbox[0], bbox[1]
            width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            rect = patches.Rectangle(
                (x, y), width, height,
                linewidth=2, edgecolor=color, facecolor='none',
            )
            ax.add_patch(rect)
            if labels and i < len(labels):
                ax.text(x, y, labels[i], color=color, fontsize=8,
                        bbox=dict(facecolor='white', alpha=0.7))
        ax.axis('off')
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release this figure, even if rendering failed.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
@spaces.GPU
def process_image(image_path):
    """Run every Florence-2 demo task on an uploaded image.

    Tasks, in order: simple caption, detailed caption, object detection, OCR,
    phrase grounding of the simple caption, and phrase grounding of the
    detailed caption ("cascaded").

    Args:
        image_path: Filepath supplied by the Gradio image input.

    Returns:
        ``(summary_text, od_plot, pg_plot, cascaded_plot)``: a text report
        followed by three annotated PIL images.
    """
    image_np, image_size = preprocess_image(image_path)
    image_np = image_np.astype(np.float32)

    def run(task, text=None):
        # Only the parsed answer is used; the raw generated text is dropped.
        parsed, _raw = run_florence_model(image_np, image_size, task, text_input=text)
        return parsed

    def boxes_and_labels(parsed, task):
        # Missing keys fall back to empty lists so plotting still works.
        entry = parsed[task]
        return entry.get('bboxes', []), entry.get('labels', [])

    grounding = '<CAPTION_TO_PHRASE_GROUNDING>'

    caption = run('<CAPTION>')
    detailed = run('<DETAILED_CAPTION>')

    od = run('<OD>')
    od_bboxes, od_labels = boxes_and_labels(od, '<OD>')

    ocr = run('<OCR>')

    pg = run(grounding, caption['<CAPTION>'])
    pg_bboxes, pg_labels = boxes_and_labels(pg, grounding)

    # Cascade: ground the phrases of the detailed caption.
    cascaded = run(grounding, detailed['<DETAILED_CAPTION>'])
    cascaded_bboxes, cascaded_labels = boxes_and_labels(cascaded, grounding)

    od_plot = plot_image_with_bboxes(image_np, od_bboxes, od_labels)
    pg_plot = plot_image_with_bboxes(image_np, pg_bboxes, pg_labels)
    cascaded_plot = plot_image_with_bboxes(image_np, cascaded_bboxes, cascaded_labels)

    summary = f"""
Image Captioning:
- Simple Caption: {caption['<CAPTION>']}
- Detailed Caption: {detailed['<DETAILED_CAPTION>']}
Object Detection:
- Detected {len(od_bboxes)} objects
OCR:
{ocr['<OCR>']}
Phrase Grounding:
- Grounded {len(pg_bboxes)} phrases from the simple caption
Cascaded Tasks:
- Grounded {len(cascaded_bboxes)} phrases from the detailed caption
"""
    return summary, od_plot, pg_plot, cascaded_plot
# Gradio interface: one image input, a text report, and three annotated plots.
with gr.Blocks(theme='NoCrypt/miku') as demo:
    gr.Markdown("""
# Image Processing with Florence-2-large
Upload an image to perform image captioning, object detection, OCR, phrase grounding, and cascaded tasks.
""")
    image_input = gr.Image(type="filepath", label="Input image")
    text_output = gr.Textbox(label="Results")
    # Labels let users tell the three otherwise identical image panes apart.
    # Their order matches the tuple returned by process_image:
    # (report text, object detection, phrase grounding, cascaded grounding).
    plot_output_1 = gr.Image(label="Object Detection")
    plot_output_2 = gr.Image(label="Phrase Grounding (simple caption)")
    plot_output_3 = gr.Image(label="Cascaded: Phrase Grounding (detailed caption)")
    image_input.upload(
        process_image,
        inputs=[image_input],
        outputs=[text_output, plot_output_1, plot_output_2, plot_output_3],
    )
    footer = """
<div style="text-align: center; margin-top: 20px;">
<a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
<a href="https://github.com/arad1367" target="_blank">GitHub</a> |
<a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a>
<br>
Made with 💖 by Pejman Ebrahimi
</div>
"""
    gr.HTML(footer)
demo.launch()
|