| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from vhap.config.base import import_module, PhotometricStageConfig, BaseTrackingConfig |
| | from vhap.model.flame import FlameHead, FlameTexPCA, FlameTexPainted, FlameUvMask |
| | from vhap.model.lbs import batch_rodrigues |
| | from vhap.util.mesh import ( |
| | get_mtl_content, |
| | get_obj_content, |
| | normalize_image_points, |
| | ) |
| | from vhap.util.log import get_logger |
| | from vhap.util.visualization import plot_landmarks_2d |
| |
|
| | from torch.utils.tensorboard import SummaryWriter |
| | import torch |
| | import torchvision |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| | import numpy as np |
| | from matplotlib import cm |
| | from typing import Literal |
| | from functools import partial |
| | import tyro |
| | import yaml |
| | from datetime import datetime |
| | import threading |
| | from typing import Optional |
| | from collections import defaultdict |
| | from copy import deepcopy |
| | import time |
| | import os |
| |
|
| |
|
| | class FlameTracker: |
    def __init__(self, cfg: BaseTrackingConfig):
        """Flame head tracker.

        Builds the FLAME head model, its texture model (painted or PCA),
        the UV mask, and the differentiable renderer selected by
        ``cfg.render.backend``.

        :param cfg: tracking configuration (model sizes, render options, ...)
        """
        self.cfg = cfg

        self.device = cfg.device
        # TensorBoard writer is attached later by the training entry point.
        self.tb_writer = None

        # FLAME head model with the configured shape/expression dimensions.
        self.flame = FlameHead(
            cfg.model.n_shape,
            cfg.model.n_expr,
            add_teeth=cfg.model.add_teeth,
            remove_lip_inside=cfg.model.remove_lip_inside,
            face_clusters=cfg.model.tex_clusters,
        ).to(self.device)

        # Exactly one texture model is instantiated: painted or PCA.
        if cfg.model.tex_painted:
            self.flame_tex_painted = FlameTexPainted(tex_size=cfg.model.tex_resolution).to(self.device)
        else:
            self.flame_tex_pca = FlameTexPCA(cfg.model.n_tex, tex_size=cfg.model.tex_resolution).to(self.device)

        self.flame_uvmask = FlameUvMask().to(self.device)

        # Renderer backends are imported lazily so that only the selected
        # backend's dependency stack needs to be installed.
        if self.cfg.render.backend == 'nvdiffrast':
            from vhap.util.render_nvdiffrast import NVDiffRenderer

            self.render = NVDiffRenderer(
                use_opengl=self.cfg.render.use_opengl,
                lighting_type=self.cfg.render.lighting_type,
                lighting_space=self.cfg.render.lighting_space,
                disturb_rate_fg=self.cfg.render.disturb_rate_fg,
                disturb_rate_bg=self.cfg.render.disturb_rate_bg,
                fid2cid=self.flame.mask.fid2cid,
            )
        elif self.cfg.render.backend == 'pytorch3d':
            from vhap.util.render_pytorch3d import PyTorch3DRenderer

            self.render = PyTorch3DRenderer()
        else:
            raise NotImplementedError(f"Unknown renderer backend: {self.cfg.render.backend}")
| | |
| | def load_from_tracked_flame_params(self, fp): |
| | """ |
| | loads checkpoint from tracked_flame_params file. Counterpart to save_result() |
| | :param fp: |
| | :return: |
| | """ |
| | report = np.load(fp) |
| |
|
| | |
| | def load_param(param, ckpt_array): |
| | param.data[:] = torch.from_numpy(ckpt_array).to(param.device) |
| |
|
| | def load_param_list(param_list, ckpt_array): |
| | for i in range(min(len(param_list), len(ckpt_array))): |
| | load_param(param_list[i], ckpt_array[i]) |
| |
|
| | load_param_list(self.rotation, report["rotation"]) |
| | load_param_list(self.translation, report["translation"]) |
| | load_param_list(self.neck_pose, report["neck_pose"]) |
| | load_param_list(self.jaw_pose, report["jaw_pose"]) |
| | load_param_list(self.eyes_pose, report["eyes_pose"]) |
| | load_param(self.shape, report["shape"]) |
| | load_param_list(self.expr, report["expr"]) |
| | load_param(self.lights, report["lights"]) |
| | |
| | if not self.calibrated: |
| | load_param(self.focal_length, report["focal_length"]) |
| | |
| | if not self.cfg.model.tex_painted: |
| | if "tex" in report: |
| | load_param(self.tex_pca, report["tex"]) |
| | else: |
| | self.logger.warn("No tex_extra found in flame_params!") |
| | |
| | if self.cfg.model.tex_extra: |
| | if "tex_extra" in report: |
| | load_param(self.tex_extra, report["tex_extra"]) |
| | else: |
| | self.logger.warn("No tex_extra found in flame_params!") |
| | |
| | if self.cfg.model.use_static_offset: |
| | if "static_offset" in report: |
| | load_param(self.static_offset, report["static_offset"]) |
| | else: |
| | self.logger.warn("No static_offset found in flame_params!") |
| |
|
| | if self.cfg.model.use_dynamic_offset: |
| | if "dynamic_offset" in report: |
| | load_param_list(self.dynamic_offset, report["dynamic_offset"]) |
| | else: |
| | self.logger.warn("No dynamic_offset found in flame_params!") |
| |
|
| | def trimmed_decays(self, is_init): |
| | decays = {} |
| | for k, v in self.decays.items(): |
| | if is_init and "init" in k or not is_init and "init" not in k: |
| | decays[k.replace("_init", "")] = v |
| | return decays |
| |
|
    def clear_cache(self):
        # Delegate cache clearing to the renderer backend.
        self.render.clear_cache()
