import numpy as np
import os

# PyOpenGL chooses its platform when it is first imported, and `import pyrender`
# below imports it. The variable therefore has to be set *before* that import for
# headless (EGL) rendering to take effect.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import torch
import torchvision
import torchvision.transforms.functional as torchvision_F
import matplotlib.pyplot as plt
import PIL
import PIL.ImageDraw
from PIL import Image, ImageFont
import trimesh
import pyrender
import cv2
import copy
import base64
import html
import io
import imageio


def _invert_masked_depth(depths, masks):
    """Prepare a depth map for display (shared by `tb_image` and `dump_depths`).

    Pixels where ``masks <= 0.5`` are filled with the largest depth found inside the
    mask, then the map is inverted (``1 - depth``), detached and moved to the CPU.
    """
    masks = (masks > 0.5).float()
    depths = depths * masks + (1 - masks) * ((depths * masks).max())
    return (1 - depths).detach().cpu()


@torch.no_grad()
def tb_image(opt, tb, step, group, name, images, masks=None, num_vis=None, from_range=(0, 1), poses=None, cmap="gray", depth=False):
    """Log a grid of images to a TensorBoard-style writer under the tag ``"<group>/<name>"``.

    Args:
        opt: options; ``opt.tb.num_images`` is the default (rows, cols) of the grid.
        tb: writer exposing ``add_image(tag, image, step)``.
        step: global step forwarded to the writer.
        group, name: joined into the tag.
        images: [B, C, H, W] tensor.
        masks: optional mask tensor (presumably [B, 1, H, W]). Appended as an alpha
            channel, or used to fill the background when ``depth=True`` (required then).
        num_vis: optional (rows, cols) overriding ``opt.tb.num_images``.
        from_range: value range of ``images`` that is mapped to [0, 1].
        poses: optional [B, 3, 4] poses whose rotation is drawn as axes on each image.
        cmap: matplotlib colormap used for single-channel images.
        depth: treat ``images`` as depth maps (see ``_invert_masked_depth``).

    Raises:
        ValueError: if ``depth`` is True but ``masks`` is None.
    """
    if not depth:
        images = preprocess_vis_image(opt, images, masks=masks, from_range=from_range, cmap=cmap)  # [B, 3 or 4, H, W]
    else:
        if masks is None:
            raise ValueError("tb_image(depth=True) requires masks")
        images = _invert_masked_depth(images, masks)
    num_H, num_W = num_vis or opt.tb.num_images
    images = images[:num_H * num_W]
    if poses is not None:
        # only the rotation part [.., 3, 3] of each pose is drawn
        rots = poses[:num_H * num_W, ..., :3]
        images = torch.stack([draw_pose(opt, image, rot, size=20, width=2) for image, rot in zip(images, rots)], dim=0)
    image_grid = torchvision.utils.make_grid(images[:, :3], nrow=num_W, pad_value=1.)
    if images.shape[1] == 4:
        # the 4th channel (mask) becomes the alpha channel of the grid
        mask_grid = torchvision.utils.make_grid(images[:, 3:], nrow=num_W, pad_value=1.)[:1]
        image_grid = torch.cat([image_grid, mask_grid], dim=0)
    tag = "{0}/{1}".format(group, name)
    tb.add_image(tag, image_grid, step)


def preprocess_vis_image(opt, images, masks=None, from_range=(0, 1), cmap="gray"):
    """Map images to [0, 1] on the CPU for visualization.

    ``images`` is linearly rescaled from ``from_range`` to [0, 1]. When ``masks`` is
    given it is concatenated along dim 1 and becomes the PNG transparency channel.
    The result is clamped to [0, 1]. Single-channel results are turned into 3-channel
    colour maps with ``cmap``.
    """
    low, high = from_range
    images = (images - low) / (high - low)
    if masks is not None:
        # the mask is directly the transparency channel of the png
        images = torch.cat([images, masks], dim=1)
    # detach so that .numpy() below / in callers also works on tensors requiring grad
    images = images.detach().clamp(min=0, max=1).cpu()
    if images.shape[1] == 1:
        images = get_heatmap(opt, images[:, 0], cmap=cmap)
    return images


