| import os |
| import sys |
| import traceback |
|
|
| from math import ceil |
|
|
| import PIL.Image |
| import torch |
| import distinctipy |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| import numpy as np |
| import facer |
| import tyro |
|
|
| from pixel3dmm import env_paths |
|
|
# Number of entries in the shared visualisation palette used by viz_results.
_NUM_PALETTE_COLORS = 22
# Seeded (rng=0) so every run assigns the same colour to the same label.
# viz_results scales each entry by 255, so entries are presumably RGB triples
# in [0, 1] -- confirm against distinctipy's docs.
colors = distinctipy.get_colors(_NUM_PALETTE_COLORS, rng=0)
|
|
|
|
def viz_results(img, seq_classes, n_classes, suppress_plot = False):
    """Render a class-index map as an RGB image.

    Args:
        img: Array/tensor whose last two dimensions give the output height and
            width; only its shape is read.
        seq_classes: Label array indexed as ``seq_classes[0, y, x]``; only the
            first map along axis 0 is rendered.
        n_classes: Labels ``0 .. n_classes - 1`` are painted with palette
            colour ``label``; every other label stays black.
        suppress_plot: When False, additionally display the image with
            matplotlib.

    Returns:
        A ``PIL.Image.Image`` built from an ``(H, W, 3)`` uint8 array.
    """
    # The module palette is used whenever it is large enough (unchanged
    # colours); otherwise fall back to a larger seeded palette instead of
    # raising IndexError.
    palette = colors if n_classes <= len(colors) else distinctipy.get_colors(n_classes, rng=0)

    seg_img = np.zeros([img.shape[-2], img.shape[-1], 3])
    label_map = seq_classes[0, :, :]
    for label in range(n_classes):
        seg_img[label_map == label] = np.array(palette[label]) * 255

    seg_u8 = seg_img.astype(np.uint8)
    if not suppress_plot:
        # Show the same uint8 RGB array that is returned.
        plt.imshow(seg_u8)
        plt.show()
    return Image.fromarray(seg_u8)
|
|
def get_color_seg(img, seq_classes, n_classes):
    """Colourise a label map, leaving a fixed set of excluded labels black.

    Args:
        img: Array/tensor whose last two dimensions give the output height and
            width; only its shape is read.
        seq_classes: Label array indexed as ``seq_classes[0, y, x]``.
        n_classes: Labels ``0 .. n_classes - 1`` are considered; the palette is
            generated locally with ``n_classes + 1`` seeded colours.

    Returns:
        A ``PIL.Image.Image`` built from an ``(H, W, 3)`` uint8 array.
    """
    height, width = img.shape[-2], img.shape[-1]
    canvas = np.zeros([height, width, 3])
    palette = distinctipy.get_colors(n_classes + 1, rng=0)

    # Labels that are never painted (left as zero / black).
    excluded = {0, 1, 3, 4, 5, 14, 16, 17, 18}

    labels = seq_classes[0, :, :]
    for cls in range(n_classes):
        if cls in excluded:
            continue
        canvas[labels == cls] = np.array(palette[cls]) * 255

    return Image.fromarray(canvas.astype(np.uint8))
|
|
|
|
def crop_gt_img(img, seq_classes, n_classes):
    """Zero out every pixel whose label is in a fixed exclusion list.

    Args:
        img: NumPy image array whose leading two axes match
            ``seq_classes[0]`` (e.g. ``(H, W)`` or ``(H, W, C)``).
            NOTE: it is modified IN PLACE -- excluded pixels are set to 0.
        seq_classes: Label array indexed as ``seq_classes[0, y, x]``.
        n_classes: Only excluded labels smaller than this are zeroed.

    Returns:
        A uint8 copy of the (already modified) ``img``.
    """
    bad_indices = [0, 1, 3, 4, 5, 14, 16, 17, 18]
    # Matches the original per-class loop over range(n_classes).
    dropped = [label for label in bad_indices if label < n_classes]
    # One vectorised mask instead of one full-image comparison per class.
    mask = np.isin(seq_classes[0, :, :], dropped)
    img[mask] = 0
    return img.astype(np.uint8)
