# vid2grid / app.py
# Last commit by ZachNagengast: 7e8a433 "Fix broken gifs with default" (12.5 kB)
import gradio as gr
from PIL import Image, ImageDraw, ImageFont, ImageSequence
import numpy as np
import cv2
import os
import tempfile
# Frames decoded from the most recently uploaded file, shared between the
# upload handler and the "Generate Grid" handler. Bound here (a module-level
# `global` statement is a no-op) so the name exists even when this module is
# imported rather than executed as __main__.
stored_frames = None
def load_and_store_frames(image_file, grid_x, grid_y):
    """Decode an uploaded GIF/MP4, cache its frames and prepare a preview video.

    Args:
        image_file: The upload from ``gr.File`` (an object exposing a ``.name``
            path). A plain path string is also accepted.
        grid_x: Number of grid columns currently selected.
        grid_y: Number of grid rows currently selected.

    Returns:
        A ``(markdown_message, video_path)`` tuple. ``video_path`` is the MP4
        for the preview player (the upload itself, or a temporary MP4 converted
        from the GIF) and is ``""`` on failure.

    Side effects:
        On success, replaces the module-level ``stored_frames`` with the
        decoded frames.
    """
    global stored_frames
    try:
        # Make sure a file was actually uploaded
        if image_file is None:
            return "File not found", ""
        # gr.File(type="file") hands over a tempfile wrapper exposing .name;
        # fall back to treating the argument itself as a path.
        file_path = getattr(image_file, "name", image_file)
        print(f"Loading frames for {file_path}")
        if file_path.lower().endswith('.mp4'):
            frames = extract_frames_from_video(file_path)
            video_path = file_path
        else:  # it's a gif
            try:
                img = Image.open(file_path)
            except Exception as e:
                print(f"Could not open GIF file: {e}")
                return "Could not open GIF file", ""
            frames = []
            # Close the GIF once frames are copied; copies are independent.
            with img:
                for i in range(getattr(img, "n_frames", 1)):
                    try:
                        img.seek(i)
                        frames.append(img.copy())
                    except Exception as e:
                        print(f"Could not seek to frame {i}: {e}")
                gif_meta = dict(img.info)
            # Reserve a temp path for the MP4 preview and close our handle right
            # away so OpenCV can open the path for writing.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
                video_path = tmp_file.name
            try:
                duration = gif_meta.get('duration', 100)
                # default to reasonable framerate if duration is 0
                framerate = 1 / (duration / 1000.0) if duration > 0 else 10
                print(f"frame count: {len(frames)} framerate: {framerate} {gif_meta}")
                convert_gif_to_video(file_path, video_path, framerate)
            except Exception as e:
                print(f"Could not convert GIF to MP4: {e}")
        # Guard before the percentage below, which divides by the frame count.
        if not frames:
            return "Could not read any frames from the file", ""
        stored_frames = frames  # Store the frames for later use
        total_frames = len(frames)
        selected_frames_count = grid_x * grid_y
        details = f"**Total Frames:** {total_frames}\n\n"
        output_info = f"Grid size: {grid_x} x {grid_y}\n\nSelected Frames: {selected_frames_count} / {total_frames} ({selected_frames_count / total_frames * 100:.2f}%)"
        return f"Frames loaded successfully\n\n{details}\n\n{output_info}", video_path
    except Exception as e:
        print(f"An error occurred while loading and storing frames: {e}")
        return f"An error occurred: {e}", ""
def generate_grid(grid_x, grid_y, font_size, font_color, position, border_size, border_color):
    """Render the numbered frame grid for the most recently uploaded file.

    Args:
        grid_x: Number of grid columns.
        grid_y: Number of grid rows.
        font_size, font_color, position, border_size, border_color:
            Forwarded unchanged to ``create_grid``.

    Returns:
        ``(grid_image, details_markdown)``. When no frames have been loaded,
        ``grid_image`` is ``None`` and the markdown asks for an upload.
    """
    # This handler has no uploaded file to (re)load from, so report a missing
    # upload instead of crashing. globals().get also tolerates the cache name
    # never having been bound at module level.
    frames = globals().get("stored_frames")
    if not frames:
        return None, "No frames loaded. Please upload a GIF or MP4 first."
    grid_img, output_info = create_grid(frames, grid_x, grid_y, font_size, font_color, position, border_size, border_color)
    details = f"Total Frames: {len(frames)}\n\n{output_info}"
    return grid_img, details
