# File size: 22,169 Bytes
# 414b431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
import numpy as np
import os
import torch
import torchvision
import torchvision.transforms.functional as torchvision_F
import matplotlib.pyplot as plt
import PIL
import PIL.ImageDraw
from PIL import Image, ImageFont
import trimesh
import pyrender
import cv2
import copy
import base64
import io
import imageio

# Set the PYOPENGL_PLATFORM environment variable to 'egl' for this process.
# NOTE(review): presumably this is what lets pyrender render offscreen without a display;
# it runs after `import pyrender` above -- confirm the setting is still picked up in that order.
os.environ['PYOPENGL_PLATFORM'] = 'egl'
@torch.no_grad()
def tb_image(opt, tb, step, group, name, images, masks=None, num_vis=None, from_range=(0, 1), poses=None, cmap="gray", depth=False):
    """Add a grid of images to `tb` under the tag "<group>/<name>" at `step`.

    In the default mode the images go through preprocess_vis_image (range
    normalization, optional mask channel, optional colormap). With depth=True
    the masks are thresholded at 0.5, every pixel outside the mask is set to
    the largest masked value, and the result is inverted (1 - value).
    NOTE: depth=True requires `masks`; passing None fails on the comparison.

    The grid shows at most rows*cols images, where (rows, cols) is `num_vis`
    or, when that is falsy, `opt.tb.num_images`. If `poses` is given, the first
    three columns of each pose are drawn onto its image with draw_pose.
    """
    if depth:
        binary = (masks > 0.5).float()
        background = (images * binary).max()
        images = images * binary + (1 - binary) * background
        images = (1 - images).detach().cpu()
    else:
        images = preprocess_vis_image(opt, images, masks=masks, from_range=from_range, cmap=cmap)

    rows, cols = num_vis or opt.tb.num_images
    limit = rows * cols
    images = images[:limit]

    if poses is not None:
        # keep only the first three columns of each pose
        rots = poses[:limit, ..., :3]
        drawn = []
        for image, rot in zip(images, rots):
            drawn.append(draw_pose(opt, image, rot, size=20, width=2))
        images = torch.stack(drawn, dim=0)

    grid = torchvision.utils.make_grid(images[:, :3], nrow=cols, pad_value=1.)
    if images.shape[1] == 4:
        # a fourth channel is laid out as its own grid and appended as one extra channel
        alpha_grid = torchvision.utils.make_grid(images[:, 3:], nrow=cols, pad_value=1.)
        grid = torch.cat([grid, alpha_grid[:1]], dim=0)

    tb.add_image("{0}/{1}".format(group, name), grid, step)

def preprocess_vis_image(opt, images, masks=None, from_range=(0, 1), cmap="gray"):
    """Map `images` from `from_range` to [0, 1] and return them on the CPU.

    If `masks` is given it is concatenated along dim 1 as an extra channel
    (used as the transparency channel of a png). Values are clamped to [0, 1].
    A result with exactly one channel is converted to RGB with get_heatmap
    using the colormap `cmap`.
    """
    lower, upper = from_range
    scaled = (images - lower) / (upper - lower)
    if masks is not None:
        # the mask is used directly as the transparency channel
        scaled = torch.cat([scaled, masks], dim=1)
    scaled = scaled.clamp(min=0, max=1).cpu()
    if scaled.shape[1] != 1:
        return scaled
    # single channel: colorize with the requested colormap
    return get_heatmap(opt, scaled[:, 0].cpu(), cmap=cmap)

def preprocess_depth_image(opt, depth, mask=None, max_depth=1000):
    """Shift `depth` so its minimum is 0 and divide by its maximum.

    When `mask` is given, pixels outside the mask are first set to `max_depth`
    so that the minimum is taken over the masked region, and are zeroed after
    the shift so that the maximum is taken over the masked region as well.
    """
    use_mask = mask is not None
    if use_mask:
        # the minimum below then comes from the masked region
        depth = depth * mask + (1 - mask) * max_depth
    shifted = depth - depth.min()
    if use_mask:
        # the maximum below then comes from the masked region
        shifted = shifted * mask
    return shifted / shifted.max()

