File size: 8,798 Bytes
0a63786
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import os
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
import imageio
from PIL import Image
import textwrap

def find_nearest_Nx(size, N=32):
    """Round `size` up to the nearest multiple of `N` (default 32)."""
    quotient, remainder = divmod(size, N)
    if remainder:
        quotient += 1
    return int(quotient * N)

def load_image_as_tensor(image_path, image_size):
    """Load an image file as a float CHW tensor with values in [0, 1].

    Args:
        image_path: Path to an image readable by OpenCV.
        image_size: Target (width, height) tuple, or a single int for a
            square output.

    Returns:
        torch.Tensor of shape (3, H, W), RGB channel order, values in [0, 1].

    Raises:
        FileNotFoundError: If the file is missing or cannot be decoded.
    """
    if isinstance(image_size, int):
        image_size = (image_size, image_size)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None (instead of raising) for missing/corrupt
        # files; fail loudly rather than crash on the slice below.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = image[..., ::-1]  # BGR -> RGB
    try:
        image = cv2.resize(image, image_size)
    except Exception as e:
        # Preserve original best-effort behavior: log and fall through with
        # the un-resized image.
        print(e)
        print(image_path)

    # np.array(...) copies, removing the negative stride from the channel flip
    image = torch.from_numpy(np.array(image).transpose(2, 0, 1)) / 255.
    return image

def show_image(image):
    """Display a (C, H, W) tensor — or the first item of a (B, C, H, W)
    batch — with matplotlib."""
    if image.ndim == 4:
        image = image[0]
    as_numpy = image.permute(1, 2, 0).detach().cpu().numpy()
    plt.imshow(as_numpy)
    plt.show()

