"""Gradio Blocks demo: genre story -> illustrative images -> interpolated video.

Each step is triggered by its own button and consumes the previous step's
output:
  1. A GPT-2 model fine-tuned for genre stories writes a story from a genre
     and some starting text.
  2. The Latent Diffusion Space renders images conditioned on that story.
  3. FILM interpolates between the images and writes the result to an MP4.
"""
import base64
import io
import os
import subprocess
import sys

import gradio as gr
import mediapy
import numpy as np  # noqa: F401  (not referenced directly; kept from original)
import tensorflow as tf  # noqa: F401  (presumably needed by FILM's code; not referenced directly here)
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import pipeline

FILM_REPO_URL = "https://github.com/google-research/frame-interpolation"
FILM_REPO_DIR = "frame-interpolation"

# 1. GPT-2 (genre-conditioned) story generation pipeline.
story_gen = pipeline("text-generation", "pranavpsv/gpt2-genre-story-generator")

# 2. LatentDiffusion: loaded as a remote interface from its Hugging Face Space.
image_gen = gr.Interface.load("spaces/multimodalart/latentdiffusion")

# 3. FILM: Frame Interpolation Model (code re-use from spaces/akhaliq/frame-interpolation/tree/main)
# Clone only when missing: cloning into an existing directory fails, and a
# failed clone should stop startup with a clear error rather than being ignored.
if not os.path.isdir(FILM_REPO_DIR):
    subprocess.run(["git", "clone", FILM_REPO_URL], check=True)
sys.path.append(FILM_REPO_DIR)
from eval import interpolator, util  # noqa: E402  (importable only after the clone)

ffmpeg_path = util.get_ffmpeg_path()
mediapy.set_ffmpeg(ffmpeg_path)
model = snapshot_download(repo_id="akhaliq/frame-interpolation-film-style")
# Use a distinct name so the `interpolator` module is not shadowed by its instance.
film_interpolator = interpolator.Interpolator(model, None)


def _decode_data_uri(image_str):
    """Decode a base64 image string into a PIL image.

    Accepts either a bare base64 payload or a ``data:<mime>;base64,<payload>``
    URI. Base64 never contains ',', so a comma marks the end of a data-URI
    header; any mime type is therefore accepted, not only PNG.
    """
    header, sep, payload = image_str.partition(",")
    encoded = payload if sep else header
    return Image.open(io.BytesIO(base64.b64decode(encoded)))


def _extract_story(generated_text, choice):
    """Return the text following the ``<choice>`` genre tag.

    Using ``partition`` keeps everything after the first tag, so a ``"> "``
    inside the generated story no longer truncates it. If the tag is not
    found, the full generated text is returned unchanged.
    """
    _, tag, story = generated_text.partition("<{0}> ".format(choice))
    return story if tag else generated_text


def generate_story(choice, input_text):
    """Generate a story in genre ``choice`` that continues ``input_text``.

    The model expects prompts of the form ``<BOS> <genre> starting text``.
    The pipeline output includes the prompt, so the control tags are stripped
    before the story is returned.
    """
    query = "<BOS> <{0}> {1}".format(choice, input_text)
    print(query)
    generated_text = story_gen(query)[0]["generated_text"]
    return _extract_story(generated_text, choice)


def generate_images(generated_text, steps=50, width=256, height=256,
                    num_images=4, diversity=6):
    """Render images for ``generated_text`` via the remote Latent Diffusion Space.

    The keyword defaults match the values previously hard-coded. Gradio calls
    this with a single input, so the defaults are what the UI uses.

    Returns:
        A list of PIL images.
    """
    image_bytes = image_gen(generated_text, steps, width, height, num_images, diversity)
    # Algo from spaces/Gradio-Blocks/latent_gpt2_story/blob/main/app.py.
    # NOTE(review): assumes element [1] of the Space's output is a list of
    # entries whose first item is a base64 image data URI; confirm against
    # the Space's API.
    return [_decode_data_uri(image[0]) for image in image_bytes[1]]