def dump_images(opt, idx, name, images, masks=None, from_range=(0, 1), poses=None, metrics=None, cmap="gray", folder='dump'):
    """Save each image of a batch as "<opt.output_path>/<folder>/<id>_<name>.png".

    Images are prepared with preprocess_vis_image. If `poses` is given, the
    first three columns of each pose are drawn on its image; if `metrics` is
    given, each metric value is written onto its image with draw_metric.
    `idx` supplies the id used in each file name.
    """
    vis = preprocess_vis_image(opt, images, masks=masks, from_range=from_range, cmap=cmap)
    if poses is not None:
        rots = poses[..., :3]
        with_pose = [draw_pose(opt, image, rot, size=20, width=2) for image, rot in zip(vis, rots)]
        vis = torch.stack(with_pose, dim=0)
    if metrics is not None:
        with_metric = [draw_metric(opt, image, metric.item()) for image, metric in zip(vis, metrics)]
        vis = torch.stack(with_metric, dim=0)

    # channels-last numpy layout for PIL
    arrays = vis.cpu().permute(0, 2, 3, 1).contiguous().numpy()
    for sample_id, array in zip(idx, arrays):
        out_path = "{}/{}/{}_{}.png".format(opt.output_path, folder, sample_id, name)
        pixels = (array * 255).astype(np.uint8)
        Image.fromarray(pixels).save(out_path)

def dump_depths(opt, idx, name, depths, masks=None, rescale=False, folder='dump'):
    """Save each depth map as a viridis-colored "<opt.output_path>/<folder>/<id>_<name>.png".

    With rescale=True the masks are thresholded at 0.5 and every pixel outside
    the mask is set to the largest masked depth (this requires `masks`).
    The depths are inverted (1 - depth) before saving.
    """
    if rescale:
        binary = (masks > 0.5).float()
        background = (depths * binary).max()
        depths = depths * binary + (1 - binary) * background
    inverted = (1 - depths).detach().cpu()
    for sample_id, depth_map in zip(idx, inverted):
        out_path = "{}/{}/{}_{}.png".format(opt.output_path, folder, sample_id, name)
        plt.imsave(out_path, depth_map.squeeze(), cmap='viridis')

# imgs_list is a list of length n_views, where each view is an image tensor of [B, 3, H, W]
def dump_gifs(opt, idx, name, imgs_list, from_range=(0, 1), folder='dump', cmap="gray"):
    """Save one looping gif per sample, using the sample's image from every view as frames.

    Files are written to "<opt.output_path>/<folder>/<id>_<name>.gif" with
    100 ms per frame. NOTE: the entries of `imgs_list` are overwritten in
    place with their preprocessed versions.
    """
    for view_pos, view in enumerate(imgs_list):
        imgs_list[view_pos] = preprocess_vis_image(opt, view, from_range=from_range, cmap=cmap)

    for sample_pos, sample_id in enumerate(idx):
        frames = []
        for view in imgs_list:
            # [H, W, 3] array of this sample as seen in this view
            array = view[sample_pos].cpu().permute(1, 2, 0).contiguous().numpy()
            frames.append(Image.fromarray((array * 255).astype(np.uint8)).convert('RGB'))
        out_path = "{}/{}/{}_{}.gif".format(opt.output_path, folder, sample_id, name)
        frames[0].save(out_path, format='GIF', append_images=frames[1:], save_all=True, duration=100, loop=0)

# attn_vis[i] is the sequence of image arrays that become the gif frames of the i-th sample
# (each array is multiplied by 255 and cast to uint8, so values are presumably in [0, 1])
def dump_attentions(opt, idx, name, attn_vis, folder='dump'):
    """Save one looping gif per sample to "<opt.output_path>/<folder>/<id>_<name>.gif" (50 ms per frame)."""
    for sample_pos, sample_id in enumerate(idx):
        frames = []
        for frame in attn_vis[sample_pos]:
            frames.append(Image.fromarray((frame * 255).astype(np.uint8)).convert('RGB'))
        out_path = "{}/{}/{}_{}.gif".format(opt.output_path, folder, sample_id, name)
        frames[0].save(out_path, format='GIF', append_images=frames[1:], save_all=True, duration=50, loop=0)