def preprocess_depth_image(opt, depth, mask=None, max_depth=1000):
    """Normalize a depth map to [0, 1] using the min/max of the masked region.

    Args:
        depth: depth tensor/array.
        mask: optional mask with the same shape as ``depth`` (1 = foreground).
        max_depth: value assigned to the background before taking the minimum.
    """
    if mask is not None:
        # background is pushed to max_depth so min() comes from the masked region
        depth = depth * mask + (1 - mask) * max_depth
    depth = depth - depth.min()
    if mask is not None:
        # background is zeroed so max() comes from the masked region
        depth = depth * mask
    depth_max = depth.max()
    if depth_max > 0:  # a constant map would otherwise become NaN
        depth = depth / depth_max
    return depth


def dump_images(opt, idx, name, images, masks=None, from_range=(0, 1), poses=None, metrics=None, cmap="gray", folder='dump'):
    """Save each image of a batch as ``<opt.output_path>/<folder>/<idx>_<name>.png``.

    Optional ``poses`` are drawn as rotation axes and optional ``metrics`` (one scalar
    tensor per image) are written onto the images before saving.
    """
    images = preprocess_vis_image(opt, images, masks=masks, from_range=from_range, cmap=cmap)  # [B, 3 or 4, H, W]
    if poses is not None:
        rots = poses[..., :3]
        images = torch.stack([draw_pose(opt, image, rot, size=20, width=2) for image, rot in zip(images, rots)], dim=0)
    if metrics is not None:
        images = torch.stack([draw_metric(opt, image, metric.item()) for image, metric in zip(images, metrics)], dim=0)
    images = images.cpu().permute(0, 2, 3, 1).contiguous().numpy()  # [B, H, W, C]
    for i, img in zip(idx, images):
        fname = "{}/{}/{}_{}.png".format(opt.output_path, folder, i, name)
        Image.fromarray((img * 255).astype(np.uint8)).save(fname)


def dump_depths(opt, idx, name, depths, masks=None, rescale=False, folder='dump'):
    """Save each depth map as a viridis-coloured ``<idx>_<name>.png``.

    With ``rescale=True`` the maps are first masked and inverted with
    ``_invert_masked_depth`` (``masks`` is then required).
    """
    if rescale:
        if masks is None:
            raise ValueError("dump_depths(rescale=True) requires masks")
        depths = _invert_masked_depth(depths, masks)
    else:
        # plt.imsave needs CPU data without autograd history
        depths = depths.detach().cpu()
    for i, depth in zip(idx, depths):
        fname = "{}/{}/{}_{}.png".format(opt.output_path, folder, i, name)
        plt.imsave(fname, depth.squeeze().numpy(), cmap='viridis')


def dump_gifs(opt, idx, name, imgs_list, from_range=(0, 1), folder='dump', cmap="gray"):
    """Save one animated GIF per sample, cycling through the views.

    ``imgs_list`` has one entry per view, each a [B, C, H, W] image tensor. The
    caller's list is left unmodified.
    """
    views = [preprocess_vis_image(opt, imgs, from_range=from_range, cmap=cmap) for imgs in imgs_list]
    for b, sample_id in enumerate(idx):
        # one [H, W, 3] frame per view of the b-th sample
        frames = [Image.fromarray((view[b].permute(1, 2, 0).contiguous().numpy() * 255).astype(np.uint8)).convert('RGB')
                  for view in views]
        fname = "{}/{}/{}_{}.gif".format(opt.output_path, folder, sample_id, name)
        frames[0].save(fname, format='GIF', append_images=frames[1:], save_all=True, duration=100, loop=0)


