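# Gradio app: detect objects with a custom YOLOv5 model, then caption the whole
# image and each detected object with the ViT-GPT2 image-captioning model.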
import os
import zipfile

import torch
import gradio as gr
from PIL import Image
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from huggingface_hub import hf_hub_download

# Global variables for models
object_detection_model = None
captioning_model = None
tokenizer = None
captioning_processor = None
# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the YOLOv5 model from Hugging Face
    try:
        print("Loading YOLOv5 model...")
        # Get the Hugging Face auth token from the environment
        auth_token = os.getenv("HF_AUTH_TOKEN")
        if not auth_token:
            print("Error: HF_AUTH_TOKEN environment variable not set.")
            object_detection_model = None
        else:
            # Download the zipped weights from the Hugging Face Hub
            zip_path = hf_hub_download(
                repo_id='Mexbow/Yolov5_object_detection',
                filename='yolov5.zip',
                token=auth_token,
            )
            # Extract the YOLOv5 model
            extract_path = './yolov5_model'  # Extraction path
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                os.makedirs(extract_path, exist_ok=True)
                zip_ref.extractall(extract_path)
            # Load the custom YOLOv5 weights via torch.hub
            model_path = os.path.join(extract_path, 'yolov5/weights/best14.pt')
            if not os.path.exists(model_path):
                print(f"Error: YOLOv5 model file not found at {model_path}")
                object_detection_model = None
            else:
                object_detection_model = torch.hub.load(
                    'ultralytics/yolov5', 'custom', path=model_path, trust_repo=True
                )
                print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")
        object_detection_model = None

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        print("Loading ViT-GPT2 model...")
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")
        captioning_model, tokenizer, captioning_processor = None, None, None
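# Note: the captioning model is loaded on CPU by default; if a GPU is available it
# could be moved with captioning_model.to("cuda") (the YOLOv5 hub loader typically
# selects CUDA automatically when one is present).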
# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        # YOLOv5 xyxy boxes are (left, top, right, bottom) in pixels,
        # which matches the tuple PIL's Image.crop expects
        left, top, right, bottom = box
        cropped_image = image.crop((left, top, right, bottom))
        cropped_images.append(cropped_image)
    return cropped_images
# Gradio interface function
def process_image(image):
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Ensure models are loaded
    if object_detection_model is None or captioning_model is None or tokenizer is None or captioning_processor is None:
        return None, "Error: models are not loaded properly", None

    try:
        # Step 1: Perform object detection with YOLOv5
        # results.xyxy[0] holds one row per detection: [x1, y1, x2, y2, confidence, class_id]
        results = object_detection_model(image)
        boxes = results.xyxy[0][:, :4].cpu().numpy()  # Bounding boxes
        labels = [results.names[int(class_id)] for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
        scores = results.xyxy[0][:, 4].cpu().numpy()  # Confidence scores

        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop detected objects and generate a caption for each one
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Prepare the per-object results as a formatted string
        detection_results = ""
        for i, (label, box, score, caption) in enumerate(zip(labels, boxes, scores, captions)):
            detection_results += f"Object {i + 1}: {label} ({score:.2f}) - Caption: {caption}\n"

        # Render the image with bounding boxes (render() returns a NumPy array)
        result_image = Image.fromarray(results.render()[0])

        # Return the annotated image, the per-object captions, and the whole-image caption
        return result_image, detection_results, original_caption
    except Exception as e:
        return None, f"Error: {e}", None
# Initialize models
init()
# Gradio Interface
interface = gr.Interface(
    fn=process_image,  # Function to run
    inputs=gr.Image(type="pil"),  # Input: image upload
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),  # Output 1: image with bounding boxes
        gr.Textbox(label="Object Captions & Bounding Boxes", lines=10),  # Output 2: per-object captions
        gr.Textbox(label="Whole Image Caption"),  # Output 3: caption for the whole image
    ],
    live=True,
)
# Launch the Gradio app
interface.launch()
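# Usage (assuming this file is saved as app.py):
#   export HF_AUTH_TOKEN=<your Hugging Face access token>   # required to download the YOLOv5 weights
#   python app.py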