def get_heatmap(opt, gray, cmap): # [N, H, W]
    """Colorize a batch of single-channel maps with the matplotlib colormap `cmap`.

    Returns a float tensor of shape [N, 3, H, W]; the alpha channel produced
    by the colormap is dropped.
    """
    colormap = plt.get_cmap(cmap)
    rgba = colormap(gray.numpy())
    rgb = torch.from_numpy(rgba[..., :3])
    return rgb.permute(0, 3, 1, 2).contiguous().float()

def dump_meshes(opt, idx, name, meshes, folder='dump'):
    """Export each mesh to "<opt.output_path>/<folder>/<id>_<name>.ply".

    Export is best effort: a mesh whose `export` call raises is reported on
    stdout and skipped so the remaining meshes are still written. Only
    `Exception` subclasses are handled, so KeyboardInterrupt and SystemExit
    still propagate.
    """
    for i, mesh in zip(idx, meshes):
        fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, i, name)
        try:
            mesh.export(fname)
        except Exception as err:
            # keep the original hint, but also say which file failed and why
            print('Mesh is empty! ({}: {})'.format(fname, err))

def dump_meshes_viz(opt, idx, name, meshes, save_frames=True, folder='dump'):
    """Render a visualization of each mesh under "<opt.output_path>/<folder>/<id>_<name>".

    Each mesh is deep-copied, rotated 180 degrees about the z axis and then
    180 degrees about the y axis, passed through trimesh's inversion fix,
    scaled with scale_to_unit_cube and rendered with visualize_mesh.
    Rendering is best effort: a failure is reported on stdout and the
    remaining meshes are still processed. Only `Exception` subclasses are
    handled, so KeyboardInterrupt and SystemExit still propagate.
    """
    for i, mesh in zip(idx, meshes):
        # work on a copy so the caller's mesh is not transformed
        mesh = copy.deepcopy(mesh)
        R = trimesh.transformations.rotation_matrix(np.radians(180), [0,0,1])
        mesh.apply_transform(R)
        R = trimesh.transformations.rotation_matrix(np.radians(180), [0,1,0])
        mesh.apply_transform(R)
        # our marching cubes outputs inverted normals for some reason so this is necessary
        trimesh.repair.fix_inversion(mesh)

        fname = "{}/{}/{}_{}".format(opt.output_path, folder, i, name)
        try:
            mesh = scale_to_unit_cube(mesh)
            visualize_mesh(mesh, fname, write_frames=save_frames)
        except Exception as err:
            # previously swallowed silently; report it so a missing gif can be traced
            print('Mesh visualization failed for {}: {}'.format(fname, err))

def dump_seen_surface(opt, idx, obj_name, img_name, seen_projs, folder='dump'):
    """Write a textured seen-surface mesh for every sample via create_seen_surface.

    seen_projs: [B, H, W, 3]. The texture of sample <id> is referenced as
    "<id>_<img_name>.png" and the output goes to "<opt.output_path>/<folder>".
    """
    for sample_id, surface_points in zip(idx, seen_projs):
        target_dir = "{}/{}".format(opt.output_path, folder)
        texture_name = "{}_{}.png".format(sample_id, img_name)
        create_seen_surface(sample_id, texture_name, surface_points, target_dir, obj_name)

