import spaces
import os
import json
import torch
import random
import requests
from PIL import Image
import numpy as np
import gradio as gr
from datetime import datetime
import torchvision.transforms as T
from diffusers import DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
from consisti2v.utils.util import save_videos_grid
from omegaconf import OmegaConf
from transformers import pipeline as translation_pipeline

# Translation pipeline for Korean to English
translator = translation_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

sample_idx = 0
scheduler_dict = {
    "DDIM": DDIMScheduler,
}
css = """
.toolbutton {
    margin: 0;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

basedir = os.getcwd()
savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
savedir_sample = os.path.join(savedir, "sample")
os.makedirs(savedir_sample, exist_ok=True)  # also creates savedir, where generated clips are logged

EXAMPLES = [  # [prompt (Korean), first frame]; the remaining settings live in EXAMPLES_HIDDEN
    ["์ค๋ก๋ผ๊ฐ ํ๋์ ์๋ ๋ ๋ฎ์ธ ๋ ์ ํ์๋ฉ์ค.", "example/example_01.png"],  # "Timelapse of a snow-covered field under an aurora."
    ["๋ถ๊ฝ๋์ด.", "example/example_02.png"],  # "Fireworks."
    ["์ฐํธ์ด๋ฅผ ํค์์น๋ ํฐ๋๊ฐ๋ฆฌ.", "example/example_03.png"],  # "A humpback whale swimming through a coral reef."
    ["์ฝ์์ ๋ น์๋ด๋ฆฌ๋ ์์ด์คํฌ๋ฆผ.", "example/example_04.png"],  # "Ice cream melting down its cone."
]

# prompt -> [first frame path, width, height, center crop, seed]
# (the Korean prompts are kept verbatim because they are used as lookup keys)
EXAMPLES_HIDDEN = {
    "์ค๋ก๋ผ๊ฐ ํ๋์ ์๋ ๋ ๋ฎ์ธ ๋ ์ ํ์๋ฉ์ค.": ["example/example_01.png", 512, 512, True, 21800],
    "๋ถ๊ฝ๋์ด.": ["example/example_02.png", 512, 512, True, 21800],
    "์ฐํธ์ด๋ฅผ ํค์์น๋ ํฐ๋๊ฐ๋ฆฌ.": ["example/example_03.png", 512, 512, True, 75692375],
    "์ฝ์์ ๋ น์๋ด๋ฆฌ๋ ์์ด์คํฌ๋ฆผ.": ["example/example_04.png", 512, 512, True, 21800],
}
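
# Load an image from a local path or a URL, optionally center-crop it to the
# target aspect ratio, resize it to (width, height), and return it as a preview.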
def update_and_resize_image(input_image_path, height_slider, width_slider, center_crop):
    if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
        pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert("RGB")
    else:
        pil_image = Image.open(input_image_path).convert("RGB")
    original_width, original_height = pil_image.size
    if center_crop:
        crop_aspect_ratio = width_slider / height_slider
        aspect_ratio = original_width / original_height
        if aspect_ratio > crop_aspect_ratio:
            new_width = int(crop_aspect_ratio * original_height)
            left = (original_width - new_width) // 2
            top = 0
            right = left + new_width
            bottom = original_height
            pil_image = pil_image.crop((left, top, right, bottom))
        elif aspect_ratio < crop_aspect_ratio:
            new_height = int(original_width / crop_aspect_ratio)
            top = (original_height - new_height) // 2
            left = 0
            right = original_width
            bottom = top + new_height
            pil_image = pil_image.crop((left, top, right, bottom))
    pil_image = pil_image.resize((width_slider, height_slider))
    return gr.Image(value=np.array(pil_image))
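
# Populate the UI from a clicked example: look up the hidden settings for the
# prompt and regenerate the preview image accordingly.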
def get_examples(prompt_textbox, input_image):
    input_image_path, width_slider, height_slider, center_crop, seed_textbox = EXAMPLES_HIDDEN[prompt_textbox]
    input_image = update_and_resize_image(input_image_path, height_slider, width_slider, center_crop)
    return prompt_textbox, input_image, input_image_path, width_slider, height_slider, center_crop, seed_textbox

# Load the pretrained ConsistI2V pipeline in fp16 and move it to the GPU
pipeline = ConditionalAnimationPipeline.from_pretrained("TIGER-Lab/ConsistI2V", torch_dtype=torch.float16)
pipeline.to("cuda")
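
# Persist an uploaded image to disk so `animate` can re-read it by path, and
# apply the same center-crop/resize logic as `update_and_resize_image` for the preview.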
def update_textbox_and_save_image(input_image, height_slider, width_slider, center_crop):
    pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
    img_path = os.path.join(savedir, "input_image.png")
    pil_image.save(img_path)
    original_width, original_height = pil_image.size
    if center_crop:
        crop_aspect_ratio = width_slider / height_slider
        aspect_ratio = original_width / original_height
        if aspect_ratio > crop_aspect_ratio:
            new_width = int(crop_aspect_ratio * original_height)
            left = (original_width - new_width) // 2
            top = 0
            right = left + new_width
            bottom = original_height
            pil_image = pil_image.crop((left, top, right, bottom))
        elif aspect_ratio < crop_aspect_ratio:
            new_height = int(original_width / crop_aspect_ratio)
            top = (original_height - new_height) // 2
            left = 0
            right = original_width
            bottom = top + new_height
            pil_image = pil_image.crop((left, top, right, bottom))
    pil_image = pil_image.resize((width_slider, height_slider))
    return gr.Textbox(value=img_path), gr.Image(value=np.array(pil_image))
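
# Main generation entry point: validates inputs, seeds the RNG, preprocesses the
# first frame, translates the Korean prompt to English, and runs the pipeline.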
@spaces.GPU  # request a GPU on ZeroGPU Spaces (assumption: this Space uses the `spaces` SDK imported above; the decorator is a no-op elsewhere)
def animate(
    prompt_textbox,
    negative_prompt_textbox,
    input_image_path,
    sampler_dropdown,
    sample_step_slider,
    width_slider,
    height_slider,
    txt_cfg_scale_slider,
    img_cfg_scale_slider,
    center_crop,
    frame_stride,
    use_frameinit,
    frame_init_noise_level,
    seed_textbox,
):
    width_slider = int(width_slider)
    height_slider = int(height_slider)
    frame_stride = int(frame_stride)
    sample_step_slider = int(sample_step_slider)
    txt_cfg_scale_slider = float(txt_cfg_scale_slider)
    img_cfg_scale_slider = float(img_cfg_scale_slider)
    frame_init_noise_level = int(frame_init_noise_level)

    if pipeline is None:
        raise gr.Error("Please select a pretrained pipeline path.")
    if input_image_path == "":
        raise gr.Error("Please upload an input image.")
    # With or without center cropping, the output dimensions are the slider values
    if width_slider % 8 != 0 or height_slider % 8 != 0:
        raise gr.Error(f"`height` and `width` have to be divisible by 8 but are {height_slider} and {width_slider}.")

    if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2:
        pipeline.unet.enable_xformers_memory_efficient_attention()

    # seed_textbox arrives as a string, so compare against "-1" rather than the int -1
    if seed_textbox != "" and str(seed_textbox) != "-1":
        torch.manual_seed(int(seed_textbox))
    else:
        torch.seed()
    seed = torch.initial_seed()

    if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
        first_frame = Image.open(requests.get(input_image_path, stream=True).raw).convert("RGB")
    else:
        first_frame = Image.open(input_image_path).convert("RGB")
    original_width, original_height = first_frame.size

    if not center_crop:
        img_transform = T.Compose([
            T.ToTensor(),
            T.Resize((height_slider, width_slider), antialias=None),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
    else:
        aspect_ratio = original_width / original_height
        crop_aspect_ratio = width_slider / height_slider
        if aspect_ratio > crop_aspect_ratio:
            center_crop_width = int(crop_aspect_ratio * original_height)
            center_crop_height = original_height
        elif aspect_ratio < crop_aspect_ratio:
            center_crop_width = original_width
            center_crop_height = int(original_width / crop_aspect_ratio)
        else:
            center_crop_width = original_width
            center_crop_height = original_height
        img_transform = T.Compose([
            T.ToTensor(),
            T.CenterCrop((center_crop_height, center_crop_width)),
            T.Resize((height_slider, width_slider), antialias=None),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    first_frame = img_transform(first_frame).unsqueeze(0).to("cuda")

    if use_frameinit:
        pipeline.init_filter(
            width=width_slider,
            height=height_slider,
            video_length=16,
            filter_params=OmegaConf.create({"method": "gaussian", "d_s": 0.25, "d_t": 0.25}),
        )

    # Translate the Korean prompt to English before passing it to the pipeline
    translated_prompt = translator(prompt_textbox)[0]["translation_text"]

    sample = pipeline(
        translated_prompt,
        negative_prompt=negative_prompt_textbox,
        first_frames=first_frame,
        num_inference_steps=sample_step_slider,
        guidance_scale_txt=txt_cfg_scale_slider,
        guidance_scale_img=img_cfg_scale_slider,
        width=width_slider,
        height=height_slider,
        video_length=16,
        noise_sampling_method="pyoco_mixed",
        noise_alpha=1.0,
        frame_stride=frame_stride,
        use_frameinit=use_frameinit,
        frameinit_noise_level=frame_init_noise_level,
        camera_motion=None,
    ).videos
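
    # Save the generated clip and append this run's settings to a JSON log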
    global sample_idx
    sample_idx += 1
    save_sample_path = os.path.join(savedir_sample, f"{sample_idx}.mp4")
    save_videos_grid(sample, save_sample_path, format="mp4")

    sample_config = {
        "prompt": prompt_textbox,
        "n_prompt": negative_prompt_textbox,
        "first_frame_path": input_image_path,
        "sampler": sampler_dropdown,
        "num_inference_steps": sample_step_slider,
        "guidance_scale_text": txt_cfg_scale_slider,
        "guidance_scale_image": img_cfg_scale_slider,
        "width": width_slider,
        "height": height_slider,
        "video_length": 16,  # matches the video_length passed to the pipeline
        "seed": seed,
    }
    json_str = json.dumps(sample_config, indent=4)
    with open(os.path.join(savedir, "logs.json"), "a") as f:
        f.write(json_str)
        f.write("\n\n")

    return gr.Video(value=save_sample_path)
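
# Build the Gradio Blocks UI and wire the widgets to the callbacks above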
def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Text+Image to Video Generation
            """
        )
        with gr.Row():
            prompt_textbox = gr.Textbox(label="Prompt (Korean)", lines=2)
            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)

        with gr.Row(equal_height=False):
            with gr.Column():
                with gr.Row():
                    sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                    sample_step_slider = gr.Slider(label="Sampling steps", value=250, minimum=10, maximum=250, step=1)
                with gr.Row():
                    center_crop = gr.Checkbox(label="Center Crop the Image", value=True)
                    width_slider = gr.Slider(label="Width", value=512, minimum=64, maximum=512, step=64)
                    height_slider = gr.Slider(label="Height", value=512, minimum=64, maximum=512, step=64)
                with gr.Row():
                    txt_cfg_scale_slider = gr.Slider(label="Text CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.5)
                    img_cfg_scale_slider = gr.Slider(label="Image CFG Scale", value=1.0, minimum=1.0, maximum=20.0, step=0.5)
                    frame_stride = gr.Slider(label="Frame Stride", value=3, minimum=1, maximum=5, step=1)
                with gr.Row():
                    use_frameinit = gr.Checkbox(label="Enable FrameInit", value=True)
                    frameinit_noise_level = gr.Slider(label="FrameInit Noise Level", value=850, minimum=1, maximum=999, step=1)

                seed_textbox = gr.Textbox(label="Seed", value=-1)
                seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                seed_button.click(fn=lambda: gr.Textbox(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])

                generate_button = gr.Button(value="Generate", variant="primary")

            with gr.Column():
                with gr.Row():
                    input_image_path = gr.Textbox(label="Input Image URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
                    preview_button = gr.Button(value="Preview")
                with gr.Row():
                    input_image = gr.Image(label="Input Image", interactive=True)
                    input_image.upload(fn=update_textbox_and_save_image, inputs=[input_image, height_slider, width_slider, center_crop], outputs=[input_image_path, input_image])
                result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)

        with gr.Row():
            batch_examples = gr.Examples(
                examples=EXAMPLES,
                fn=get_examples,
                cache_examples=True,
                examples_per_page=4,
                inputs=[prompt_textbox, input_image],
                outputs=[prompt_textbox, input_image, input_image_path, width_slider, height_slider, center_crop, seed_textbox],
            )

        preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image])
        input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image])
        generate_button.click(
            fn=animate,
            inputs=[
                prompt_textbox,
                negative_prompt_textbox,
                input_image_path,
                sampler_dropdown,
                sample_step_slider,
                width_slider,
                height_slider,
                txt_cfg_scale_slider,
                img_cfg_scale_slider,
                center_crop,
                frame_stride,
                use_frameinit,
                frameinit_noise_level,
                seed_textbox,
            ],
            outputs=[result_video],
        )

    return demo
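
# Launch the demo with a public share link when the script is run directly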
if __name__ == "__main__":
    demo = ui()
    demo.launch(share=True)