def dump_attentions(opt, idx, name, attn_vis, folder='dump'):
    """Save one animated GIF per sample from precomputed attention visualizations.

    ``attn_vis[i]`` is the sequence of frames for the i-th sample: arrays accepted by
    ``Image.fromarray`` after scaling by 255 (presumably values in [0, 1]).
    """
    for b, sample_id in enumerate(idx):
        frames = [Image.fromarray((img * 255).astype(np.uint8)).convert('RGB') for img in attn_vis[b]]
        fname = "{}/{}/{}_{}.gif".format(opt.output_path, folder, sample_id, name)
        frames[0].save(fname, format='GIF', append_images=frames[1:], save_all=True, duration=50, loop=0)


def get_heatmap(opt, gray, cmap):
    """Convert a [N, H, W] CPU tensor into a [N, 3, H, W] float tensor using a matplotlib colormap."""
    color = plt.get_cmap(cmap)(gray.numpy())
    color = torch.from_numpy(color[..., :3]).permute(0, 3, 1, 2).contiguous().float()  # drop alpha
    return color


def dump_meshes(opt, idx, name, meshes, folder='dump'):
    """Export each mesh to ``<idx>_<name>.ply``; failures are reported and skipped."""
    for i, mesh in zip(idx, meshes):
        fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, i, name)
        try:
            mesh.export(fname)
        except Exception:
            # best effort: one failed export must not stop the remaining ones
            print('Mesh is empty!')


def dump_meshes_viz(opt, idx, name, meshes, save_frames=True, folder='dump'):
    """Render a turntable GIF (and optionally the frames) for each mesh.

    Output goes to ``<idx>_<name>.gif`` and, if ``save_frames``, to the directory
    ``<idx>_<name>/``. Rendering failures are reported and skipped.
    """
    # our marching cubes outputs inverted normals for some reason so this is necessary
    flip_z = trimesh.transformations.rotation_matrix(np.radians(180), [0, 0, 1])
    flip_y = trimesh.transformations.rotation_matrix(np.radians(180), [0, 1, 0])
    for i, mesh in zip(idx, meshes):
        mesh = copy.deepcopy(mesh)  # keep the caller's mesh untouched
        mesh.apply_transform(flip_z)
        mesh.apply_transform(flip_y)
        fname = "{}/{}/{}_{}".format(opt.output_path, folder, i, name)
        try:
            mesh = scale_to_unit_cube(mesh)
            visualize_mesh(mesh, fname, write_frames=save_frames)
        except Exception as err:
            # best effort, but no longer silent
            print("Failed to visualize mesh {}: {}".format(fname, err))


def dump_seen_surface(opt, idx, obj_name, img_name, seen_projs, folder='dump'):
    """Write the seen surface of every sample as a textured OBJ (see ``create_seen_surface``).

    ``seen_projs`` is [B, H, W, 3]; the texture of sample ``i`` is ``<i>_<img_name>.png``.
    """
    out_folder = "{}/{}".format(opt.output_path, folder)
    for i, seen_proj in zip(idx, seen_projs):
        img_fname = "{}_{}.png".format(i, img_name)
        create_seen_surface(i, img_fname, seen_proj, out_folder, obj_name)


