import os
import numpy as np
import torch
import torchvision
from torchvision.transforms import Resize
import imageio
from einops import rearrange
import cv2
from PIL import Image
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.openpose import OpenposeDetector
import decord

# Make decord hand back torch tensors instead of its own NDArray type.
decord.bridge.set_bridge('torch')

# Detector instances are created once at import time and shared by the
# pre-processing helpers below.
apply_canny = CannyDetector()
apply_openpose = OpenposeDetector()


def add_watermark(image, im_size_h, im_size_w,
                  watermark_path="__assets__/picsart_watermark.jpg",
                  wmsize=16, bbuf=5, opacity=0.9):
    """Creates a watermark on the saved inference image.

    We request that you do not remove this to properly assign credit to
    Shi-Lab's work.

    The watermark file is resized to ``wmsize`` x ``wmsize`` and pasted into
    the bottom-right corner of ``image``, leaving ``bbuf`` pixels of margin
    to the bottom and right edges.

    Args:
        image: Array indexed as ``[row, column, channel]``; modified in place.
        im_size_h: Height of ``image`` in pixels.
        im_size_w: Width of ``image`` in pixels.
        watermark_path: Path of the watermark image file.
        wmsize: Edge length, in pixels, of the pasted watermark.
        bbuf: Margin, in pixels, between the watermark and the image border.
        opacity: Currently unused; kept so existing callers keep working.

    Returns:
        The same ``image`` object with the watermark written into it.

    Raises:
        ValueError: If the image is too small to hold the watermark plus
            its margin.
    """
    loc_h = im_size_h - wmsize - bbuf
    loc_w = im_size_w - wmsize - bbuf
    if loc_h < 0 or loc_w < 0:
        raise ValueError(
            f"image of size {im_size_h}x{im_size_w} is too small for a "
            f"{wmsize}px watermark with a {bbuf}px margin"
        )

    with Image.open(watermark_path) as wm_file:
        watermark = np.asarray(wm_file.resize((wmsize, wmsize)))

    # Explicit end indices: a negative-offset end (``-bbuf``) degenerates to
    # an empty slice when bbuf == 0.
    image[loc_h:loc_h + wmsize, loc_w:loc_w + wmsize, :] = watermark
    return image


def _frame_to_hwc_uint8(frame):
    """Convert one ``c h w`` tensor frame to an ``h w c`` uint8 numpy array."""
    return rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)


def _maps_to_control(detected_maps):
    """Stack per-frame ``h w c`` maps into an ``f c h w`` float tensor in [0, 1]."""
    stacked = np.stack(detected_maps)
    control = torch.from_numpy(stacked).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')


def pre_process_canny(input_video, low_threshold=100, high_threshold=200):
    """Run the Canny detector on every frame of a video.

    Args:
        input_video: Iterable of frames, each laid out as ``c h w``.
        low_threshold: Lower threshold forwarded to the Canny detector.
        high_threshold: Upper threshold forwarded to the Canny detector.

    Returns:
        Float tensor laid out as ``f c h w``: the detected maps divided
        by 255.
    """
    detected_maps = []
    for frame in input_video:
        img = _frame_to_hwc_uint8(frame)
        edges = apply_canny(img, low_threshold, high_threshold)
        detected_maps.append(HWC3(edges))
    return _maps_to_control(detected_maps)


def pre_process_pose(input_video, apply_pose_detect: bool = True):
    """Produce a pose control map for every frame of a video.

    Args:
        input_video: Iterable of frames, each laid out as ``c h w``.
        apply_pose_detect: When True, run the OpenPose detector on each
            frame; when False, the frame itself is used as the control map.

    Returns:
        Float tensor laid out as ``f c h w``: the control maps divided
        by 255.
    """
    detected_maps = []
    for frame in input_video:
        img = HWC3(_frame_to_hwc_uint8(frame))
        if apply_pose_detect:
            detected_map, _ = apply_openpose(img)
        else:
            detected_map = img
        detected_map = HWC3(detected_map)
        # Bring the map back to the frame's own resolution.
        height, width, _ = img.shape
        detected_map = cv2.resize(detected_map, (width, height),
                                  interpolation=cv2.INTER_NEAREST)
        detected_maps.append(detected_map)
    return _maps_to_control(detected_maps)


def _render_frames(frames, rescale):
    """Turn an iterable of float frames into watermarked uint8 images.

    Shared by ``create_video`` and ``create_gif``.
    """
    outputs = []
    for frame in frames:
        grid = torchvision.utils.make_grid(torch.Tensor(frame), nrow=4)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp before the cast so out-of-range values saturate instead of
        # wrapping around in uint8.
        grid = (grid * 255).clamp(0, 255).numpy().astype(np.uint8)
        # NOTE(review): the shape is unpacked as (h, w, c), so frames are
        # presumably passed in as single h x w x c images -- confirm against
        # callers.
        h_, w_, _ = grid.shape
        outputs.append(add_watermark(grid, im_size_h=h_, im_size_w=w_))
    return outputs


