import base64
import json
import os
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import pyrebase
import requests
from openai import OpenAI
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

from prompts import remove_unwanted_prompt

model = YOLO("yolo11n.pt")

def get_middle_thumbnail(input_image: Image.Image, grid_size=(10, 10), padding=3):
    """
    Extract the middle thumbnail from a sprite sheet, handling different aspect
    ratios and removing padding.

    Args:
        input_image: PIL Image containing the sprite sheet
        grid_size: Tuple of (columns, rows)
        padding: Number of padding pixels on each side (default: 3)

    Returns:
        PIL.Image.Image: The middle thumbnail with padding removed
    """
    sprite_sheet = input_image

    # Calculate thumbnail dimensions based on the actual sprite sheet size
    sprite_width, sprite_height = sprite_sheet.size
    thumb_width_with_padding = sprite_width // grid_size[0]
    thumb_height_with_padding = sprite_height // grid_size[1]

    # Remove padding to get the actual image dimensions
    thumb_width = thumb_width_with_padding - (2 * padding)  # e.g. 726 - 6 = 720
    thumb_height = thumb_height_with_padding - (2 * padding)  # varies with input

    # Locate the middle thumbnail in the grid
    total_thumbs = grid_size[0] * grid_size[1]
    middle_index = total_thumbs // 2
    middle_row = middle_index // grid_size[0]
    middle_col = middle_index % grid_size[0]

    # Pixel coordinates for cropping, offset past the padding
    left = (middle_col * thumb_width_with_padding) + padding
    top = (middle_row * thumb_height_with_padding) + padding
    right = left + thumb_width  # don't add padding here
    bottom = top + thumb_height  # don't add padding here

    # Crop and return the middle thumbnail
    middle_thumb = sprite_sheet.crop((left, top, right, bottom))
    return middle_thumb
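
# Example usage (hypothetical; the file name is a placeholder). For a
# 7260x4850 sprite sheet laid out 10x10, each padded cell is 726x485, so the
# returned middle frame (grid row 5, col 0) is 720x479:
#
#     sheet = Image.open("sprite_sheet.jpg")
#     frame = get_middle_thumbnail(sheet)
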
def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
    """
    Convert a PIL image to a base64 string.

    Args:
        image: PIL Image object
        format: Image format to use for encoding (default: JPEG)

    Returns:
        Base64-encoded string of the image
    """
    buffered = BytesIO()
    image.save(buffered, format=format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
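
# The OpenAI vision call below embeds this string in a data URL, e.g.:
#
#     b64 = encode_image_to_base64(frame)  # "frame" is any PIL image
#     url = f"data:image/jpeg;base64,{b64}"
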
def add_top_numbers(
    input_image,
    num_divisions=20,
    margin=90,
    font_size=120,
    dot_spacing=20,
):
    """
    Add numbered divisions across the top and bottom of an image, with dotted
    vertical lines marking each division.

    Args:
        input_image (Image): PIL Image
        num_divisions (int): Number of divisions to create
        margin (int): Size of the margin in pixels for the numbers
        font_size (int): Font size for the numbers
        dot_spacing (int): Spacing between dots in pixels

    Returns:
        PIL.Image.Image: New image with numbered margins and dotted guides
    """
    original_image = input_image

    # Create a new image with extra space for numbers on top and bottom
    new_width = original_image.width
    new_height = original_image.height + (2 * margin)
    new_image = Image.new("RGB", (new_width, new_height), "white")

    # Paste the original image in the middle
    new_image.paste(original_image, (0, margin))

    # Initialize the drawing context
    draw = ImageDraw.Draw(new_image)
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        print("Using default font")
        font = ImageFont.load_default(size=font_size)

    # Calculate the division width
    division_width = original_image.width / num_divisions

    # Draw division numbers and dotted lines
    for i in range(num_divisions):
        x = (i * division_width) + (division_width / 2)

        # Number at the top and at the bottom
        draw.text((x, margin // 2), str(i + 1), fill="black", font=font, anchor="mm")
        draw.text(
            (x, new_height - (margin // 2)),
            str(i + 1),
            fill="black",
            font=font,
            anchor="mm",
        )

        # Dotted vertical line from the top margin to the bottom margin
        current_y = margin
        while current_y < new_height - margin:
            # Each dot is a small filled ellipse (3 px radius)
            draw.ellipse(
                [x - 3, current_y - 3, x + 3, current_y + 3],
                fill="black",
            )
            current_y += dot_spacing

    return new_image
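
# Example (mirrors the call made in get_image_crop below):
#
#     numbered = add_top_numbers(frame, num_divisions=20, margin=50, font_size=30)
#     numbered.save("numbered_preview.jpg")  # hypothetical output path
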
def crop_and_draw_divisions(
    input_image,
    left_division,
    right_division,
    num_divisions=20,
    line_color=(255, 0, 0),
    line_width=2,
    head_margin_percent=0.1,
):
    """
    Create both 9:16 and 16:9 crops and draw guide lines.

    Args:
        input_image (Image): PIL Image
        left_division (int): Left-side division number (1-20)
        right_division (int): Right-side division number (1-20)
        num_divisions (int): Total number of divisions (default: 20)
        line_color (tuple): RGB color tuple for the lines (default: red)
        line_width (int): Width of the lines in pixels (default: 2)
        head_margin_percent (float): Margin above the head as a fraction of
            person height (default: 0.1)

    Returns:
        tuple: (cropped_image_16_9, image_with_lines, cropped_image_9_16)
    """
    # Calculate division width and crop boundaries
    division_width = input_image.width / num_divisions
    left_boundary = (left_division - 1) * division_width
    right_boundary = right_division * division_width

    # First take the 9:16 crop between the divisions
    cropped_image_9_16 = input_image.crop(
        (left_boundary, 0, right_boundary, input_image.height)
    )

    # Run YOLO (module-level model) on the 9:16 crop to get the person bbox;
    # this raises IndexError if no person is detected, which the caller handles
    bbox = model(cropped_image_9_16, classes=[0])[0].boxes.xyxy.cpu().numpy()[0]
    x1, y1, x2, y2 = bbox

    # Top boundary with a margin above the head
    head_margin = (y2 - y1) * head_margin_percent
    top_boundary = max(0, y1 - head_margin)

    # 16:9 dimensions based on the width between the divisions
    crop_width = right_boundary - left_boundary
    crop_height_16_9 = int(crop_width * 9 / 16)
    bottom_boundary = min(input_image.height, top_boundary + crop_height_16_9)

    # 16:9 crop from the original image
    cropped_image_16_9 = input_image.crop(
        (left_boundary, top_boundary, right_boundary, bottom_boundary)
    )

    # Draw guide lines for both crops on a copy of the original image
    image_with_lines = input_image.copy()
    draw = ImageDraw.Draw(image_with_lines)

    # Vertical lines (shared by both crops)
    draw.line(
        [(left_boundary, 0), (left_boundary, input_image.height)],
        fill=line_color,
        width=line_width,
    )
    draw.line(
        [(right_boundary, 0), (right_boundary, input_image.height)],
        fill=line_color,
        width=line_width,
    )

    # Horizontal lines (16:9 crop)
    draw.line(
        [(left_boundary, top_boundary), (right_boundary, top_boundary)],
        fill=line_color,
        width=line_width,
    )
    draw.line(
        [(left_boundary, bottom_boundary), (right_boundary, bottom_boundary)],
        fill=line_color,
        width=line_width,
    )

    return cropped_image_16_9, image_with_lines, cropped_image_9_16
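
# Example (hypothetical division numbers):
#
#     wide, guides, tall = crop_and_draw_divisions(frame, left_division=6,
#                                                  right_division=15)
#     # wide -> 16:9 crop, guides -> annotated original, tall -> 9:16 crop
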
def analyze_image(numbered_input_image: Image.Image, prompt, input_image):
    """
    Perform inference on an image using GPT-4o.

    Args:
        numbered_input_image (Image): PIL Image with numbered divisions
        prompt (str): The prompt/question about the image
        input_image (Image): The input image without numbers

    Returns:
        tuple: (cropped_image_16_9, image_with_lines, cropped_image_9_16,
            left_division, right_division)
    """
    client = OpenAI()
    base64_image = encode_image_to_base64(numbered_input_image, format="JPEG")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
            ],
        }
    ]
    response = client.chat.completions.create(
        model="gpt-4o", messages=messages, max_tokens=300
    )
    messages.extend(
        [
            {"role": "assistant", "content": response.choices[0].message.content},
            {
                "role": "user",
                "content": "Please return the response as JSON with keys left_row and right_row.",
            },
        ]
    )
    response = (
        client.chat.completions.create(model="gpt-4o", messages=messages)
        .choices[0]
        .message.content
    )

    # Extract the JSON object from the model's reply
    left_index = response.find("{")
    right_index = response.rfind("}")
    try:
        if left_index == -1 or right_index == -1:
            raise ValueError("No JSON object found in the model response")
        response_json = json.loads(response[left_index : right_index + 1])
        cropped_image_16_9, image_with_lines, cropped_image_9_16 = (
            crop_and_draw_divisions(
                input_image=input_image,
                left_division=response_json["left_row"],
                right_division=response_json["right_row"],
            )
        )
    except Exception as e:
        print(e)
        return input_image, input_image, input_image, 0, 20
    return (
        cropped_image_16_9,
        image_with_lines,
        cropped_image_9_16,
        response_json["left_row"],
        response_json["right_row"],
    )
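
# The second completion is expected to contain a JSON object such as
# {"left_row": 6, "right_row": 15}; anything outside the outermost braces is
# stripped before parsing. Example call (hypothetical inputs):
#
#     _, guides, _, left, right = analyze_image(
#         numbered, remove_unwanted_prompt(2), frame
#     )
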
def get_sprite_firebase(cid, rsid, uid):
    """Fetch sprite-sheet paths for a session from the Firebase realtime database."""
    config = {
        "apiKey": os.getenv("FIREBASE_API_KEY"),
        "authDomain": os.getenv("FIREBASE_AUTH_DOMAIN"),
        "databaseURL": os.getenv("FIREBASE_DATABASE_URL"),
        "projectId": os.getenv("FIREBASE_PROJECT_ID"),
        "storageBucket": os.getenv("FIREBASE_STORAGE_BUCKET"),
        "messagingSenderId": os.getenv("FIREBASE_MESSAGING_SENDER_ID"),
        "appId": os.getenv("FIREBASE_APP_ID"),
        "measurementId": os.getenv("FIREBASE_MEASUREMENT_ID"),
    }
    firebase = pyrebase.initialize_app(config)
    db = firebase.database()
    account_id = os.getenv("ROLL_ACCOUNT")

    COLLAB_EDIT_LINK = "collab_sprite_link_handler"
    path = f"{account_id}/{COLLAB_EDIT_LINK}/{uid}/{cid}/{rsid}"
    data = db.child(path).get()

    return data.val()
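
# A minimal sketch of the environment this expects (values are placeholders):
#
#     os.environ["FIREBASE_API_KEY"] = "<api-key>"
#     os.environ["FIREBASE_DATABASE_URL"] = "https://<project>.firebaseio.com"
#     os.environ["ROLL_ACCOUNT"] = "<account-id>"
#
# The realtime-database path read above has the shape
# {ROLL_ACCOUNT}/collab_sprite_link_handler/{uid}/{cid}/{rsid}.
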
def find_persons_center(image):
    """
    Find the center point of all persons in the image. If multiple persons are
    detected, merge their bounding boxes and take the center of the union.

    Args:
        image: CV2/numpy array image

    Returns:
        int: x-coordinate of the center point of all persons
    """
    # Detect persons (class 0 in the COCO dataset)
    results = model(image, classes=[0])
    if not results or len(results[0].boxes) == 0:
        # No persons detected: fall back to the center of the image
        return image.shape[1] // 2

    # Get all person boxes
    boxes = results[0].boxes.xyxy.cpu().numpy()
    print(f"Detected {len(boxes)} persons in the image")

    if len(boxes) == 1:
        # One person: return the center of their bounding box
        x1, _, x2, _ = boxes[0]
        center_x = int((x1 + x2) // 2)
        print(f"Single person detected at center x: {center_x}")
        return center_x

    # Multiple persons: merge the bounding boxes
    left_x = min(box[0] for box in boxes)
    right_x = max(box[2] for box in boxes)
    merged_center_x = int((left_x + right_x) // 2)
    print(f"Multiple persons merged bounding box center x: {merged_center_x}")
    print(f"Merged bounds: left={left_x}, right={right_x}")
    return merged_center_x
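
# Example (hypothetical): find_persons_center expects a BGR numpy array, so
# convert PIL images first:
#
#     frame_cv = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
#     center_x = find_persons_center(frame_cv)  # falls back to width // 2
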
def create_layouts(image, left_division, right_division):
    """
    Create layout variations of the image using half, one-third, and two-thirds
    width. All variations, including the 16:9 and 9:16 crops, are centered on
    the detected persons.

    Args:
        image: PIL Image
        left_division: Left division index (1-20)
        right_division: Right division index (1-20)

    Returns:
        tuple: (list of layout variations, cutout_image, cutout_16_9, cutout_9_16)
    """
    # Convert PIL Image to cv2 format
    if isinstance(image, Image.Image):
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    else:
        image_cv = image.copy()

    # Get image dimensions
    height, width = image_cv.shape[:2]

    # Calculate division width and crop boundaries
    division_width = width / 20  # assuming 20 divisions
    left_boundary = int((left_division - 1) * division_width)
    right_boundary = int(right_division * division_width)

    # 1. Create the cutout image between the divisions
    cutout_image = image_cv[:, left_boundary:right_boundary].copy()
    cutout_width = right_boundary - left_boundary
    cutout_height = cutout_image.shape[0]

    # 2. Run YOLO on the cutout to get the person bounding box and center
    results = model(cutout_image, classes=[0])

    # Defaults if nothing is detected
    cutout_center_x = cutout_image.shape[1] // 2
    cutout_center_y = cutout_height // 2
    person_top = 0.0
    person_height = float(cutout_height)

    if results and len(results[0].boxes) > 0:
        boxes = results[0].boxes.xyxy.cpu().numpy()
        if len(boxes) == 1:
            # Single person
            x1, y1, x2, y2 = boxes[0]
            cutout_center_x = int((x1 + x2) // 2)
            cutout_center_y = int((y1 + y2) // 2)
            person_top = y1
            person_height = y2 - y1
        else:
            # Multiple persons: merge the bounding boxes
            left_x = min(box[0] for box in boxes)
            right_x = max(box[2] for box in boxes)
            top_y = min(box[1] for box in boxes)  # top of the highest person
            bottom_y = max(box[3] for box in boxes)  # bottom of the lowest person
            cutout_center_x = int((left_x + right_x) // 2)
            cutout_center_y = int((top_y + bottom_y) // 2)
            person_top = top_y
            person_height = bottom_y - top_y

    # 3. Create 16:9 and 9:16 versions with the person properly framed
    aspect_16_9 = 16 / 9
    aspect_9_16 = 9 / 16

    # 16:9 version, with a 20% margin above the person
    target_height_16_9 = int(cutout_width / aspect_16_9)
    if target_height_16_9 <= cutout_height:
        # Start 20% of the person height above the person's top
        top_margin = int(person_height * 0.2)
        y_start = int(max(0, person_top - top_margin))
        # If the crop would extend past the bottom, shift it up
        if y_start + target_height_16_9 > cutout_height:
            y_start = int(max(0, cutout_height - target_height_16_9))
        y_end = int(min(cutout_height, y_start + target_height_16_9))
        cutout_16_9 = cutout_image[y_start:y_end, :].copy()
    else:
        # Rare case: the cutout is too short, so narrow the width instead
        new_width = int(cutout_height * aspect_16_9)
        x_start = max(
            0, min(cutout_width - new_width, cutout_center_x - new_width // 2)
        )
        x_end = min(cutout_width, x_start + new_width)
        cutout_16_9 = cutout_image[:, x_start:x_end].copy()

    # 9:16 version, centered horizontally on the person
    target_width_9_16 = int(cutout_height * aspect_9_16)
    if target_width_9_16 <= cutout_width:
        x_start = int(
            max(
                0,
                min(
                    cutout_width - target_width_9_16,
                    cutout_center_x - target_width_9_16 // 2,
                ),
            )
        )
        x_end = int(min(cutout_width, x_start + target_width_9_16))
        cutout_9_16 = cutout_image[:, x_start:x_end].copy()
    else:
        # Rare case: the cutout is too narrow, so reduce the height instead
        new_height = int(cutout_width / aspect_9_16)
        y_start = int(
            max(0, min(cutout_height - new_height, cutout_center_y - new_height // 2))
        )
        y_end = int(min(cutout_height, y_start + new_height))
        cutout_9_16 = cutout_image[y_start:y_end, :].copy()

    # 4. Scale the center back to original image coordinates
    original_center_x = left_boundary + cutout_center_x

    # 5. Create layout variations on the original image, centered on the persons
    # Half-width layout
    half_width = width // 2
    half_left_x = max(0, min(width - half_width, original_center_x - half_width // 2))
    half_right_x = half_left_x + half_width
    half_width_crop = image_cv[:, half_left_x:half_right_x].copy()

    # Third-width layout
    third_width = width // 3
    third_left_x = max(
        0, min(width - third_width, original_center_x - third_width // 2)
    )
    third_right_x = third_left_x + third_width
    third_width_crop = image_cv[:, third_left_x:third_right_x].copy()

    # Two-thirds-width layout
    two_thirds_width = (width * 2) // 3
    two_thirds_left_x = max(
        0, min(width - two_thirds_width, original_center_x - two_thirds_width // 2)
    )
    two_thirds_right_x = two_thirds_left_x + two_thirds_width
    two_thirds_crop = image_cv[:, two_thirds_left_x:two_thirds_right_x].copy()

    # Shared cv2 text settings for the crop labels
    label_settings = {
        "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
        "fontScale": 1.0,
        "thickness": 2,
    }

    def add_label(img, label):
        # Draw a black background rectangle so the label stays readable
        text_size = cv2.getTextSize(label, **label_settings)
        cv2.rectangle(
            img,
            (10, 10),
            (10 + text_size[0][0] + 10, 10 + text_size[0][1] + 10),
            (0, 0, 0),
            -1,
        )
        # Draw the label text in white
        cv2.putText(
            img,
            label,
            (15, 15 + text_size[0][1]),
            **label_settings,
            color=(255, 255, 255),
            lineType=cv2.LINE_AA,
        )
        return img

    cutout_image = add_label(cutout_image, "Cutout")
    cutout_16_9 = add_label(cutout_16_9, "16:9")
    cutout_9_16 = add_label(cutout_9_16, "9:16")
    half_width_crop = add_label(half_width_crop, "Half Width")
    third_width_crop = add_label(third_width_crop, "Third Width")
    two_thirds_crop = add_label(two_thirds_crop, "Two-Thirds Width")

    # Convert all output images to PIL format
    layout_crops = []
    for layout in (half_width_crop, third_width_crop, two_thirds_crop):
        layout_crops.append(Image.fromarray(cv2.cvtColor(layout, cv2.COLOR_BGR2RGB)))

    cutout_pil = Image.fromarray(cv2.cvtColor(cutout_image, cv2.COLOR_BGR2RGB))
    cutout_16_9_pil = Image.fromarray(cv2.cvtColor(cutout_16_9, cv2.COLOR_BGR2RGB))
    cutout_9_16_pil = Image.fromarray(cv2.cvtColor(cutout_9_16, cv2.COLOR_BGR2RGB))

    return layout_crops, cutout_pil, cutout_16_9_pil, cutout_9_16_pil
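
# Example (hypothetical divisions):
#
#     layouts, cutout, wide, tall = create_layouts(frame, 6, 15)
#     wide.save("cutout_16_9.jpg")  # hypothetical output path
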
def draw_all_crops_on_original(image, left_division, right_division):
    """
    Create a visualization showing all crop regions overlaid on the original
    image. Each crop region is outlined in a different color and labeled, and
    all crops are centered on the person's center point.

    Args:
        image: PIL Image
        left_division: Left division index (1-20)
        right_division: Right division index (1-20)

    Returns:
        PIL.Image.Image: Original image with all crop regions visualized
    """
    # Convert PIL Image to cv2 format
    if isinstance(image, Image.Image):
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    else:
        image_cv = image.copy()

    # Get a clean copy for drawing and the image dimensions
    visualization = image_cv.copy()
    height, width = image_cv.shape[:2]

    # Calculate division width and crop boundaries
    division_width = width / 20  # assuming 20 divisions
    left_boundary = int((left_division - 1) * division_width)
    right_boundary = int(right_division * division_width)

    # Find the person bounding box and center within the cutout
    cutout_image = image_cv[:, left_boundary:right_boundary].copy()
    results = model(cutout_image, classes=[0])

    # Defaults if nothing is detected
    cutout_center_x = cutout_image.shape[1] // 2
    cutout_center_y = cutout_image.shape[0] // 2
    person_top = 0.0
    person_height = float(cutout_image.shape[0])

    if results and len(results[0].boxes) > 0:
        boxes = results[0].boxes.xyxy.cpu().numpy()
        if len(boxes) == 1:
            # Single person
            x1, y1, x2, y2 = boxes[0]
            cutout_center_x = int((x1 + x2) // 2)
            cutout_center_y = int((y1 + y2) // 2)
            person_top = y1
            person_height = y2 - y1
        else:
            # Multiple persons: merge the bounding boxes
            left_x = min(box[0] for box in boxes)
            right_x = max(box[2] for box in boxes)
            top_y = min(box[1] for box in boxes)  # top of the highest person
            bottom_y = max(box[3] for box in boxes)  # bottom of the lowest person
            cutout_center_x = int((left_x + right_x) // 2)
            cutout_center_y = int((top_y + bottom_y) // 2)
            person_top = top_y
            person_height = bottom_y - top_y

    # Scale back to original image coordinates; the vertical values are already
    # in original space since the cutout only trims horizontally
    original_center_x = left_boundary + cutout_center_x
    original_center_y = cutout_center_y
    original_person_top = person_top
    original_person_height = person_height

    # Colors for the different crops (BGR format)
    colors = {
        "cutout": (0, 165, 255),  # orange
        "16:9": (0, 255, 0),  # green
        "9:16": (255, 0, 0),  # blue
        "half": (255, 255, 0),  # cyan
        "third": (255, 0, 255),  # magenta
        "two_thirds": (0, 255, 255),  # yellow
    }

    # Line thickness and font
    thickness = 3
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.8
    font_thickness = 2

    # 1. Draw the cutout region (original divisions)
    cv2.rectangle(
        visualization,
        (left_boundary, 0),
        (right_boundary, height),
        colors["cutout"],
        thickness,
    )
    cv2.putText(
        visualization,
        "Cutout",
        (left_boundary + 5, 30),
        font,
        font_scale,
        colors["cutout"],
        font_thickness,
    )

    # 2. Draw 16:9 and 9:16 versions of the cutout, centered on the person
    cutout_width = right_boundary - left_boundary
    cutout_height = height

    # 16:9 version with a 20% margin above the person
    aspect_16_9 = 16 / 9
    target_height_16_9 = int(cutout_width / aspect_16_9)
    if target_height_16_9 <= height:
        # Start 20% of the person height above the person's top
        top_margin = int(original_person_height * 0.2)
        y_start = int(max(0, original_person_top - top_margin))
        # If the crop would extend past the bottom, shift it up
        if y_start + target_height_16_9 > height:
            y_start = int(max(0, height - target_height_16_9))
        y_end = int(min(height, y_start + target_height_16_9))
        cv2.rectangle(
            visualization,
            (left_boundary, y_start),
            (right_boundary, y_end),
            colors["16:9"],
            thickness,
        )
        cv2.putText(
            visualization,
            "16:9",
            (left_boundary + 5, y_start + 30),
            font,
            font_scale,
            colors["16:9"],
            font_thickness,
        )

    # 9:16 version centered on the person
    aspect_9_16 = 9 / 16
    target_width_9_16 = int(cutout_height * aspect_9_16)
    if target_width_9_16 <= cutout_width:
        x_start = max(
            0,
            min(
                left_boundary + cutout_width - target_width_9_16,
                original_center_x - target_width_9_16 // 2,
            ),
        )
        x_end = x_start + target_width_9_16
        cv2.rectangle(
            visualization, (x_start, 0), (x_end, height), colors["9:16"], thickness
        )
        cv2.putText(
            visualization,
            "9:16",
            (x_start + 5, 60),
            font,
            font_scale,
            colors["9:16"],
            font_thickness,
        )

    # 3. Draw the centered layout variations
    # Half-width layout
    half_width = width // 2
    half_left_x = max(0, min(width - half_width, original_center_x - half_width // 2))
    half_right_x = half_left_x + half_width
    cv2.rectangle(
        visualization,
        (half_left_x, 0),
        (half_right_x, height),
        colors["half"],
        thickness,
    )
    cv2.putText(
        visualization,
        "Half Width",
        (half_left_x + 5, 90),
        font,
        font_scale,
        colors["half"],
        font_thickness,
    )

    # Third-width layout
    third_width = width // 3
    third_left_x = max(
        0, min(width - third_width, original_center_x - third_width // 2)
    )
    third_right_x = third_left_x + third_width
    cv2.rectangle(
        visualization,
        (third_left_x, 0),
        (third_right_x, height),
        colors["third"],
        thickness,
    )
    cv2.putText(
        visualization,
        "Third Width",
        (third_left_x + 5, 120),
        font,
        font_scale,
        colors["third"],
        font_thickness,
    )

    # Two-thirds-width layout
    two_thirds_width = (width * 2) // 3
    two_thirds_left_x = max(
        0, min(width - two_thirds_width, original_center_x - two_thirds_width // 2)
    )
    two_thirds_right_x = two_thirds_left_x + two_thirds_width
    cv2.rectangle(
        visualization,
        (two_thirds_left_x, 0),
        (two_thirds_right_x, height),
        colors["two_thirds"],
        thickness,
    )
    cv2.putText(
        visualization,
        "Two-Thirds Width",
        (two_thirds_left_x + 5, 150),
        font,
        font_scale,
        colors["two_thirds"],
        font_thickness,
    )

    # 4. Mark the center point of the person(s)
    center_radius = 8
    cv2.circle(
        visualization,
        (original_center_x, height // 2),
        center_radius,
        (255, 255, 255),
        -1,
    )
    cv2.circle(
        visualization, (original_center_x, height // 2), center_radius, (0, 0, 0), 2
    )
    cv2.putText(
        visualization,
        "Person Center",
        (original_center_x + 10, height // 2),
        font,
        font_scale,
        (255, 255, 255),
        font_thickness,
    )

    # Convert back to PIL format
    visualization_pil = Image.fromarray(cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB))
    return visualization_pil
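
# Example (hypothetical divisions):
#
#     overview = draw_all_crops_on_original(frame, 6, 15)
#     overview.save("all_crops_overview.jpg")  # hypothetical output path
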
def get_image_crop(cid=None, rsid=None, uid=None):
    """
    Build 16:9 and 9:16 crops plus layout variations for every sprite sheet in
    a session.

    Returns:
        gr.Gallery: Gallery of all generated images with captions
    """
    image_paths = get_sprite_firebase(cid, rsid, uid)

    # Lists to store all images and captions
    all_images = []
    all_captions = []

    for image_path in image_paths:
        # Load the image (from a URL or a local file)
        try:
            if image_path.startswith(("http://", "https://")):
                response = requests.get(image_path, timeout=30)
                input_image = Image.open(BytesIO(response.content))
            else:
                input_image = Image.open(image_path)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            continue

        # Get the middle thumbnail
        mid_image = get_middle_thumbnail(input_image)

        # Add numbered divisions so the model can reference crop positions
        numbered_mid_image = add_top_numbers(
            input_image=mid_image,
            num_divisions=20,
            margin=50,
            font_size=30,
            dot_spacing=20,
        )

        # Ask GPT-4o for the optimal crop divisions
        (
            _,
            _,
            _,
            left_division,
            right_division,
        ) = analyze_image(numbered_mid_image, remove_unwanted_prompt(2), mid_image)

        # Safety checks for the divisions
        if left_division <= 0:
            left_division = 1
        if right_division > 20:
            right_division = 20
        if left_division >= right_division:
            left_division = 1
            right_division = 20
        print(f"Using divisions: left={left_division}, right={right_division}")

        # Create layouts and cutouts
        layouts, cutout_image, cutout_16_9, cutout_9_16 = create_layouts(
            mid_image, left_division, right_division
        )

        # Visualization with all crops overlaid on the original
        all_crops_visualization = draw_all_crops_on_original(
            mid_image, left_division, right_division
        )

        # Start with the visualization showing all crops
        all_images.append(all_crops_visualization)
        all_captions.append(f"All Crops Visualization {all_crops_visualization.size}")

        # Input and middle images
        all_images.append(input_image)
        all_captions.append(f"Input Image {input_image.size}")
        all_images.append(mid_image)
        all_captions.append(f"Middle Thumbnail {mid_image.size}")

        # Cutout images
        all_images.append(cutout_image)
        all_captions.append(f"Cutout Image {cutout_image.size}")
        all_images.append(cutout_16_9)
        all_captions.append(f"16:9 Crop {cutout_16_9.size}")
        all_images.append(cutout_9_16)
        all_captions.append(f"9:16 Crop {cutout_9_16.size}")

        # Layout variations
        for layout, label in zip(
            layouts, ["Half Width", "Third Width", "Two-Thirds Width"]
        ):
            all_images.append(layout)
            all_captions.append(f"{label} {layout.size}")

    # Return a gallery with all images
    return gr.Gallery(value=list(zip(all_images, all_captions)))
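
# A minimal launch sketch (hypothetical; this section does not show the app's
# actual UI wiring):
#
#     if __name__ == "__main__":
#         demo = gr.Interface(
#             fn=get_image_crop,
#             inputs=[gr.Textbox(label="cid"), gr.Textbox(label="rsid"),
#                     gr.Textbox(label="uid")],
#             outputs=gr.Gallery(label="Crops"),
#         )
#         demo.launch()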