def create_seen_surface(sample_ID, img_path, XYZ, output_folder, obj_name, connect_thres=0.005):
    """Write a per-pixel point map as a textured OBJ mesh plus its MTL material.

    Every pixel with positive z becomes a vertex whose texture coordinate is its
    normalized pixel position. Each 2x2 pixel neighbourhood adds up to two triangles,
    only when all three corners are valid and the two edges leaving the triangle's
    first corner are shorter than ``connect_thres``.

    Args:
        sample_ID: sample index used in the output file names.
        img_path: texture path written into the MTL file (``map_Ka`` / ``map_Kd``).
        XYZ: [H, W, 3] tensor of 3D points (z <= 0 means "no point").
        output_folder: directory that receives ``<sample_ID>_<obj_name>.obj/.mtl``.
        obj_name: base name of the output files.
        connect_thres: maximum edge length for connecting neighbouring pixels.
    """
    height, width = XYZ.shape[:2]
    base_name = "{}_{}".format(sample_ID, obj_name)

    with open("{}/{}.mtl".format(output_folder, base_name), "w") as f:
        f.writelines([
            "newmtl material_0\n",
            "Ka 0.200000 0.200000 0.200000\n",
            "Kd 0.752941 0.752941 0.752941\n",
            "Ks 1.000000 1.000000 1.000000\n",
            "Tr 1.000000\n",
            "illum 2\n",
            "Ns 0.000000\n",
            "map_Ka %s\n" % img_path,
            "map_Kd %s\n" % img_path,
        ])

    # validity mask evaluated once instead of one tensor comparison per access
    valid = (XYZ[..., 2] > 0).tolist()

    def close(a, b):
        return torch.norm(XYZ[a] - XYZ[b]).item() < connect_thres

    XYZ_to_idx = {}
    with open("{}/{}.obj".format(output_folder, base_name), "w") as f:
        f.write("mtllib {}.mtl\n".format(base_name))
        for y in range(height):
            for x in range(width):
                if valid[y][x]:
                    XYZ_to_idx[(y, x)] = len(XYZ_to_idx) + 1  # OBJ indices are 1-based
                    f.write("v %.4f %.4f %.4f\n" % (XYZ[y][x][0], XYZ[y][x][1], XYZ[y][x][2]))
                    f.write("vt %.8f %.8f\n" % (float(x) / float(width), 1.0 - float(y) / float(height)))
        f.write("usemtl material_0\n")

        def write_face(*corners):
            ids = [XYZ_to_idx[c] for c in corners]
            # vertex and texture-coordinate indices coincide
            f.write("f %d/%d %d/%d %d/%d\n" % (ids[0], ids[0], ids[1], ids[1], ids[2], ids[2]))

        for y in range(height - 1):
            for x in range(width - 1):
                top_left, top_right, bottom_left, bottom_right = (y, x), (y, x + 1), (y + 1, x), (y + 1, x + 1)
                if valid[y][x] and valid[y][x + 1] and valid[y + 1][x]:
                    if close(top_left, top_right) and close(top_left, bottom_left):
                        write_face(top_left, top_right, bottom_left)
                if valid[y][x + 1] and valid[y + 1][x + 1] and valid[y + 1][x]:
                    if close(top_right, bottom_right) and close(top_right, bottom_left):
                        write_face(top_right, bottom_right, bottom_left)


def dump_pointclouds_compare(opt, idx, name, preds, gts, folder='dump'):
    """Save predicted (red) and ground-truth (green) point clouds together in one PLY per sample."""
    for b, sample_id in enumerate(idx):
        pred = preds[b].cpu().numpy()  # [N1, 3]
        gt = gts[b].cpu().numpy()  # [N2, 3]
        color_pred = np.zeros(pred.shape, dtype=np.uint8)
        color_pred[:, 0] = 255
        color_gt = np.zeros(gt.shape, dtype=np.uint8)
        color_gt[:, 1] = 255
        pc_color = trimesh.points.PointCloud(vertices=np.vstack([pred, gt]), colors=np.vstack([color_pred, color_gt]))
        fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, sample_id, name)
        pc_color.export(fname)


