# text2live / Text2LIVE-main / util / video_logger.py
# Hugging Face file-viewer header (not part of the program):
#   uploaded by SupermanxKiaski ("Upload 356 files"), commit 16d007c,
#   raw / history / blame, no virus detected, 3.85 kB
from pathlib import Path
import imageio
import torch
from util.util import tensor2im
class DataLogger:
    """Collect per-epoch losses, atlas images, videos and model checkpoints.

    The logger keeps CPU copies of the dataset's original video (one per
    layer) and of its alpha video, so a full composite can be rebuilt each
    time images are logged.
    """

    def __init__(self, config, dataset):
        """Store the config and cache CPU copies of the dataset videos.

        Args:
            config: mapping read for ``finetune_foreground`` here and, later,
                for ``log_images_freq``, ``use_wandb``, ``finetune_<layer>``,
                ``save_model_starting_epoch`` and ``results_folder``.
            dataset: object exposing ``original_video`` and ``all_alpha``
                tensors.
        """
        # Both layers start out as the unedited video; the fine-tuned layer
        # is replaced by its composite in ``log_data``.
        self.layers_edits = {
            "background": dataset.original_video.detach().cpu(),
            "foreground": dataset.original_video.detach().cpu(),
        }
        self.alpha_video = dataset.all_alpha.detach().cpu()
        self.config = config
        self.layer_name = "foreground" if config["finetune_foreground"] else "background"

    @staticmethod
    def _import_wandb():
        """Import ``wandb`` lazily so it is only required when it is used."""
        import wandb

        return wandb

    @torch.no_grad()
    def log_data(self, epoch, lr, losses, model, dataset):
        """Build the log dictionary for one epoch.

        Losses, ``epoch`` and ``lr`` are always recorded. When ``epoch`` is a
        multiple of ``config["log_images_freq"]`` the atlases and videos
        rendered by ``dataset.render_video_from_atlas`` are added as well, and
        a checkpoint may be written.

        Args:
            epoch: current epoch number.
            lr: current learning rate (stored as-is).
            losses: mapping ``layer -> {loss_name: tensor}``.
            model: passed to ``dataset.render_video_from_atlas``; its
                ``state_dict()`` is what gets checkpointed.
            dataset: object providing ``render_video_from_atlas``.

        Returns:
            dict keyed by ``Loss/...``, ``epoch``, ``lr`` and, on image
            epochs, ``Atlases/...`` and ``Videos/...``. Image/video values are
            wrapped in ``wandb.Image`` / ``wandb.Video`` when
            ``config["use_wandb"]`` is true.

        Raises:
            ImportError: if ``config["use_wandb"]`` is true on an image epoch
                and ``wandb`` is not installed.
        """
        log_data = {
            f"Loss/{layer}_{key}": value.detach()
            for layer, layer_losses in losses.items()
            for key, value in layer_losses.items()
        }
        log_data["epoch"] = epoch
        log_data["lr"] = lr

        if epoch % self.config["log_images_freq"] != 0:
            return log_data

        use_wandb = self.config["use_wandb"]
        # The module is only needed (and only imported) on the wandb path.
        wandb = self._import_wandb() if use_wandb else None

        layer = self.layer_name
        edited_atlas_dict, edit_dict, uv_mask = dataset.render_video_from_atlas(model, layer=layer)

        def masked(atlas):
            # Move to CPU and apply the UV mask returned by the renderer.
            return atlas.detach().cpu() * uv_mask

        alpha_of_edit = None
        edit_only = None
        for key, atlas in edited_atlas_dict.items():
            if key == "edit":
                # The raw edit is not logged on its own; it is combined with
                # the alpha atlas into an RGBA image below.
                edit_only = masked(atlas)
                continue
            image = tensor2im(masked(atlas))
            log_data[f"Atlases/{layer}_masked_{key}"] = wandb.Image(image) if use_wandb else image
            if key == "alpha":
                alpha_of_edit = masked(atlas)

        # Only compose the RGBA layer when both inputs were actually rendered;
        # otherwise torch.cat would fail on None.
        if edit_only is not None and alpha_of_edit is not None:
            # Append the first channel (dim 1) of the alpha atlas to the edit.
            rgba_edit = tensor2im(torch.cat((edit_only, alpha_of_edit[:, [0]]), dim=1))
            log_data[f"Atlases/{layer}_rgba_layer"] = wandb.Image(rgba_edit) if use_wandb else rgba_edit

        for key, clip in edit_dict.items():
            if key in ("composite", "edit"):
                continue
            # presumably values are in [0, 1] before scaling to uint8 - TODO confirm
            video = (255 * clip.detach().cpu()).to(torch.uint8)
            log_data[f"Videos/{layer}_{key}"] = (
                wandb.Video(video, fps=10, format="mp4") if use_wandb else video
            )

        if self.config[f"finetune_{layer}"]:
            self.layers_edits[layer] = edit_dict["composite"].detach().cpu()

        # Alpha-blend the (possibly edited) foreground over the background.
        full_video = (
            self.alpha_video * self.layers_edits["foreground"]
            + (1 - self.alpha_video) * self.layers_edits["background"]
        )
        full_video = (255 * full_video.detach().cpu()).to(torch.uint8)
        log_data["Videos/full_video"] = (
            wandb.Video(full_video, fps=10, format="mp4") if use_wandb else full_video
        )

        # Save model checkpoint.
        # NOTE(review): checkpoints are only written on image-logging epochs;
        # confirm this coupling is intended.
        if epoch > self.config["save_model_starting_epoch"]:
            filename = f"checkpoint_epoch_{epoch}.pt"
            dict_to_save = {
                "model": model.state_dict(),
            }
            checkpoint_dir = Path(wandb.run.dir if use_wandb else self.config["results_folder"])
            # Make sure the target folder exists before writing into it.
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            torch.save(dict_to_save, f"{checkpoint_dir}/{filename}")

        return log_data

    def save_locally(self, log_data):
        """Write the videos and atlas images of ``log_data`` to disk.

        Files go to ``<results_folder>/<epoch>/`` with ``/`` in each key
        replaced by ``_``. Keys starting with ``Videos`` are written as
        ``.mp4``, keys starting with ``Atlases`` as ``.png``; every other
        entry is ignored.

        Args:
            log_data: dictionary as returned by ``log_data`` (with
                ``use_wandb`` disabled, so the values are raw arrays/tensors).
        """
        path = Path(self.config["results_folder"], str(log_data["epoch"]))
        path.mkdir(parents=True, exist_ok=True)
        for key, value in log_data.items():
            save_name = key.replace("/", "_")
            if key.startswith("Videos"):
                # presumably (T, C, H, W) -> (T, H, W, C) for imageio - TODO confirm
                imageio.mimwrite(f"{path}/{save_name}.mp4", value.permute(0, 2, 3, 1))
            elif key.startswith("Atlases"):
                imageio.imwrite(f"{path}/{save_name}.png", value)