import torch
from enum import Enum
import gc
import numpy as np
import jax.numpy as jnp
import jax
from PIL import Image
from typing import List
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from flax import jax_utils
import einops

from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
from diffusers import (
    FlaxDDIMScheduler,
    FlaxAutoencoderKL,
    FlaxUNet2DConditionModel as VanillaFlaxUNet2DConditionModel,
)
from text_to_animation.models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from diffusers import FlaxControlNetModel
from text_to_animation.pipelines.text_to_video_pipeline_flax import (
    FlaxTextToVideoPipeline,
)

import utils.utils as utils
import utils.gradio_utils as gradio_utils
import os

on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"

# Undo device sharding: (devices, batch, ...) -> (devices * batch, ...)
unshard = lambda x: einops.rearrange(x, "d b ... -> (d b) ...")


class ModelType(Enum):
    Text2Video = 1
    ControlNetPose = 2
    StableDiffusion = 3


def replicate_devices(array):
    # Add a leading device axis and repeat the array once per local device,
    # so it can be passed to a pmapped function.
    return jnp.expand_dims(array, 0).repeat(jax.device_count(), 0)


class ControlAnimationModel:
    def __init__(self, dtype, **kwargs):
        self.dtype = dtype
        self.rng = jax.random.PRNGKey(0)
        self.pipe = None
        self.model_type = None

        self.states = {}
        self.model_name = ""

    def set_model(
        self,
        model_id: str,
        **kwargs,
    ):
        # Release any previously loaded pipeline before loading a new one.
        if hasattr(self, "pipe") and self.pipe is not None:
            del self.pipe
            self.pipe = None
        gc.collect()

        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
            "fusing/stable-diffusion-v1-5-controlnet-openpose",
            from_pt=True,
            dtype=jnp.float16,
        )

        scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
            model_id, subfolder="scheduler", from_pt=True
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            model_id, subfolder="feature_extractor"
        )
        unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
        )
        unet_vanilla = VanillaFlaxUNet2DConditionModel.from_config(
            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
        )
        vae, vae_params = FlaxAutoencoderKL.from_pretrained(
            model_id, subfolder="vae", from_pt=True, dtype=self.dtype
        )
        text_encoder = FlaxCLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
        )
        self.pipe = FlaxTextToVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            unet_vanilla=unet_vanilla,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        )
        self.params = {
            "unet": unet_params,
            "vae": vae_params,
            "scheduler": scheduler_state,
            "controlnet": controlnet_params,
            "text_encoder": text_encoder.params,
        }
        # Replicate the parameter tree across devices for pmapped inference.
        self.p_params = jax_utils.replicate(self.params)
        self.model_name = model_id

    def generate_initial_frames(
        self,
        prompt: str,
        video_path: str,
        n_prompt: str = "",
        num_imgs: int = 4,
        resolution: int = 512,
        model_id: str = "runwayml/stable-diffusion-v1-5",
    ) -> List[Image.Image]:
        self.set_model(model_id=model_id)

        video_path = gradio_utils.motion_to_video_path(video_path)

        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
        prompts = added_prompt + ", " + prompt

        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly"
        negative_prompts = added_n_prompt + ", " + n_prompt

        video, fps = utils.prepare_video(
            video_path, resolution, None, self.dtype, False, output_fps=4
        )
        control = utils.pre_process_pose(video, apply_pose_detect=False)

        # Draw one seed per requested image and build a PRNG key for each.
        seeds = [seed for seed in jax.random.randint(self.rng, [num_imgs], 0, 65536)]
        prngs = [jax.random.PRNGKey(seed) for seed in seeds]
        print(seeds)

        images = self.pipe.generate_starting_frames(
            params=self.p_params,
            prngs=prngs,
            controlnet_image=control,
            prompt=prompts,
            neg_prompt=negative_prompts,
        )

        images = [np.array(images[i]) for i in range(images.shape[0])]

        return images

    def generate_video_from_frame(self, controlnet_video, prompt, seed, neg_prompt=""):
        # Generate a video from the pose sequence, using the seed provided.
        prng_seed = jax.random.PRNGKey(seed)
        len_vid = controlnet_video.shape[0]
        # print(f"Generating video from prompt {' style '+ prompt}, with {controlnet_video.shape[0]} frames and prng seed {seed}")

        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
        prompts = added_prompt + ", " + prompt

        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly"
        negative_prompts = added_n_prompt + ", " + neg_prompt

        # prompt_ids = self.pipe.prepare_text_inputs(["aardman style " + prompt] * len_vid)
        # n_prompt_ids = self.pipe.prepare_text_inputs([neg_prompt] * len_vid)
        prompt_ids = self.pipe.prepare_text_inputs([prompts] * len_vid)
        n_prompt_ids = self.pipe.prepare_text_inputs([negative_prompts] * len_vid)

        # Replicate all inputs across devices for the pmapped pipeline call.
        prng = replicate_devices(prng_seed)  # jax.random.split(prng, jax.device_count())
        image = replicate_devices(controlnet_video)
        prompt_ids = replicate_devices(prompt_ids)
        n_prompt_ids = replicate_devices(n_prompt_ids)
        motion_field_strength_x = replicate_devices(jnp.array(3))
        motion_field_strength_y = replicate_devices(jnp.array(4))
        smooth_bg_strength = replicate_devices(jnp.array(0.8))

        vid = self.pipe(
            image=image,
            prompt_ids=prompt_ids,
            neg_prompt_ids=n_prompt_ids,
            params=self.p_params,
            prng_seed=prng,
            jit=True,
            smooth_bg_strength=smooth_bg_strength,
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
        ).images[0]

        return utils.create_gif(np.array(vid), 4, path=None, watermark=None)
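
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): a minimal example
# of how this class is expected to be driven. The prompt, motion template, and
# dtype below are hypothetical placeholders, and running this requires the
# pretrained weights plus a JAX-capable accelerator host.
#
#   model = ControlAnimationModel(dtype=jnp.bfloat16)
#   frames = model.generate_initial_frames(
#       prompt="an astronaut dancing",   # hypothetical prompt
#       video_path="dance",              # hypothetical motion template name
#       num_imgs=4,
#   )
#   # After choosing a starting frame (and noting its seed), render the full
#   # clip from the corresponding pose video:
#   # gif = model.generate_video_from_frame(controlnet_video, prompt, seed)
# ---------------------------------------------------------------------------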