def dump_pointclouds(opt, idx, name, pcs, colors, folder='dump', colormap='jet'):
    """Save coloured point clouds as PLY files.

    ``colors`` holds per-point [N, 3] values or [N, 1] scalars; scalars are mapped
    through ``colormap``. Values are scaled by 255 (presumably in [0, 1]).
    """
    for i, pc, color in zip(idx, pcs, colors):
        pc = pc.cpu().numpy()  # [N, 3]
        color = color.cpu().numpy()  # [N, 3] or [N, 1]
        if color.shape[1] == 1:
            # single channel color in [0, 1] -> rgb(a) through the colormap
            color = plt.get_cmap(colormap)(color[:, 0])
        color = (color * 255).astype(np.uint8)
        pc_color = trimesh.points.PointCloud(vertices=pc, colors=color)
        fname = "{}/{}/{}_{}.ply".format(opt.output_path, folder, i, name)
        pc_color.export(fname)


@torch.no_grad()
def vis_pointcloud(opt, vis, step, split, pred, GT=None):
    """Send 3D scatter plots of predicted (blue) and, if given, ground-truth (magenta) points to visdom.

    At most ``opt.visdom.num_samples`` samples and ``opt.visdom.num_points`` points
    per cloud are plotted.
    """
    win_name = "{0}/{1}".format(opt.group, opt.name)
    clouds = [(pred.cpu().numpy(), "blue")]
    if GT is not None:
        clouds.append((GT.cpu().numpy(), "magenta"))
    num_points = opt.visdom.num_points
    # never index past the smallest batch
    num_samples = min([opt.visdom.num_samples] + [len(points) for points, _ in clouds])
    for i in range(num_samples):
        data = [dict(
            type="scatter3d",
            x=[float(n) for n in points[i, :num_points, 0]],
            y=[float(n) for n in points[i, :num_points, 1]],
            z=[float(n) for n in points[i, :num_points, 2]],
            mode="markers",
            marker=dict(
                color=color,
                size=1,
            ),
        ) for points, color in clouds]
        vis._send(dict(
            data=data,
            win="{0} #{1}".format(split, i),
            eid="{0}/{1}".format(opt.group, opt.name),
            layout=dict(
                title="{0} #{1} ({2})".format(split, i, step),
                autosize=True,
                margin=dict(l=30, r=30, b=30, t=30, ),
                showlegend=False,
                yaxis=dict(
                    scaleanchor="x",
                    scaleratio=1,
                )
            ),
            opts=dict(title="{0} #{1} ({2})".format(win_name, i, step), ),
        ))


@torch.no_grad()
def draw_pose(opt, image, rot_mtrx, size=15, width=1):
    """Draw the projected x/y/z axes (red/green/blue) of a rotation in the image's top-left corner.

    Args:
        image: [3 or 4, H, W] float image; the channel count is preserved.
        rot_mtrx: [3, 3] rotation (a [3, 4] pose also works; only the first three
            columns are drawn).
        size: axis length in pixels; the axes are centred at ``(size, size)``.
        width: line width in pixels.
    """
    mode = "RGBA" if image.shape[0] == 4 else "RGB"
    image_pil = torchvision_F.to_pil_image(image.cpu()).convert("RGBA")
    overlay = Image.new("RGBA", image_pil.size, (0, 0, 0, 0))
    draw = PIL.ImageDraw.Draw(overlay)
    center = (size, size)
    # column k of the rotation matrix is the rotated k-th unit axis; its first two
    # components are its projection onto the image plane
    endpoints = [(size + size * float(p[0]), size + size * float(p[1])) for p in rot_mtrx.t()]
    for endpoint, fill in zip(endpoints, [(255, 0, 0), (0, 255, 0), (0, 0, 255)]):
        draw.line([center, endpoint], fill=fill, width=width)
    image_pil.alpha_composite(overlay)
    return torchvision_F.to_tensor(image_pil.convert(mode))


