Spaces:
Runtime error
Runtime error
import os | |
import numpy as np | |
import torch | |
import torchvision | |
from torchvision.transforms import Resize | |
import imageio | |
from einops import rearrange | |
import cv2 | |
from annotator.util import resize_image, HWC3 | |
from annotator.canny import CannyDetector | |
from annotator.openpose import OpenposeDetector | |
import decord | |
# Switch decord's output bridge to torch: its readers then hand back
# torch.Tensor objects instead of decord NDArrays.
decord.bridge.set_bridge('torch')

# Annotators are constructed once at import time (Canny first, then
# OpenPose) and shared by the pre_process_* helpers below.
apply_canny, apply_openpose = CannyDetector(), OpenposeDetector()
def add_watermark(image, im_size=None, watermark_path="__assets__/pair_watermark.png",
                  wmsize=16, bbuf=5, opacity=0.9):
    '''
    Creates a watermark on the saved inference image.
    We request that you do not remove this to properly assign credit to
    Shi-Lab's work.

    The watermark is resized to ``wmsize`` x ``wmsize`` pixels and blended
    into the bottom-right corner of ``image``, ``bbuf`` pixels away from the
    bottom and right edges.

    Args:
        image: uint8 frame laid out as (H, W, C), the layout ``imageio.mimsave``
            expects for the frames this is applied to. A 2-D (H, W) frame is
            also accepted when the watermark has a single colour channel.
        im_size: unused. It is kept so existing ``add_watermark(x, im_size=...)``
            calls keep working. The position is derived from ``image.shape``,
            so frames of any size and aspect ratio are handled.
        watermark_path: path of the watermark image, or an already-loaded
            (h, w[, c]) array. Passing an array lets callers avoid re-reading
            the file for every frame.
        wmsize: side length in pixels of the square watermark.
        bbuf: margin in pixels between the watermark and the bottom/right edges.
        opacity: blend weight of the watermark in [0, 1]. A value of 1.0
            overwrites the covered pixels. If the watermark has an alpha
            channel (2 or 4 channels), that alpha further scales the weight
            per pixel.

    Returns:
        A new array with the same shape and dtype as ``image``. The input is
        left unmodified.

    Raises:
        ValueError: if the frame is too small for the watermark plus margin.
    '''
    out = np.array(image, copy=True)
    height, width = out.shape[:2]
    top, left = height - wmsize - bbuf, width - wmsize - bbuf
    if top < 0 or left < 0:
        raise ValueError(
            f"{height}x{width} image is too small for a {wmsize}px watermark "
            f"with a {bbuf}px margin")

    if isinstance(watermark_path, np.ndarray):
        watermark = watermark_path
    else:
        watermark = np.asarray(imageio.imread(watermark_path))
    if watermark.shape[:2] != (wmsize, wmsize):
        # cv2.resize takes the target size as (width, height).
        watermark = cv2.resize(watermark, (wmsize, wmsize), interpolation=cv2.INTER_AREA)
    if watermark.ndim == 2:
        watermark = watermark[:, :, None]

    # Per-pixel blend weight, shaped (wmsize, wmsize, 1) to broadcast over channels.
    alpha = np.full((wmsize, wmsize, 1), opacity, dtype=np.float32)
    if watermark.shape[2] in (2, 4):
        # Gray+alpha / RGBA: the last channel is transparency, not colour
        # (assumes an 8-bit alpha channel).
        alpha = alpha * (watermark[:, :, -1:].astype(np.float32) / 255.0)
        watermark = watermark[:, :, :-1]

    # Channel-last view of the copy. Writes through `canvas` land in `out`.
    canvas = out if out.ndim == 3 else out[:, :, None]
    if watermark.shape[2] == 1 and canvas.shape[2] > 1:
        watermark = np.repeat(watermark, canvas.shape[2], axis=2)
    region = canvas[top:top + wmsize, left:left + wmsize]
    blended = (region.astype(np.float32) * (1.0 - alpha)
               + watermark.astype(np.float32) * alpha)
    region[...] = np.clip(np.rint(blended), 0, 255).astype(out.dtype)
    return out
def pre_process_canny(input_video, low_threshold=100, high_threshold=200):
    """Compute a Canny edge control map for every frame of ``input_video``.

    Args:
        input_video: iterable of (c, h, w) frames, presumably torch tensors.
            Values are cast straight to uint8, so they are presumably already
            in [0, 255].
        low_threshold: lower threshold forwarded to the Canny annotator.
        high_threshold: upper threshold forwarded to the Canny annotator.

    Returns:
        Float tensor of shape (f, c, h, w) with values scaled to [0, 1].
    """
    def edge_map(frame):
        # Channel-last uint8 image for the annotator, then forced to 3 channels.
        hwc = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        return HWC3(apply_canny(hwc, low_threshold, high_threshold))[None]

    stacked = np.concatenate([edge_map(frame) for frame in input_video])
    control = torch.from_numpy(stacked.copy()).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')
