import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
from torchvision.transforms.functional import crop
import gradio as gr
import base64
import io
from huggingface_hub import hf_hub_download
import zipfile
import os
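
# Note: init() below expects an HF_AUTH_TOKEN environment variable (a Hugging Face access
# token, e.g. set as a Space secret) so that hf_hub_download can authenticate when fetching
# the YOLOv5 weights archive.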
# Global variables for models
object_detection_model = None
captioning_model = None
tokenizer = None
captioning_processor = None
# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the YOLOv5 model from Hugging Face
    try:
        print("Loading YOLOv5 model...")
        # Get Hugging Face auth token from environment variable
        auth_token = os.getenv("HF_AUTH_TOKEN")
        if not auth_token:
            print("Error: HF_AUTH_TOKEN environment variable not set.")
            object_detection_model = None
        else:
            # Download the zip file from Hugging Face
            zip_path = hf_hub_download(repo_id='Mexbow/Yolov5_object_detection', filename='yolov5.zip', token=auth_token)
            # Extract the YOLOv5 model
            extract_path = './yolov5_model'  # Extraction path
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                os.makedirs(extract_path, exist_ok=True)
                zip_ref.extractall(extract_path)
            # Load the YOLOv5 model from the extracted weights
            model_path = os.path.join(extract_path, 'yolov5/weights/best14.pt')
            if not os.path.exists(model_path):
                print(f"Error: YOLOv5 model file not found at {model_path}")
                object_detection_model = None
            else:
                object_detection_model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path, trust_repo=True)
                print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")
        object_detection_model = None
    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        print("Loading ViT-GPT2 model...")
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")
        captioning_model, tokenizer, captioning_processor = None, None, None
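
# Note: both models are used on CPU by default here. If a GPU is available, moving the
# models to "cuda" with .to("cuda") (and moving the processed inputs in process_image
# to the same device) would likely speed up inference; this is optional and untested here.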
# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        left, top, right, bottom = box
        cropped_image = image.crop((left, top, right, bottom))
        cropped_images.append(cropped_image)
    return cropped_images
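
# Illustrative example (not part of the app's flow): for boxes like
# [[34.0, 50.0, 210.0, 300.0]], crop_objects returns a one-element list containing the
# PIL image cropped to that (left, top, right, bottom) region.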
# Gradio interface function
def process_image(image):
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Ensure models are loaded
    if object_detection_model is None or captioning_model is None or tokenizer is None or captioning_processor is None:
        return None, "Error: models are not loaded properly.", None

    try:
        # Step 1: Perform object detection with YOLOv5
        results = object_detection_model(image)
        boxes = results.xyxy[0][:, :4].cpu().numpy()  # Bounding boxes (x1, y1, x2, y2)
        labels = [results.names[int(class_id)] for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
        scores = results.xyxy[0][:, 4].cpu().numpy()  # Confidence scores
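        # Each row of results.xyxy[0] is [x1, y1, x2, y2, confidence, class_id]
        # (standard YOLOv5 hub output), hence the column slices above.
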
        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop detected objects and generate a caption for each object
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Prepare the per-object results as a formatted string (label, confidence, box, caption)
        detection_results = ""
        for i, (label, box, score, caption) in enumerate(zip(labels, boxes, scores, captions)):
            x1, y1, x2, y2 = box
            detection_results += (
                f"Object {i + 1}: {label} ({score:.2f}) "
                f"[{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}] - Caption: {caption}\n"
            )

        # Render the image with bounding boxes drawn on it
        result_image = results.render()[0]

        # Return the image with detections, the formatted per-object captions, and the whole-image caption
        return result_image, detection_results, original_caption
    except Exception as e:
        return None, f"Error: {e}", None
# Initialize models
init()
# Gradio Interface
interface = gr.Interface(
    fn=process_image,  # Function to run
    inputs=gr.Image(type="pil"),  # Input: image upload
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),  # Output 1: image with bounding boxes
        gr.Textbox(label="Object Captions & Bounding Boxes", lines=10),  # Output 2: formatted per-object captions
        gr.Textbox(label="Whole Image Caption")  # Output 3: caption for the whole image
    ],
    live=True
)

# Launch the Gradio app
interface.launch()
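
# Note: on Hugging Face Spaces the app above is served automatically. When running locally,
# `python app.py` launches the server (Gradio's default is http://127.0.0.1:7860); passing
# share=True to launch() would additionally create a temporary public link.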