@torch.no_grad()
def draw_metric(opt, image, metric):
    """Write ``metric`` (3 decimals, red) near the bottom-right corner of a [3 or 4, H, W] image."""
    mode = "RGBA" if image.shape[0] == 4 else "RGB"
    image_pil = torchvision_F.to_pil_image(image.cpu()).convert("RGBA")
    overlay = Image.new("RGBA", image_pil.size, (0, 0, 0, 0))
    draw = PIL.ImageDraw.Draw(overlay)
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 24)
    except OSError:
        # font not installed: fall back to Pillow's built-in font instead of crashing
        font = ImageFont.load_default()
    position = (image_pil.size[0] - 80, image_pil.size[1] - 35)
    draw.text(position, '{:.3f}'.format(metric), fill="red", font=font)
    image_pil.alpha_composite(overlay)
    return torchvision_F.to_tensor(image_pil.convert(mode))


@torch.no_grad()
def show_att_on_image(img, mask):
    """
    Convert the grayscale attention into heatmap on the image.
    Parameters
    ----------
    img: np.array, [H, W, 3]
        Original colored image in [0, 1].
    mask: np.array, [H, W]
        Attention map in [0, 1].
    Returns
    ----------
    np image with attention applied (heatmap + image, divided by its maximum).
    """
    # check the validity
    assert np.max(img) <= 1
    assert np.max(mask) <= 1
    # generate heatmap and normalize into [0, 1]
    heatmap = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    # add heatmap onto the image and re-scale
    merged = heatmap + np.float32(img)
    return merged / np.max(merged)


def look_at(camera_position, camera_target, up_vector):
    """Build a 4x4 look-at matrix.

    The upper-left 3x3 columns are the unit vectors (right, up, back), with
    ``back = normalize(camera_position - camera_target)``. The last row holds the
    negated projections of ``camera_position`` onto those axes. In this file only
    the 3x3 block is used (as the camera rotation in ``visualize_mesh``).
    ``up_vector`` must not be parallel to the viewing direction.
    """
    back = camera_position - camera_target
    back = back / np.linalg.norm(back)
    right = np.cross(up_vector, back)
    right = right / np.linalg.norm(right)
    up = np.cross(back, right)
    return np.array([
        [right[0], up[0], back[0], 0.0],
        [right[1], up[1], back[1], 0.0],
        [right[2], up[2], back[2], 0.0],
        [-np.dot(right, camera_position), -np.dot(up, camera_position), -np.dot(back, camera_position), 1.0]
    ])


def scale_to_unit_cube(mesh):
    """Center a mesh at its bounding-box centroid and scale its longest bounding-box side to 1.

    A ``trimesh.Scene`` is first merged into one mesh. Only vertices and faces are
    kept in the returned ``trimesh.Trimesh``; the input is not modified.
    """
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump().sum()
    vertices = (mesh.vertices - mesh.bounding_box.centroid) / np.max(mesh.bounding_box.extents)
    return trimesh.Trimesh(vertices=vertices, faces=mesh.faces)


def get_positions_and_rotations(n_frames=180, r=1.5):
    '''
    Camera path orbiting the origin.

    n_frames: how many frames (the path has 2 * (n_frames // 3) + 2 * (n_frames // 6) poses)
    r: how far should the camera be (horizontal radius)

    Returns (positions, rotations): lists of 3-vectors and of 4x4 `look_at` matrices
    aimed at the origin with +y as up.
    '''
    n_frame_full_circ = n_frames // 3  # frames for a full circle
    n_frame_half_circ = n_frames // 6  # frames for a half circle
    # full circle in horizontal axes going from 1 to -1 height axis
    pos1 = [np.array([r * np.cos(theta), elev, r * np.sin(theta)])
            for theta, elev in zip(np.linspace(0.5 * np.pi, 2.5 * np.pi, n_frame_full_circ), np.linspace(1, -1, n_frame_full_circ))]
    # half circle in horizontal axes at fixed -1 height
    pos2 = [np.array([r * np.cos(theta), -1, r * np.sin(theta)])
            for theta in np.linspace(2.5 * np.pi, 3.5 * np.pi, n_frame_half_circ)]
    # full circle in horizontal axes going from -1 to 1 height axis
    pos3 = [np.array([r * np.cos(theta), elev, r * np.sin(theta)])
            for theta, elev in zip(np.linspace(3.5 * np.pi, 5.5 * np.pi, n_frame_full_circ), np.linspace(-1, 1, n_frame_full_circ))]
    # half circle in horizontal axes at fixed 1 height
    pos4 = [np.array([r * np.cos(theta), 1, r * np.sin(theta)])
            for theta in np.linspace(3.5 * np.pi, 4.5 * np.pi, n_frame_half_circ)]
    pos = pos1 + pos2 + pos3 + pos4
    target = np.array([0.0, 0.0, 0.0])
    up = np.array([0.0, 1.0, 0.0])
    rot = [look_at(x, target, up) for x in pos]
    return pos, rot