def extract_video(video_path, save_dir, sampling_fps, skip_frames=0):
    """Extract frames from a video into `save_dir` at roughly `sampling_fps`.

    Frames are written as `frame0000.jpg`, `frame0001.jpg`, ... and existing
    files are not overwritten.

    Args:
        video_path: Path to the input video.
        save_dir: Output directory (created if missing).
        sampling_fps: Desired sampling rate; every (source_fps/sampling_fps)-th
            frame is kept.
        skip_frames: Number of initial frames to skip before sampling starts.
    """
    os.makedirs(save_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    # max(1, ...) guards against sampling_fps exceeding the source fps (or an
    # unreadable fps of 0), which would make frame_skip 0 and crash the
    # modulo below with ZeroDivisionError.
    frame_skip = max(1, int(cap.get(cv2.CAP_PROP_FPS) / sampling_fps))
    frame_count = 0
    save_count = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count < skip_frames:  # skip the first N frames
            frame_count += 1
            continue
        if (frame_count - skip_frames) % frame_skip == 0:
            # Save the frame as an image file if it doesn't already exist
            save_path = os.path.join(save_dir, f"frame{save_count:04d}.jpg")
            save_count += 1
            if not os.path.exists(save_path):
                cv2.imwrite(save_path, frame)
        frame_count += 1
    cap.release()
    cv2.destroyAllWindows()

def concatenate_frames_to_video(frame_dir, video_path, fps):
    """Assemble `frame_dir/frame*.jpg` files into an mp4 video at `fps`.

    Args:
        frame_dir: Directory containing files named `frame...` (zero-padded
            names so lexicographic order is temporal order).
        video_path: Output video path; its parent directory is created.
        fps: Output frame rate.

    Raises:
        ValueError: If no frame files are found in `frame_dir`.
    """
    out_dir = os.path.dirname(video_path)
    if out_dir:  # os.makedirs('') raises when video_path has no directory part
        os.makedirs(out_dir, exist_ok=True)
    # Collect and sort the frame file names (zero-padded => temporal order).
    frame_files = sorted(f for f in os.listdir(frame_dir) if f.startswith("frame"))
    if not frame_files:
        # Fail with a clear message instead of an IndexError below.
        raise ValueError(f"No frame files found in {frame_dir}")
    # Use the first frame to fix the output resolution.
    frame = cv2.imread(os.path.join(frame_dir, frame_files[0]))
    height, width, _ = frame.shape
    # Initialize the video writer
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    # Loop through the frame files and add them to the video
    for frame_file in frame_files:
        frame_path = os.path.join(frame_dir, frame_file)
        frame = cv2.imread(frame_path)
        out.write(frame)
    # Release the video writer
    out.release()

def cumulative_histogram(hist):
    """Return the cumulative sum of `hist` as a new array of the same type."""
    cum_hist = hist.copy()
    for index, count in enumerate(hist[1:], start=1):
        cum_hist[index] = cum_hist[index - 1] + count
    return cum_hist

def histogram_matching(src_img, ref_img):
    """Match the color distribution of `src_img` to that of `ref_img`.

    Both inputs are float RGB images with values in [0, 1]. The matching is
    done per channel in YUV space; a float32 RGB image in [0, 1] is returned.
    """
    # Quantize to 8-bit so each channel has exactly 256 histogram levels.
    src_img = (src_img * 255).astype(np.uint8)
    ref_img = (ref_img * 255).astype(np.uint8)
    src_img_yuv = cv2.cvtColor(src_img, cv2.COLOR_RGB2YUV)
    ref_img_yuv = cv2.cvtColor(ref_img, cv2.COLOR_RGB2YUV)

    matched_img = np.zeros_like(src_img_yuv)
    for channel in range(src_img_yuv.shape[2]):
        src_hist, _ = np.histogram(src_img_yuv[:, :, channel].ravel(), 256, (0, 256))
        ref_hist, _ = np.histogram(ref_img_yuv[:, :, channel].ravel(), 256, (0, 256))

        src_cum_hist = cumulative_histogram(src_hist)
        ref_cum_hist = cumulative_histogram(ref_hist)

        # Build a monotone lookup table: for each source level i, find the
        # smallest reference level j whose CDF reaches the source CDF at i.
        # j is deliberately carried across iterations of i (both CDFs are
        # non-decreasing), keeping the construction O(256) per channel.
        lut = np.zeros(256, dtype=np.uint8)
        j = 0
        for i in range(256):
            while ref_cum_hist[j] < src_cum_hist[i] and j < 255:
                j += 1
            lut[i] = j

        # Remap the channel through the lookup table.
        matched_img[:, :, channel] = cv2.LUT(src_img_yuv[:, :, channel], lut)

    matched_img = cv2.cvtColor(matched_img, cv2.COLOR_YUV2RGB)
    matched_img = matched_img.astype(np.float32) / 255
    return matched_img

def canny_image_batch(image_batch, low_threshold=100, high_threshold=200):
    """Run Canny edge detection on a batch of images.

    Args:
        image_batch: Either a (B, C, H, W) torch tensor with values in
            [-1, 1], or an iterable of (H, W, C) uint8 numpy images.
        low_threshold: Lower hysteresis threshold for cv2.Canny.
        high_threshold: Upper hysteresis threshold for cv2.Canny.

    Returns:
        3-channel edge maps in the same container type as the input: a
        (B, 3, H, W) float tensor in [0, 1] on the input's device, or a
        (B, H, W, 3) uint8 numpy array.
    """
    is_torch = False  # bugfix: was undefined (NameError) for numpy input
    if isinstance(image_batch, torch.Tensor):
        # [-1, 1] tensor -> [0, 255] numpy array
        is_torch = True
        device = image_batch.device
        image_batch = (image_batch + 1) * 127.5
        image_batch = image_batch.permute(0, 2, 3, 1).detach().cpu().numpy()
        image_batch = image_batch.astype(np.uint8)
    image_batch = np.array([cv2.Canny(image, low_threshold, high_threshold) for image in image_batch])
    # Replicate the single-channel edge map into 3 identical channels.
    image_batch = image_batch[:, :, :, None]
    image_batch = np.concatenate([image_batch, image_batch, image_batch], axis=3)

    if is_torch:
        # [0, 255] numpy array -> [0, 1] float tensor
        # NOTE(review): despite the original comment, this is NOT rescaled
        # back to [-1, 1]; callers appear to rely on the [0, 1] range.
        image_batch = torch.from_numpy(image_batch).permute(0, 3, 1, 2).float() / 255.
        image_batch = image_batch.to(device)
    return image_batch


def images_to_gif(images, filename, fps):
    """Save a sequence of BGR frames (uint8, or float32 in [0, 1]) as a GIF.

    Args:
        images: Iterable of (H, W, 3) BGR numpy frames.
        filename: Output gif path; its parent directory is created if needed.
        fps: Playback frame rate (converted to per-frame duration).
    """
    out_dir = os.path.dirname(filename)
    if out_dir:  # os.makedirs('') raises when filename has no directory part
        os.makedirs(out_dir, exist_ok=True)
    # Normalize to 0-255 and convert to uint8
    images = [(img * 255).astype(np.uint8) if img.dtype == np.float32 else img for img in images]
    images = [Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) for img in images]
    imageio.mimsave(filename, images, duration=1 / fps)

def load_gif(image_path):
    """Load a gif as a (num_frames, H, W, 3) numpy array (alpha dropped).

    Args:
        image_path: Path to the gif file.

    Returns:
        numpy array of all frames, restricted to the first 3 channels.
    """
    # Redundant local `import imageio` removed — it is imported at file level.
    gif = imageio.get_reader(image_path)
    try:
        np_images = np.array([frame[..., :3] for frame in gif])
    finally:
        gif.close()  # release the underlying file handle
    return np_images

