import gc
import os
from enum import Enum

import einops
import jax
import jax.numpy as jnp
import numpy as np
from flax import jax_utils
from transformers import CLIPTokenizer, CLIPFeatureExtractor, FlaxCLIPTextModel
from diffusers import (
    FlaxDDIMScheduler,
    FlaxAutoencoderKL,
    FlaxControlNetModel,
    FlaxUNet2DConditionModel as VanillaFlaxUNet2DConditionModel,
)
from text_to_animation.models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from text_to_animation.pipelines.text_to_video_pipeline_flax import (
    FlaxTextToVideoPipeline,
)
import utils.utils as utils
import utils.gradio_utils as gradio_utils
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"


def unshard(x):
    # Collapse the leading device axis: (devices, batch, ...) -> (devices * batch, ...).
    return einops.rearrange(x, "d b ... -> (d b) ...")


class ModelType(Enum):
    Text2Video = 1
    ControlNetPose = 2
    StableDiffusion = 3


def replicate_devices(array):
    # Add a leading axis and tile the array once per local device so it can be
    # fed to a pmapped function.
    return jnp.expand_dims(array, 0).repeat(jax.device_count(), 0)


class ControlAnimationModel:
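    """Flax/JAX text-to-animation model built on Stable Diffusion + ControlNet
    (openpose). Loads the pipeline components, replicates their parameters
    across devices, and generates candidate first frames and full videos."""
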
def __init__(self, dtype, **kwargs):
self.dtype = dtype
self.rng = jax.random.PRNGKey(0)
self.pipe = None
self.model_type = None
self.states = {}
self.model_name = ""

    def set_model(self, model_id: str, **kwargs):
        """Load the scheduler, tokenizer, UNets, VAE, text encoder, and the
        openpose ControlNet for `model_id`, then build the pipeline and
        replicate its parameters across devices."""
if hasattr(self, "pipe") and self.pipe is not None:
del self.pipe
self.pipe = None
gc.collect()
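        # The openpose ControlNet is loaded in float16 regardless of self.dtype.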
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
"fusing/stable-diffusion-v1-5-controlnet-openpose",
from_pt=True,
dtype=jnp.float16,
)
scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
model_id, subfolder="scheduler", from_pt=True
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
feature_extractor = CLIPFeatureExtractor.from_pretrained(
model_id, subfolder="feature_extractor"
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", from_pt=True, dtype=self.dtype
)
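        # from_config builds the vanilla UNet module from its config alone; Flax
        # modules hold no weights, so this module is presumably applied with the
        # `unet_params` loaded above.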
unet_vanilla = VanillaFlaxUNet2DConditionModel.from_config(
model_id, subfolder="unet", from_pt=True, dtype=self.dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
model_id, subfolder="vae", from_pt=True, dtype=self.dtype
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
model_id, subfolder="text_encoder", from_pt=True, dtype=self.dtype
)
self.pipe = FlaxTextToVideoPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
unet_vanilla=unet_vanilla,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=feature_extractor,
)
self.params = {
"unet": unet_params,
"vae": vae_params,
"scheduler": scheduler_state,
"controlnet": controlnet_params,
"text_encoder": text_encoder.params,
}
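        # Replicate the full parameter pytree across local devices for pmapped calls.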
self.p_params = jax_utils.replicate(self.params)
self.model_name = model_id

    def generate_initial_frames(
        self,
        prompt: str,
        video_path: str,
        n_prompt: str = "",
        seed: int = 0,
        num_imgs: int = 4,
        resolution: int = 512,
        model_id: str = "runwayml/stable-diffusion-v1-5",
    ):
        """Generate `num_imgs` candidate first frames conditioned on the pose
        video at `video_path`. Returns the prepared video and the candidate
        frames as a list of numpy arrays."""
self.set_model(model_id=model_id)
video_path = gradio_utils.motion_to_video_path(video_path)
added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
prompts = added_prompt + ", " + prompt
added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly"
negative_prompts = added_n_prompt + ", " + n_prompt
video, fps = utils.prepare_video(
video_path, resolution, None, self.dtype, False, output_fps=4
)
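        # Pose detection is skipped: the input video is assumed to already
        # contain rendered pose frames.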
control = utils.pre_process_pose(video, apply_pose_detect=False)
        # Split the seed into distinct per-image PRNG keys so the candidate
        # frames differ; reusing a single key would make all frames identical.
        prngs = list(jax.random.split(jax.random.PRNGKey(seed), num_imgs))
images = self.pipe.generate_starting_frames(
params=self.p_params,
prngs=prngs,
controlnet_image=control,
prompt=prompts,
neg_prompt=negative_prompts,
)
images = [np.array(images[i]) for i in range(images.shape[0])]
return video, images

    def generate_video_from_frame(self, controlnet_video, prompt, n_prompt, seed):
        """Generate a full pose-guided video for `prompt` using the provided
        seed and per-frame ControlNet conditioning."""
        prng_seed = jax.random.PRNGKey(seed)
        len_vid = controlnet_video.shape[0]
        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
        prompts = added_prompt + ", " + prompt
        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly"
        negative_prompts = added_n_prompt + ", " + n_prompt
        prompt_ids = self.pipe.prepare_text_inputs([prompts] * len_vid)
        n_prompt_ids = self.pipe.prepare_text_inputs([negative_prompts] * len_vid)
        # Broadcast all inputs across local devices for the pmapped pipeline call.
        prng = replicate_devices(prng_seed)
        image = replicate_devices(controlnet_video)
        prompt_ids = replicate_devices(prompt_ids)
        n_prompt_ids = replicate_devices(n_prompt_ids)
        motion_field_strength_x = replicate_devices(jnp.array(3))
        motion_field_strength_y = replicate_devices(jnp.array(4))
        smooth_bg_strength = replicate_devices(jnp.array(0.8))
        vid = self.pipe(
            image=image,
            prompt_ids=prompt_ids,
            neg_prompt_ids=n_prompt_ids,
            params=self.p_params,
            prng_seed=prng,
            jit=True,
            smooth_bg_strength=smooth_bg_strength,
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
        ).images[0]
return utils.create_gif(np.array(vid), 4, path=None, watermark=None)
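

# Minimal usage sketch, kept as comments (assumptions: a JAX runtime with
# enough device memory, a motion name that gradio_utils.motion_to_video_path
# can resolve, and a pose video for ControlNet conditioning; the prompt, seed,
# and motion name below are illustrative only):
#
#     model = ControlAnimationModel(dtype=jnp.float16)
#     video, frames = model.generate_initial_frames(
#         prompt="a clay figure waving",
#         video_path="dance1",
#         seed=0,
#     )
#     gif = model.generate_video_from_frame(
#         controlnet_video=video, prompt="a clay figure waving", n_prompt="", seed=0
#     )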