# https://github.com/princeton-vl/oasis/blob/master/utils/vis_mesh.py
def create_seen_surface(sample_ID, img_path, XYZ, output_folder, obj_name, connect_thres=0.005):
    """Write "<sample_ID>_<obj_name>.mtl" and ".obj" into `output_folder` from a per-pixel point map.

    XYZ is indexed as XYZ[row][col] -> (x, y, z). Every pixel with z > 0
    becomes a vertex whose texture coordinate is its normalized pixel
    position; `img_path` is referenced as the texture in the material file.
    Each 2x2 pixel neighbourhood contributes up to two triangles, and a
    triangle is emitted only if its three pixels have z > 0 and the two edges
    checked below are shorter than `connect_thres`.
    """
    n_rows, n_cols = XYZ.shape[:2]

    def visible(pixel):
        return XYZ[pixel[0]][pixel[1]][2] > 0

    def close(pixel_a, pixel_b):
        gap = XYZ[pixel_a[0]][pixel_a[1]] - XYZ[pixel_b[0]][pixel_b[1]]
        return torch.norm(gap).item() < connect_thres

    with open("{}/{}_{}.mtl".format(output_folder, sample_ID, obj_name), "w") as mtl:
        mtl.writelines([
            "newmtl material_0\n",
            "Ka 0.200000 0.200000 0.200000\n",
            "Kd 0.752941 0.752941 0.752941\n",
            "Ks 1.000000 1.000000 1.000000\n",
            "Tr 1.000000\n",
            "illum 2\n",
            "Ns 0.000000\n",
            "map_Ka %s\n" % img_path,
            "map_Kd %s\n" % img_path,
        ])

    vertex_id = {}  # (row, col) -> 1-based index of the vertex in the obj file

    def face_line(pixel_a, pixel_b, pixel_c):
        a, b, c = vertex_id[pixel_a], vertex_id[pixel_b], vertex_id[pixel_c]
        # vertex index and texture index coincide for every corner
        return "f %d/%d %d/%d %d/%d\n" % (a, a, b, b, c, c)

    with open("{}/{}_{}.obj".format(output_folder, sample_ID, obj_name), "w") as obj:
        obj.write("mtllib {}_{}.mtl\n".format(sample_ID, obj_name))

        next_id = 1
        for row in range(n_rows):
            for col in range(n_cols):
                if not visible((row, col)):
                    continue
                vertex_id[(row, col)] = next_id
                next_id += 1
                point = XYZ[row][col]
                obj.write("v %.4f %.4f %.4f\n" % (point[0], point[1], point[2]))
                obj.write("vt %.8f %.8f\n" % (float(col) / float(n_cols), 1.0 - float(row) / float(n_rows)))

        obj.write("usemtl material_0\n")
        for row in range(n_rows - 1):
            for col in range(n_cols - 1):
                here, right = (row, col), (row, col + 1)
                below, diagonal = (row + 1, col), (row + 1, col + 1)
                if visible(here) and visible(right) and visible(below):
                    # if close enough, connect vertices to form a face
                    if close(here, right) and close(here, below):
                        obj.write(face_line(here, right, below))
                if visible(right) and visible(diagonal) and visible(below):
                    if close(right, diagonal) and close(right, below):
                        obj.write(face_line(right, diagonal, below))

def dump_pointclouds_compare(opt, idx, name, preds, gts, folder='dump'):
    """Export prediction and ground truth of each sample as one colored point cloud.

    Predicted points are colored red and ground-truth points green; the
    combined cloud is written to "<opt.output_path>/<folder>/<id>_<name>.ply".
    """
    for sample_pos, sample_id in enumerate(idx):
        pred_points = preds[sample_pos].cpu().numpy()   # [N1, 3]
        gt_points = gts[sample_pos].cpu().numpy()   # [N2, 3]

        pred_rgb = np.zeros(pred_points.shape, dtype=np.uint8)
        pred_rgb[:, 0] = 255   # red channel
        gt_rgb = np.zeros(gt_points.shape, dtype=np.uint8)
        gt_rgb[:, 1] = 255   # green channel

        cloud = trimesh.points.PointCloud(
            vertices=np.vstack([pred_points, gt_points]),
            colors=np.vstack([pred_rgb, gt_rgb]),
        )
        cloud.export("{}/{}/{}_{}.ply".format(opt.output_path, folder, sample_id, name))

