Spaces:
Build error
Build error
import os | |
import PIL.Image | |
import numpy as np | |
import torchvision | |
from torchvision.transforms import Resize, InterpolationMode | |
import imageio | |
from einops import rearrange | |
import cv2 | |
from PIL import Image | |
from annotator.util import resize_image, HWC3 | |
from annotator.openpose import OpenposeDetector | |
import decord | |
import jax | |
import torch | |
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( | |
download_from_original_stable_diffusion_ckpt, | |
) | |
from huggingface_hub import hf_hub_download | |
import flax.linen as nn | |
apply_openpose = OpenposeDetector() | |
def add_watermark(image, watermark_path, wm_rel_size=1 / 16, boundary=5): | |
""" | |
Creates a watermark on the saved inference image. | |
We request that you do not remove this to properly assign credit to | |
Shi-Lab's work. | |
""" | |
watermark = Image.open(watermark_path) | |
w_0, h_0 = watermark.size | |
H, W, _ = image.shape | |
wmsize = int(max(H, W) * wm_rel_size) | |
aspect = h_0 / w_0 | |
if aspect > 1.0: | |
watermark = watermark.resize((wmsize, int(aspect * wmsize)), Image.LANCZOS) | |
else: | |
watermark = watermark.resize((int(wmsize / aspect), wmsize), Image.LANCZOS) | |
w, h = watermark.size | |
loc_h = H - h - boundary | |
loc_w = W - w - boundary | |
image = Image.fromarray(image) | |
mask = watermark if watermark.mode in ("RGBA", "LA") else None | |
image.paste(watermark, (loc_w, loc_h), mask) | |
return image | |
def load_safetensors_model(model_link): | |
ckpt_path = hf_hub_download( | |
repo_id=model_link, filename="ligne_claire_anime_diffusion_v1.safetensors" | |
) | |
print(f"Checkpoint path: {ckpt_path}") | |
# !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml | |
pipe = download_from_original_stable_diffusion_ckpt( | |
checkpoint_path=ckpt_path, | |
original_config_file="v1-inference.yaml", | |
from_safetensors=True, | |
) | |
pipe.save_pretrained("./models/ligne_claire", safe_serialization=True) | |
return pipe | |
def pre_process_pose(input_video, apply_pose_detect: bool = True): | |
detected_maps = [] | |
for frame in input_video: | |
img = rearrange(frame, "c h w -> h w c").astype(np.uint8) | |
img = HWC3(img) | |
if apply_pose_detect: | |
detected_map, _ = apply_openpose(img) | |
else: | |
detected_map = img | |
detected_map = HWC3(detected_map) | |
H, W, C = img.shape | |
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST) | |
detected_maps.append(detected_map[None]) | |
detected_maps = np.concatenate(detected_maps) | |
control = (detected_maps.copy()) / 255.0 | |
return rearrange(control, "f h w c -> f c h w") | |
def create_video(frames, fps, rescale=False, path=None, watermark=None): | |
if path is None: | |
dir = "temporal" | |
os.makedirs(dir, exist_ok=True) | |
path = os.path.join(dir, "movie.mp4") | |
outputs = [] | |
for i, x in enumerate(frames): | |
x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4) | |
if rescale: | |
x = (x + 1.0) / 2.0 # -1,1 -> 0,1 | |
x = (x * 255).numpy().astype(np.uint8) | |
if watermark is not None: | |
x = add_watermark(x, watermark) | |
outputs.append(x) | |
# imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x) | |
imageio.mimsave(path, outputs, fps=fps) | |
return path | |
def create_gif(frames, fps, rescale=False, path=None, watermark=None): | |
if path is None: | |
dir = "temporal" | |
os.makedirs(dir, exist_ok=True) | |
path = os.path.join(dir, "canny_db.gif") | |
outputs = [] | |
for i, x in enumerate(frames): | |
x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4) | |
if rescale: | |
x = (x + 1.0) / 2.0 # -1,1 -> 0,1 | |
x = (x * 255).numpy().astype(np.uint8) | |
if watermark is not None: | |
x = add_watermark(x, watermark) | |
outputs.append(x) | |
# imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x) | |
imageio.mimsave(path, outputs, fps=fps) | |
return path | |
def prepare_video( | |
video_path: str, | |
resolution: int, | |
device, | |
dtype, | |
normalize=True, | |
start_t: float = 0, | |
end_t: float = -1, | |
output_fps: int = -1, | |
): | |
vr = decord.VideoReader(video_path) | |
initial_fps = vr.get_avg_fps() | |
if output_fps == -1: | |
output_fps = int(initial_fps) | |
if end_t == -1: | |
end_t = len(vr) / initial_fps | |
else: | |
end_t = min(len(vr) / initial_fps, end_t) | |
assert 0 <= start_t < end_t | |
assert output_fps > 0 | |
start_f_ind = int(start_t * initial_fps) | |
end_f_ind = int(end_t * initial_fps) | |
num_f = int((end_t - start_t) * output_fps) | |
sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int) | |
video = vr.get_batch(sample_idx) | |
video = video.asnumpy() | |
_, h, w, _ = video.shape | |
video = rearrange(video, "f h w c -> f c h w") | |
video = torch.Tensor(video) # .to(device).to(dtype) | |
# Use max if you want the larger side to be equal to resolution (e.g. 512) | |
# k = float(resolution) / min(h, w) | |
k = float(resolution) / max(h, w) | |
h *= k | |
w *= k | |
h = int(np.round(h / 64.0)) * 64 | |
w = int(np.round(w / 64.0)) * 64 | |
video = Resize((h, w), interpolation=InterpolationMode.BILINEAR, antialias=True)( | |
video | |
) | |
if normalize: | |
video = video / 127.5 - 1.0 | |
# video = rearrange(video, "f c h w -> f h w c").numpy() #channel first to channel last | |
video = video.numpy() | |
return video, output_fps | |
def post_process_gif(list_of_results, image_resolution): | |
output_file = "/tmp/ddxk.gif" | |
imageio.mimsave(output_file, list_of_results, fps=4) | |
return output_file | |