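"""CoherentControl / model.py

Flax/JAX Stable Diffusion + OpenPose ControlNet pipeline that generates a
stylized video frame by frame from per-frame pose-control images.
"""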
import numpy as np
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel

import utils
import gradio_utils

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

class Model:
    def __init__(self, **kwargs):
        # OpenPose ControlNet, converted from the PyTorch checkpoint.
        # (A fine-tuned alternative: "JFoz/dog-cat-pose".)
        self.base_controlnet, self.base_controlnet_params = FlaxControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_openpose", dtype=jnp.bfloat16, from_pt=True
        )
        self.pipe, self.params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.base_controlnet,
            revision="flax",
            dtype=jnp.bfloat16,
        )

    def infer_frame(self, frame_id, prompt, negative_prompt, rng, **kwargs):
        print(prompt[frame_id], frame_id)
        # NOTE: num_samples must be divisible by jax.device_count() for shard() below.
        num_samples = 1
        prompt_ids = self.pipe.prepare_text_inputs([prompt[frame_id]] * num_samples)
        negative_prompt_ids = self.pipe.prepare_text_inputs([negative_prompt[frame_id]] * num_samples)
        processed_image = self.pipe.prepare_image_inputs([kwargs['image'][frame_id]] * num_samples)
        self.params["controlnet"] = self.base_controlnet_params
        # Replicate the weights and shard the inputs across all available devices.
        p_params = replicate(self.params)
        prompt_ids = shard(prompt_ids)
        negative_prompt_ids = shard(negative_prompt_ids)
        processed_image = shard(processed_image)
        # Remaining kwargs (height, width, eta, output_type) are currently ignored
        # by the Flax pipeline call.
        output = self.pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=rng,
            num_inference_steps=kwargs.get('num_inference_steps', 50),
            guidance_scale=kwargs.get('guidance_scale', 7.5),
            controlnet_conditioning_scale=kwargs.get('controlnet_conditioning_scale', 1.0),
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images
        # Collapse the device axis: (devices, batch, h, w, c) -> (num_samples, h, w, c).
        output_images = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
        return output_images

    def inference(self, **kwargs):
        seed = kwargs.pop('seed', 0)
        # One PRNG key per device, since the pipeline runs with jit=True (pmap).
        rng = create_key(seed)
        rng = jax.random.split(rng, jax.device_count())
        f = len(kwargs['image'])
        print('frames', f)
        assert 'prompt' in kwargs
        # The same prompt / negative prompt is used for every frame.
        prompt = [kwargs.pop('prompt')] * f
        negative_prompt = [kwargs.pop('negative_prompt', '')] * f
        result = []
        for i in range(f):
            print(f'Processing frame {i + 1} / {f}')
            result.append(self.infer_frame(frame_id=i,
                                           prompt=prompt,
                                           negative_prompt=negative_prompt,
                                           rng=rng,
                                           **kwargs))
        return np.stack(result, axis=0)

    def process_controlnet_pose(self,
                                video_path,
                                prompt,
                                num_inference_steps=20,
                                controlnet_conditioning_scale=1.0,
                                guidance_scale=9.0,
                                seed=42,
                                eta=0.0,
                                resolution=512,
                                save_path=None):
        print("Module Pose")
        video_path = gradio_utils.motion_to_video_path(video_path)
        added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth'
        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'
        # Decode the video and extract per-frame pose-control images.
        video, fps = utils.prepare_video(
            video_path, resolution, False, output_fps=4)
        control = utils.pre_process_pose(
            video, apply_pose_detect=False)
        print('N frames', len(control))
        f, _, h, w = video.shape
        result = self.inference(image=control,
                                prompt=prompt + ', ' + added_prompt,
                                height=h,
                                width=w,
                                negative_prompt=negative_prompts,
                                num_inference_steps=num_inference_steps,
                                guidance_scale=guidance_scale,
                                controlnet_conditioning_scale=controlnet_conditioning_scale,
                                eta=eta,
                                seed=seed,
                                output_type='numpy',
                                )
        return utils.create_gif(result.astype(np.float16), fps, path=save_path)
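
# Minimal usage sketch (assumptions: the motion reference below is hypothetical;
# gradio_utils.motion_to_video_path must map it to an actual video file, and the
# checkpoints above must be downloadable).
if __name__ == "__main__":
    model = Model()
    model.process_controlnet_pose(
        video_path="Motion 1",          # hypothetical motion name / video reference
        prompt="an astronaut dancing on the moon",
        num_inference_steps=20,
        seed=42,
        save_path="output.gif",         # hypothetical output path
    )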