def pre_process_pose(input_video, apply_pose_detect: bool = True):
    """Build an OpenPose control map for every frame of ``input_video``.

    Args:
        input_video: iterable of (c, h, w) frames, presumably torch tensors.
            Values are cast straight to uint8, so they are presumably already
            in [0, 255].
        apply_pose_detect: when False, the frames themselves (normalized to
            3 channels) are used as the control maps instead of running the
            pose detector.

    Returns:
        Float tensor of shape (f, c, h, w) with values scaled to [0, 1].
    """
    pose_maps = []
    for frame in input_video:
        rgb = HWC3(rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8))
        if not apply_pose_detect:
            pose = rgb
        else:
            pose, _ = apply_openpose(rgb)
        height, width, _ = rgb.shape
        # Match the frame's exact size. Nearest-neighbour keeps hard edges.
        pose = cv2.resize(HWC3(pose), (width, height), interpolation=cv2.INTER_NEAREST)
        pose_maps.append(pose[None])
    control = torch.from_numpy(np.concatenate(pose_maps).copy()).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')
def _frames_to_uint8(frames, rescale):
    """Return ``frames`` as a list of watermarked uint8 images for imageio.

    Each element is converted with ``torch.Tensor`` and tiled with
    ``torchvision.utils.make_grid`` (4 per row). When ``rescale`` is true it
    is mapped from [-1, 1] to [0, 1]. It is then clamped to [0, 1] and scaled
    to [0, 255].
    """
    outputs = []
    for frame in frames:
        x = torchvision.utils.make_grid(torch.Tensor(frame), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp before the uint8 cast: values slightly outside [0, 1] would
        # otherwise wrap around instead of saturating at 0 / 255.
        x = (x.clamp(0.0, 1.0) * 255).numpy().astype(np.uint8)
        # imageio.mimsave expects (H, W[, C]) frames, so dim 0 is the height.
        x = add_watermark(x, im_size=x.shape[0])
        outputs.append(x)
    return outputs


def create_video(frames, fps, rescale=False, path=None):
    """Encode ``frames`` as an MP4 video and return the output path.

    Args:
        frames: iterable of frames (or frame batches), expected in [0, 1], or
            in [-1, 1] when ``rescale`` is true.
        fps: frame rate passed to ``imageio.mimsave``.
        rescale: map values from [-1, 1] to [0, 1] before encoding.
        path: output file. Defaults to ``temporal/movie.mp4``; the
            ``temporal`` directory is created if needed.

    Returns:
        The path the video was written to.
    """
    if path is None:
        out_dir = "temporal"
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, 'movie.mp4')
    imageio.mimsave(path, _frames_to_uint8(frames, rescale), fps=fps)
    return path


def create_gif(frames, fps, rescale=False, path=None):
    """Encode ``frames`` as a GIF and return the output path.

    Args:
        frames: iterable of frames (or frame batches), expected in [0, 1], or
            in [-1, 1] when ``rescale`` is true.
        fps: frame rate passed to ``imageio.mimsave``.
        rescale: map values from [-1, 1] to [0, 1] before encoding.
        path: output file. Defaults to ``temporal/canny_db.gif``; the
            ``temporal`` directory is created if needed.

    Returns:
        The path the GIF was written to.
    """
    if path is None:
        out_dir = "temporal"
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, 'canny_db.gif')
    imageio.mimsave(path, _frames_to_uint8(frames, rescale), fps=fps)
    return path
