import torch
import gradio as gr
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionInstructPix2PixPipeline,
    StableVideoDiffusionPipeline,
    WanPipeline,
)
from diffusers.utils import export_to_video, load_image
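
# Assumed runtime dependencies (not pinned anywhere in this file): torch, gradio,
# a diffusers release that ships WanPipeline, accelerate (required by
# enable_model_cpu_offload), and a video-export backend for export_to_video
# (imageio-ffmpeg or opencv-python, depending on the diffusers version).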
# Half precision on GPU keeps memory low; fall back to full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
def make_pipe(cls, model_id, **kwargs):
    """Load a diffusers pipeline in the shared dtype with a small memory footprint."""
    pipe = cls.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
    if device == "cuda":
        # Model CPU offload moves each submodule to the GPU only while it runs;
        # don't call .to("cuda") on an offloaded pipeline afterwards.
        pipe.enable_model_cpu_offload()
    return pipe
# Pipelines are created lazily on first use, so startup stays fast and only the
# tabs you actually use pay the download and memory cost.
TXT2IMG_PIPE = None
IMG2IMG_PIPE = None
TXT2VID_PIPE = None
IMG2VID_PIPE = None
def generate_image_from_text(prompt):
    global TXT2IMG_PIPE
    if TXT2IMG_PIPE is None:
        TXT2IMG_PIPE = make_pipe(
            StableDiffusionPipeline,
            "stabilityai/stable-diffusion-2-1-base",
        )
    return TXT2IMG_PIPE(prompt, num_inference_steps=20).images[0]
def generate_image_from_image_and_prompt(image, prompt):
    global IMG2IMG_PIPE
    if IMG2IMG_PIPE is None:
        IMG2IMG_PIPE = make_pipe(
            StableDiffusionInstructPix2PixPipeline,
            "timbrooks/instruct-pix2pix",
        )
    out = IMG2IMG_PIPE(prompt=prompt, image=image, num_inference_steps=8)
    return out.images[0]
def generate_video_from_text(prompt):
    global TXT2VID_PIPE
    if TXT2VID_PIPE is None:
        TXT2VID_PIPE = make_pipe(
            WanPipeline,
            "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        )
    frames = TXT2VID_PIPE(prompt=prompt, num_frames=12).frames[0]
    return export_to_video(frames, "/tmp/wan_video.mp4", fps=8)
def generate_video_from_image(image):
    global IMG2VID_PIPE
    if IMG2VID_PIPE is None:
        IMG2VID_PIPE = make_pipe(
            StableVideoDiffusionPipeline,
            "stabilityai/stable-video-diffusion-img2vid-xt",
            variant="fp16" if dtype == torch.float16 else None,
        )
    image = load_image(image).resize((512, 288))
    # Match the output resolution to the downscaled input and decode frames in
    # small chunks to keep peak memory low.
    frames = IMG2VID_PIPE(
        image,
        height=288,
        width=512,
        num_inference_steps=16,
        decode_chunk_size=4,
    ).frames[0]
    return export_to_video(frames, "/tmp/svd_video.mp4", fps=8)
with gr.Blocks() as demo:
    gr.Markdown("## Lightweight Any-to-Any AI Playground")

    with gr.Tab("Text → Image"):
        text_input = gr.Textbox(label="Prompt")
        image_output = gr.Image(label="Generated Image")
        generate_button = gr.Button("Generate")
        generate_button.click(generate_image_from_text, inputs=text_input, outputs=image_output)

    with gr.Tab("Image → Image"):
        # type="pil" hands the diffusers pipeline a PIL image instead of a raw numpy array.
        input_image = gr.Image(label="Input Image", type="pil")
        prompt_input = gr.Textbox(label="Edit Prompt")
        output_image = gr.Image(label="Edited Image")
        edit_button = gr.Button("Generate")
        edit_button.click(generate_image_from_image_and_prompt, inputs=[input_image, prompt_input], outputs=output_image)

    with gr.Tab("Text → Video"):
        video_prompt = gr.Textbox(label="Prompt")
        video_output = gr.Video(label="Generated Video")
        video_button = gr.Button("Generate")
        video_button.click(generate_video_from_text, inputs=video_prompt, outputs=video_output)

    with gr.Tab("Image → Video"):
        anim_image = gr.Image(label="Input Image", type="pil")
        anim_video_output = gr.Video(label="Animated Video")
        anim_button = gr.Button("Animate")
        anim_button.click(generate_video_from_image, inputs=anim_image, outputs=anim_video_output)

demo.launch()