File size: 7,135 Bytes
a001281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import decord
import cv2

import os, io, csv, torch, math, random
from typing import Optional
from einops import rearrange
import numpy as np
from decord import VideoReader
from petrel_client.client import Client
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
from torch.utils.data.distributed import DistributedSampler

import animatediff.data.video_transformer as video_transforms
from animatediff.utils.util import zero_rank_print, detect_edges, prepare_mask_coef_by_score


def get_score(video_data,
              cond_frame_idx,
              weight=(1.0, 1.0, 1.0, 1.0),
              use_edge=True):
    """
    Per-frame difference score of a clip relative to its condition frame.

    Similar to ``get_score`` under utils/util.py (which also relies on
    ``detect_edges``).

    Every frame and the condition frame are converted from RGB to HSV with
    OpenCV (after casting to float32). When ``use_edge`` is set,
    ``detect_edges`` is applied to the V plane (cast to uint8) and its output
    is appended as a fourth plane. A frame's score is the weighted sum of
    per-plane absolute differences, divided by ``h * w`` and then by the sum
    of the weights actually used.

    Args:
        video_data: np.ndarray of shape (f, h, w, 3) holding RGB frames.
            ``WebVid10M.__getitem__`` passes values roughly in [0, 255].
        cond_frame_idx: index (along the first axis) of the reference frame.
        weight: per-plane weights in the order (H, S, V, edge). Only the first
            ``n`` weights are used, where ``n`` is the number of planes being
            compared (4 with ``use_edge``, 3 without), capped at ``len(weight)``.
        use_edge: whether to compare the edge map as an extra plane.

    Returns:
        list with one score per frame.
    """
    h, w = video_data.shape[1], video_data.shape[2]

    def _planes(frame):
        # HSV planes of one RGB frame, plus the edge map of V if requested.
        planes = list(
            cv2.split(cv2.cvtColor(frame.astype(np.float32), cv2.COLOR_RGB2HSV)))
        if use_edge:
            planes.append(detect_edges(lum=planes[-1].astype(np.uint8)))
        # Compare in float: the edge map may be an unsigned-integer array, for
        # which `a - b` wraps around instead of going negative. This is a
        # no-op for the HSV planes, which are already float32.
        return [np.asarray(p, dtype=np.float32) for p in planes]

    cond_planes = _planes(video_data[cond_frame_idx])
    # Without edges there are only 3 planes, so a 4-element weight must not
    # index past them; normalize by the weights that actually take part.
    used_weights = list(weight)[:len(cond_planes)]
    weight_sum = sum(used_weights)

    scores = []
    for frame in video_data:
        planes = _planes(frame)
        weighted_abs_diff = [
            np.sum(np.abs(p - cp)) * wt
            for p, cp, wt in zip(planes, cond_planes, used_weights)
        ]
        scores.append(sum(weighted_abs_diff) / (h * w) / weight_sum)
    return scores