|
|
|
|
# Run the facer models on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Both models are built once, at import time, and used as globals by main():
# the 'retinaface/mobilenet' face detector and the 'farl/celebm/448' parser.
face_detector = facer.face_detector('retinaface/mobilenet', device=device)
face_parser = facer.face_parser('farl/celebm/448', device=device)
|
|
|
|
def main(video_name : str):
    """Segment every cropped frame of ``video_name`` with the facer face parser.

    Input frames (``*.png`` / ``*.jpg``) are read from
    ``{env_paths.PREPROCESSED_DATA}/{video_name}/cropped/``. For each frame
    ``<stem>`` this writes:

    * ``seg_og/<stem>.png`` -- the per-pixel argmax class-index map (uint8);
      pixels whose parser logits are all zero get label ``n_classes + 1``;
    * ``seg_non_crop_annotations/color_<stem>.png`` -- a colour rendering of
      that map produced by ``viz_results``.

    The whole video is skipped when ``seg_og/`` already holds as many entries
    as there are input frames. Individual frames are skipped when their colour
    rendering already exists, so an interrupted run resumes where it stopped.
    Errors while processing a batch are printed and that batch is skipped.

    Args:
        video_name: Sub-directory name below ``env_paths.PREPROCESSED_DATA``.
    """
    out = f'{env_paths.PREPROCESSED_DATA}/{video_name}'
    out_seg = f'{out}/seg_og/'
    out_seg_annot = f'{out}/seg_non_crop_annotations/'
    os.makedirs(out_seg, exist_ok=True)
    os.makedirs(out_seg_annot, exist_ok=True)
    folder = f'{out}/cropped/'

    frames = sorted(f for f in os.listdir(folder) if f.endswith(('.png', '.jpg')))

    if len(os.listdir(out_seg)) == len(frames):
        print(f'''
<<<<<<<< ALREADY COMPLETED SEGMENTATION FOR {video_name}, SKIPPING >>>>>>>>
''')
        return

    batch_size = 1

    # Step by batch_size so a trailing partial batch is processed too
    # (len(frames) // batch_size batches would silently drop it).
    for start in range(0, len(frames), batch_size):
        image_stack = []
        frame_stack = []
        for file in frames[start:start + batch_size]:
            # Derive the stem exactly as the output files are named, so the
            # resume check looks for the file that is actually written.
            frame = os.path.splitext(file)[0]
            if os.path.exists(f'{out_seg_annot}/color_{frame}.png'):
                print('DONE')
                continue
            # convert('RGB') yields H x W x 3 for RGB, RGBA, grayscale and
            # palette images alike; the context manager closes the file.
            with Image.open(f'{folder}/{file}') as img:
                rgb = np.array(img.convert('RGB'))
            image = facer.hwc2bchw(torch.from_numpy(rgb)).to(device=device)
            image_stack.append(image)
            frame_stack.append(frame)

        if not image_stack:
            continue
        image_batch = torch.cat(image_stack, dim=0)

        try:
            with torch.inference_mode():
                faces = face_detector(image_batch)
                torch.cuda.empty_cache()
                faces = face_parser(image_batch, faces, bbox_scale_factor=1.25)
                torch.cuda.empty_cache()

                seg_logits = faces['seg']['logits']
                # Pixels whose logits are all zero along dim 1 get their own label.
                back_ground = torch.all(seg_logits == 0, dim=1).detach().cpu().numpy()
                seg_probs = seg_logits.softmax(dim=1)
                n_classes = seg_probs.shape[1]
                seg_classes = seg_probs.argmax(dim=1).detach().cpu().numpy().astype(np.uint8)
                seg_classes[back_ground] = n_classes + 1

                saved_images = set()
                for face_idx in range(seg_classes.shape[0]):
                    # image_ids maps each detected face to its image in the
                    # batch; the output name must come from that image.
                    image_idx = faces['image_ids'][face_idx].item()
                    if image_idx in saved_images:
                        # Keep the first face per frame; later detections
                        # would overwrite the same output files.
                        continue
                    saved_images.add(image_idx)
                    frame = frame_stack[image_idx]
                    try:
                        color_img = viz_results(
                            image_batch[image_idx:image_idx + 1],
                            seq_classes=seg_classes[face_idx:face_idx + 1],
                            n_classes=n_classes + 1,
                            suppress_plot=True,
                        )
                        color_img.save(f'{out_seg_annot}/color_{frame}.png')
                    except Exception:
                        # Best effort: report the failure but still save the
                        # class-index map below.
                        traceback.print_exc()
                    Image.fromarray(seg_classes[face_idx]).save(f'{out_seg}/{frame}.png')
                torch.cuda.empty_cache()
        except Exception:
            traceback.print_exc()
|
|
|
|
if __name__ == '__main__':
    # Build a command-line interface from main()'s signature and invoke it.
    tyro.cli(main)
|
|
|
|