def dump_pointclouds(opt, idx, name, pcs, colors, folder='dump', colormap='jet'):
    """Export each point cloud with its colors to "<opt.output_path>/<folder>/<id>_<name>.ply".

    A color array with a single column is treated as a scalar per point and is
    mapped through the matplotlib colormap `colormap`, then scaled to uint8.
    """
    for sample_id, points, point_colors in zip(idx, pcs, colors):
        xyz = points.cpu().numpy()   # [N, 3]
        rgb = point_colors.cpu().numpy()   # [N, 3] or [N, 1]
        if rgb.shape[1] == 1:
            # scalar per point (expected in [0, 1]) -> colormap color
            mapped = plt.get_cmap(colormap)(rgb[:, 0])
            rgb = (mapped * 255).astype(np.uint8)
        cloud = trimesh.points.PointCloud(vertices=xyz, colors=rgb)
        cloud.export("{}/{}/{}_{}.ply".format(opt.output_path, folder, sample_id, name))

@torch.no_grad()
def vis_pointcloud(opt, vis, step, split, pred, GT=None):
    """Send 3D scatter plots of predicted (and optionally ground-truth) points to `vis`.

    One plot is sent per sample for the first `opt.visdom.num_samples` samples,
    each showing at most `opt.visdom.num_points` points per cloud. Predictions
    are drawn in blue and, when `GT` is given, ground truth in magenta.

    Args:
        opt: options providing `group`, `name`, `visdom.num_samples` and `visdom.num_points`.
        vis: object exposing `_send(dict)` (presumably a visdom client -- confirm against caller).
        step: value shown in the plot titles.
        split: label used in the window name and titles.
        pred: tensor indexed as pred[sample, point, xyz].
        GT: optional tensor with the same layout as `pred`; omitted from the plot when None.
    """
    win_name = "{0}/{1}".format(opt.group, opt.name)
    # GT is optional (its default is None), so only convert what was actually passed
    clouds = [(pred.cpu().numpy(), "blue")]
    if GT is not None:
        clouds.append((GT.cpu().numpy(), "magenta"))
    num_points = opt.visdom.num_points
    for i in range(opt.visdom.num_samples):
        data = [dict(
            type="scatter3d",
            x=[float(n) for n in points[i, :num_points, 0]],
            y=[float(n) for n in points[i, :num_points, 1]],
            z=[float(n) for n in points[i, :num_points, 2]],
            mode="markers",
            marker=dict(
                color=color,
                size=1,
            ),
        ) for points, color in clouds]
        vis._send(dict(
            data=data,
            win="{0} #{1}".format(split, i),
            eid="{0}/{1}".format(opt.group, opt.name),
            layout=dict(
                title="{0} #{1} ({2})".format(split, i, step),
                autosize=True,
                margin=dict(l=30, r=30, b=30, t=30, ),
                showlegend=False,
                yaxis=dict(
                    scaleanchor="x",
                    scaleratio=1,
                )
            ),
            opts=dict(title="{0} #{1} ({2})".format(win_name, i, step), ),
        ))

@torch.no_grad()
def draw_pose(opt, image, rot_mtrx, size=15, width=1):
    """Draw the first three columns of `rot_mtrx` as colored axes in the image corner.

    The columns of a rotation matrix are the rotated x, y and z unit vectors;
    the first two components of each column are used as its 2D projection.
    The axes start at pixel (size, size) and are drawn in red, green and blue.
    Returns the image as a tensor with the same number of channels (RGB, or
    RGBA for a 4-channel input).
    """
    out_mode = "RGB" if image.shape[0] != 4 else "RGBA"
    canvas = torchvision_F.to_pil_image(image.cpu()).convert("RGBA")
    overlay = PIL.Image.new("RGBA", canvas.size, (0, 0, 0, 0))
    pen = PIL.ImageDraw.Draw(overlay)

    origin = (size, size)
    # one tip per matrix column, scaled by `size` and offset to the origin
    tips = [(size + size * column[0], size + size * column[1]) for column in rot_mtrx.t()]
    axis_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
    for axis, color in enumerate(axis_colors):
        pen.line([origin, tips[axis]], fill=color, width=width)

    canvas.alpha_composite(overlay)
    return torchvision_F.to_tensor(canvas.convert(out_mode))