class WebVid10M(Dataset):
    """
    WebVid-10M clips read through a petrel/ceph client, with per-frame scores.

    Each CSV row must provide ``videoid``, ``name`` (returned as the text) and
    ``page_dir``; the clip is fetched from
    ``webvideo:s3://WebVid10M/{page_dir}/{videoid}.mp4``.

    A sample is a dict with:
        pixel_values: float tensor normalized with mean/std 0.5 (about [-1, 1]),
                      shape (f, c, h, w), or (c, h, w) when ``is_image`` is set.
        text:         the ``name`` column.
        score:        per-frame scores from ``get_score`` relative to the
                      condition frame (a single entry for images).
        cond_frames:  index of the condition frame (always 0 for images).
        vid:          the ``videoid`` column.
    """

    def __init__(
            self,
            csv_path,
            sample_n_frames, sample_stride,
            sample_size=(320, 512),
            conf_path="~/petreloss.conf",
            static_video=False,
            is_image=False,
        ):
        """
        Args:
            csv_path: annotation CSV with ``videoid``, ``name``, ``page_dir``.
            sample_n_frames: number of frames per clip.
            sample_stride: temporal stride; the random temporal crop spans
                ``sample_n_frames * sample_stride`` frames.
            sample_size: int or (h, w) output size after resize + center crop.
            conf_path: petrel client configuration file.
            static_video: repeat one random frame ``sample_n_frames`` times.
            is_image: return a single random frame as a (c, h, w) tensor.
        """
        zero_rank_print(f"initializing ceph client ...")
        self._client          = Client(conf_path=conf_path, enable_mc=True)
        self.sample_n_frames  = sample_n_frames
        self.sample_stride    = sample_stride
        self.temporal_sampler = video_transforms.TemporalRandomCrop(sample_n_frames * sample_stride)
        self.static_video     = static_video
        self.is_image         = is_image

        zero_rank_print(f"(~1 mins) loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """
        Fetch and decode the clip for CSV row ``idx``.

        Returns:
            (pixel_values, name, cond_frames, videoid): pixel_values is a float
            tensor in [0, 1] of shape (f, c, h, w), or (c, h, w) when
            ``is_image`` is set; cond_frames is a random index into the sampled
            frames (0 for images, which have a single frame).

        Raises:
            ValueError: the fetched object is empty, or the temporal crop is
                shorter than ``sample_n_frames``. Errors from the ceph client
                and decord propagate too; ``__getitem__`` retries on any error.
        """
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
        ceph_dir = f"webvideo:s3://WebVid10M/{page_dir}/{videoid}.mp4"

        video_bytes = io.BytesIO(self._client.Get(ceph_dir))

        # Explicit checks instead of `assert`, so they still run under `python -O`.
        if video_bytes.getbuffer().nbytes == 0:
            raise ValueError(f"empty object: {ceph_dir}")

        video_reader = VideoReader(video_bytes)
        total_frames = len(video_reader)

        if self.is_image:
            frame_indice = [random.randint(0, total_frames - 1)]
        elif self.static_video:
            # One random frame repeated sample_n_frames times.
            frame_indice = np.full(self.sample_n_frames, random.randint(0, total_frames - 1), dtype=int)
        else:
            start_frame_ind, end_frame_ind = self.temporal_sampler(total_frames)
            if end_frame_ind - start_frame_ind < self.sample_n_frames:
                raise ValueError(
                    f"clip too short: {end_frame_ind - start_frame_ind} frames "
                    f"< sample_n_frames={self.sample_n_frames} ({ceph_dir})")
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

        pixel_values_np = video_reader.get_batch(frame_indice).asnumpy()
        del video_reader

        # An image sample has exactly one frame, so 0 is the only valid index.
        cond_frames = 0 if self.is_image else random.randint(0, self.sample_n_frames - 1)

        # f h w c -> f c h w, scaled to [0, 1]
        pixel_values = torch.from_numpy(pixel_values_np).permute(0, 3, 1, 2).contiguous() / 255.

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name, cond_frames, videoid

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return one sample dict; on any loading error, retry with a random row."""
        while True:
            try:
                video, name, cond_frames, videoid = self.get_batch(idx)
                break
            except Exception:
                # Best-effort loading: broken or missing clips are skipped by
                # resampling. NOTE(review): loops forever if every row fails.
                idx = random.randint(0, self.length - 1)

        video = self.pixel_transforms(video)

        # get_score expects (f, h, w, c) arrays in [0, 255]: undo the
        # normalization and treat a single (c, h, w) image as a one-frame clip.
        frames = video if video.dim() == 4 else video.unsqueeze(0)
        frames_np = (frames.permute(0, 2, 3, 1).numpy() / 2 + 0.5) * 255
        score = get_score(frames_np, cond_frame_idx=cond_frames)
        del frames_np

        sample = dict(pixel_values=video, text=name, score=score, cond_frames=cond_frames, vid=videoid)
        return sample



if __name__ == "__main__":
    # Smoke test: iterate the dataset and print the mask coefficients derived
    # from each clip.
    dataset = WebVid10M(
        csv_path="results_10M_train.csv",
        sample_size=(320, 512),
        sample_n_frames=16,
        sample_stride=4,
        static_video=False,
        is_image=False,
    )

    distributed_sampler = DistributedSampler(
        dataset,
        num_replicas=1,
        rank=0,
        shuffle=True,
        seed=5,
    )
    batch_size = 1
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, sampler=distributed_sampler)

    # Table of 16 rows x 2 values handed to prepare_mask_coef_by_score as
    # `statistic` (one row per frame of a 16-frame clip).
    STATISTIC = [[0., 0.],
        [0.3535855, 24.23687346],
        [0.91609545, 30.65091947],
        [1.41165152, 34.40093286],
        [1.56943881, 36.99639585],
        [1.73182842, 39.42044163],
        [1.82733002, 40.94703526],
        [1.88060527, 42.66233244],
        [1.96208071, 43.73070788],
        [2.02723091, 44.25965378],
        [2.10820894, 45.66120213],
        [2.21115041, 46.29561324],
        [2.23412351, 47.08810863],
        [2.29430165, 47.9515062],
        [2.32986362, 48.69085638],
        [2.37310751, 49.19931439]]

    for idx, batch in enumerate(dataloader):
        # Undo the mean/std 0.5 normalization back to roughly [0, 255].
        pixel_values = (batch['pixel_values'] / 2. + 0.5) * 255
        # Per-frame scores are already computed by the dataset (batch['score']).
        # Condition on frame 0 of every clip; size from the actual batch, which
        # can be smaller than batch_size at the end of an epoch.
        cond_frames = [0] * pixel_values.shape[0]
        # NOTE(review): confirm whether prepare_mask_coef_by_score expects the
        # video tensor or only its shape as the first argument.
        score = prepare_mask_coef_by_score(pixel_values, cond_frames, statistic=STATISTIC)
        print(f'num: {idx}, diff: {score}')