def create_grid(frames, grid_x, grid_y, font_size, font_color, position, border_size, border_color):
    """Tile evenly spaced, numbered frames into a single grid image.

    Args:
        frames: Sequence of PIL images; sizes are taken from ``frames[0]``
            (assumes all frames share that size).
        grid_x: Number of columns.
        grid_y: Number of rows.
        font_size: Font size used for the frame numbers.
        font_color: PIL color for the numbers.
        position: One of "Top Left", "Top Right", "Bottom Left", "Bottom Right".
        border_size: Border width in pixels drawn around each frame.
        border_color: PIL color name for the borders and grid background.

    Returns:
        ``(grid_image, output_info_markdown)``.

    Raises:
        ValueError: If ``frames`` is empty or the grid has no cells.
        KeyError: If ``position`` is not one of the supported names.
    """
    # Slider values may arrive as floats; PIL sizes and np.linspace need ints.
    grid_x, grid_y = int(grid_x), int(grid_y)
    font_size, border_width = int(font_size), int(border_size)
    total_frames = len(frames)
    selected_frames_count = grid_x * grid_y
    if total_frames == 0:
        raise ValueError("No frames to build a grid from")
    if selected_frames_count <= 0:
        raise ValueError("Grid dimensions must be positive")

    # Select evenly spaced frames (indices repeat when there are more cells than frames)
    selected_frames_indices = np.linspace(0, total_frames - 1, selected_frames_count).astype(int)
    # Normalize to RGB so palette/alpha GIF frames paste predictably.
    selected_frames = [frames[i].convert("RGB") for i in selected_frames_indices]

    try:
        font = ImageFont.truetype("Lato-Regular.ttf", font_size)
    except OSError:
        print("Font not found, using default font.")
        font = ImageFont.load_default()

    # Text anchor offsets relative to each frame's top-left corner; the
    # right/bottom variants approximate the text extent with font_size.
    positions = {
        "Top Left": (20, 20),
        "Top Right": (frames[0].width - 20 - font_size, 20),
        "Bottom Left": (20, frames[0].height - 20 - font_size),
        "Bottom Right": (frames[0].width - 20 - font_size, frames[0].height - 20 - font_size),
    }
    offset_x, offset_y = positions[position]
    fill_color = border_color.lower()

    # Add a border and a 1-based frame number to each selected frame.
    modified_frames = []
    for i, frame in enumerate(selected_frames):
        frame_with_border = Image.new('RGB', (frame.width + 2 * border_width, frame.height + 2 * border_width), fill_color)
        frame_with_border.paste(frame, (border_width, border_width))
        draw = ImageDraw.Draw(frame_with_border)
        draw.text((border_width + offset_x, border_width + offset_y), str(i + 1), font=font, fill=font_color)
        modified_frames.append(frame_with_border)

    # Combine modified frames into a grid, filling row by row.
    cell_width = modified_frames[0].width
    cell_height = modified_frames[0].height
    grid_img = Image.new('RGB', (cell_width * grid_x, cell_height * grid_y), fill_color)
    for i, frame in enumerate(modified_frames):
        row, col = divmod(i, grid_x)
        grid_img.paste(frame, (col * frame.width, row * frame.height))

    output_info = f"Grid size: {grid_x} x {grid_y}\n\nSelected Frames: {selected_frames_count} / {total_frames} ({selected_frames_count / total_frames * 100:.2f}%)"
    return grid_img, output_info
def extract_frames_from_video(video_file):
    """Extract frames from an MP4 video.

    Args:
        video_file: Path to a video file readable by OpenCV.

    Returns:
        A list of RGB ``PIL.Image`` frames in decode order. The list is empty
        if the file cannot be opened or yields no frames.
    """
    frames = []
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            print(f"Could not open video file: {video_file}")
            return frames
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Convert BGR format (used by OpenCV) to RGB
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        # Release the capture even if decoding/conversion raises mid-stream.
        cap.release()
    return frames
