File size: 14,060 Bytes
2ba4412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
import os
import os.path as osp
import sys
import cv2
import glob
import math
import torch
import gzip
import copy
import time
import json
import pickle
import base64
import imageio
import hashlib
import requests
import binascii
import zipfile
# import skvideo.io
import numpy as np
from io import BytesIO
import urllib.request
import torch.nn.functional as F
import torchvision.utils as tvutils
from multiprocessing.pool import ThreadPool as Pool
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont


def gen_text_image(captions, text_size):
    '''
    Render every caption as black text on a white square canvas.

    captions: iterable of strings; one image is produced per string.
    text_size: side length (in pixels) of each square canvas; the font size
        is derived from it as int(text_size / 20).

    Returns a tensor made from the stacked RGB canvases, one entry per
    caption in input order (N x text_size x text_size x 3).
    '''
    # NOTE(review): the ratio below is always 1, so captions are wrapped at a
    # fixed 38 characters per line regardless of text_size.
    wrap_width = int(38 * (text_size / text_size))
    typeface = ImageFont.truetype('data/font/DejaVuSans.ttf', size=int(text_size / 20))

    def _render(caption):
        # Hard-wrap the caption into fixed-width chunks and draw them from
        # the top-left corner of a fresh white canvas.
        canvas = Image.new("RGB", (text_size, text_size), color="white")
        pieces = [caption[pos:pos + wrap_width] for pos in range(0, len(caption), wrap_width)]
        ImageDraw.Draw(canvas).text((0, 0), "\n".join(pieces), fill="black", font=typeface)
        return np.array(canvas)

    rendered = [_render(caption) for caption in captions]
    return torch.from_numpy(np.stack(rendered, axis=0))