| |
|
| | def get_current_frame(self, frame_idx, include_keyframes=False): |
| | """ |
| | Creates a single item batch from the frame data at index frame_idx in the dataset. |
| | If include_keyframes option is set, keyframe data will be appended to the batch. However, |
| | it is guaranteed that the frame data belonging to frame_idx is at position 0 |
| | :param frame_idx: |
| | :return: |
| | """ |
| | indices = [frame_idx] |
| | if include_keyframes: |
| | indices += self.cfg.exp.keyframes |
| |
|
| | samples = [] |
| | for idx in indices: |
| | sample = self.dataset.getitem_by_timestep(idx) |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | samples.append(sample) |
| |
|
| | |
| | sample = {} |
| | for k, v in samples[0].items(): |
| | values = [s[k] for s in samples] |
| | if isinstance(v, torch.Tensor): |
| | values = torch.cat(values, dim=0) |
| | sample[k] = values |
| |
|
| | if "lmk2d_iris" in sample: |
| | sample["lmk2d"] = torch.cat([sample["lmk2d"], sample["lmk2d_iris"]], dim=1) |
| | return sample |
| |
|
| | def fill_cam_params_into_sample(self, sample): |
| | """ |
| | Adds intrinsics and extrinics to sample, if data is not calibrated |
| | """ |
| | if self.calibrated: |
| | assert "intrinsic" in sample |
| | assert "extrinsic" in sample |
| | else: |
| | b, _, h, w = sample["rgb"].shape |
| | |
| |
|
| | |
| | f = self.focal_length * max(h, w) |
| | cx, cy = torch.tensor([[0.5*w], [0.5*h]]).to(f) |
| |
|
| | sample["intrinsic"] = torch.stack([f, f, cx, cy], dim=1) |
| | sample["extrinsic"] = self.RT[None, ...].expand(b, -1, -1) |
| |
|
| | def configure_optimizer(self, params, lr_scale=1.0): |
| | """ |
| | Creates optimizer for the given set of parameters |
| | :param params: |
| | :return: |
| | """ |
| | |
| | params = params.copy() |
| | param_groups = [] |
| | default_lr = self.cfg.lr.base |
| |
|
| | |
| | group_def = { |
| | "translation": ["translation"], |
| | "expr": ["expr"], |
| | "light": ["lights"], |
| | } |
| | if not self.calibrated: |
| | group_def ["cam"] = ["cam"] |
| | if self.cfg.model.use_static_offset: |
| | group_def ["static_offset"] = ["static_offset"] |
| | if self.cfg.model.use_dynamic_offset: |
| | group_def ["dynamic_offset"] = ["dynamic_offset"] |
| |
|
| | |
| | group_lr = { |
| | "translation": self.cfg.lr.translation, |
| | "expr": self.cfg.lr.expr, |
| | "light": self.cfg.lr.light, |
| | } |
| | if not self.calibrated: |
| | group_lr["cam"] = self.cfg.lr.camera |
| | if self.cfg.model.use_static_offset: |
| | group_lr["static_offset"] = self.cfg.lr.static_offset |
| | if self.cfg.model.use_dynamic_offset: |
| | group_lr["dynamic_offset"] = self.cfg.lr.dynamic_offset |
| |
|
| | for group_name, param_keys in group_def.items(): |
| | selected = [] |
| | for p in param_keys: |
| | if p in params: |
| | selected += params.pop(p) |
| | if len(selected) > 0: |
| | param_groups.append({"params": selected, "lr": group_lr[group_name] * lr_scale}) |
| |
|
| | |
| | selected = [] |
| | for _, v in params.items(): |
| | selected += v |
| | param_groups.append({"params": selected}) |
| |
|
| | optim = torch.optim.Adam(param_groups, lr=default_lr * lr_scale) |
| | return optim |
| |
|
| | def initialize_frame(self, frame_idx): |
| | """ |
| | Initializes parameters of frame frame_idx |
| | :param frame_idx: |
| | :return: |
| | """ |
| | if frame_idx > 0: |
| | self.initialize_from_previous(frame_idx) |
| |
|
| | def initialize_from_previous(self, frame_idx): |
| | """ |
| | Initializes the flame parameters with the optimized ones from the previous frame |
| | :param frame_idx: |
| | :return: |
| | """ |
| | if frame_idx == 0: |
| | return |
| |
|
| | param_list = [ |
| | self.expr, |
| | self.neck_pose, |
| | self.jaw_pose, |
| | self.translation, |
| | self.rotation, |
| | self.eyes_pose, |
| | ] |
| |
|
| | for param in param_list: |
| | param[frame_idx].data = param[frame_idx - 1].detach().clone().data |
| |
|
| | def select_frame_indices(self, frame_idx, include_keyframes): |
| | indices = [frame_idx] |
| | if include_keyframes: |
| | indices += self.cfg.exp.keyframes |
| | return indices |
| |
|
    def forward_flame(self, frame_idx, include_keyframes):
        """
        Evaluates the flame model using the given parameters
        :param frame_idx: timestep whose parameters are evaluated
        :param include_keyframes: also evaluate the configured keyframes
        :return: vertices, canonical vertices, landmarks and albedos, one batch
            entry per selected frame
        """
        indices = self.select_frame_indices(frame_idx, include_keyframes)

        # Per-frame vertex offsets are optional.
        dynamic_offset = self.to_batch(self.dynamic_offset, indices) if self.cfg.model.use_dynamic_offset else None

        ret = self.flame(
            self.shape[None, ...].expand(len(indices), -1),  # shape is shared across frames
            self.to_batch(self.expr, indices),
            self.to_batch(self.rotation, indices),
            self.to_batch(self.neck_pose, indices),
            self.to_batch(self.jaw_pose, indices),
            self.to_batch(self.eyes_pose, indices),
            self.to_batch(self.translation, indices),
            return_verts_cano=True,
            static_offset=self.static_offset,
            dynamic_offset=dynamic_offset,
        )
        # NOTE(review): ret may contain more entries; only the first three are used here.
        verts, verts_cano, lmks = ret[0], ret[1], ret[2]
        # The same albedo map is shared by all frames in the batch.
        albedos = self.get_albedo().expand(len(indices), -1, -1, -1)
        return verts, verts_cano, lmks, albedos
| | |
| | def get_base_texture(self): |
| | if self.cfg.model.tex_extra and not self.cfg.model.residual_tex: |
| | albedos_base = self.tex_extra[None, ...] |
| | else: |
| | if self.cfg.model.tex_painted: |
| | albedos_base = self.flame_tex_painted() |
| | else: |
| | albedos_base = self.flame_tex_pca(self.tex_pca[None, :]) |
| | return albedos_base |
| | |
| | def get_albedo(self): |
| | albedos_base = self.get_base_texture() |
| |
|
| | if self.cfg.model.tex_extra and self.cfg.model.residual_tex: |
| | albedos_res = self.tex_extra[None, :] |
| | if albedos_base.shape[-1] != albedos_res.shape[-1] or albedos_base.shape[-2] != albedos_res.shape[-2]: |
| | albedos_base = F.interpolate(albedos_base, albedos_res.shape[-2:], mode='bilinear') |
| | albedos = albedos_base + albedos_res |
| | else: |
| | albedos = albedos_base |
| |
|
| | return albedos |
| |
|
    def rasterize_flame(
        self, sample, verts, faces, camera_index=None, train_mode=False
    ):
        """
        Rasterizes the flame head mesh
        :param sample: batch dict carrying "intrinsic" and "extrinsic" cameras
        :param verts: vertex positions to rasterize
        :param faces: face index tensor
        :param camera_index: if given, restrict rasterization to this camera
        :param train_mode: forwarded to the renderer backend
        :return: rasterization result dict from the renderer backend
        """
        # camera parameters
        K = sample["intrinsic"].clone().to(self.device)
        RT = sample["extrinsic"].to(self.device)
        if camera_index is not None:
            # Index with a one-element list to keep the batch dimension.
            K = K[[camera_index]]
            RT = RT[[camera_index]]

        H, W = self.image_size
        image_size = H, W

        rast_dict = self.render.rasterize(verts, faces, RT, K, image_size, False, train_mode)
        return rast_dict
| |
|
| | @torch.no_grad() |
| | def get_background_color(self, gt_rgb, gt_alpha, stage): |
| | if stage is None: |
| | background = self.cfg.render.background_eval |
| | else: |
| | background = self.cfg.render.background_train |
| |
|
| | if background == 'target': |
| | """use gt_rgb as background""" |
| | color = gt_rgb.permute(0, 2, 3, 1) |
| | elif background == 'white': |
| | color = [1, 1, 1] |
| | elif background == 'black': |
| | color = [0, 0, 0] |
| | else: |
| | raise NotImplementedError(f"Unknown background mode: {background}") |
| | return color |
| | |
    def render_rgba(
        self, rast_dict, verts, faces, albedos, lights, background_color=[1, 1, 1],
        align_texture_except_fid=None, align_boundary_except_vid=None, enable_disturbance=False,
    ):
        """
        Renders the rgba image from the rasterization result and
        the optimized texture + lights

        :param rast_dict: output of rasterize_flame()
        :param verts: posed vertices
        :param faces: face index tensor
        :param albedos: albedo texture map(s)
        :param lights: lighting parameters or None
        :param background_color: RGB list or background image
        :return: dict of rendered images in BCHW layout (contains "rgba")
        """
        faces_uv = self.flame.textures_idx
        if self.cfg.render.backend == 'nvdiffrast':
            verts_uv = self.flame.verts_uvs.clone()
            # Flip the v axis of the UVs for this backend.
            verts_uv[:, 1] = 1 - verts_uv[:, 1]
            tex = albedos

            render_out = self.render.render_rgba(
                rast_dict, verts, faces, verts_uv, faces_uv, tex, lights, background_color,
                align_texture_except_fid, align_boundary_except_vid, enable_disturbance
            )
            # BHWC -> BCHW for every returned image.
            render_out = {k: v.permute(0, 3, 1, 2) for k, v in render_out.items()}
        elif self.cfg.render.backend == 'pytorch3d':
            B = verts.shape[0]
            verts_uv = self.flame.face_uvcoords.repeat(B, 1, 1)
            tex = albedos.expand(B, -1, -1, -1)

            rgba = self.render.render_rgba(
                rast_dict, verts, faces, verts_uv, faces_uv, tex, lights, background_color
            )
            render_out = {'rgba': rgba.permute(0, 3, 1, 2)}
        else:
            raise NotImplementedError(f"Unknown renderer backend: {self.cfg.render.backend}")

        return render_out
