import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import zipfile
import os

# Global variables for models
object_detection_model = None
captioning_model = None
tokenizer = None
captioning_processor = None


# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the YOLOv5 model from Hugging Face
    try:
        print("Loading YOLOv5 model...")
        # Get the Hugging Face auth token from an environment variable
        auth_token = os.getenv("HF_AUTH_TOKEN")
        if not auth_token:
            print("Error: HF_AUTH_TOKEN environment variable not set.")
            object_detection_model = None
        else:
            # Download the zip file from Hugging Face
            zip_path = hf_hub_download(
                repo_id='Mexbow/Yolov5_object_detection',
                filename='yolov5.zip',
                token=auth_token,
            )
            # Extract the YOLOv5 model
            extract_path = './yolov5_model'  # Extraction path
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                os.makedirs(extract_path, exist_ok=True)
                zip_ref.extractall(extract_path)
            # Load the YOLOv5 model
            model_path = os.path.join(extract_path, 'yolov5/weights/best14.pt')
            if not os.path.exists(model_path):
                print(f"Error: YOLOv5 model file not found at {model_path}")
                object_detection_model = None
            else:
                object_detection_model = torch.hub.load(
                    'ultralytics/yolov5', 'custom', path=model_path, trust_repo=True
                )
                print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")
        object_detection_model = None

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        print("Loading ViT-GPT2 model...")
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")
        captioning_model, tokenizer, captioning_processor = None, None, None


# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        left, top, right, bottom = box
        cropped_image = image.crop((left, top, right, bottom))
        cropped_images.append(cropped_image)
    return cropped_images


# Gradio interface function
def process_image(image):
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Ensure models are loaded
    if object_detection_model is None or captioning_model is None or tokenizer is None or captioning_processor is None:
        return None, "Error: models are not loaded properly.", None

    try:
        # Step 1: Perform object detection with YOLOv5
        results = object_detection_model(image)
        boxes = results.xyxy[0][:, :4].cpu().numpy()  # Bounding boxes (x1, y1, x2, y2)
        labels = [results.names[int(class_id)] for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
        scores = results.xyxy[0][:, 4].cpu().numpy()  # Confidence scores

        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop detected objects and generate a caption for each object
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Prepare the detection results for display as a formatted string
        detection_results = ""
        for i, (label, score, caption) in enumerate(zip(labels, scores, captions)):
            detection_results += f"Object {i + 1}: {label} (confidence {score:.2f}) - Caption: {caption}\n"

        # Render the image with bounding boxes drawn on it
        result_image = results.render()[0]

        # Return the annotated image, the per-object captions, and the whole-image caption
        return result_image, detection_results, original_caption
    except Exception as e:
        return None, f"Error: {e}", None


# Initialize models
init()

# Gradio Interface
interface = gr.Interface(
    fn=process_image,             # Function to run
    inputs=gr.Image(type="pil"),  # Input: image upload
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),                  # Output 1: image with bounding boxes
        gr.Textbox(label="Object Captions & Bounding Boxes", lines=10),  # Output 2: formatted per-object captions
        gr.Textbox(label="Whole Image Caption"),                         # Output 3: caption for the whole image
    ],
    live=True,
)

# Launch the Gradio app
interface.launch()
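# Example (sketch, commented out): calling the pipeline directly without the Gradio UI,
# assuming a local image file named "sample.jpg" exists in the working directory.
#
#   from PIL import Image
#   annotated, object_captions, whole_caption = process_image(Image.open("sample.jpg"))
#   print(whole_caption)
#   print(object_captions)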