def prepare_video(video_path:str, resolution:int, device, dtype, normalize=True, start_t:float=0, end_t:float=-1, output_fps:int=-1):
    """Load, temporally resample and spatially resize a video.

    Args:
        video_path: path readable by ``decord.VideoReader``.
        resolution: target length of the longer side. Both output sides are
            rounded down to a multiple of 8 and the aspect ratio is kept.
        device: device the returned tensor is moved to.
        dtype: dtype the returned tensor is cast to.
        normalize: map pixel values from [0, 255] to [-1, 1].
        start_t: start of the time window, in seconds.
        end_t: end of the time window, in seconds. -1 means the end of the
            clip; larger values are clamped to the clip length.
        output_fps: sampling rate of the returned frames. -1 keeps the
            source's average fps (truncated to an int).

    Returns:
        ``(video, output_fps)``, where ``video`` has shape (f, c, h, w).

    Raises:
        AssertionError: if the time window is empty/negative or
            ``output_fps`` is not positive.
    """
    vr = decord.VideoReader(video_path)
    initial_fps = vr.get_avg_fps()
    if output_fps == -1:
        output_fps = int(initial_fps)
    if end_t == -1:
        end_t = len(vr) / initial_fps
    else:
        end_t = min(len(vr) / initial_fps, end_t)
    assert 0 <= start_t < end_t
    assert output_fps > 0

    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    # At least one frame, even for windows shorter than one output frame period.
    num_f = max(1, int((end_t - start_t) * output_fps))
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
    # Guard against float rounding producing an index past the last frame.
    sample_idx = np.clip(sample_idx, 0, len(vr) - 1)

    # Decode only the sampled frames rather than the whole clip.
    video = vr.get_batch(sample_idx.tolist())
    # With decord's bridge set to 'torch', get_batch returns a torch.Tensor,
    # which has no .asnumpy(); accept either output type.
    if torch.is_tensor(video):
        video = video.cpu().numpy()
    else:
        video = video.asnumpy()
    _, h, w, _ = video.shape

    video = rearrange(video, "f h w c -> f c h w")
    video = torch.Tensor(video).to(device).to(dtype)

    # Scale the longer side to `resolution` and keep the aspect ratio. Both
    # sides are rounded down to multiples of 8 (presumably what the
    # downstream model requires; TODO confirm).
    if h > w:
        w = int(w * resolution / h)
        w = w - w % 8
        h = resolution - resolution % 8
    else:
        h = int(h * resolution / w)
        h = h - h % 8
        w = resolution - resolution % 8
    video = Resize((h, w))(video)

    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps
def post_process_gif(list_of_results, image_resolution, output_file=None, fps=4):
    """Write ``list_of_results`` as a GIF and return its path.

    Args:
        list_of_results: frames accepted by ``imageio.mimsave``.
        image_resolution: unused. It is kept for caller compatibility.
        output_file: destination path. Defaults to ``ddxk.gif`` in the
            platform's temporary directory, instead of a hard-coded ``/tmp``
            that does not exist on every OS. The default name is fixed, so
            concurrent calls overwrite each other's output; pass a unique
            path when that matters.
        fps: playback rate of the GIF.

    Returns:
        The path the GIF was written to.
    """
    import tempfile

    if output_file is None:
        output_file = os.path.join(tempfile.gettempdir(), "ddxk.gif")
    imageio.mimsave(output_file, list_of_results, fps=fps)
    return output_file
class CrossFrameAttnProcessor:
    """Attention processor in which every frame attends to its chunk's first frame.

    ``attn`` must provide ``prepare_attention_mask``, ``to_q``/``to_k``/``to_v``,
    ``head_to_batch_dim``/``batch_to_head_dim``, ``get_attention_scores``,
    ``to_out`` and ``cross_attention_norm``/``norm_cross`` (presumably a
    diffusers attention module; TODO confirm the targeted diffusers version).

    In self-attention, the keys and values of every frame are replaced by
    those of frame 0. Cross-attention (``encoder_hidden_states`` given) is
    left unchanged.

    Args:
        unet_chunk_size: number of chunks stacked along the batch axis. The
            batch axis is split as ``(chunk, frame)`` with
            ``frames = batch // unet_chunk_size``. The default of 2 presumably
            matches classifier-free guidance (unconditional + conditional
            halves); TODO confirm against the caller.
    """

    def __init__(self, unet_chunk_size=2):
        self.unet_chunk_size = unet_chunk_size

    def _first_frame_only(self, tensor):
        """Replace every frame's entries with those of frame 0 of its chunk.

        ``tensor`` is (chunk * frames, sequence, channels); the result has
        the same shape.
        """
        video_length = tensor.size()[0] // self.unet_chunk_size
        # All frames read frame 0. Using ``arange(video_length) - 1`` with
        # index 0 mapped to 0 would give "previous frame" attention instead.
        former_frame_index = [0] * video_length
        tensor = rearrange(tensor, "(b f) d c -> b f d c", f=video_length)
        tensor = tensor[:, former_frame_index]
        return rearrange(tensor, "b f d c -> (b f) d c")

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None):
        is_cross_attention = encoder_hidden_states is not None
        batch_size, sequence_length, _ = hidden_states.shape
        # The mask is applied to the keys, so its target length is the key
        # sequence length: the encoder's length in cross-attention.
        key_length = encoder_hidden_states.shape[1] if is_cross_attention else sequence_length
        attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
        query = attn.to_q(hidden_states)

        if not is_cross_attention:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Sparse (first-frame) attention for self-attention only.
        if not is_cross_attention:
            key = self._first_frame_only(key)
            value = self._first_frame_only(value)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states