import base64
import json
import os
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import pyrebase
import requests
from openai import OpenAI
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

from prompts import remove_unwanted_prompt


def get_middle_thumbnail(input_image: Image.Image, grid_size=(10, 10), padding=3):
    """
    Extract the middle thumbnail from a sprite sheet, handling different aspect
    ratios and removing padding.

    Args:
        input_image: PIL Image containing the sprite sheet
        grid_size: Tuple of (columns, rows)
        padding: Number of padding pixels on each side of a thumbnail (default 3)

    Returns:
        PIL.Image.Image: The middle thumbnail with padding removed
    """
    sprite_sheet = input_image

    # Calculate thumbnail dimensions based on the actual sprite sheet size
    sprite_width, sprite_height = sprite_sheet.size
    thumb_width_with_padding = sprite_width // grid_size[0]
    thumb_height_with_padding = sprite_height // grid_size[1]

    # Remove padding to get the actual image dimensions
    thumb_width = thumb_width_with_padding - (2 * padding)  # e.g. 726 - 6 = 720
    thumb_height = thumb_height_with_padding - (2 * padding)  # varies with input

    # Calculate the middle position
    total_thumbs = grid_size[0] * grid_size[1]
    middle_index = total_thumbs // 2

    # Calculate the row and column of the middle thumbnail
    middle_row = middle_index // grid_size[0]
    middle_col = middle_index % grid_size[0]

    # Calculate pixel coordinates for cropping, including the padding offset
    left = (middle_col * thumb_width_with_padding) + padding
    top = (middle_row * thumb_height_with_padding) + padding
    right = left + thumb_width  # Don't add padding here
    bottom = top + thumb_height  # Don't add padding here

    # Crop and return the middle thumbnail
    middle_thumb = sprite_sheet.crop((left, top, right, bottom))
    return middle_thumb
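
# Usage sketch (hypothetical, not part of the app flow): extract the middle
# frame from a 10x10 sprite sheet; "sprite.png" is a placeholder path.
#
#   sheet = Image.open("sprite.png")
#   thumb = get_middle_thumbnail(sheet, grid_size=(10, 10), padding=3)
#   thumb.save("middle_thumb.png")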


def get_person_bbox(frame, model):
    """Detect people in a frame and return the largest bounding box."""
    results = model(frame, classes=[0])  # class 0 is "person" in COCO

    if not results or len(results[0].boxes) == 0:
        return None

    # Get all person boxes as (x1, y1, x2, y2) rows
    boxes = results[0].boxes.xyxy.cpu().numpy()

    # Compare areas to find the largest person
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    largest_idx = np.argmax(areas)

    return boxes[largest_idx]
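
# Usage sketch (hypothetical): run detection on a single frame loaded with
# OpenCV; "frame.jpg" is a placeholder path.
#
#   model = YOLO("yolo11n.pt")
#   bbox = get_person_bbox(cv2.imread("frame.jpg"), model)  # None if no person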


def generate_crops(frame):
    """Generate both 16:9 and 9:16 crops based on person detection."""
    # Load the YOLO model
    model = YOLO("yolo11n.pt")

    # Convert a PIL Image to cv2 (BGR) format if needed
    if isinstance(frame, Image.Image):
        frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)

    original_height, original_width = frame.shape[:2]

    bbox = get_person_bbox(frame, model)
    if bbox is None:
        return None, None

    # Extract coordinates
    x1, y1, x2, y2 = map(int, bbox)
    person_height = y2 - y1
    person_width = x2 - x1
    person_center_x = (x1 + x2) // 2
    person_center_y = (y1 + y2) // 2

    # Generate the 16:9 crop (focus on the upper body)
    aspect_ratio_16_9 = 16 / 9
    crop_width_16_9 = min(original_width, int(person_height * aspect_ratio_16_9))
    crop_height_16_9 = min(original_height, int(crop_width_16_9 / aspect_ratio_16_9))

    # For 16:9, center horizontally and align the top with the person's top
    x1_16_9 = max(0, person_center_x - crop_width_16_9 // 2)
    x2_16_9 = x1_16_9 + crop_width_16_9
    y1_16_9 = max(0, y1)  # Start from the person's top
    y2_16_9 = y1_16_9 + crop_height_16_9

    # Shift the window back inside the frame if it exceeds a boundary,
    # keeping the full crop size
    if x2_16_9 > original_width:
        x2_16_9 = original_width
        x1_16_9 = max(0, x2_16_9 - crop_width_16_9)
    if y2_16_9 > original_height:
        y2_16_9 = original_height
        y1_16_9 = max(0, y2_16_9 - crop_height_16_9)

    # Generate the 9:16 crop (full body)
    aspect_ratio_9_16 = 9 / 16
    crop_width_9_16 = min(original_width, int(person_height * aspect_ratio_9_16))
    crop_height_9_16 = min(original_height, int(crop_width_9_16 / aspect_ratio_9_16))

    # For 9:16, center both horizontally and vertically on the person
    x1_9_16 = max(0, person_center_x - crop_width_9_16 // 2)
    x2_9_16 = x1_9_16 + crop_width_9_16
    y1_9_16 = max(0, person_center_y - crop_height_9_16 // 2)
    y2_9_16 = y1_9_16 + crop_height_9_16

    # Shift the window back inside the frame if it exceeds a boundary,
    # keeping the full crop size
    if x2_9_16 > original_width:
        x2_9_16 = original_width
        x1_9_16 = max(0, x2_9_16 - crop_width_9_16)
    if y2_9_16 > original_height:
        y2_9_16 = original_height
        y1_9_16 = max(0, y2_9_16 - crop_height_9_16)

    # Create the crops
    crop_16_9 = frame[y1_16_9:y2_16_9, x1_16_9:x2_16_9]
    crop_9_16 = frame[y1_9_16:y2_9_16, x1_9_16:x2_9_16]

    # Resize to standard preview dimensions
    crop_16_9 = cv2.resize(crop_16_9, (426, 240))  # ~16:9 aspect ratio
    crop_9_16 = cv2.resize(crop_9_16, (240, 426))  # ~9:16 aspect ratio

    return crop_16_9, crop_9_16
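
# Usage sketch (hypothetical): crop a still image around the detected person.
# "frame.jpg" is a placeholder path; both return values are BGR numpy arrays
# (or None if no person was found).
#
#   wide, tall = generate_crops(cv2.imread("frame.jpg"))
#   if wide is not None:
#       cv2.imwrite("crop_16_9.jpg", wide)
#       cv2.imwrite("crop_9_16.jpg", tall)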


def visualize_crops(image, bbox, crops_info):
    """
    Visualize the original bbox and the calculated crops.

    Args:
        image: BGR numpy array
        bbox: [x1, y1, x2, y2] of the detected person
        crops_info: dict with "crop_16_9" and "crop_9_16" coordinate dicts
    """
    viz = image.copy()

    # Draw the original person bbox in blue (BGR)
    cv2.rectangle(
        viz, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (255, 0, 0), 2
    )

    # Draw the 16:9 crop in green
    crop_16_9 = crops_info["crop_16_9"]
    cv2.rectangle(
        viz,
        (int(crop_16_9["x1"]), int(crop_16_9["y1"])),
        (int(crop_16_9["x2"]), int(crop_16_9["y2"])),
        (0, 255, 0),
        2,
    )

    # Draw the 9:16 crop in red
    crop_9_16 = crops_info["crop_9_16"]
    cv2.rectangle(
        viz,
        (int(crop_9_16["x1"]), int(crop_9_16["y1"])),
        (int(crop_9_16["x2"]), int(crop_9_16["y2"])),
        (0, 0, 255),
        2,
    )

    return viz
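
# Usage sketch (hypothetical): the crops_info layout assumed here mirrors the
# keys read above; the coordinate values are illustrative only.
#
#   crops_info = {
#       "crop_16_9": {"x1": 0, "y1": 50, "x2": 640, "y2": 410},
#       "crop_9_16": {"x1": 200, "y1": 0, "x2": 440, "y2": 426},
#   }
#   overlay = visualize_crops(cv2.imread("frame.jpg"), [220, 40, 420, 400], crops_info)
#   cv2.imwrite("crops_debug.jpg", overlay)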


def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
    """
    Convert a PIL image to a base64 string.

    Args:
        image: PIL Image object
        format: Image format to use for encoding (default: JPEG)

    Returns:
        Base64-encoded string of the image
    """
    buffered = BytesIO()
    image.save(buffered, format=format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
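
# Usage sketch (hypothetical): build a data URL suitable for an image_url
# message part, as done in analyze_image below.
#
#   b64 = encode_image_to_base64(Image.open("frame.jpg"))
#   data_url = f"data:image/jpeg;base64,{b64}"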


def add_top_numbers(
    input_image,
    num_divisions=20,
    margin=90,
    font_size=120,
    dot_spacing=20,
):
    """
    Add numbered divisions across the top and bottom of an image, with dotted
    vertical lines marking each division.

    Args:
        input_image (Image): PIL Image
        num_divisions (int): Number of divisions to create
        margin (int): Size of the margin in pixels for the numbers
        font_size (int): Font size for the numbers
        dot_spacing (int): Spacing between dots in pixels
    """
    original_image = input_image

    # Create a new image with extra space for numbers on top and bottom
    new_width = original_image.width
    new_height = original_image.height + (2 * margin)
    new_image = Image.new("RGB", (new_width, new_height), "white")

    # Paste the original image in the middle
    new_image.paste(original_image, (0, margin))

    # Initialize the drawing context
    draw = ImageDraw.Draw(new_image)
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        print("Using default font")
        font = ImageFont.load_default(size=font_size)

    # Calculate the division width
    division_width = original_image.width / num_divisions

    # Draw division numbers and dotted lines
    for i in range(num_divisions):
        x = (i * division_width) + (division_width / 2)

        # Draw the number at the top
        draw.text((x, margin // 2), str(i + 1), fill="black", font=font, anchor="mm")

        # Draw the number at the bottom
        draw.text(
            (x, new_height - (margin // 2)),
            str(i + 1),
            fill="black",
            font=font,
            anchor="mm",
        )

        # Draw a dotted line from the top margin to the bottom margin
        y_start = margin
        y_end = new_height - margin

        # Draw dots with the specified spacing
        current_y = y_start
        while current_y < y_end:
            draw.ellipse(
                [x - 2, current_y - 2, x + 2, current_y + 2],
                fill="black",
            )
            current_y += dot_spacing

    return new_image
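
# Usage sketch (hypothetical): overlay a 20-division ruler on a thumbnail so a
# vision model can reference horizontal positions by number.
#
#   ruler = add_top_numbers(Image.open("thumb.png"), num_divisions=20)
#   ruler.save("thumb_numbered.png")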


def crop_and_draw_divisions(
    input_image,
    left_division,
    right_division,
    num_divisions=20,
    line_color=(255, 0, 0),
    line_width=2,
    head_margin_percent=0.1,
):
    """
    Create both 9:16 and 16:9 crops and draw guide lines.

    Args:
        input_image (Image): PIL Image
        left_division (int): Left-side division number (1-20)
        right_division (int): Right-side division number (1-20)
        num_divisions (int): Total number of divisions (default: 20)
        line_color (tuple): RGB color tuple for lines (default: red)
        line_width (int): Width of lines in pixels (default: 2)
        head_margin_percent (float): Percentage margin above the head (default: 0.1)

    Returns:
        tuple: (cropped_image_16_9, image_with_lines, cropped_image_9_16)
    """
    yolo_model = YOLO("yolo11n.pt")

    # Calculate the division width and crop boundaries
    division_width = input_image.width / num_divisions
    left_boundary = (left_division - 1) * division_width
    right_boundary = right_division * division_width

    # First get the 9:16 crop
    cropped_image_9_16 = input_image.crop(
        (left_boundary, 0, right_boundary, input_image.height)
    )

    # Run YOLO on the 9:16 crop to get the person bbox; fall back to the top
    # of the crop if no person is detected
    boxes = yolo_model(cropped_image_9_16, classes=[0])[0].boxes.xyxy.cpu().numpy()
    if len(boxes) == 0:
        top_boundary = 0
    else:
        _, y1, _, y2 = boxes[0]
        # Calculate the top boundary with a head margin
        head_margin = (y2 - y1) * head_margin_percent
        top_boundary = max(0, y1 - head_margin)

    # Calculate 16:9 dimensions based on the width between the divisions
    crop_width = right_boundary - left_boundary
    crop_height_16_9 = int(crop_width * 9 / 16)

    # Calculate the bottom boundary for 16:9
    bottom_boundary = min(input_image.height, top_boundary + crop_height_16_9)

    # Create the 16:9 crop from the original image
    cropped_image_16_9 = input_image.crop(
        (left_boundary, top_boundary, right_boundary, bottom_boundary)
    )

    # Draw guide lines for both crops on the original image
    image_with_lines = input_image.copy()
    draw = ImageDraw.Draw(image_with_lines)

    # Draw vertical lines (shared by both crops)
    draw.line(
        [(left_boundary, 0), (left_boundary, input_image.height)],
        fill=line_color,
        width=line_width,
    )
    draw.line(
        [(right_boundary, 0), (right_boundary, input_image.height)],
        fill=line_color,
        width=line_width,
    )

    # Draw horizontal lines (for the 16:9 crop)
    draw.line(
        [(left_boundary, top_boundary), (right_boundary, top_boundary)],
        fill=line_color,
        width=line_width,
    )
    draw.line(
        [(left_boundary, bottom_boundary), (right_boundary, bottom_boundary)],
        fill=line_color,
        width=line_width,
    )

    return cropped_image_16_9, image_with_lines, cropped_image_9_16
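
# Usage sketch (hypothetical): keep divisions 8 through 14 of a numbered
# thumbnail, yielding the two crops plus a guide-line overlay.
#
#   wide, guides, tall = crop_and_draw_divisions(
#       Image.open("thumb.png"), left_division=8, right_division=14
#   )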


def analyze_image(numbered_input_image: Image.Image, prompt, input_image):
    """
    Ask a multimodal model (gpt-4o) which divisions to crop between, then
    apply the crop.

    Args:
        numbered_input_image (Image): PIL Image with the numbered ruler overlay
        prompt (str): The prompt/question about the image
        input_image (Image): The input image without numbers

    Returns:
        tuple: (cropped_image_16_9, image_with_lines, cropped_image_9_16)
    """
    client = OpenAI()
    base64_image = encode_image_to_base64(numbered_input_image, format="JPEG")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
            ],
        }
    ]

    response = client.chat.completions.create(
        model="gpt-4o", messages=messages, max_tokens=300
    )

    messages.extend(
        [
            {"role": "assistant", "content": response.choices[0].message.content},
            {
                "role": "user",
                "content": "Please return the response as JSON with the keys left_row and right_row.",
            },
        ],
    )

    response = (
        client.chat.completions.create(model="gpt-4o", messages=messages)
        .choices[0]
        .message.content
    )

    # Extract the JSON object from the response and parse it safely
    left_index = response.find("{")
    right_index = response.rfind("}")
    try:
        if left_index == -1 or right_index == -1:
            raise ValueError("No JSON object found in the model response")
        response_json = json.loads(response[left_index : right_index + 1])
        cropped_image_16_9, image_with_lines, cropped_image_9_16 = (
            crop_and_draw_divisions(
                input_image=input_image,
                left_division=response_json["left_row"],
                right_division=response_json["right_row"],
            )
        )
    except Exception as e:
        print(e)
        return input_image, input_image, input_image

    return cropped_image_16_9, image_with_lines, cropped_image_9_16
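
# Usage sketch (hypothetical): requires OPENAI_API_KEY in the environment and
# a prompt that asks the model to pick left/right division numbers.
#
#   thumb = Image.open("thumb.png")
#   numbered = add_top_numbers(thumb)
#   wide, guides, tall = analyze_image(numbered, remove_unwanted_prompt(2), thumb)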


def get_sprite_firebase(cid, rsid, uid):
    """Fetch the sprite-sheet URLs for a session from the Firebase Realtime Database."""
    config = {
        "apiKey": os.getenv("FIREBASE_API_KEY"),
        "authDomain": os.getenv("FIREBASE_AUTH_DOMAIN"),
        "databaseURL": os.getenv("FIREBASE_DATABASE_URL"),
        "projectId": os.getenv("FIREBASE_PROJECT_ID"),
        "storageBucket": os.getenv("FIREBASE_STORAGE_BUCKET"),
        "messagingSenderId": os.getenv("FIREBASE_MESSAGING_SENDER_ID"),
        "appId": os.getenv("FIREBASE_APP_ID"),
        "measurementId": os.getenv("FIREBASE_MEASUREMENT_ID"),
    }
    firebase = pyrebase.initialize_app(config)
    db = firebase.database()

    account_id = os.getenv("ROLL_ACCOUNT")
    COLLAB_EDIT_LINK = "collab_sprite_link_handler"
    path = f"{account_id}/{COLLAB_EDIT_LINK}/{uid}/{cid}/{rsid}"
    data = db.child(path).get()

    return data.val()
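
# Usage sketch (hypothetical): all FIREBASE_* variables and ROLL_ACCOUNT must
# be set in the environment; the cid/rsid/uid values are placeholders.
#
#   image_paths = get_sprite_firebase(cid="c1", rsid="r1", uid="u1")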


def get_image_crop(cid=None, rsid=None, uid=None):
    """Return a gallery of 16:9 and 9:16 crops for every sprite sheet in a session."""
    image_paths = get_sprite_firebase(cid, rsid, uid)

    input_images = []
    mid_images = []
    cropped_image_16_9s = []
    images_with_lines = []
    cropped_image_9_16s = []

    for image_path in image_paths:
        response = requests.get(image_path, timeout=30)
        input_image = Image.open(BytesIO(response.content))
        input_images.append(input_image)

        # Get the middle thumbnail
        mid_image = get_middle_thumbnail(input_image)
        mid_images.append(mid_image)

        numbered_mid_image = add_top_numbers(
            input_image=mid_image,
            num_divisions=20,
            margin=50,
            font_size=30,
            dot_spacing=20,
        )

        cropped_image_16_9, image_with_lines, cropped_image_9_16 = analyze_image(
            numbered_mid_image, remove_unwanted_prompt(2), mid_image
        )

        cropped_image_16_9s.append(cropped_image_16_9)
        images_with_lines.append(image_with_lines)
        cropped_image_9_16s.append(cropped_image_9_16)

    return gr.Gallery(
        [
            *input_images,
            *mid_images,
            *cropped_image_16_9s,
            *images_with_lines,
            *cropped_image_9_16s,
        ]
    )
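
# Usage sketch (hypothetical): wire the pipeline into a minimal Gradio demo.
# The input labels are placeholders; the real Space may wire this differently.
#
#   with gr.Blocks() as demo:
#       cid = gr.Textbox(label="cid")
#       rsid = gr.Textbox(label="rsid")
#       uid = gr.Textbox(label="uid")
#       gallery = gr.Gallery(label="Crops")
#       gr.Button("Run").click(get_image_crop, [cid, rsid, uid], gallery)
#   demo.launch()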