| |
|
| | def render_normal(self, rast_dict, verts, faces): |
| | """ |
| | Renders the rgba image from the rasterization result and |
| | the optimized texture + lights |
| | """ |
| | uv_coords = self.flame.face_uvcoords |
| | uv_coords = uv_coords.repeat(verts.shape[0], 1, 1) |
| | return self.render.render_normal(rast_dict, verts, faces, uv_coords) |
| |
|
    def compute_lmk_energy(self, sample, pred_lmks, disable_jawline_landmarks=False):
        """
        Computes the landmark energy loss term between groundtruth landmarks and flame landmarks
        :param sample: batch dict providing "lmk2d" plus camera parameters
        :param pred_lmks: predicted 3D landmarks
        :param disable_jawline_landmarks: NOTE(review): currently unused in this
            body — confirm whether jawline masking was meant to happen here
        :return: the lmk loss for all 68 facial landmarks, a separate 2 pupil landmark loss and
        a relative eye close term
        """
        img_size = sample["rgb"].shape[-2:]

        # Ground-truth landmarks: last channel is a per-landmark confidence.
        lmk2d = sample["lmk2d"].clone().to(pred_lmks)
        lmk2d, confidence = lmk2d[:, :, :2], lmk2d[:, :, 2]
        lmk2d[:, :, 0], lmk2d[:, :, 1] = normalize_image_points(
            lmk2d[:, :, 0], lmk2d[:, :, 1], img_size
        )

        # Project predicted landmarks into the same normalized space.
        K = sample["intrinsic"].to(self.device)
        RT = sample["extrinsic"].to(self.device)
        pred_lmk_ndc = self.render.world_to_ndc(pred_lmks, RT, K, img_size, flip_y=True)
        pred_lmk2d = pred_lmk_ndc[:, :, :2]

        if (lmk2d.shape[1] == 70):
            # 68 facial landmarks + 2 iris landmarks.
            diff = lmk2d - pred_lmk2d
            confidence = confidence[:, :70]
            # Double the weight of the iris landmarks (indices 68, 69).
            confidence[:, 68:] = confidence[:, 68:] * 2
        else:
            diff = lmk2d[:, :68] - pred_lmk2d[:, :68]
            confidence = confidence[:, :68]

        # Confidence-weighted L1 distance per landmark.
        lmk_loss = torch.norm(diff, dim=2, p=1) * confidence

        result_dict = {
            "gt_lmk2d": lmk2d,
            "pred_lmk2d": pred_lmk2d,
        }

        return lmk_loss.mean(), result_dict
| |
|
| | def compute_photometric_energy( |
| | self, |
| | sample, |
| | verts, |
| | faces, |
| | albedos, |
| | rast_dict, |
| | step_i=None, |
| | stage=None, |
| | include_keyframes=False, |
| | ): |
| | """ |
| | Computes the dense photometric energy |
| | :param sample: |
| | :param vertices: |
| | :param albedos: |
| | :return: |
| | """ |
| | gt_rgb = sample["rgb"].to(verts) |
| | if "alpha" in sample: |
| | gt_alpha = sample["alpha_map"].to(verts) |
| | else: |
| | gt_alpha = None |
| |
|
| | lights = self.lights[None] if self.lights is not None else None |
| | bg_color = self.get_background_color(gt_rgb, gt_alpha, stage) |
| |
|
| | align_texture_except_fid = self.flame.mask.get_fid_by_region( |
| | self.cfg.pipeline[stage].align_texture_except |
| | ) if stage is not None else None |
| | align_boundary_except_vid = self.flame.mask.get_vid_by_region( |
| | self.cfg.pipeline[stage].align_boundary_except |
| | ) if stage is not None else None |
| |
|
| | render_out = self.render_rgba( |
| | rast_dict, verts, faces, albedos, lights, bg_color, |
| | align_texture_except_fid, align_boundary_except_vid, |
| | enable_disturbance=stage!=None, |
| | ) |
| |
|
| | pred_rgb = render_out['rgba'][:, :3] |
| | pred_alpha = render_out['rgba'][:, 3:] |
| | pred_mask = render_out['rgba'][:, [3]].detach() > 0 |
| | pred_mask = pred_mask.expand(-1, 3, -1, -1) |
| |
|
| | results_dict = render_out |
| |
|
| | |
| | error_rgb = gt_rgb - pred_rgb |
| | color_loss = error_rgb.abs().sum() / pred_mask.detach().sum() |
| |
|
| | results_dict.update( |
| | { |
| | "gt_rgb": gt_rgb, |
| | "pred_rgb": pred_rgb, |
| | "error_rgb": error_rgb, |
| | "pred_alpha": pred_alpha, |
| | } |
| | ) |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | photo_loss = color_loss |
| | |
| | return photo_loss, results_dict |
| | |
    def compute_regularization_energy(self, result_dict, verts, verts_cano, lmks, albedos, frame_idx, include_keyframes, stage):
        """
        Computes the energy term that penalizes strong deviations from the flame base model

        :param result_dict: intermediate render results (used for the diffuse term)
        :param verts: posed vertices
        :param verts_cano: canonical (unposed) vertices
        :param lmks: predicted landmarks (unused here)
        :param albedos: albedo maps (unused here)
        :param frame_idx: current frame index
        :param include_keyframes: include configured keyframes in the batch
        :param stage: name of the current pipeline stage
        :return: dict mapping regularizer names to weighted loss terms
        """
        log_dict = {}

        # Unit standard deviations for the quadratic priors below.
        std_tex = 1
        std_expr = 1
        std_shape = 1

        indices = self.select_frame_indices(frame_idx, include_keyframes)

        # Temporal smoothness of the global pose (tracking stages only).
        if self.opt_dict['pose'] and 'tracking' in stage:
            E_pose_smooth = self.compute_pose_smooth_energy(frame_idx, stage=='global_tracking')
            log_dict["pose_smooth"] = E_pose_smooth

        # Joint smoothness and prior.
        if self.opt_dict['joints']:
            if 'tracking' in stage:
                joint_smooth = self.compute_joint_smooth_energy(frame_idx, stage=='global_tracking')
                log_dict["joint_smooth"] = joint_smooth

            joint_prior = self.compute_joint_prior_energy(frame_idx)
            log_dict["joint_prior"] = joint_prior

        # Quadratic prior on expression coefficients.
        if self.opt_dict['expr']:
            expr = self.to_batch(self.expr, indices)
            reg_expr = (expr / std_expr) ** 2
            log_dict["reg_expr"] = self.cfg.w.reg_expr * reg_expr.mean()

        # Quadratic prior on shape coefficients.
        if self.opt_dict['shape']:
            reg_shape = (self.shape / std_shape) ** 2
            log_dict["reg_shape"] = self.cfg.w.reg_shape * reg_shape.mean()

        # Texture regularizers.
        if self.opt_dict['texture']:
            # Prior on PCA texture coefficients (painted textures have none).
            if not self.cfg.model.tex_painted:
                reg_tex_pca = (self.tex_pca / std_tex) ** 2
                log_dict["reg_tex_pca"] = self.cfg.w.reg_tex_pca * reg_tex_pca.mean()

            if self.cfg.model.tex_extra:
                if self.cfg.model.residual_tex:
                    # Magnitude penalty on the residual texture.
                    if self.cfg.w.reg_tex_res is not None:
                        reg_tex_res = self.tex_extra ** 2
                        log_dict["reg_tex_res"] = self.cfg.w.reg_tex_res * reg_tex_res.mean()

                    # Total-variation penalty on the combined albedo; the weight
                    # is rescaled to be invariant to image down/up-sampling.
                    if self.cfg.w.reg_tex_tv is not None:
                        tex = self.get_albedo()[0]
                        tv_y = (tex[..., :-1, :] - tex[..., 1:, :]) ** 2
                        tv_x = (tex[..., :, :-1] - tex[..., :, 1:]) ** 2
                        tv = tv_y.reshape(tv_y.shape[0], -1) + tv_x.reshape(tv_x.shape[0], -1)
                        w_reg_tex_tv = self.cfg.w.reg_tex_tv * self.cfg.data.scale_factor ** 2
                        if self.cfg.data.n_downsample_rgb is not None:
                            w_reg_tex_tv /= (self.cfg.data.n_downsample_rgb ** 2)
                        log_dict["reg_tex_tv"] = w_reg_tex_tv * tv.mean()

                    # Extra magnitude penalty restricted to masked UV regions.
                    if self.cfg.w.reg_tex_res_clusters is not None:
                        mask_sclerae = self.flame_uvmask.get_uvmask_by_region(self.cfg.w.reg_tex_res_for)[None, :, :]
                        reg_tex_res_clusters = self.tex_extra ** 2 * mask_sclerae
                        log_dict["reg_tex_res_clusters"] = self.cfg.w.reg_tex_res_clusters * reg_tex_res_clusters.mean()

        # Lighting regularizers.
        if self.opt_dict['lights']:
            # Pull lights towards the uniform lighting initialization.
            if self.cfg.w.reg_light is not None and self.lights is not None:
                reg_light = (self.lights - self.lights_uniform) ** 2
                log_dict["reg_light"] = self.cfg.w.reg_light * reg_light.mean()

            # Penalize over-bright and spatially varying diffuse shading.
            if self.cfg.w.reg_diffuse is not None and self.lights is not None:
                diffuse = result_dict['diffuse_detach_normal']
                reg_diffuse = F.relu(diffuse.max() - 1) + diffuse.var(dim=1).mean()
                log_dict["reg_diffuse"] = self.cfg.w.reg_diffuse * reg_diffuse

        # Vertex offset regularizers (static and/or dynamic offsets).
        if self.opt_dict['static_offset'] or self.opt_dict['dynamic_offset']:
            if self.static_offset is not None or self.dynamic_offset is not None:
                offset = 0
                if self.static_offset is not None:
                    offset += self.static_offset
                if self.dynamic_offset is not None:
                    offset += self.to_batch(self.dynamic_offset, indices)

                # Laplacian smoothness of the offset field, optionally relaxed
                # in configured regions.
                if self.cfg.w.reg_offset_lap is not None:
                    vert_wo_offset = (verts_cano - offset).detach()
                    reg_offset_lap = self.compute_laplacian_smoothing_loss(
                        vert_wo_offset, vert_wo_offset + offset
                    )
                    if len(self.cfg.w.reg_offset_lap_relax_for) > 0:
                        w = self.scale_vertex_weights_by_region(
                            weights=torch.ones_like(verts[:, :, :1]),
                            scale_factor=self.cfg.w.reg_offset_lap_relax_coef,
                            region=self.cfg.w.reg_offset_lap_relax_for,
                        )
                        reg_offset_lap *= w
                    log_dict["reg_offset_lap"] = self.cfg.w.reg_offset_lap * reg_offset_lap.mean()

                # Magnitude penalty on offsets, optionally relaxed per region.
                if self.cfg.w.reg_offset is not None:
                    reg_offset = offset.abs()
                    if len(self.cfg.w.reg_offset_relax_for) > 0:
                        w = self.scale_vertex_weights_by_region(
                            weights=torch.ones_like(verts[:, :, :1]),
                            scale_factor=self.cfg.w.reg_offset_relax_coef,
                            region=self.cfg.w.reg_offset_relax_for,
                        )
                        reg_offset *= w
                    log_dict["reg_offset"] = self.cfg.w.reg_offset * reg_offset.mean()

                # Encourage rigid (low-variance) offsets within given regions.
                if self.cfg.w.reg_offset_rigid is not None:
                    reg_offset_rigid = 0
                    for region in self.cfg.w.reg_offset_rigid_for:
                        vids = self.flame.mask.get_vid_by_region([region])
                        reg_offset_rigid += offset[:, vids, :].var(dim=-2).mean()
                    log_dict["reg_offset_rigid"] = self.cfg.w.reg_offset_rigid * reg_offset_rigid

                # Tie each frame's dynamic offset to frame 0 and its predecessor.
                if self.cfg.w.reg_offset_dynamic is not None and self.dynamic_offset is not None and self.opt_dict['dynamic_offset']:
                    if frame_idx == 0:
                        reg_offset_d = torch.zeros_like(self.dynamic_offset[0])
                        offset_d = self.dynamic_offset[0]
                    else:
                        reg_offset_d = torch.stack([self.dynamic_offset[0], self.dynamic_offset[frame_idx - 1]])
                        offset_d = self.dynamic_offset[frame_idx]

                    reg_offset_dynamic = ((offset_d - reg_offset_d) ** 2).mean()
                    log_dict["reg_offset_dynamic"] = self.cfg.w.reg_offset_dynamic * reg_offset_dynamic

        return log_dict