@torch.no_grad()
def draw_metric(opt, image, metric):
    """Write `metric` (three decimals, red) near the bottom-right corner of `image`.

    Uses the "DejaVuSans.ttf" font at size 24 and returns the image as a
    tensor with the same number of channels (RGB, or RGBA for a 4-channel input).
    """
    out_mode = "RGB" if image.shape[0] != 4 else "RGBA"
    canvas = torchvision_F.to_pil_image(image.cpu()).convert("RGBA")
    overlay = PIL.Image.new("RGBA", canvas.size, (0, 0, 0, 0))
    pen = PIL.ImageDraw.Draw(overlay)

    font = ImageFont.truetype("DejaVuSans.ttf", 24)
    canvas_width, canvas_height = canvas.size
    anchor = (canvas_width - 80, canvas_height - 35)
    pen.text(anchor, '{:.3f}'.format(metric), fill="red", font=font)

    canvas.alpha_composite(overlay)
    return torchvision_F.to_tensor(canvas.convert(out_mode))

@torch.no_grad()
def show_att_on_image(img, mask):
    """
    Overlay a grayscale attention map on an image as a JET heatmap.
    Parameters
    ----------
    img: np.array, [H, W, 3]
        Original colored image in [0, 1].
    mask: np.array, [H, W]
        Attention map in [0, 1].
    Returns
    ----------
    np image: heatmap plus image, divided by its maximum value.
    """
    # both inputs must not exceed 1
    assert np.max(img) <= 1
    assert np.max(mask) <= 1

    # colorize the attention map (OpenCV returns BGR, so convert to RGB) and scale to [0, 1]
    mask_u8 = np.uint8(255*mask)
    jet_bgr = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    jet_rgb = cv2.cvtColor(jet_bgr, cv2.COLOR_BGR2RGB)
    heat = np.float32(jet_rgb) / 255

    # sum of heatmap and image, re-scaled so the maximum becomes 1
    overlay = heat + np.float32(img)
    return overlay / np.max(overlay)

def look_at(camera_position, camera_target, up_vector):
    """Build a 4x4 matrix from a camera position, the point it looks at and an up vector.

    The upper-left 3x3 block has as columns: the normalized cross product of
    `up_vector` with the backward axis, the cross product of the backward axis
    with that vector, and the backward axis itself (unit vector from target to
    camera). The last row holds dot products of those axes with
    `camera_position` (the first two negated) followed by 1.
    """
    backward = camera_position - camera_target
    backward = backward / np.linalg.norm(backward)

    side = np.cross(up_vector, backward)
    side = side / np.linalg.norm(side)

    upward = np.cross(backward, side)

    rows = [[side[k], upward[k], backward[k], 0.0] for k in range(3)]
    rows.append([
        -np.dot(side, camera_position),
        -np.dot(upward, camera_position),
        np.dot(backward, camera_position),
        1.0,
    ])
    return np.array(rows)

def scale_to_unit_cube(mesh):
    """Return a new Trimesh centered on its bounding box and scaled so its longest side is 1.

    A trimesh.Scene is first merged into a single mesh. Only vertices and
    faces are carried over to the returned mesh.
    """
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump().sum()

    centered = mesh.vertices - mesh.bounding_box.centroid
    longest_side = np.max(mesh.bounding_box.extents)
    centered *= 2 / longest_side   # longest side spans 2 ...
    centered *= 0.5                # ... then halved, so it spans 1

    return trimesh.Trimesh(vertices=centered, faces=mesh.faces)