def convert_gif_to_video(gif_path, output_video_path, frame_rate):
    """Re-encode a GIF as an ``mp4v`` MP4 file for the preview player.

    Best-effort: failures are printed and the function returns without
    raising, possibly leaving ``output_video_path`` empty or partially written.

    Args:
        gif_path: Path of the GIF to read.
        output_video_path: Destination path for the MP4.
        frame_rate: Output frames per second; a missing or non-positive value
            falls back to 10.
    """
    if not frame_rate or frame_rate <= 0:
        frame_rate = 10
    try:
        # Load the gif
        gif = Image.open(gif_path)
    except Exception as e:
        print(f"Could not open GIF file: {e}")
        return
    with gif:
        try:
            # Define the codec and create VideoWriter object
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (gif.width, gif.height))
        except Exception as e:
            print(f"Could not create VideoWriter object: {e}")
            return
        try:
            if not out.isOpened():
                print(f"Could not open video writer for {output_video_path}")
                return
            for frame_index in range(getattr(gif, "n_frames", 1)):
                gif.seek(frame_index)
                # Flatten palette/transparency to RGB, then reorder to OpenCV's BGR.
                frame_bgr = cv2.cvtColor(np.array(gif.convert("RGB")), cv2.COLOR_RGB2BGR)
                out.write(frame_bgr)
        except Exception as e:
            print(f"Could not write frame to video: {e}")
        finally:
            # Always finalize/close the output file.
            out.release()
def gif_or_video_info(image_file, grid_x, grid_y, font_size, font_color, position, border_size, border_color):
    """Decode a GIF/MP4 and build its grid, details text and preview video in one pass.

    Args:
        image_file: The upload from ``gr.File`` (an object exposing ``.name``);
            a plain path string is also accepted.
        grid_x, grid_y, font_size, font_color, position, border_size, border_color:
            Forwarded unchanged to ``create_grid``.

    Returns:
        ``(grid_image, details_markdown, video_path)``. ``video_path`` is the
        upload itself for MP4s, or a temporary MP4 converted from the GIF.
    """
    # Rewind the upload's underlying handle, if it has one.
    file_obj = getattr(image_file, "file", None)
    if file_obj is not None:
        file_obj.seek(0)
    file_path = getattr(image_file, "name", image_file)
    if file_path.lower().endswith('.mp4'):
        video_path = file_path
        cap = cv2.VideoCapture(file_path)
        try:
            frame_rate = cap.get(cv2.CAP_PROP_FPS)  # Get the actual frame rate of the video
        finally:
            cap.release()
        frames = extract_frames_from_video(file_path)
        total_frames = len(frames)
    else:  # it's a gif
        with Image.open(file_path) as img:
            frames = []
            for i in range(img.n_frames):
                img.seek(i)
                frames.append(img.copy())
            total_frames = img.n_frames
            duration = img.info.get('duration', 100)
        # Duration is in milliseconds; default to a reasonable framerate when it
        # is 0 (as load_and_store_frames does) instead of dividing by zero.
        frame_rate = 1 / (duration / 1000.0) if duration > 0 else 10
        # Reserve a temp path and close our handle so OpenCV can write to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
            video_path = tmp_file.name
        convert_gif_to_video(file_path, video_path, frame_rate)
    grid_img, output_info = create_grid(frames, grid_x, grid_y, font_size, font_color, position, border_size, border_color)
    details = f"**Total Frames:** {total_frames}\n\n**Frame Rate:** {frame_rate} frames/sec\n\n{output_info}"
    return grid_img, details, video_path
def gif_info(image_file, grid_x, grid_y, font_size, font_color, position, border_size, border_color):
    """Forward every argument, unchanged and in order, to ``gif_or_video_info``.

    Returns:
        Whatever ``gif_or_video_info`` returns.
    """
    result = gif_or_video_info(
        image_file,
        grid_x,
        grid_y,
        font_size,
        font_color,
        position,
        border_size,
        border_color,
    )
    return result
def mirror(x):
    """Identity function: hand back the argument itself, unchanged."""
    reflected = x
    return reflected