| |
|
| | def scale_vertex_weights_by_region(self, weights, scale_factor, region): |
| | indices = self.flame.mask.get_vid_by_region(region) |
| | weights[:, indices] *= scale_factor |
| |
|
| | for _ in range(self.cfg.w.blur_iter): |
| | M = self.flame.laplacian_matrix_negate_diag[None, ...] |
| | weights = M.bmm(weights) / 2 |
| | return weights |
| | |
| | def compute_pose_smooth_energy(self, frame_idx, use_next_frame=False): |
| | """ |
| | Regularizes the global pose of the flame head model to be temporally smooth |
| | """ |
| | idx = frame_idx |
| | idx_prev = np.clip(idx - 1, 0, self.n_timesteps - 1) |
| | if use_next_frame: |
| | idx_next = np.clip(idx + 1, 0, self.n_timesteps - 1) |
| | ref_indices = [idx_prev, idx_next] |
| | else: |
| | ref_indices = [idx_prev] |
| |
|
| | E_trans = ((self.translation[[idx]] - self.translation[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_trans |
| | E_rot = ((self.rotation[[idx]] - self.rotation[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_rot |
| | return E_trans + E_rot |
| | |
| | def compute_joint_smooth_energy(self, frame_idx, use_next_frame=False): |
| | """ |
| | Regularizes the joints of the flame head model to be temporally smooth |
| | """ |
| | idx = frame_idx |
| | idx_prev = np.clip(idx - 1, 0, self.n_timesteps - 1) |
| | if use_next_frame: |
| | idx_next = np.clip(idx + 1, 0, self.n_timesteps - 1) |
| | ref_indices = [idx_prev, idx_next] |
| | else: |
| | ref_indices = [idx_prev] |
| |
|
| | E_joint_smooth = 0 |
| | E_joint_smooth += ((self.neck_pose[[idx]] - self.neck_pose[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_neck |
| | E_joint_smooth += ((self.jaw_pose[[idx]] - self.jaw_pose[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_jaw |
| | E_joint_smooth += ((self.eyes_pose[[idx]] - self.eyes_pose[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_eyes |
| | return E_joint_smooth |
| | |
    def compute_joint_prior_energy(self, frame_idx):
        """
        Regularizes the joints of the flame head model towards neutral joint locations
        """
        # The two eye joints are handled as separate 3-DoF poses.
        poses = [
            ("neck", self.neck_pose[[frame_idx], :]),
            ("jaw", self.jaw_pose[[frame_idx], :]),
            ("eyes", self.eyes_pose[[frame_idx], :3]),
            ("eyes", self.eyes_pose[[frame_idx], 3:]),
        ]

        # Regularize in rotation-matrix space: compare each pose's rotation
        # matrix against the identity obtained from the zero pose.
        E_joint_prior = 0
        for name, pose in poses:
            rotmats = batch_rodrigues(torch.cat([torch.zeros_like(pose), pose], dim=0))
            diff = ((rotmats[[0]] - rotmats[1:]) ** 2).mean()

            if name == 'jaw':
                # Heavily penalize a negative first (x) component of the jaw pose.
                diff += F.relu(-pose[:, 0]).mean() * 10

                # Penalize the non-primary (y, z) jaw components.
                diff += (pose[:, 1:] ** 2).mean() * 3
            elif name == 'eyes':
                # Encourage both eyes to share the same rotation.
                diff += ((self.eyes_pose[[frame_idx], :3] - self.eyes_pose[[frame_idx], 3:]) ** 2).mean()

            # Per-joint weight looked up as cfg.w["prior_<name>"].
            E_joint_prior += diff * self.cfg.w[f"prior_{name}"]
        return E_joint_prior
| |
|
| | def compute_laplacian_smoothing_loss(self, verts, offset_verts): |
| | L = self.flame.laplacian_matrix[None, ...].detach() |
| | basis_lap = L.bmm(verts).detach() |
| |
|
| | offset_lap = L.bmm(offset_verts) |
| | diff = (offset_lap - basis_lap) ** 2 |
| | diff = diff.sum(dim=-1, keepdim=True) |
| | return diff |
| |
|
    def compute_energy(
        self,
        sample,
        frame_idx,
        include_keyframes=False,
        step_i=None,
        stage=None,
    ):
        """
        Compute total energy for frame frame_idx
        :param sample: batch dict for the current frame (plus keyframes)
        :param frame_idx: index of the frame to optimize
        :param include_keyframes: if key frames shall be included when predicting the per
        frame energy
        :param step_i: current optimization step, forwarded to sub-energies
        :param stage: current pipeline stage; None disables regularization
        :return: loss, log dict, predicted vertices and landmarks
        """
        log_dict = {}

        gt_rgb = sample["rgb"]
        result_dict = {"gt_rgb": gt_rgb}

        verts, verts_cano, lmks, albedos = self.forward_flame(frame_idx, include_keyframes)
        faces = self.flame.faces

        # num_cameras may arrive collated as a list; take the first entry.
        if isinstance(sample["num_cameras"], list):
            num_cameras = sample["num_cameras"][0]
        else:
            num_cameras = sample["num_cameras"]

        # Landmark energy (predictions are tiled once per camera).
        if self.cfg.w.landmark is not None:
            lmks_n = self.repeat_n_times(lmks, num_cameras)
            if not self.cfg.w.always_enable_jawline_landmarks and stage is not None:
                disable_jawline_landmarks = self.cfg.pipeline[stage]['disable_jawline_landmarks']
            else:
                disable_jawline_landmarks = False
            E_lmk, _result_dict = self.compute_lmk_energy(sample, lmks_n, disable_jawline_landmarks)
            log_dict["lmk"] = self.cfg.w.landmark * E_lmk
            result_dict.update(_result_dict)

        # Photometric energy: only during photometric stages or evaluation.
        if stage is None or isinstance(self.cfg.pipeline[stage], PhotometricStageConfig):
            if self.cfg.w.photo is not None:
                verts_n = self.repeat_n_times(verts, num_cameras)
                rast_dict = self.rasterize_flame(
                    sample, verts_n, self.flame.faces, train_mode=True
                )

                photo_energy_func = self.compute_photometric_energy
                E_photo, _result_dict = photo_energy_func(
                    sample,
                    verts,
                    faces,
                    albedos,
                    rast_dict,
                    step_i,
                    stage,
                    include_keyframes,
                )
                result_dict.update(_result_dict)
                log_dict["photo"] = self.cfg.w.photo * E_photo

        # Regularization only applies during training stages.
        if stage is not None:
            _log_dict = self.compute_regularization_energy(
                result_dict, verts, verts_cano, lmks, albedos, frame_idx, include_keyframes, stage
            )
            log_dict.update(_log_dict)

        E_total = torch.stack([v for k, v in log_dict.items()]).sum()
        log_dict["total"] = E_total

        return E_total, log_dict, verts, faces, lmks, albedos, result_dict
| |
|
| | @staticmethod |
| | def to_batch(x, indices): |
| | return torch.stack([x[i] for i in indices]) |
| |
|
| | @staticmethod |
| | def repeat_n_times(x: torch.Tensor, n: int): |
| | """Expand a tensor from shape [F, ...] to [F*n, ...]""" |
| | return x.unsqueeze(1).repeat_interleave(n, dim=1).reshape(-1, *x.shape[1:]) |
| |
|
    @torch.no_grad()
    def log_scalars(
        self,
        log_dict,
        frame_idx,
        session: Literal["train", "eval"] = "train",
        stage=None,
        frame_step=None,
    ):
        """
        Logs scalars in log_dict to tensorboard and self.logger
        :param log_dict: mapping of scalar names to values
        :param frame_idx: current frame index
        :param session: "train" or "eval". NOTE(review): any other value leaves
            msg_prefix unassigned and raises NameError at the end — confirm callers.
        :param stage: pipeline stage name (required when session == "train")
        :param frame_step: per-frame step count, used only for the log message
        :return:
        """

        # Also log the optimized focal length while the camera is optimizable.
        if not self.calibrated and stage is not None and 'cam' in self.cfg.pipeline[stage].optimizable_params:
            log_dict["focal_length"] = self.focal_length.squeeze(0)

        log_msg = ""

        if session == "train":
            global_step = self.global_step
        else:
            global_step = frame_idx

        for k, v in log_dict.items():
            # Decay values go to tensorboard but are omitted from the text log.
            if not k.startswith("decay"):
                log_msg += "{}: {:.4f} ".format(k, v)
            if self.tb_writer is not None:
                self.tb_writer.add_scalar(f"{session}/{k}", v, global_step)

        if session == "train":
            assert stage is not None
            if frame_step is not None:
                msg_prefix = f"[{session}-{stage}] frame {frame_idx} step {frame_step}: "
            else:
                msg_prefix = f"[{session}-{stage}] frame {frame_idx} step {self.global_step}: "
        elif session == "eval":
            msg_prefix = f"[{session}] frame {frame_idx}: "
        self.logger.info(msg_prefix + log_msg)
| |
|
    def save_obj_with_texture(self, vertices, faces, uv_coordinates, uv_indices, albedos, obj_path, mtl_path, texture_path):
        """Export the mesh as an .obj with a companion .mtl and texture image.

        The .mtl and .obj reference each other by file name only (``.name`` is
        read from the path arguments, so they are expected to be Path objects).
        """
        # save texture image
        torchvision.utils.save_image(albedos.squeeze(0), texture_path)

        # save .mtl referencing the texture by file name
        with open(mtl_path, 'w') as f:
            f.write(get_mtl_content(texture_path.name))

        # save .obj referencing the .mtl by file name
        with open(obj_path, 'w') as f:
            f.write(get_obj_content(vertices, faces, uv_coordinates, uv_indices, mtl_path.name))
| | |
    def async_func(func):
        """Decorator to run a function asynchronously"""
        # NOTE: defined in the class body without @staticmethod on purpose —
        # it is applied to methods below at class-definition time.
        def wrapper(*args, **kwargs):
            # The decorated callable is a method, so args[0] is self.
            self = args[0]
            if self.cfg.async_func:
                # Fire-and-forget: the caller does not wait for completion,
                # and any return value of func is discarded.
                thread = threading.Thread(target=func, args=args, kwargs=kwargs)
                thread.start()
            else:
                func(*args, **kwargs)
        return wrapper
| | |
    @torch.no_grad()
    @async_func
    def log_media(
        self,
        verts: torch.tensor,
        faces: torch.tensor,
        lmks: torch.tensor,
        albedos: torch.tensor,
        output_dict: dict,
        sample: dict,
        frame_idx: int,
        session: str,
        stage: Optional[str]=None,
        frame_step: int=None,
        epoch=None,
    ):
        """
        Logs current tracking visualization to tensorboard
        :param verts: posed vertices
        :param faces: face index tensor
        :param lmks: predicted landmarks
        :param albedos: albedo maps
        :param output_dict: intermediate results from the energy computation
        :param sample: current batch
        :param frame_idx: current frame index
        :param session: session name, used for output paths
        :param stage: pipeline stage name, if any
        :param frame_step: per-frame step count, used for output paths
        :param epoch: epoch number, used for output paths
        :return:
        """
        # May run on a worker thread when cfg.async_func is set (see async_func).
        tic = time.time()
        prepare_output_path = partial(
            self.prepare_output_path,
            session=session,
            frame_idx=frame_idx,
            stage=stage,
            step=frame_step,
            epoch=epoch,
        )

        """images"""
        if not self.cfg.w.always_enable_jawline_landmarks and stage is not None:
            disable_jawline_landmarks = self.cfg.pipeline[stage]['disable_jawline_landmarks']
        else:
            disable_jawline_landmarks = False
        img = self.visualize_tracking(verts, lmks, albedos, output_dict, sample, disable_jawline_landmarks=disable_jawline_landmarks)
        img_path = prepare_output_path(folder_name="image_grid", file_type=self.cfg.log.image_format)
        torchvision.utils.save_image(img, img_path)

        """meshes"""
        texture_path = prepare_output_path(folder_name="mesh", file_type=self.cfg.log.image_format)
        mtl_path = prepare_output_path(folder_name="mesh", file_type="mtl")
        obj_path = prepare_output_path(folder_name="mesh", file_type="obj")

        vertices = verts.squeeze(0).detach().cpu().numpy()
        faces = faces.detach().cpu().numpy()
        uv_coordinates = self.flame.verts_uvs.cpu().numpy()
        uv_indices = self.flame.textures_idx.cpu().numpy()
        self.save_obj_with_texture(vertices, faces, uv_coordinates, uv_indices, albedos, obj_path, mtl_path, texture_path)
        """"""

        toc = time.time() - tic
        if stage is not None:
            msg_prefix = f"[{session}-{stage}] frame {frame_idx}"
        else:
            msg_prefix = f"[{session}] frame {frame_idx}"
        if frame_step is not None:
            msg_prefix += f" step {frame_step}"
        self.logger.info(f"{msg_prefix}: Logging media took {toc:.2f}s")
| |
|
| | @torch.no_grad() |
| | def visualize_tracking( |
| | self, |
| | verts, |
| | lmks, |
| | albedos, |
| | output_dict, |
| | sample, |
| | return_imgs_seperately=False, |
| | disable_jawline_landmarks=False, |
| | ): |
| | """ |
| | Visualizes the tracking result |
| | """ |
| | if len(self.cfg.log.view_indices) > 0: |
| | view_indices = torch.tensor(self.cfg.log.view_indices) |
| | else: |
| | num_views = sample["rgb"].shape[0] |
| | if num_views > 1: |
| | step = (num_views - 1) // (self.cfg.log.max_num_views - 1) |
| | view_indices = torch.arange(0, num_views, step=step) |
| | else: |
| | view_indices = torch.tensor([0]) |
| | num_views_log = len(view_indices) |
| |
|
| | imgs = [] |
| |
|
| | |
| | gt_rgb = output_dict["gt_rgb"][view_indices].cpu() |
| | transfm = torchvision.transforms.Resize(gt_rgb.shape[-2:]) |
| | imgs += [img[None] for img in gt_rgb] |
| |
|
| | if "pred_rgb" in output_dict: |
| | pred_rgb = transfm(output_dict["pred_rgb"][view_indices].cpu()) |
| | pred_rgb = torch.clip(pred_rgb, min=0, max=1) |
| | imgs += [img[None] for img in pred_rgb] |
| |
|
| | if "error_rgb" in output_dict: |
| | error_rgb = transfm(output_dict["error_rgb"][view_indices].cpu()) |
| | error_rgb = error_rgb.mean(dim=1) / 2 + 0.5 |
| | cmap = cm.get_cmap("seismic") |
| | error_rgb = cmap(error_rgb.cpu()) |
| | error_rgb = torch.from_numpy(error_rgb[..., :3]).to(gt_rgb).permute(0, 3, 1, 2) |
| | imgs += [img[None] for img in error_rgb] |
| | |
| | |
| | if "cid" in output_dict: |
| | cid = transfm(output_dict["cid"][view_indices].cpu()) |
| | cid = cid / cid.max() |
| | cid = cid.expand(-1, 3, -1, -1).clone() |
| |
|
| | pred_alpha = transfm(output_dict["pred_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1) |
| | bg = pred_alpha == 0 |
| | cid[bg] = 1 |
| | imgs += [img[None] for img in cid] |
| | |
| | |
| | if "albedo" in output_dict: |
| | albedo = transfm(output_dict["albedo"][view_indices].cpu()) |
| | albedo = torch.clip(albedo, min=0, max=1) |
| |
|
| | pred_alpha = transfm(output_dict["pred_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1) |
| | bg = pred_alpha == 0 |
| | albedo[bg] = 1 |
| | imgs += [img[None] for img in albedo] |
| | |
| | |
| | if "normal" in output_dict: |
| | normal = transfm(output_dict["normal"][view_indices].cpu()) |
| | normal = torch.clip(normal/2+0.5, min=0, max=1) |
| | imgs += [img[None] for img in normal] |
| | |
| | |
| | diffuse = None |
| | if self.cfg.render.lighting_type != 'constant' and "diffuse" in output_dict: |
| | diffuse = transfm(output_dict["diffuse"][view_indices].cpu()) |
| | diffuse = torch.clip(diffuse, min=0, max=1) |
| | imgs += [img[None] for img in diffuse] |
| | |
| | |
| | if "aa" in output_dict: |
| | aa = transfm(output_dict["aa"][view_indices].cpu()) |
| | aa = torch.clip(aa, min=0, max=1) |
| | imgs += [img[None] for img in aa] |
| |
|
| | |
| | if "gt_alpha" in output_dict: |
| | gt_alpha = transfm(output_dict["gt_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1) |
| | imgs += [img[None] for img in gt_alpha] |
| |
|
| | if "pred_alpha" in output_dict: |
| | pred_alpha = transfm(output_dict["pred_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1) |
| | color_alpha = torch.tensor([0.2, 0.5, 1])[None, :, None, None] |
| | fg_mask = (pred_alpha > 0).float() |
| | if diffuse is not None: |
| | fg_mask *= diffuse |
| | w = 0.7 |
| | overlay_alpha = fg_mask * (w * color_alpha * pred_alpha + (1-w) * gt_rgb) \ |
| | + (1 - fg_mask) * gt_rgb |
| | imgs += [img[None] for img in overlay_alpha] |
| |
|
| | if "error_alpha" in output_dict: |
| | error_alpha = transfm(output_dict["error_alpha"][view_indices].cpu()) |
| | error_alpha = error_alpha.mean(dim=1) / 2 + 0.5 |
| | cmap = cm.get_cmap("seismic") |
| | error_alpha = cmap(error_alpha.cpu()) |
| | error_alpha = ( |
| | torch.from_numpy(error_alpha[..., :3]).to(gt_rgb).permute(0, 3, 1, 2) |
| | ) |
| | imgs += [img[None] for img in error_alpha] |
| | else: |
| | error_alpha = None |
| | |
| | |
| | vis_lmk = self.visualize_landmarks(gt_rgb, output_dict, view_indices, disable_jawline_landmarks) |
| | if vis_lmk is not None: |
| | imgs += [img[None] for img in vis_lmk] |
| | |
| | num_types = len(imgs) // len(view_indices) |
| | |
| | if return_imgs_seperately: |
| | return imgs |
| | else: |
| | if self.cfg.log.stack_views_in_rows: |
| | imgs = [imgs[j * num_views_log + i] for i in range(num_views_log) for j in range(num_types)] |
| | imgs = torch.cat(imgs, dim=0).cpu() |
| | return torchvision.utils.make_grid(imgs, nrow=num_types) |
| | else: |
| | imgs = torch.cat(imgs, dim=0).cpu() |
| | return torchvision.utils.make_grid(imgs, nrow=num_views_log) |
| | |
| | @torch.no_grad() |
| | def visualize_landmarks(self, gt_rgb, output_dict, view_indices=torch.tensor([0]), disable_jawline_landmarks=False): |
| | h, w = gt_rgb.shape[-2:] |
| | unit = h / 750 |
| | wh = torch.tensor([[[w, h]]]) |
| | vis_lmk = None |
| | if "gt_lmk2d" in output_dict: |
| | gt_lmk2d = (output_dict['gt_lmk2d'][view_indices].cpu() * 0.5 + 0.5) * wh |
| | if disable_jawline_landmarks: |
| | gt_lmk2d = gt_lmk2d[:, 17:68] |
| | else: |
| | gt_lmk2d = gt_lmk2d[:, :68] |
| | vis_lmk = gt_rgb.clone() if vis_lmk is None else vis_lmk |
| | for i in range(len(view_indices)): |
| | vis_lmk[i] = plot_landmarks_2d( |
| | vis_lmk[i].clone(), |
| | gt_lmk2d[[i]], |
| | colors="green", |
| | unit=unit, |
| | input_float=True, |
| | ).to(vis_lmk[i]) |
| | if "pred_lmk2d" in output_dict: |
| | pred_lmk2d = (output_dict['pred_lmk2d'][view_indices].cpu() * 0.5 + 0.5) * wh |
| | if disable_jawline_landmarks: |
| | pred_lmk2d = pred_lmk2d[:, 17:68] |
| | else: |
| | pred_lmk2d = pred_lmk2d[:, :68] |
| | vis_lmk = gt_rgb.clone() if vis_lmk is None else vis_lmk |
| | for i in range(len(view_indices)): |
| | vis_lmk[i] = plot_landmarks_2d( |
| | vis_lmk[i].clone(), |
| | pred_lmk2d[[i]], |
| | colors="red", |
| | unit=unit, |
| | input_float=True, |
| | ).to(vis_lmk[i]) |
| | return vis_lmk |
| | |
| | @torch.no_grad() |
| | def evaluate(self, make_visualization=True, epoch=0): |
| | |
| | self.save_result(epoch=epoch) |
| |
|
| | self.logger.info("Started Evaluation") |
| | |
| | photo_loss = [] |
| | for frame_idx in range(self.n_timesteps): |
| |
|
| | sample = self.get_current_frame(frame_idx, include_keyframes=False) |
| | self.clear_cache() |
| | self.fill_cam_params_into_sample(sample) |
| | ( |
| | E_total, |
| | log_dict, |
| | verts, |
| | faces, |
| | lmks, |
| | albedos, |
| | output_dict, |
| | ) = self.compute_energy(sample, frame_idx) |
| |
|
| | self.log_scalars(log_dict, frame_idx, session="eval") |
| | photo_loss.append(log_dict["photo"].item()) |
| |
|
| | if make_visualization: |
| | self.log_media( |
| | verts, |
| | faces, |
| | lmks, |
| | albedos, |
| | output_dict, |
| | sample, |
| | frame_idx, |
| | session="eval", |
| | epoch=epoch, |
| | ) |
| | |
| | self.tb_writer.add_scalar(f"eval_mean/photo", np.mean(photo_loss), epoch) |
| |
|
| | def prepare_output_path(self, session, frame_idx, folder_name, file_type, stage=None, step=None, epoch=None): |
| | if epoch is not None: |
| | output_folder = self.out_dir / f'{session}_{epoch}' / folder_name |
| | else: |
| | output_folder = self.out_dir / session / folder_name |
| | os.makedirs(output_folder, exist_ok=True) |
| | |
| | if stage is not None: |
| | assert step is not None |
| | fname = "frame_{:05d}_{:03d}_{}.{}".format(frame_idx, step, stage, file_type) |
| | else: |
| | fname = "frame_{:05d}.{}".format(frame_idx, file_type) |
| | return output_folder / fname |
| |
|
| | def save_result(self, fname=None, epoch=None): |
| | """ |
| | Saves tracked/optimized flame parameters. |
| | :return: |
| | """ |
| | |
| | keys = [ |
| | "rotation", |
| | "translation", |
| | "neck_pose", |
| | "jaw_pose", |
| | "eyes_pose", |
| | "shape", |
| | "expr", |
| | "timestep_id", |
| | "n_processed_frames", |
| | ] |
| | values = [ |
| | self.rotation, |
| | self.translation, |
| | self.neck_pose, |
| | self.jaw_pose, |
| | self.eyes_pose, |
| | self.shape, |
| | self.expr, |
| | np.array(self.dataset.timestep_ids), |
| | self.frame_idx, |
| | ] |
| | if not self.calibrated: |
| | keys += ["focal_length"] |
| | values += [self.focal_length] |
| | |
| | if not self.cfg.model.tex_painted: |
| | keys += ["tex"] |
| | values += [self.tex_pca] |
| | |
| | if self.cfg.model.tex_extra: |
| | keys += ["tex_extra"] |
| | values += [self.tex_extra] |
| | |
| | if self.lights is not None: |
| | keys += ["lights"] |
| | values += [self.lights] |
| | |
| | if self.cfg.model.use_static_offset: |
| | keys += ["static_offset"] |
| | values += [self.static_offset] |
| |
|
| | if self.cfg.model.use_dynamic_offset: |
| | keys += ["dynamic_offset"] |
| | values += [self.dynamic_offset] |
| |
|
| | export_dict = {} |
| | for k, v in zip(keys, values): |
| | if not isinstance(v, np.ndarray): |
| | if isinstance(v, list): |
| | v = torch.stack(v) |
| | if isinstance(v, torch.Tensor): |
| | v = v.detach().cpu().numpy() |
| | export_dict[k] = v |
| |
|
| | export_dict["image_size"] = np.array(self.image_size) |
| |
|
| | fname = fname if fname is not None else "tracked_flame_params" |
| | if epoch is not None: |
| | fname = f"{fname}_{epoch}" |
| | np.savez(self.out_dir / f'{fname}.npz', **export_dict) |
| |
|
| |
|
| | class GlobalTracker(FlameTracker): |
    def __init__(self, cfg: BaseTrackingConfig):
        """
        Sets up a tracking run: output directory, tensorboard writer, logger,
        config dump, dataset, and the optimizable parameter tensors.

        :param cfg: full tracking configuration
        """
        super().__init__(cfg)

        self.calibrated = cfg.data.calibrated

        # Every run writes into its own timestamped output directory.
        out_dir = cfg.exp.output_folder / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        out_dir.mkdir(parents=True,exist_ok=True)

        self.frame_idx = self.cfg.begin_frame_idx
        self.out_dir = out_dir
        self.tb_writer = SummaryWriter(self.out_dir)

        # Logging intervals, in optimization steps.
        self.log_interval_scalar = self.cfg.log.interval_scalar
        self.log_interval_media = self.cfg.log.interval_media

        # Persist the full config next to the results for reproducibility.
        config_yaml_path = out_dir / 'config.yml'
        config_yaml_path.write_text(yaml.dump(cfg), "utf8")
        print(tyro.to_yaml(cfg))

        self.logger = get_logger(__name__, root=True, log_dir=out_dir)

        # Dataset class is resolved dynamically from the config target;
        # each sample batches all camera views of one timestep.
        self.dataset = import_module(cfg.data._target)(
            cfg=cfg.data,
            img_to_tensor=True,
            batchify_all_views=True,
        )

        self.image_size = self.dataset[0]["rgb"].shape[-2:]
        self.n_timesteps = len(self.dataset)

        # Allocate all optimizable parameter tensors (zero-initialized).
        self.init_params()

        # Optionally warm-start from previously tracked FLAME parameters.
        if self.cfg.model.flame_params_path is not None:
            self.load_from_tracked_flame_params(self.cfg.model.flame_params_path)
| | |
| | def init_params(self): |
| | train_tensors = [] |
| |
|
| | |
| | self.shape = torch.zeros(self.cfg.model.n_shape).to(self.device) |
| | self.expr = torch.zeros(self.n_timesteps, self.cfg.model.n_expr).to(self.device) |
| |
|
| | |
| | self.neck_pose = torch.zeros(self.n_timesteps, 3).to(self.device) |
| | self.jaw_pose = torch.zeros(self.n_timesteps, 3).to(self.device) |
| | self.eyes_pose = torch.zeros(self.n_timesteps, 6).to(self.device) |
| |
|
| | |
| | self.translation = torch.zeros(self.n_timesteps, 3).to(self.device) |
| | self.rotation = torch.zeros(self.n_timesteps, 3).to(self.device) |
| |
|
| | |
| | self.tex_pca = torch.zeros(self.cfg.model.n_tex).to(self.device) |
| | if self.cfg.model.tex_extra: |
| | res = self.cfg.model.tex_resolution |
| | self.tex_extra = torch.zeros(3, res, res).to(self.device) |
| | |
| | if self.cfg.render.lighting_type == 'SH': |
| | self.lights_uniform = torch.zeros(9, 3).to(self.device) |
| | self.lights_uniform[0] = torch.tensor([np.sqrt(4 * np.pi)]).expand(3).float().to(self.device) |
| | self.lights = self.lights_uniform.clone() |
| | else: |
| | self.lights = None |
| |
|
| | train_tensors += ( |
| | [self.shape, self.translation, self.rotation, self.neck_pose, self.jaw_pose, self.eyes_pose, self.expr,] |
| | ) |
| |
|
| | if not self.cfg.model.tex_painted: |
| | train_tensors += [self.tex_pca] |
| | if self.cfg.model.tex_extra: |
| | train_tensors += [self.tex_extra] |
| |
|
| | if self.lights is not None: |
| | train_tensors += [self.lights] |
| |
|
| | if self.cfg.model.use_static_offset: |
| | self.static_offset = torch.zeros(1, self.flame.v_template.shape[0], 3).to(self.device) |
| | train_tensors += [self.static_offset] |
| | else: |
| | self.static_offset = None |
| | |
| | if self.cfg.model.use_dynamic_offset: |
| | self.dynamic_offset = torch.zeros(self.n_timesteps, self.flame.v_template.shape[0], 3).to(self.device) |
| | train_tensors += self.dynamic_offset |
| | else: |
| | self.dynamic_offset = None |
| |
|
| | |
| | if not self.calibrated: |
| | |
| | self.focal_length = torch.tensor([1.5]).to(self.device) |
| | self.RT = torch.eye(3, 4).to(self.device) |
| | self.RT[2, 3] = -1 |
| | train_tensors += [self.focal_length] |
| |
|
| | for t in train_tensors: |
| | t.requires_grad = True |
| |
|
    def optimize(self):
        """
        Optimizes FLAME parameters on all frames of the dataset: first a
        sequential per-frame pass (with extra initialization stages on the
        first frame), then a global pass with random frame sampling.
        :return:
        """
        self.global_step = 0

        # ---- Pass 1: sequential tracking, frame by frame -------------------
        self.logger.info(f"Start sequential tracking FLAME in {self.n_timesteps} frames")
        dataloader = DataLoader(self.dataset, batch_size=None, shuffle=False, num_workers=0)
        for sample in dataloader:
            timestep = sample["timestep_index"][0].item()
            if timestep == 0:
                # Extended initialization on the first frame only.
                self.optimize_stage('lmk_init_rigid', sample)
                self.optimize_stage('lmk_init_all', sample)
                if self.cfg.exp.photometric:
                    self.optimize_stage('rgb_init_texture', sample)
                    self.optimize_stage('rgb_init_all', sample)
                    if self.cfg.model.use_static_offset:
                        self.optimize_stage('rgb_init_offset', sample)

            # Per-frame tracking stage (photometric or landmark-only).
            if self.cfg.exp.photometric:
                self.optimize_stage('rgb_sequential_tracking', sample)
            else:
                self.optimize_stage('lmk_sequential_tracking', sample)
            # Warm-start the next frame from this frame's solution.
            # (Method name keeps the existing "timtestep" spelling.)
            self.initialize_next_timtestep(timestep)

        self.evaluate(make_visualization=False, epoch=0)

        # ---- Pass 2: global optimization over shuffled frames --------------
        self.logger.info(f"Start global optimization of all frames")

        dataloader = DataLoader(self.dataset, batch_size=None, shuffle=True, num_workers=0)
        if self.cfg.exp.photometric:
            self.optimize_stage(stage='rgb_global_tracking', dataloader=dataloader, lr_scale=0.1)
        else:
            self.optimize_stage(stage='lmk_global_tracking', dataloader=dataloader, lr_scale=0.1)

        self.logger.info("All done.")
| | |
    def optimize_stage(
        self,
        # Annotation extended to include the landmark-only tracking stages,
        # which optimize() actually passes here.
        stage: Literal['lmk_init_rigid', 'lmk_init_all', 'rgb_init_texture', 'rgb_init_all', 'rgb_init_offset', 'lmk_sequential_tracking', 'rgb_sequential_tracking', 'lmk_global_tracking', 'rgb_global_tracking'],
        sample: Optional[dict] = None,
        dataloader = None,
        lr_scale: float = 1.0,
    ):
        """
        Runs one pipeline stage: either a fixed number of steps on a single
        sample, or a number of epochs over a dataloader (with LR decay and
        periodic evaluation).

        :param stage: pipeline stage whose config provides step/epoch counts
        :param sample: single batch to optimize on (sequential mode)
        :param dataloader: frame loader for epoch-based mode (global mode);
            required when `sample` is None
        :param lr_scale: multiplier applied to the configured learning rates
        """
        params = self.get_train_parameters(stage)
        optimizer = self.configure_optimizer(params, lr_scale=lr_scale)

        if sample is not None:
            # Single-sample mode: iterate a fixed number of steps.
            num_steps = self.cfg.pipeline[stage].num_steps
            for step_i in range(num_steps):
                self.optimize_iter(sample, optimizer, stage)
        else:
            assert dataloader is not None
            # Epoch mode: exponential LR decay per epoch, periodic evaluation.
            num_epochs = self.cfg.pipeline[stage].num_epochs
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
            for epoch_i in range(num_epochs):
                self.logger.info(f"EPOCH {epoch_i+1} / {num_epochs}")
                for step_i, sample in enumerate(dataloader):
                    self.optimize_iter(sample, optimizer, stage)
                scheduler.step()

                if (epoch_i + 1) % 10 == 0:
                    self.evaluate(make_visualization=True, epoch=epoch_i+1)
| | |
| | def optimize_iter(self, sample, optimizer, stage): |
| | |
| | self.clear_cache() |
| |
|
| | timestep_index = sample["timestep_index"][0] |
| | self.fill_cam_params_into_sample(sample) |
| | ( |
| | E_total, |
| | log_dict, |
| | verts, |
| | faces, |
| | lmks, |
| | albedos, |
| | output_dict, |
| | ) = self.compute_energy( |
| | sample, frame_idx=timestep_index, stage=stage, |
| | ) |
| | optimizer.zero_grad() |
| | E_total.backward() |
| | optimizer.step() |
| |
|
| | |
| | if (self.global_step+1) % self.log_interval_scalar == 0: |
| | self.log_scalars( |
| | log_dict, |
| | timestep_index, |
| | session="train", |
| | stage=stage, |
| | frame_step=self.global_step, |
| | ) |
| |
|
| | if (self.global_step+1) % self.log_interval_media == 0: |
| | self.log_media( |
| | verts, |
| | faces, |
| | lmks, |
| | albedos, |
| | output_dict, |
| | sample, |
| | timestep_index, |
| | session="train", |
| | stage=stage, |
| | frame_step=self.global_step, |
| | ) |
| | del verts, faces, lmks, albedos, output_dict |
| | self.global_step += 1 |
| |
|
| |
|
| | def get_train_parameters( |
| | self, stage: Literal['lmk_init_rigid', 'lmk_init_all', 'rgb_init_all', 'rgb_init_offset', 'rgb_sequential_tracking', 'rgb_global_tracking'], |
| | ): |
| | """ |
| | Collects the parameters to be optimized for the current frame |
| | :return: dict of parameters |
| | """ |
| | self.opt_dict = defaultdict(bool) |
| | for p in self.cfg.pipeline[stage].optimizable_params: |
| | self.opt_dict[p] = True |
| | |
| | params = defaultdict(list) |
| | |
| | |
| | if self.opt_dict["cam"] and not self.calibrated: |
| | params["cam"] = [self.focal_length] |
| |
|
| | if self.opt_dict["shape"]: |
| | params["shape"] = [self.shape] |
| | |
| | if self.opt_dict["texture"]: |
| | if not self.cfg.model.tex_painted: |
| | params["tex"] = [self.tex_pca] |
| | if self.cfg.model.tex_extra: |
| | params["tex_extra"] = [self.tex_extra] |
| |
|
| | if self.opt_dict["static_offset"] and self.cfg.model.use_static_offset: |
| | params["static_offset"] = [self.static_offset] |
| | |
| | if self.opt_dict["lights"] and self.lights is not None: |
| | params["lights"] = [self.lights] |
| | |
| | |
| | if self.opt_dict["pose"]: |
| | params["translation"].append(self.translation) |
| | params["rotation"].append(self.rotation) |
| |
|
| | if self.opt_dict["joints"]: |
| | params["eyes"].append(self.eyes_pose) |
| | params["neck"].append(self.neck_pose) |
| | params["jaw"].append(self.jaw_pose) |
| |
|
| | if self.opt_dict["expr"]: |
| | params["expr"].append(self.expr) |
| | |
| | if self.opt_dict["dynamic_offset"] and self.cfg.model.use_dynamic_offset: |
| | params["dynamic_offset"].append(self.dynamic_offset) |
| |
|
| | return params |
| |
|
| | def initialize_next_timtestep(self, timestep): |
| | if timestep < self.n_timesteps - 1: |
| | self.translation[timestep + 1].data.copy_(self.translation[timestep]) |
| | self.rotation[timestep + 1].data.copy_(self.rotation[timestep]) |
| | self.neck_pose[timestep + 1].data.copy_(self.neck_pose[timestep]) |
| | self.jaw_pose[timestep + 1].data.copy_(self.jaw_pose[timestep]) |
| | self.eyes_pose[timestep + 1].data.copy_(self.eyes_pose[timestep]) |
| | self.expr[timestep + 1].data.copy_(self.expr[timestep]) |
| | if self.cfg.model.use_dynamic_offset: |
| | self.dynamic_offset[timestep + 1].data.copy_(self.dynamic_offset[timestep]) |
| |
|