import os
import tempfile

import cv2
import gradio as gr
import torch
from diffusers import DiffusionPipeline
| |
|
# Hugging Face access token, read from the HF_TOKEN_SD environment variable
# (None when the variable is unset). It is forwarded to from_pretrained below.
MY_SECRET_TOKEN = os.environ.get("HF_TOKEN_SD")

# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Stable Diffusion v1.4 wrapped in the community "interpolate_stable_diffusion"
# pipeline, which supplies the `walk` method called by `run`.
# The safety checker is explicitly disabled.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="interpolate_stable_diffusion",
    safety_checker=None,
    use_auth_token=MY_SECRET_TOKEN,
)
pipe = pipe.to(device)

# Attention slicing trades some speed for lower peak memory (per the diffusers docs).
pipe.enable_attention_slicing()
| |
|
def _write_video(frame_filepaths, video_path, fps):
    """Encode the image files in *frame_filepaths*, in order, into an MPEG-4 video.

    The frame size is taken from the first image.

    Raises:
        RuntimeError: if a frame cannot be read or the writer cannot be opened.
    """
    first = cv2.imread(frame_filepaths[0])
    if first is None:
        raise RuntimeError(f"Could not read frame {frame_filepaths[0]!r}")
    height, width = first.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open video writer for {video_path!r}")
    try:
        writer.write(first)
        for path in frame_filepaths[1:]:
            frame = cv2.imread(path)
            if frame is None:
                raise RuntimeError(f"Could not read frame {path!r}")
            writer.write(frame)
    finally:
        # Always finalize the container, even if a later frame fails.
        writer.release()


def run(prompt1, seed1, prompt2, seed2, prompt3, seed3):
    """Interpolate between three prompts and stitch the frames into a video.

    Calls ``pipe.walk`` (community ``interpolate_stable_diffusion`` pipeline)
    with the three prompt/seed pairs. It then encodes the returned frame
    files into a 24 fps MPEG-4 video.

    Args:
        prompt1, prompt2, prompt3: Text prompts, in interpolation order.
        seed1, seed2, seed3: Seeds paired with each prompt. They come from
            the UI sliders and are cast to ``int`` before use.

    Returns:
        ``(video_path, frame_filepaths)`` for the Gradio video and gallery
        outputs.

    Raises:
        RuntimeError: if ``walk`` returns no frames, a frame cannot be read,
            or the video writer cannot be opened.
    """
    output_dir = "./dreams"
    frame_filepaths = pipe.walk(
        prompts=[prompt1, prompt2, prompt3],
        # Slider values may arrive as floats; seeding needs integers.
        seeds=[int(seed1), int(seed2), int(seed3)],
        num_interpolation_steps=16,
        output_dir=output_dir,
        batch_size=4,
        height=512,
        width=512,
        guidance_scale=8.5,
        num_inference_steps=50,
    )
    if not frame_filepaths:
        raise RuntimeError("The interpolation pipeline returned no frames.")

    # Give each call its own output file so concurrent requests cannot
    # overwrite one another's video (the old code always wrote "out.mp4").
    os.makedirs(output_dir, exist_ok=True)
    fd, video_path = tempfile.mkstemp(prefix="interpolation_", suffix=".mp4", dir=output_dir)
    os.close(fd)

    _write_video(frame_filepaths, video_path, fps=24)
    # No cv2.destroyAllWindows(): no windows are opened here, and the call
    # raises on headless OpenCV builds.
    return video_path, frame_filepaths
| |
|
with gr.Blocks() as demo:
    with gr.Column():
        # Static header shown above the controls.
        gr.HTML('''
            <h1 style='font-size: 2em;text-align:center;font-weight:900;'>
                Stable Diffusion Interpolation • Community pipeline
            </h1>
            <p style='text-align: center;'><br />
                This community pipeline returns a list of images saved under the folder as defined in output_dir. <br />
                You can use these images to create videos of stable diffusion.
            </p>

            <p style='text-align: center;'>
                This demo can be run on a GPU of at least 8GB VRAM and should take approximately 5 minutes.<br />
                —
            </p>

        ''')
        with gr.Row():
            # Left side: three prompt/seed pairs and the run button.
            with gr.Column():
                with gr.Column():
                    with gr.Row():
                        intpol_prompt_1 = gr.Textbox(lines=1, label="prompt 1")
                        seed1 = gr.Slider(label="Seed 1", minimum=0, maximum=2147483647, step=1, randomize=True)
                    with gr.Row():
                        intpol_prompt_2 = gr.Textbox(lines=1, label="prompt 2")
                        seed2 = gr.Slider(label="Seed 2", minimum=0, maximum=2147483647, step=1, randomize=True)
                    with gr.Row():
                        intpol_prompt_3 = gr.Textbox(lines=1, label="prompt 3")
                        seed3 = gr.Slider(label="Seed 3", minimum=0, maximum=2147483647, step=1, randomize=True)
                    intpol_run = gr.Button("Run Interpolation")

            # Right side: the stitched video and the individual frames.
            with gr.Column():
                video_output = gr.Video(label="Generated video", show_label=True)
                # NOTE(review): `.style()` is deprecated in later Gradio 3.x and
                # removed in 4.x; move these options to the constructor when upgrading.
                gallery_output = gr.Gallery(label="Generated images", show_label=False).style(grid=2, height="auto")

    # Inputs are passed to run() positionally, in this order.
    intpol_run.click(
        run,
        inputs=[intpol_prompt_1, seed1, intpol_prompt_2, seed2, intpol_prompt_3, seed3],
        outputs=[video_output, gallery_output],
    )

# Each request runs for minutes (see the header text). Gradio's queue is the
# recommended way to serve long-running events without request timeouts.
demo.queue()
demo.launch()