def generate_interpolation(gallery, times_to_interpolate=4, fps=15):
    """Interpolate between the gallery images with FILM and write an MP4.

    Args:
        gallery: value of the gallery component when used as an input. It is
            assumed to be a list of base64 image strings (the form this code
            decodes); TODO confirm for the pinned gradio version.
        times_to_interpolate: passed to FILM's
            ``interpolate_recursively_from_files``.
        fps: frame rate of the output video.

    Returns:
        The path ``"out.mp4"``.

    Raises:
        ValueError: if fewer than two images are available to interpolate.
    """
    if gallery is None or len(gallery) < 2:
        raise ValueError("Generate at least two images before generating a video.")
    input_frames = []
    for index, image_str in enumerate(gallery):
        frame_path = "frame_{0}.png".format(index)
        _decode_data_uri(image_str).save(frame_path)
        input_frames.append(frame_path)
    frames = list(util.interpolate_recursively_from_files(
        input_frames, times_to_interpolate, film_interpolator))
    mediapy.write_video("out.mp4", frames, fps=fps)
    return "out.mp4"


demo = gr.Blocks()

with demo:
    with gr.Row():

        # Left column (inputs)
        with gr.Column():
            input_story_type = gr.Radio(choices=['superhero', 'action', 'drama', 'horror', 'thriller', 'sci_fi'], value='sci_fi', label="Genre")
            input_start_text = gr.Textbox(placeholder='A teddy bear outer space', label="Starting Text")

            gr.Markdown("Be sure to run each of the buttons one at a time, they depend on each others' outputs!")

            # Rows of instructions & buttons
            with gr.Row():
                gr.Markdown("1. Select a type of story, then write some starting text! Then hit the 'Generate Story' button to generate a story! Feel free to edit the generated story afterwards!")
                button_gen_story = gr.Button("Generate Story")
            with gr.Row():
                gr.Markdown("2. After generating a story, hit the 'Generate Images' button to create some visuals for your story! (Can re-run multiple times!)")
                button_gen_images = gr.Button("Generate Images")
            with gr.Row():
                gr.Markdown("3. After generating some images, hit the 'Generate Video' button to create a short video by interpolating the previously generated visuals!")
                button_gen_video = gr.Button("Generate Video")

            # Rows of references
            with gr.Row():
                gr.Markdown("--Models Used--")
            with gr.Row():
                # The linked model is GPT-2, not GPT-J.
                gr.Markdown("Story Generation: [GPT-2](https://huggingface.co/pranavpsv/gpt2-genre-story-generator)")
            with gr.Row():
                gr.Markdown("Image Generation Conditioned on Text: [Latent Diffusion](https://huggingface.co/spaces/multimodalart/latentdiffusion) | [Github Repo](https://github.com/CompVis/latent-diffusion)")
            with gr.Row():
                gr.Markdown("Interpolations: [FILM](https://huggingface.co/spaces/akhaliq/frame-interpolation) | [Github Repo](https://github.com/google-research/frame-interpolation)")
            with gr.Row():
                gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=gradio-blocks_story_and_video_generation)")

        # Right column (outputs)
        with gr.Column():
            output_generated_story = gr.Textbox(label="Generated Story")
            output_gallery = gr.Gallery(label="Generated Story Images")
            output_interpolation = gr.Video(label="Generated Video")

    # Bind functions to buttons
    button_gen_story.click(fn=generate_story, inputs=[input_story_type, input_start_text], outputs=output_generated_story)
    button_gen_images.click(fn=generate_images, inputs=output_generated_story, outputs=output_gallery)
    button_gen_video.click(fn=generate_interpolation, inputs=output_gallery, outputs=output_interpolation)

if __name__ == "__main__":
    demo.launch(debug=True, enable_queue=True)