def visualize_mesh(mesh, output_path, resolution=(200, 200), write_gif=True, write_frames=True, time_per_frame=80, n_frames=180):
    '''
    Render a turntable animation of a mesh with pyrender.

    mesh: Trimesh mesh object
    output_path: ".gif" gets appended if write_gif; used as the frame directory if write_frames
    resolution: (width, height) of the rendered frames
    time_per_frame: how many milliseconds to show each GIF frame
    n_frames: how many frames in total (see get_positions_and_rotations)
    '''
    width, height = resolution
    # set material
    mat = pyrender.MetallicRoughnessMaterial(
        metallicFactor=0.8,
        roughnessFactor=1.0,
        alphaMode='OPAQUE',
        baseColorFactor=(0.5, 0.5, 0.8, 1.0),
    )
    # define and add scene elements; the camera and spot light move together
    render_mesh = pyrender.Mesh.from_trimesh(mesh, material=mat)
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=width / height)
    light = pyrender.SpotLight(color=np.ones(3), intensity=15.0, innerConeAngle=np.pi / 4.0, outerConeAngle=np.pi / 4.0)
    scene = pyrender.Scene()
    scene.add(render_mesh)
    cam_node = scene.add(camera)
    light_node = scene.add(light)

    positions, rotations = get_positions_and_rotations(n_frames=n_frames)
    renderer = pyrender.OffscreenRenderer(width, height)
    image_list = []
    try:
        for pos, rot in zip(positions, rotations):
            pose = np.eye(4)
            pose[:3, 3] = pos
            pose[:3, :3] = rot[:3, :3]
            scene.set_pose(cam_node, pose)
            scene.set_pose(light_node, pose)
            color, _ = renderer.render(scene)
            image_list.append(Image.fromarray(color))
    finally:
        # release the OpenGL/EGL context; it is otherwise leaked on every call
        renderer.delete()

    if not image_list:
        return
    # save to file
    if write_gif:
        image_list[0].save(f"{output_path}.gif", format='GIF', append_images=image_list[1:], save_all=True,
                           duration=time_per_frame, loop=0)
    if write_frames:
        os.makedirs(output_path, exist_ok=True)
        for i, img in enumerate(image_list):
            img.save(os.path.join(output_path, f"{i:04d}.jpg"))