def add_text_to_frame(frame, text, font_scale=1, thickness=2, color=(0, 0, 0), bg_color=(255, 255, 255), max_width=30):
    """
    Add text to a frame.

    Renders `text` (word-wrapped at `max_width` characters, each line
    horizontally centered) on a solid `bg_color` strip that is appended
    BELOW the frame, so the returned image is taller than the input.

    Args:
        frame: (H, W, 3) uint8 image; not modified.
        text: Caption string.
        font_scale: cv2.putText font scale.
        thickness: cv2.putText stroke thickness.
        color: Text color (BGR or RGB depending on the frame's convention).
        bg_color: Fill color of the appended caption strip.
        max_width: Wrap width in characters.

    Returns:
        New image of shape (H + strip_height, W, 3).
    """
    # Make a copy of the frame
    frame_with_text = np.copy(frame)
    # Choose font
    font = cv2.FONT_HERSHEY_SIMPLEX
    # Split text into lines if it's too long
    lines = textwrap.wrap(text, width=max_width)
    # Get total text height
    # NOTE(review): the per-line budget is thickness-based rather than
    # glyph-height-based; the 60 * font_scale slack appears to absorb the
    # difference — confirm long multi-line captions do not overflow the strip.
    total_text_height = len(lines) * (thickness * font_scale + 10) + 60 * font_scale
    # Create an image filled with the background color, having enough space for the text
    text_bg_img = np.full((int(total_text_height), frame.shape[1], 3), bg_color, dtype=np.uint8)
    # Put each line on the text background image
    y = 0
    for line in lines:
        text_size, _ = cv2.getTextSize(line, font, font_scale, thickness)
        text_x = (text_bg_img.shape[1] - text_size[0]) // 2  # center horizontally
        y += text_size[1] + 10  # advance baseline by measured line height + padding
        cv2.putText(text_bg_img, line, (text_x, y), font, font_scale, color, thickness)
    # Append the text background image to the frame
    frame_with_text = np.vstack((frame_with_text, text_bg_img))
    
    return frame_with_text

def add_text_to_gif(numpy_images, text, **kwargs):
    """
    Add text to each frame of a gif.

    Applies `add_text_to_frame` to every frame with the same caption;
    extra keyword arguments are forwarded unchanged.
    """
    captioned_frames = [
        add_text_to_frame(frame, text, **kwargs) for frame in numpy_images
    ]
    return np.array(captioned_frames)

def pad_images_to_same_height(images):
    """
    Pad images to the same height.

    Each image is bottom-padded with white so all match the tallest one.
    """
    target_height = max(image.shape[0] for image in images)
    return [
        cv2.copyMakeBorder(
            image,
            0,
            target_height - image.shape[0],
            0,
            0,
            cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        for image in images
    ]

def concatenate_gifs(gifs):
    """
    Concatenate gifs.

    Places the gifs side by side frame-by-frame, truncating all of them to
    the shortest gif and white-padding frames to a common height.
    """
    # Truncate every gif to the shortest one so frames stay in lockstep.
    num_frames = min(gif.shape[0] for gif in gifs)
    trimmed = [gif[:num_frames] for gif in gifs]

    frames_out = []
    for frame_group in zip(*trimmed):
        # Equalize heights, then join the frames horizontally.
        padded = pad_images_to_same_height(list(frame_group))
        frames_out.append(np.concatenate(padded, axis=1))

    return np.array(frames_out)

def stack_gifs(gifs):
    '''vertically stack gifs (truncated to the shortest) frame-by-frame'''
    num_frames = min(gif.shape[0] for gif in gifs)
    stacked = [
        np.concatenate([gif[frame_idx] for gif in gifs], axis=0)
        for frame_idx in range(num_frames)
    ]
    return np.array(stacked)

def save_tensor_to_gif(images, filename, fps):
    """Denormalize a (1, T, C, H, W) tensor from [-1, 1] to [0, 1] and write
    the frames out as a gif via `images_to_gif`."""
    frames = images.squeeze(0).detach().cpu().numpy()
    frames = frames.transpose(0, 2, 3, 1) * 0.5 + 0.5
    images_to_gif(frames, filename, fps)

def save_tensor_to_images(images, output_dir):
    """Denormalize a (1, T, C, H, W) tensor from [-1, 1] and save each frame
    as `output_dir/000.jpg`, `001.jpg`, ...

    Args:
        images: Tensor batch of frames with values nominally in [-1, 1].
        output_dir: Destination directory (created if missing).
    """
    images = images.squeeze(0).detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5
    # plt.imsave rejects float images outside [0, 1]; model outputs can
    # slightly overshoot [-1, 1], so clip defensively before saving.
    images = np.clip(images, 0.0, 1.0)
    os.makedirs(output_dir, exist_ok=True)
    for i in range(images.shape[0]):
        plt.imsave(f'{output_dir}/{i:03d}.jpg', images[i])