def get_positions_and_rotations(n_frames=180, r=1.5):
    '''
        Camera path around the origin, made of four consecutive segments.

        n_frames: how many frames (a full circle uses n_frames // 3 of them,
                  a half circle n_frames // 6)
        r: how far should the camera be (radius in the horizontal x-z plane)

        Returns (positions, matrices): the camera positions and, for each,
        the look_at matrix towards the origin with +y as the up vector.
    '''
    full_circle = n_frames // 3
    half_circle = n_frames // 6

    def on_circle(theta, height):
        return np.array([r*np.cos(theta), height, r*np.sin(theta)])

    pos = []
    # full circle while the height goes from 1 to -1
    for theta, elev in zip(np.linspace(0.5*np.pi, 2.5*np.pi, full_circle), np.linspace(1, -1, full_circle)):
        pos.append(on_circle(theta, elev))
    # half circle at fixed height -1
    for theta in np.linspace(2.5*np.pi, 3.5*np.pi, half_circle):
        pos.append(on_circle(theta, -1))
    # full circle while the height goes from -1 to 1
    for theta, elev in zip(np.linspace(3.5*np.pi, 5.5*np.pi, full_circle), np.linspace(-1, 1, full_circle)):
        pos.append(on_circle(theta, elev))
    # half circle at fixed height 1
    for theta in np.linspace(3.5*np.pi, 4.5*np.pi, half_circle):
        pos.append(on_circle(theta, 1))

    origin = np.array([0.0, 0.0, 0.0])
    y_up = np.array([0.0, 1.0, 0.0])
    rot = [look_at(position, origin, y_up) for position in pos]
    return pos, rot

def visualize_mesh(mesh, output_path, resolution=(200,200), write_gif=True, write_frames=True, time_per_frame=80, n_frames=180):
    '''
        Render a mesh from a moving camera and save the frames as a gif and/or as jpg files.

        mesh: Trimesh mesh object
        output_path: absolute path, ".gif" will get added if write_gif, and this will be used as dirname if write_frames is true
        resolution: arguments passed to pyrender.OffscreenRenderer
        write_gif: save all frames as one looping gif
        write_frames: save every frame as "<output_path>/<index>.jpg"
        time_per_frame: how many milliseconds to wait for each frame
        n_frames: how many frames in total
    '''

    # set material
    mat = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.8,
            roughnessFactor=1.0,
            alphaMode='OPAQUE',
            baseColorFactor=(0.5, 0.5, 0.8, 1.0),
        )
    # define and add scene elements
    render_mesh = pyrender.Mesh.from_trimesh(mesh, material=mat)
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0)
    light = pyrender.SpotLight(color=np.ones(3), intensity=15.0,
                               innerConeAngle=np.pi/4.0,
                               outerConeAngle=np.pi/4.0)

    scene = pyrender.Scene()
    scene.add(render_mesh)
    cam_node = scene.add(camera)
    light_node = scene.add(light)

    positions, rotations = get_positions_and_rotations(n_frames=n_frames)

    renderer = pyrender.OffscreenRenderer(*resolution)

    # move the camera (the light follows it) and generate images
    image_list = []
    try:
        for pos, rot in zip(positions, rotations):
            pose = np.eye(4)
            pose[:3, 3] = pos
            pose[:3,:3] = rot[:3,:3]

            scene.set_pose(cam_node, pose)
            scene.set_pose(light_node, pose)

            color, _ = renderer.render(scene)
            image_list.append(Image.fromarray(color, mode="RGB"))
    finally:
        # release the renderer's OpenGL resources even if rendering fails
        renderer.delete()

    # save to file
    if write_gif:
        # honour time_per_frame (it used to be ignored in favour of a fixed 80 ms)
        image_list[0].save(f"{output_path}.gif", format='GIF', append_images=image_list[1:], save_all=True, duration=time_per_frame, loop=0)

    if write_frames:
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        for i, img in enumerate(image_list):
            img.save(os.path.join(output_path, f"{i:04d}.jpg"))

def get_base64_encoded_image(image_path):
    """
    Returns a base64-encoded JPEG thumbnail of the image at the given path.

    The image is shrunk to fit within 200x200 pixels and re-encoded as JPEG
    (quality 80) to keep the encoded string small; RGBA images are converted
    to RGB first.

    Args:
    image_path (str): The path to the image file.

    Returns:
    str: The base64-encoded JPEG data.
    """
    with open(image_path, "rb") as stream:
        picture = Image.open(stream)
        if picture.mode == 'RGBA':
            picture = picture.convert('RGB')
        picture.thumbnail((200, 200))
        jpeg_bytes = io.BytesIO()
        picture.save(jpeg_bytes, format="JPEG", quality=80)
        encoded = base64.b64encode(jpeg_bytes.getvalue())
        return encoded.decode("utf-8")