def get_base64_encoded_image(image_path):
    """
    Returns a base64-encoded JPEG thumbnail of the image at the given path.

    The image is converted to RGB when JPEG cannot store its mode (alpha is dropped),
    shrunk to fit in 200x200 keeping its aspect ratio, and re-encoded with quality 80.

    Args:
        image_path (str): The path to the image file.

    Returns:
        str: The base64-encoded JPEG data.
    """
    with open(image_path, "rb") as f:
        img = Image.open(f)
        # JPEG cannot store alpha or palette images (RGBA, LA, P, ...)
        if img.mode not in ("RGB", "L"):
            img = img.convert('RGB')
        # Resize the image to reduce its file size (loads the pixels while f is open)
        img.thumbnail((200, 200))
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=80)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def get_base64_encoded_gif(gif_path):
    """
    Returns a reduced, base64-encoded version of the GIF at the given path.

    Every 4th frame is kept, each frame is downsampled by 2 along both image axes, and
    the result is re-encoded at 10 fps with subrectangle optimization.

    Args:
        gif_path (str): The path to the GIF file.

    Returns:
        str: The base64-encoded GIF data.
    """
    with open(gif_path, "rb") as f:
        frames = imageio.mimread(f)
    frames = [frame[::2, ::2] for frame in frames[::4]]
    buffer = io.BytesIO()
    imageio.mimsave(buffer, frames, format="GIF", fps=10, subrectangles=True)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def create_gif_html(folder_path, html_file, skip_every=10, suffixes=None):
    """
    Creates an HTML file with a grid of sample visualizations.

    A sample is identified by a file ``<id>_image_input.png`` in ``folder_path``. Every
    ``skip_every``-th sample (in id order) gets one row of embedded images.

    NOTE(review): the per-sample HTML template of the previous version was lost (its
    markup strings are empty in the reviewed copy). This version embeds, for each
    sample, the files listed in ``suffixes`` (``<id>_<suffix>``) or, by default, every
    ``<id>_*.png`` / ``<id>_*.gif`` file with the input image first. Confirm this
    matches the report layout that was intended.

    Args:
        folder_path (str): The path to the folder containing the sample visualizations.
        html_file (str): The name of the HTML file to create.
        skip_every (int): Keep one sample out of every ``skip_every``.
        suffixes (list[str] | None): Optional explicit file suffixes to embed per sample.
    """
    folder_path = os.path.abspath(folder_path)

    # Only "<integer id>_..." names are sample outputs; anything else (e.g. a report
    # written into the same folder) is ignored instead of crashing the int() sort key.
    sample_files = [fn for fn in os.listdir(folder_path) if fn.split("_")[0].isdigit()]
    sample_files.sort(key=lambda fn: int(fn.split("_")[0]))
    input_files = [fn for fn in sample_files if fn.endswith("_image_input.png")]
    ids = [fn.split("_")[0] for fn in input_files[::skip_every]]

    def assets_for(sample_id):
        # file names to embed for one sample, in display order
        if suffixes is not None:
            return ["{}_{}".format(sample_id, suffix) for suffix in suffixes]
        prefix = sample_id + "_"
        names = sorted(fn for fn in sample_files if fn.startswith(prefix) and fn.endswith((".png", ".gif")))
        names.sort(key=lambda fn: not fn.endswith("_image_input.png"))  # input image first
        return names

    def embed(path):
        # GIFs keep their animation; still images are re-encoded as JPEG thumbnails
        if path.lower().endswith(".gif"):
            return "image/gif", get_base64_encoded_gif(path)
        return "image/jpeg", get_base64_encoded_image(path)

    with open(html_file, "w") as f:
        f.write("<html>\n")
        f.write("<head>\n")
        f.write("<style>\n")
        f.write(".sample { display: flex; flex-wrap: wrap; align-items: flex-start; gap: 8px; margin-bottom: 16px; }\n")
        f.write(".sample figure { margin: 0; text-align: center; font-family: sans-serif; font-size: 12px; }\n")
        f.write("</style>\n")
        f.write("</head>\n")
        f.write("<body>\n")
        for sample_id in ids:
            # build the whole row first so a failing asset never leaves half-written HTML
            try:
                cells = []
                for fname in assets_for(sample_id):
                    mime, data = embed(os.path.join(folder_path, fname))
                    label = html.escape(fname)
                    cells.append(f'<figure><img src="data:{mime};base64,{data}" alt="{label}">'
                                 f'<figcaption>{label}</figcaption></figure>\n')
            except Exception as err:
                print("Skipping sample {}: {}".format(sample_id, err))
                continue
            f.write('<div class="sample">\n')
            f.writelines(cells)
            f.write("</div>\n")
        f.write("</body>\n")
        f.write("</html>\n")