# --- Gradio UI ---------------------------------------------------------------
# Left column: upload, preview, file details and the grid/text controls.
# Right column: the generated grid image.
with gr.Blocks() as app:
    gr.Markdown('## vid2grid Generator')
    gr.Markdown('Upload a GIF or MP4 to generate a grid from its frames. Use the sliders to adjust the grid size and text settings.\n\nThis is particularly useful for use with multi modal models such as GPT-4V to retrieve descriptions of short videos or gifs, [example here.](https://twitter.com/zachnagengast/status/1712896232170180651)\n\n **Note:** The grid will be generated only after clicking the "Generate Grid" button.')
    with gr.Row():
        with gr.Column():
            # Upload widget limited to GIF/MP4; its value is passed to load_and_store_frames.
            control_image = gr.File(label="Upload a short MP4 or GIF", type="file", elem_id="file_upload", file_types=[".gif", ".mp4"])
            # Read-only preview player; GIF uploads are shown via their converted MP4.
            video_preview = gr.Video(interactive=False, label="Preview", format="mp4")
            # Status/details text, written by both the upload and the "Generate Grid" handlers.
            gif_details = gr.Markdown("No file found.")
            # Example gallery (currently disabled). If re-enabled it would load
            # demo.mp4 from this file's directory via load_and_store_frames.
            # gr.Examples(
            #     examples=[os.path.join(os.path.dirname(__file__), "demo.mp4")],
            #     inputs=[control_image],
            #     outputs=[gif_details, video_preview],
            #     fn=load_and_store_frames,
            #     cache_examples=True,
            # )
            process_button = gr.Button("Generate Grid") # New button to trigger the heavy computation
            grid_x_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Grid X Size")
            grid_y_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Grid Y Size")
            font_color_dropdown = gr.Dropdown(choices=["Black", "White", "Red", "Green", "Blue"], value="White", label="Numbering Color")
            position_radio = gr.Radio(choices=["Top Left", "Top Right", "Bottom Left", "Bottom Right"], value="Top Left", label="Numbering Position")
            font_size_slider = gr.Slider(minimum=10, maximum=100, step=5, value=40, label="Font Size")
            border_color_dropdown = gr.Dropdown(choices=["Black", "White", "Red", "Green", "Blue"], value="White", label="Border Color")
            border_size_slider = gr.Slider(minimum=0, maximum=100, step=5, value=10, label="Border Size")
        with gr.Column():
            # Initial image (a remote URL) shown until a grid is generated.
            result_image = gr.Image(label="Generated Grid", value="https://i.imgur.com/fYrBwbd.png")
    # Event wiring. Uploading only decodes/caches frames and fills the preview
    # and details; the grid itself is rendered only by the "Generate Grid" click.
    control_image.upload(load_and_store_frames, inputs=[control_image, grid_x_slider, grid_y_slider], outputs=[gif_details, video_preview])
    # Disabled live-update listeners (regenerate on every control change).
    # NOTE(review): if re-enabled, generate_grid returns two values, so the
    # grid_x_slider listener's third output (video_preview) must be dropped.
    # grid_x_slider.change(generate_grid, inputs=[grid_x_slider, grid_y_slider, font_size_slider, font_color_dropdown, position_radio, border_size_slider, border_color_dropdown], outputs=[result_image, gif_details, video_preview])
    # grid_y_slider.change(generate_grid, inputs=[grid_x_slider, grid_y_slider, font_size_slider, font_color_dropdown, position_radio, border_size_slider, border_color_dropdown], outputs=[result_image, gif_details])
    # font_size_slider.change(generate_grid, inputs=[grid_x_slider, grid_y_slider, font_size_slider, font_color_dropdown, position_radio, border_size_slider, border_color_dropdown], outputs=[result_image, gif_details])
    # font_color_dropdown.change(generate_grid, inputs=[grid_x_slider, grid_y_slider, font_size_slider, font_color_dropdown, position_radio, border_size_slider, border_color_dropdown], outputs=[result_image, gif_details])
    # position_radio.change(generate_grid, inputs=[grid_x_slider, grid_y_slider, font_size_slider, font_color_dropdown, position_radio, border_size_slider, border_color_dropdown], outputs=[result_image, gif_details])
    # border_size_slider.change(generate_grid, inputs=[grid_x_slider, grid_y_slider, font_size_slider, font_color_dropdown, position_radio, border_size_slider, border_color_dropdown], outputs=[result_image, gif_details])
    # border_color_dropdown.change(generate_grid, inputs=[grid_x_slider, grid_y_slider, font_size_slider, font_color_dropdown, position_radio, border_size_slider, border_color_dropdown], outputs=[result_image, gif_details])
    # Render the grid from the cached frames using the current control values.
    process_button.click(generate_grid, inputs=[grid_x_slider, grid_y_slider, font_size_slider, font_color_dropdown, position_radio, border_size_slider, border_color_dropdown], outputs=[result_image, gif_details])
def _launch_app():
    """Clear the shared frame cache, then start serving the Gradio app."""
    global stored_frames
    stored_frames = None
    app.launch()


if __name__ == "__main__":
    _launch_app()