def get_base64_encoded_gif(gif_path):
    """
    Returns a reduced, base64-encoded version of the GIF at the given path.

    To keep the encoded string small, only every 4th frame is kept and each
    kept frame is subsampled by 2 along both image axes; the result is
    re-encoded as a GIF at 10 fps with subrectangles enabled.

    Args:
    gif_path (str): The path to the GIF file.

    Returns:
    str: The base64-encoded GIF.
    """
    with open(gif_path, "rb") as stream:
        all_frames = imageio.mimread(stream)
        # every 4th frame, every 2nd pixel in each direction
        reduced = [frame[::2, ::2] for frame in all_frames[::4]]
        gif_bytes = io.BytesIO()
        imageio.mimsave(gif_bytes, reduced, format="GIF", fps=10, subrectangles=True)
        encoded = base64.b64encode(gif_bytes.getvalue())
        return encoded.decode("utf-8")

def create_gif_html(folder_path, html_file, skip_every=10):
    """
    Creates an HTML file with a grid of sample visualizations.
    
    Args:
    folder_path (str): The path to the folder containing the sample visualizations.
    html_file (str): The name of the HTML file to create.
    """
    # convert path to absolute path
    folder_path = os.path.abspath(folder_path)
    
    # Get a list of all the sample IDs
    ids = []
    count = 0
    all_files = sorted(os.listdir(folder_path), key=lambda x: int(x.split("_")[0]))
    for filename in all_files:
        if filename.endswith("_image_input.png"):
            if count % skip_every == 0:
                ids.append(filename.split("_")[0])
            count += 1

    # Write the HTML file
    with open(html_file, "w") as f:
        # Write the HTML header and CSS style
        f.write("<html>\n")
        f.write("<head>\n")
        f.write("<style>\n")
        f.write(".sample-container {\n")
        f.write("  display: inline-block;\n")
        f.write("  margin: 10px;\n")
        f.write("  width: 350px;\n")
        f.write("  height: 150px;\n")
        f.write("  text-align: center;\n")
        f.write("}\n")
        f.write(".sample-container:nth-child(6n+1) {\n")
        f.write("  clear: left;\n")
        f.write("}\n")
        f.write(".image-container, .gif-container {\n")
        f.write("  display: inline-block;\n")
        f.write("  margin: 10px;\n")
        f.write("  width: 90px;\n")
        f.write("  height: 90px;\n")
        f.write("  object-fit: cover;\n")
        f.write("}\n")
        f.write("</style>\n")
        f.write("</head>\n")
        f.write("<body>\n")
        
        # Write the sample visualizations to the HTML file
        for sample_id in ids:
            try:
                f.write("<div class=\"sample-container\">\n")
                f.write(f"<div class=\"sample-id\"><p>{sample_id}</p></div>\n")
                f.write(f"<div class=\"image-container\"><img src=\"data:image/png;base64,{get_base64_encoded_image(os.path.join(folder_path, sample_id + '_image_input.png'))}\" width=\"90\" height=\"90\"></div>\n")
                f.write(f"<div class=\"image-container\"><img src=\"data:image/png;base64,{get_base64_encoded_image(os.path.join(folder_path, sample_id + '_depth_est.png'))}\" width=\"90\" height=\"90\"></div>\n")
                f.write(f"<div class=\"gif-container\"><img src=\"data:image/gif;base64,{get_base64_encoded_gif(os.path.join(folder_path, sample_id + '_mesh_viz.gif'))}\" width=\"90\" height=\"90\"></div>\n")
                f.write("</div>\n")
            except:
                pass
        
        # Write the HTML footer
        f.write("</body>\n")
        f.write("</html>\n")