def create_video(frames, fps, rescale=False, path=None):
    """Write ``frames`` to a video file and return its path.

    Args:
        frames: Iterable of float frames accepted by ``torch.Tensor``.
        fps: Frame rate passed to ``imageio.mimsave``.
        rescale: When True, map values from [-1, 1] to [0, 1] first.
        path: Output file path; defaults to ``temporal/movie.mp4``.

    Returns:
        The path the video was written to.
    """
    if path is None:
        out_dir = "temporal"
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, 'movie.mp4')
    else:
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    imageio.mimsave(path, _render_frames(frames, rescale), fps=fps)
    return path


def create_gif(frames, fps, rescale=False):
    """Write ``frames`` to ``temporal/canny_db.gif`` and return that path.

    Args:
        frames: Iterable of float frames accepted by ``torch.Tensor``.
        fps: Frame rate passed to ``imageio.mimsave``.
        rescale: When True, map values from [-1, 1] to [0, 1] first.

    Returns:
        The path the GIF was written to.
    """
    out_dir = "temporal"
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, 'canny_db.gif')

    imageio.mimsave(path, _render_frames(frames, rescale), fps=fps)
    return path


def prepare_video(video_path: str, resolution: int, device, dtype,
                  normalize=True, start_t: float = 0, end_t: float = -1,
                  output_fps: int = -1):
    """Load a clip, resample it in time, and resize it for the model.

    Args:
        video_path: Path of the video file to read.
        resolution: Target size of the longer side, before rounding down to
            a multiple of 8.
        device: Device the returned tensor is moved to.
        dtype: Dtype the returned tensor is cast to.
        normalize: When True, map pixel values from [0, 255] to [-1, 1].
        start_t: Clip start, in seconds.
        end_t: Clip end, in seconds; ``-1`` means the end of the video.
        output_fps: Sampling rate of the result; ``-1`` means the source's
            average frame rate truncated to an int.

    Returns:
        Tuple ``(video, output_fps)`` where ``video`` is laid out as
        ``f c h w``.
    """
    vr = decord.VideoReader(video_path)
    initial_fps = vr.get_avg_fps()
    duration = len(vr) / initial_fps

    if output_fps == -1:
        output_fps = int(initial_fps)
    end_t = duration if end_t == -1 else min(duration, end_t)

    assert 0 <= start_t < end_t, "start_t must lie in [0, end_t)"
    assert output_fps > 0, "output_fps must be positive"

    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    num_f = int((end_t - start_t) * output_fps)
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f,
                             endpoint=False).astype(int)

    # Decode only the frames that are actually sampled.
    video = vr.get_batch(sample_idx)
    # With the torch bridge enabled, decord returns a torch.Tensor, which
    # has no .asnumpy(); otherwise it returns its own NDArray type.
    if not torch.is_tensor(video):
        video = torch.from_numpy(video.asnumpy())

    _, h, w, _ = video.shape
    video = rearrange(video, "f h w c -> f c h w")
    video = video.to(device=device, dtype=dtype)

    # Scale the longer side to `resolution`, keep the aspect ratio, and round
    # both sides down to a multiple of 8.
    if h > w:
        w = int(w * resolution / h)
        w -= w % 8
        h = resolution - resolution % 8
    else:
        h = int(h * resolution / w)
        h -= h % 8
        w = resolution - resolution % 8
    video = Resize((h, w))(video)

    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps


def post_process_gif(list_of_results, image_resolution):
    """Save ``list_of_results`` as a 4 fps GIF and return its path.

    Args:
        list_of_results: Sequence of frames passed to ``imageio.mimsave``.
        image_resolution: Unused; kept for interface compatibility.

    Returns:
        The fixed output path ``/tmp/ddxk.gif``.
    """
    output_file = "/tmp/ddxk.gif"
    imageio.mimsave(output_file, list_of_results, fps=4)
    return output_file


class CrossFrameAttnProcessor:
    """Attention processor in which self-attention attends to the first frame.

    For self-attention (no ``encoder_hidden_states``), the keys and values of
    every frame are replaced by those of frame 0 of the same chunk.
    Cross-attention is left untouched.
    """

    def __init__(self, unet_chunk_size=2):
        # Number of chunks stacked along the leading dimension; the leading
        # size divided by this gives the number of frames per chunk.
        self.unet_chunk_size = unet_chunk_size

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None):
        """Compute attention output for ``hidden_states``.

        Args:
            attn: Attention module supplying the projections and helpers
                used below (``to_q``, ``to_k``, ``to_v``, ``to_out``, ...).
            hidden_states: Tensor of shape (batch, sequence, channels).
            encoder_hidden_states: Optional context; when given, ordinary
                cross-attention is computed.
            attention_mask: Optional mask forwarded to
                ``attn.prepare_attention_mask``.

        Returns:
            The attended hidden states after the output projection and
            dropout.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Sparse Attention
        if not is_cross_attention:
            video_length = key.size()[0] // self.unet_chunk_size
            # Every frame looks at frame 0 of its own chunk.
            former_frame_index = [0] * video_length
            key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
            key = key[:, former_frame_index]
            key = rearrange(key, "b f d c -> (b f) d c")
            value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
            value = value[:, former_frame_index]
            value = rearrange(value, "b f d c -> (b f) d c")

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states