@torch.no_grad()
def save_video_refimg_and_text(
    local_path,
    ref_frame,
    gen_video,
    captions,
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
    text_size=256,
    nrow=4,
    save_fps=8,
    retry=5):
    '''
    Save a grid in which every sample shows its reference frame, its generated
    video and its rendered caption side by side.

    local_path: output path WITHOUT extension. '.png' is appended when the
        video has a single frame, '.mp4' otherwise.
    ref_frame: BxCxHxW reference images, normalized with mean/std.
    gen_video: BxCxFxHxW generated videos, normalized with mean/std.
    captions: one string per sample, rendered with gen_text_image.
    mean, std: per-channel statistics used to undo the normalization.
    text_size: side length of the square caption image. The three tiles are
        concatenated along the width, so it has to equal the frame height.
    nrow: ignored (kept for interface compatibility); the number of grid rows
        is always max(B // 2, 1).
    save_fps: frame rate passed to ffmpeg.
    retry: number of write attempts before giving up.

    Side effects: ref_frame and gen_video are de-normalized IN PLACE. Frames
    are converted with .numpy(), so the tensors are expected to live on CPU.
    Saving is best effort: if every attempt fails, the last error is printed
    and the function returns normally.
    '''
    import shutil
    import subprocess

    nrow = max(int(gen_video.size(0) / 2), 1)
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw

    # One caption image per sample, repeated for every frame: B x F x S x S x 3.
    text_images = gen_text_image(captions, text_size)
    text_images = text_images.unsqueeze(1)
    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1)

    # unsqueeze() returns a view, so the in-place ops below also modify the
    # caller's ref_frame.
    ref_frame = ref_frame.unsqueeze(2)
    ref_frame = ref_frame.mul_(vid_std).add_(vid_mean)
    ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2)
    ref_frame.clamp_(0, 1)
    ref_frame = ref_frame * 255.0
    ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c')

    gen_video = gen_video.mul_(vid_std).add_(vid_mean)
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0

    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = torch.cat([ref_frame, images, text_images], dim=3)

    # Tile the batch into nrow rows; the result is one big image per frame.
    images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow)
    images = [(img.numpy()).astype('uint8') for img in images]

    # Resolve the output name once, so that a retry does not append the
    # extension a second time.
    single_image = len(images) == 1
    out_path = local_path + ('.png' if single_image else '.mp4')
    frame_dir = os.path.join(os.path.dirname(out_path), '%s_frames' % (os.path.basename(out_path)))

    exception = None
    for _ in range(retry):
        try:
            if single_image:
                # [:, :, ::-1] flips RGB to the BGR order OpenCV writes.
                cv2.imwrite(out_path, images[0][:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            else:
                # Dump the frames into a scratch directory and let ffmpeg
                # encode them; the directory is removed even on failure.
                shutil.rmtree(frame_dir, ignore_errors=True)
                os.makedirs(frame_dir, exist_ok=True)
                try:
                    for fid, frame in enumerate(images):
                        tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                        cv2.imwrite(tpth, frame[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
                    # Argument list instead of a shell string: paths with
                    # spaces or shell metacharacters are passed verbatim.
                    subprocess.run(
                        [
                            'ffmpeg', '-y', '-f', 'image2', '-loglevel', 'quiet',
                            '-framerate', str(save_fps),
                            '-i', os.path.join(frame_dir, '%04d.png'),
                            '-vcodec', 'libx264', '-crf', '17',
                            '-pix_fmt', 'yuv420p', out_path,
                        ],
                        check=True,
                    )
                finally:
                    shutil.rmtree(frame_dir, ignore_errors=True)
            exception = None
            break
        except Exception as e:
            exception = e

    if exception is not None:
        # Best effort: report the failure instead of dropping it silently.
        print('save video to {} failed, error: {}'.format(out_path, exception), flush=True)


@torch.no_grad()
def save_i2vgen_video(
    local_path,
    image_id,
    gen_video,
    captions,
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
    text_size=256,
    retry=5,
    save_fps = 8
):
    '''
    Save both the generated video and the input conditions.

    Each frame of the output shows, from left to right, the conditioning
    image, the generated frame and the rendered caption. Only the first
    sample of the batch is written.

    local_path: output video path, used as given (no extension is appended).
    image_id: BxCxHxW conditioning images, normalized with mean/std.
    gen_video: BxCxFxHxW generated videos, normalized with mean/std.
    captions: one string per sample, rendered with gen_text_image.
    mean, std: per-channel statistics used to undo the normalization.
    text_size: side length of the square caption image. The tiles are
        concatenated along the width, so it has to equal the frame height.
    retry: number of write attempts before giving up.
    save_fps: frame rate passed to ffmpeg.

    Side effects: gen_video is de-normalized IN PLACE. Frames are converted
    with .numpy(), so the tensors are expected to live on CPU.

    Raises the error of the last attempt when all attempts fail (including a
    non-zero ffmpeg exit status).
    '''
    import shutil
    import subprocess

    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw

    # One caption image per sample, repeated for every frame: B x F x S x S x 3.
    text_images = gen_text_image(captions, text_size)
    text_images = text_images.unsqueeze(1)
    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1)

    # Repeat the conditioning image along a new frame axis (B, C, F, H, W).
    image_id = image_id.unsqueeze(2)
    image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2)
    image_id = image_id.mul_(vid_std).add_(vid_mean)
    image_id.clamp_(0, 1)
    image_id = image_id * 255.0
    image_id = rearrange(image_id, 'b c f h w -> b f h w c')

    gen_video = gen_video.mul_(vid_std).add_(vid_mean)
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0

    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = torch.cat([image_id, images, text_images], dim=3)
    images = images[0]
    images = [(img.numpy()).astype('uint8') for img in images]

    frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))

    exception = None
    for _ in range(retry):
        try:
            # Dump the frames into a scratch directory and let ffmpeg encode
            # them; the directory is removed even when the attempt fails.
            shutil.rmtree(frame_dir, ignore_errors=True)
            os.makedirs(frame_dir, exist_ok=True)
            try:
                for fid, frame in enumerate(images):
                    tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                    # [:, :, ::-1] flips RGB to the BGR order OpenCV writes.
                    cv2.imwrite(tpth, frame[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
                # Argument list instead of a shell string: paths with spaces
                # or shell metacharacters are passed verbatim.
                subprocess.run(
                    [
                        'ffmpeg', '-y', '-f', 'image2', '-loglevel', 'quiet',
                        '-framerate', str(save_fps),
                        '-i', os.path.join(frame_dir, '%04d.png'),
                        '-vcodec', 'libx264', '-crf', '17',
                        '-pix_fmt', 'yuv420p', local_path,
                    ],
                    check=True,
                )
            finally:
                shutil.rmtree(frame_dir, ignore_errors=True)
            # A successful attempt clears errors recorded by earlier ones, so
            # a save that worked on a retry is not reported as a failure.
            exception = None
            break
        except Exception as e:
            exception = e

    if exception is not None:
        raise exception


@torch.no_grad()
def save_i2vgen_video_safe(
    local_path,
    gen_video,
    captions,
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
    text_size=256,
    retry=5,
    save_fps = 8
):
    '''
    Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame.

    Only the first sample of the batch is written. A single-frame video is
    saved as an image to local_path + '.png'; otherwise the frames are
    encoded with imageio to local_path as given.

    local_path: output path (see above for the single-frame case).
    gen_video: BxCxFxHxW generated videos, normalized with mean/std.
    captions, text_size: not used by this function; kept in the signature.
    mean, std: per-channel statistics used to undo the normalization.
    retry: number of write attempts before giving up.
    save_fps: frame rate of the written video.

    Side effects: gen_video is de-normalized IN PLACE. Frames are converted
    with .numpy(), so the tensor is expected to live on CPU.

    Raises the error of the last attempt when all attempts fail.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw

    gen_video = gen_video.mul_(vid_std).add_(vid_mean)
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0

    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = images[0]
    images = [(img.numpy()).astype('uint8') for img in images]
    num_image = len(images)

    # Resolve the output name once, so that a retry does not append '.png'
    # a second time.
    single_image = num_image == 1
    out_path = (local_path + '.png') if single_image else local_path

    # Known-issue workaround: drop the last frame of a video when more than
    # 40% of its values fall into the mid-gray range [117, 137].
    if num_image > 1:
        last_frame = images[-1]
        gray_ratio = (np.sum((last_frame >= 117) & (last_frame <= 137))) / (last_frame.size)
        if gray_ratio > 0.4:
            images = images[:-1]

    exception = None
    for _ in range(retry):
        try:
            if single_image:
                # [:, :, ::-1] flips RGB to the BGR order OpenCV writes.
                cv2.imwrite(out_path, images[0][:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            else:
                # The context manager closes the writer even when appending
                # a frame raises.
                with imageio.get_writer(out_path, fps=save_fps, codec='libx264', quality=8) as writer:
                    for frame in images:
                        writer.append_data(frame)
            # A successful attempt clears errors recorded by earlier ones, so
            # a save that worked on a retry is not reported as a failure.
            exception = None
            break
        except Exception as e:
            exception = e

    if exception is not None:
        raise exception


@torch.no_grad()
def save_t2vhigen_video_safe(
    local_path,
    gen_video,
    captions,
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
    text_size=256,
    retry=5,
    save_fps = 8
):
    '''
    Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame.

    Only the first sample of the batch is written. A single-frame video is
    saved as an image to local_path + '.png'; otherwise the frames are dumped
    as PNG files and encoded with ffmpeg to local_path as given.

    local_path: output path (see above for the single-frame case).
    gen_video: BxCxFxHxW generated videos, normalized with mean/std.
    captions, text_size: not used by this function; kept in the signature.
    mean, std: per-channel statistics used to undo the normalization.
    retry: number of write attempts before giving up.
    save_fps: frame rate passed to ffmpeg.

    Side effects: gen_video is de-normalized IN PLACE. Frames are converted
    with .numpy(), so the tensor is expected to live on CPU.

    Raises the error of the last attempt when all attempts fail (including a
    non-zero ffmpeg exit status).
    '''
    import shutil
    import subprocess

    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw

    gen_video = gen_video.mul_(vid_std).add_(vid_mean)
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0

    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = images[0]
    images = [(img.numpy()).astype('uint8') for img in images]
    num_image = len(images)

    # Resolve the output name once, so that a retry does not append '.png'
    # a second time.
    single_image = num_image == 1
    out_path = (local_path + '.png') if single_image else local_path
    frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))

    # Known-issue workaround: drop the last frame of a video when more than
    # 40% of its values fall into the mid-gray range [117, 137].
    if num_image > 1:
        last_frame = images[-1]
        gray_ratio = (np.sum((last_frame >= 117) & (last_frame <= 137))) / (last_frame.size)
        if gray_ratio > 0.4:
            images = images[:-1]

    exception = None
    for _ in range(retry):
        try:
            if single_image:
                # [:, :, ::-1] flips RGB to the BGR order OpenCV writes.
                cv2.imwrite(out_path, images[0][:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            else:
                # Dump the frames into a scratch directory and let ffmpeg
                # encode them; the directory is removed even on failure.
                shutil.rmtree(frame_dir, ignore_errors=True)
                os.makedirs(frame_dir, exist_ok=True)
                try:
                    for fid, frame in enumerate(images):
                        tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                        cv2.imwrite(tpth, frame[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
                    # Argument list instead of a shell string: paths with
                    # spaces or shell metacharacters are passed verbatim.
                    subprocess.run(
                        [
                            'ffmpeg', '-y', '-f', 'image2', '-loglevel', 'quiet',
                            '-framerate', str(save_fps),
                            '-i', os.path.join(frame_dir, '%04d.png'),
                            '-vcodec', 'libx264', '-crf', '17',
                            '-pix_fmt', 'yuv420p', out_path,
                        ],
                        check=True,
                    )
                finally:
                    shutil.rmtree(frame_dir, ignore_errors=True)
            # A successful attempt clears errors recorded by earlier ones, so
            # a save that worked on a retry is not reported as a failure.
            exception = None
            break
        except Exception as e:
            exception = e

    if exception is not None:
        raise exception




@torch.no_grad()
def save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_tensor, model_kwargs, source_imgs, 
                                   mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], nrow=8, retry=5, save_fps=8):
    '''
    Save the generated videos next to their conditioning inputs.

    Every condition in model_kwargs[0] and the video batch are each tiled
    into a grid of nrow rows; the grids are then concatenated along the
    width, conditions first and the video last.

    local_path: output path ending in '.mp4' or '.gif'. When the video has a
        single frame, '.mp4' is replaced by '.png' and an image is written.
    video_tensor: BxCxFxHxW videos normalized with mean/std; B has to be
        divisible by nrow.
    model_kwargs: sequence whose first element maps condition names to
        5-D tensors. Conditions with 1, 2, 3 or 4 channels are converted to
        3 channels and resized to the video's (F, H, W); 4 channels are
        treated as color + alpha and composited over white.
    source_imgs: not used; kept in the signature for compatibility.
    mean, std: per-channel statistics used to undo the normalization.
    nrow: number of rows of every grid.
    retry: number of save attempts; with retry <= 0 nothing is written.
    save_fps: frame rate of the written video / GIF.

    Side effects: video_tensor is de-normalized IN PLACE.
    Saving is best effort: if every attempt fails, the last error is printed
    and the function returns normally.
    '''
    from pathlib import Path

    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
    video_tensor = video_tensor.mul_(std).add_(mean)  # unnormalize back to [0,1]
    video_tensor.clamp_(0, 1)
    # The conditions are moved to CPU below; the video has to join them there
    # for torch.cat and .numpy() to work.
    video_tensor = video_tensor.cpu()

    _, _, n, h, w = video_tensor.shape

    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        channels = conditions.size(1)
        if channels == 1:
            conditions = torch.cat([conditions, conditions, conditions], dim=1)
        elif channels == 2:
            # Pad to three channels by repeating the first one.
            conditions = torch.cat([conditions, conditions[:, :1]], dim=1)
        elif channels == 4:  # means it is a mask.
            color = (conditions[:, 0:3] + 1.) / 2.
            alpha = conditions[:, 3:4]
            conditions = color * alpha + 1.0 * (1.0 - alpha)
        if channels in (1, 2, 3, 4):
            # Resize to the video's frame count and resolution (done once).
            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu()

    # Initialized up front so the final check is valid even when the loop
    # body never runs (retry <= 0).
    exception = None
    for _ in range(retry):
        try:
            vid_grid = rearrange(video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            cond_grids = [
                rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
                for con in model_kwargs_channel3.values()
            ]
            grid = torch.cat(cond_grids + [vid_grid], dim=3)
            grid = grid.permute(1, 2, 3, 0)  # f, H, W, c

            images = [(img.numpy()).astype('uint8') for img in grid * 255.0]
            if len(images) == 1:
                local_path = local_path.replace('.mp4', '.png')
                # [:, :, ::-1] flips RGB to the BGR order OpenCV writes.
                cv2.imwrite(local_path, images[0][:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            else:
                save_fmt = Path(local_path).suffix

                if save_fmt == ".mp4":
                    with imageio.get_writer(local_path, fps=save_fps) as writer:
                        for frame in images:
                            writer.append_data(np.ascontiguousarray(frame))

                elif save_fmt == ".gif":
                    outputs = [Image.fromarray(frame) for frame in images]
                    outputs[0].save(
                        fp=local_path,
                        format="GIF",
                        append_images=outputs[1:],
                        save_all=True,
                        duration=(1 / save_fps * 1000),
                        loop=0,
                    )
                else:
                    raise ValueError("Unsupported file type. Use .mp4 or .gif.")

            exception = None
            break
        except Exception as e:
            exception = e

    if exception is not None:
        print('save video to {} failed, error: {}'.format(local_path, exception), flush=True)