| | """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py |
| | """ |
| | import base64 |
| | import gc |
| | import json |
| | import os |
| | import hashlib |
| | import random |
| | from datetime import datetime |
| | from glob import glob |
| |
|
import cv2
import gradio as gr
import numpy as np
import pkg_resources
import requests
import torch
from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
                       DPMSolverMultistepScheduler,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       FlowMatchEulerDiscreteScheduler, PNDMScheduler)
from omegaconf import OmegaConf
from PIL import Image
from safetensors import safe_open

from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
from ..dist import set_multi_gpus_devices
from ..utils.fm_solvers import FlowDPMSolverMultistepScheduler
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from ..utils.utils import save_videos_grid

gradio_version = pkg_resources.get_distribution("gradio").version
gradio_version_is_above_4 = int(gradio_version.split('.')[0]) >= 4

| | css = """ |
| | .toolbutton { |
| | margin-buttom: 0em 0em 0em 0em; |
| | max-width: 2.5em; |
| | min-width: 2.5em !important; |
| | height: 2.5em; |
| | } |
| | """ |
| |
|
ddpm_scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "Euler A": EulerAncestralDiscreteScheduler,
    "DPM++": DPMSolverMultistepScheduler,
    "PNDM": PNDMScheduler,
    "DDIM": DDIMScheduler,
    "DDIM_Origin": DDIMScheduler,
    "DDIM_Cog": CogVideoXDDIMScheduler,
}
flow_scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
all_scheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
# Backward-compatible alias for the original (misspelled) name, in case other
# modules still import it.
all_cheduler_dict = all_scheduler_dict

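# Illustrative usage (assumption, not from the original file): the sampler name
# chosen in the Gradio dropdown indexes these dicts to obtain a scheduler class,
# which diffusers can instantiate from a pretrained pipeline folder, e.g.:
#   scheduler_cls = all_scheduler_dict["Euler"]
#   scheduler = scheduler_cls.from_pretrained(model_path, subfolder="scheduler")
# Here `model_path` is a hypothetical diffusers model directory.
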
class Fun_Controller:
    """Base controller for the Gradio demo: discovers local models and configs
    and holds the loaded pipeline. `update_diffusion_transformer` and `generate`
    are stubs expected to be implemented by subclasses."""
    def __init__(
        self, GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=None, ulysses_degree=1, ring_degree=1,
        fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
        weight_dtype=None, savedir_sample=None,
    ):
        self.basedir = os.getcwd()
        self.config_dir = os.path.join(self.basedir, "config")
        self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
        if savedir_sample is None:
            self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        else:
            self.savedir_sample = savedir_sample
        os.makedirs(self.savedir_sample, exist_ok=True)

        self.GPU_memory_mode = GPU_memory_mode
        self.model_name = model_name
        self.diffusion_transformer_dropdown = model_name
        self.scheduler_dict = scheduler_dict
        self.model_type = model_type
        if config_path is not None:
            self.config_path = os.path.realpath(config_path)
            self.config = OmegaConf.load(config_path)
        else:
            self.config_path = None
        self.ulysses_degree = ulysses_degree
        self.ring_degree = ring_degree
        self.fsdp_dit = fsdp_dit
        self.fsdp_text_encoder = fsdp_text_encoder
        self.compile_dit = compile_dit
        self.weight_dtype = weight_dtype
        self.device = set_multi_gpus_devices(self.ulysses_degree, self.ring_degree)

        self.diffusion_transformer_list = []
        self.motion_module_list = []
        self.personalized_model_list = []
        self.config_list = []

        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.transformer = None
        self.transformer_2 = None
        self.pipeline = None
        self.base_model_path = "none"
        self.base_model_2_path = "none"
        self.lora_model_path = "none"
        self.lora_model_2_path = "none"

        self.refresh_config()
        self.refresh_diffusion_transformer()
        self.refresh_personalized_model()
        if model_name is not None:
            self.update_diffusion_transformer(model_name)

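    # Illustrative construction (assumption, not from the original file):
    #   controller = Fun_Controller(
    #       GPU_memory_mode="model_cpu_offload", scheduler_dict=all_scheduler_dict,
    #       model_name=None, model_type="Inpaint",
    #       config_path="config/example.yaml",
    #   )
    # Both "model_cpu_offload" and "config/example.yaml" are hypothetical values.
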
    def refresh_config(self):
        config_list = []
        for root, dirs, files in os.walk(self.config_dir):
            for file in files:
                if file.endswith(('.yaml', '.yml')):
                    full_path = os.path.join(root, file)
                    config_list.append(full_path)
        self.config_list = config_list

    def refresh_diffusion_transformer(self):
        self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))

    def refresh_personalized_model(self):
        personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]

    def update_model_type(self, model_type):
        self.model_type = model_type

    def update_config(self, config_dropdown):
        self.config_path = config_dropdown
        self.config = OmegaConf.load(config_dropdown)
        print(f"Update config: {config_dropdown}")

    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        # Implemented by subclasses: load the selected pretrained model.
        pass

    def update_base_model(self, base_model_dropdown, is_checkpoint_2=False):
        if not is_checkpoint_2:
            self.base_model_path = base_model_dropdown
        else:
            self.base_model_2_path = base_model_dropdown
        print(f"Update base model: {base_model_dropdown}")
        if base_model_dropdown == "none":
            return gr.update()
        if self.transformer is None and not is_checkpoint_2:
            gr.Info("Please select a pretrained model path.")
            print("Please select a pretrained model path.")
            return gr.update(value=None)
        elif self.transformer_2 is None and is_checkpoint_2:
            gr.Info("Please select a pretrained model path.")
            print("Please select a pretrained model path.")
            return gr.update(value=None)
        else:
            base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
            base_model_state_dict = {}
            # Load the .safetensors checkpoint on CPU, then merge it into the
            # transformer; strict=False tolerates missing or extra keys.
            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    base_model_state_dict[key] = f.get_tensor(key)
            if not is_checkpoint_2:
                self.transformer.load_state_dict(base_model_state_dict, strict=False)
            else:
                self.transformer_2.load_state_dict(base_model_state_dict, strict=False)
            print("Update base model done")
            return gr.update()

    def update_lora_model(self, lora_model_dropdown, is_checkpoint_2=False):
        print(f"Update lora model: {lora_model_dropdown}")
        if lora_model_dropdown == "none":
            # Reset the path for the slot being updated.
            if not is_checkpoint_2:
                self.lora_model_path = "none"
            else:
                self.lora_model_2_path = "none"
            return gr.update()
        lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
        if not is_checkpoint_2:
            self.lora_model_path = lora_model_dropdown
        else:
            self.lora_model_2_path = lora_model_dropdown
        return gr.update()

    def clear_cache(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    def auto_model_clear_cache(self, model):
        # Move the model to CPU so its CUDA allocations can be released,
        # clear the allocator caches, then restore the original device.
        origin_device = model.device
        model = model.to("cpu")
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        model = model.to(origin_device)

    def input_check(self,
        resize_method,
        generation_method,
        start_image,
        end_image,
        validation_video,
        control_video,
        is_api = False,
    ):
        if self.transformer is None:
            if is_api:
                return "", "Please select a pretrained model path."
            else:
                raise gr.Error("Please select a pretrained model path.")

        if control_video is not None and self.model_type == "Inpaint":
            if is_api:
                return "", "If a control video is provided, please set model_type to \"Control\"."
            else:
                raise gr.Error("If a control video is provided, please set model_type to \"Control\".")

        if control_video is None and self.model_type == "Control":
            if is_api:
                return "", "If model_type is \"Control\", please provide a control video."
            else:
                raise gr.Error("If model_type is \"Control\", please provide a control video.")

        if resize_method == "Resize according to Reference":
            if start_image is None and validation_video is None and control_video is None:
                if is_api:
                    return "", "Please upload an image when using \"Resize according to Reference\"."
                else:
                    raise gr.Error("Please upload an image when using \"Resize according to Reference\".")

        # A transformer whose input channels equal the VAE latent channels has no
        # extra conditioning channels, i.e. it is a text-to-video model.
        if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
            if is_api:
                return "", "Please select an image-to-video pretrained model when using image-to-video."
            else:
                raise gr.Error("Please select an image-to-video pretrained model when using image-to-video.")

        if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation":
            if is_api:
                return "", "Please select an image-to-video pretrained model when using long video generation."
            else:
                raise gr.Error("Please select an image-to-video pretrained model when using long video generation.")

        if start_image is None and end_image is not None:
            if is_api:
                return "", "If an ending image is specified, please also specify a starting image."
            else:
                raise gr.Error("If an ending image is specified, please also specify a starting image.")
        return "", "OK"

    def get_height_width_from_reference(
        self,
        base_resolution,
        start_image,
        validation_video,
        control_video,
    ):
        spatial_compression_ratio = self.vae.config.spatial_compression_ratio if hasattr(self.vae.config, "spatial_compression_ratio") else 8
        # Scale the 512-based aspect-ratio buckets to the requested base resolution.
        aspect_ratio_sample_size = {key: [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
        if self.model_type == "Inpaint":
            if validation_video is not None:
                original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
            else:
                original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
        else:
            original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
        # Round each side down to a multiple of spatial_compression_ratio * 2 so
        # the latents divide evenly.
        height_slider, width_slider = [int(x / spatial_compression_ratio / 2) * spatial_compression_ratio * 2 for x in closest_size]
        return height_slider, width_slider

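    # Worked example (illustrative): with spatial_compression_ratio = 8 the
    # rounding step uses multiples of 16, so a bucket side of 1080 becomes
    # int(1080 / 8 / 2) * 8 * 2 = 67 * 16 = 1072, while 720 stays 720 (45 * 16).
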
    def save_outputs(self, is_image, length_slider, sample, fps):
        def save_results():
            if not os.path.exists(self.savedir_sample):
                os.makedirs(self.savedir_sample, exist_ok=True)
            index = len(os.listdir(self.savedir_sample)) + 1
            prefix = str(index).zfill(8)

            md5_hash = hashlib.md5(sample.cpu().numpy().tobytes()).hexdigest()

            if is_image or length_slider == 1:
                save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png")
                print(f"Saving to {save_sample_path}")
                # First frame of the first sample: (C, H, W) -> (H, W, C).
                image = sample[0, :, 0]
                image = image.transpose(0, 1).transpose(1, 2)
                image = (image * 255).numpy().astype(np.uint8)
                image = Image.fromarray(image)
                image.save(save_sample_path)
            else:
                save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4")
                print(f"Saving to {save_sample_path}")
                save_videos_grid(sample, save_sample_path, fps=fps)
            return save_sample_path

        # Under multi-GPU inference, only rank 0 writes the output.
        if self.ulysses_degree * self.ring_degree > 1:
            import torch.distributed as dist
            if dist.get_rank() == 0:
                save_sample_path = save_results()
            else:
                save_sample_path = None
        else:
            save_sample_path = save_results()
        return save_sample_path

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        enable_teacache = None,
        teacache_threshold = None,
        num_skip_start_steps = None,
        teacache_offload = None,
        cfg_skip_ratio = None,
        enable_riflex = None,
        riflex_k = None,
        is_api = False,
    ):
        # Implemented by subclasses: run the sampling pipeline and return outputs.
        pass

def post_to_host(
    diffusion_transformer_dropdown,
    base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
    prompt_textbox, negative_prompt_textbox,
    sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
    base_resolution, generation_method, length_slider, cfg_scale_slider,
    start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox,
    ref_image = None, enable_teacache = None, teacache_threshold = None, num_skip_start_steps = None,
    teacache_offload = None, cfg_skip_ratio = None, enable_riflex = None, riflex_k = None,
):
    def encode_file_to_base64(path):
        # Read a local file and return its contents as a base64 string.
        with open(path, 'rb') as file:
            return base64.b64encode(file.read()).decode('utf-8')

    if start_image is not None:
        start_image = encode_file_to_base64(start_image)

    if end_image is not None:
        end_image = encode_file_to_base64(end_image)

    if validation_video is not None:
        validation_video = encode_file_to_base64(validation_video)

    if validation_video_mask is not None:
        validation_video_mask = encode_file_to_base64(validation_video_mask)

    if ref_image is not None:
        ref_image = encode_file_to_base64(ref_image)

    datas = {
        "base_model_path": base_model_dropdown,
        "lora_model_path": lora_model_dropdown,
        "lora_alpha_slider": lora_alpha_slider,
        "prompt_textbox": prompt_textbox,
        "negative_prompt_textbox": negative_prompt_textbox,
        "sampler_dropdown": sampler_dropdown,
        "sample_step_slider": sample_step_slider,
        "resize_method": resize_method,
        "width_slider": width_slider,
        "height_slider": height_slider,
        "base_resolution": base_resolution,
        "generation_method": generation_method,
        "length_slider": length_slider,
        "cfg_scale_slider": cfg_scale_slider,
        "start_image": start_image,
        "end_image": end_image,
        "validation_video": validation_video,
        "validation_video_mask": validation_video_mask,
        "denoise_strength": denoise_strength,
        "seed_textbox": seed_textbox,

        "ref_image": ref_image,
        "enable_teacache": enable_teacache,
        "teacache_threshold": teacache_threshold,
        "num_skip_start_steps": num_skip_start_steps,
        "teacache_offload": teacache_offload,
        "cfg_skip_ratio": cfg_skip_ratio,
        "enable_riflex": enable_riflex,
        "riflex_k": riflex_k,
    }

    session = requests.Session()
    session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})

    response = session.post(url=f'{os.environ.get("EAS_URL")}/videox_fun/infer_forward', json=datas, timeout=300)

    outputs = response.json()
    return outputs

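# Illustrative call (assumption, not from the original file): with EAS_URL and
# EAS_TOKEN exported, a text-to-video request could look like:
#   outputs = post_to_host(
#       "none", "none", "none", 0.55,
#       "A cat walks on the grass.", "",
#       "Flow", 50, "Generate by", 672, 384,
#       512, "Video Generation", 49, 6.0,
#       None, None, None, None, 0.70, "43",
#   )
# All argument values above are hypothetical; the server returns a JSON dict,
# optionally containing "base64_encoding" (see Fun_Controller_Client.generate).
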
class Fun_Controller_Client:
    """Thin client controller: forwards generation requests to a remote host via
    `post_to_host` and saves the decoded result locally."""
    def __init__(self, scheduler_dict, savedir_sample):
        self.basedir = os.getcwd()
        if savedir_sample is None:
            self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        else:
            self.savedir_sample = savedir_sample
        os.makedirs(self.savedir_sample, exist_ok=True)

        self.scheduler_dict = scheduler_dict

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        denoise_strength,
        seed_textbox,
        ref_image = None,
        enable_teacache = None,
        teacache_threshold = None,
        num_skip_start_steps = None,
        teacache_offload = None,
        cfg_skip_ratio = None,
        enable_riflex = None,
        riflex_k = None,
    ):
        is_image = generation_method == "Image Generation"

        outputs = post_to_host(
            diffusion_transformer_dropdown,
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
            prompt_textbox, negative_prompt_textbox,
            sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
            base_resolution, generation_method, length_slider, cfg_scale_slider,
            start_image, end_image, validation_video, validation_video_mask, denoise_strength,
            seed_textbox, ref_image = ref_image, enable_teacache = enable_teacache, teacache_threshold = teacache_threshold,
            num_skip_start_steps = num_skip_start_steps, teacache_offload = teacache_offload,
            cfg_skip_ratio = cfg_skip_ratio, enable_riflex = enable_riflex, riflex_k = riflex_k,
        )

        try:
            base64_encoding = outputs["base64_encoding"]
        except KeyError:
            # The host reported an error instead of returning a sample.
            return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]

        decoded_data = base64.b64decode(base64_encoding)

        if not os.path.exists(self.savedir_sample):
            os.makedirs(self.savedir_sample, exist_ok=True)
        md5_hash = hashlib.md5(decoded_data).hexdigest()

        index = len(os.listdir(self.savedir_sample)) + 1
        prefix = str(index).zfill(8)

        if is_image or length_slider == 1:
            save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png")
            print(f"Saving to {save_sample_path}")
            with open(save_sample_path, "wb") as file:
                file.write(decoded_data)
            if gradio_version_is_above_4:
                return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
            else:
                return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
        else:
            save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4")
            print(f"Saving to {save_sample_path}")
            with open(save_sample_path, "wb") as file:
                file.write(decoded_data)
            if gradio_version_is_above_4:
                return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
            else:
                return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"