# learn2refocus/training/validation.py
import os

import matplotlib.image
import numpy as np
import torch
import videoio
from accelerate.logging import get_logger
from PIL import Image

from svd_pipeline import StableVideoDiffusionPipeline
logger = get_logger(__name__, log_level="INFO")
def valid_net(args, val_dataset, val_dataloader, unet, image_encoder, vae, zero, accelerator, global_step, weight_dtype):
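    """Generate validation videos with the current model weights and save them to disk.

    Builds a StableVideoDiffusionPipeline from the in-training UNet, image encoder,
    and VAE, runs it over every batch in `val_dataloader`, and writes one MP4 per
    sample (plus per-frame PNGs when `args.test` is set) under
    `<args.output_dir>/validation_images`.
    """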
logger.info(
f"Running validation... \n Generating {args.num_validation_images} videos."
)
    # The models need unwrapping for compatibility with distributed training mode.
pipeline = StableVideoDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unet,
image_encoder=image_encoder,
vae=vae,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline.set_progress_bar_config(disable=True)
    # Run inference and save outputs.
    val_save_dir = os.path.join(args.output_dir, "validation_images")
    logger.info(f"Validation images will be saved to {val_save_dir}")
    os.makedirs(val_save_dir, exist_ok=True)
num_frames = args.num_frames
unet.eval()
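    # torch.autocast expects a device *type* ("cuda"/"cpu"), so strip any ":0" index;
    # autocast is only enabled when training runs in fp16 mixed precision.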
with torch.autocast(
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
):
for batch in val_dataloader:
            # Run inference without autograd (torch.no_grad is what keeps memory bounded here).
with torch.no_grad():
torch.cuda.empty_cache()
pixel_values = batch["pixel_values"].to(accelerator.device)
original_pixel_values = batch['original_pixel_values'].to(accelerator.device)
idx = batch["idx"].to(accelerator.device)
if "focal_stack_num" in batch:
focal_stack_num = batch["focal_stack_num"][0].item()
else:
focal_stack_num = None
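                # The "ablate_time" conditioning repurposes SVD's motion_bucket_id to
                # carry the focal-stack position instead of a motion strength.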
svd_output, gt_frames = pipeline(
pixel_values,
height=pixel_values.shape[3],
width=pixel_values.shape[4],
num_frames=args.num_frames,
decode_chunk_size=8,
motion_bucket_id=0 if args.conditioning != "ablate_time" else focal_stack_num,
min_guidance_scale=1.5,
max_guidance_scale=1.5,
reconstruction_guidance_scale=args.reconstruction_guidance,
fps=7,
noise_aug_strength=0,
accelerator=accelerator,
weight_dtype=weight_dtype,
                    conditioning=args.conditioning,
                    focal_stack_num=focal_stack_num,
                    zero=zero,
# generator=generator,
)
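            # svd_output.frames is batched as (B, C, T, H, W) (the permutes below assume
            # channel-first, frame-second layout); keep the single video and its ground truth.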
video_frames = svd_output.frames[0]
gt_frames = gt_frames[0]
with torch.no_grad():
                if args.num_frames == 10:
                    # Drop the padded final frame from predictions, ground truth, and
                    # inputs, and keep num_frames in sync for the loops below.
                    video_frames = video_frames[:, :-1]
                    gt_frames = gt_frames[:, :-1]
                    original_pixel_values = original_pixel_values[:, :-1]
                    num_frames = args.num_frames - 1
                if len(original_pixel_values.shape) == 5:
                    pixel_values = original_pixel_values[0]  # assumes batch size 1
                else:
                    # Single conditioning image: tile it across all frames.
                    pixel_values = original_pixel_values.repeat(num_frames, 1, 1, 1)
                # Map from [-1, 1] back to [0, 1].
                pixel_values_normalized = torch.clamp(pixel_values * 0.5 + 0.5, 0, 1)
                video_frames_normalized = torch.clamp(video_frames * 0.5 + 0.5, 0, 1)
                # (C, T, H, W) -> (T, C, H, W)
                video_frames_normalized = video_frames_normalized.permute(1, 0, 2, 3)
                gt_frames = torch.clamp(gt_frames, 0, 1)
                gt_frames = gt_frames.permute(1, 0, 2, 3)
                # Resize everything to even spatial dimensions (most video codecs require them).
                target_size = ((pixel_values.shape[2] // 2) * 2, (pixel_values.shape[3] // 2) * 2)
                video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, target_size, mode='bilinear')
                gt_frames = torch.nn.functional.interpolate(gt_frames, target_size, mode='bilinear')
                pixel_values_normalized = torch.nn.functional.interpolate(pixel_values_normalized, target_size, mode='bilinear')
                video_dir = os.path.join(val_save_dir, f"position_{focal_stack_num}", "videos")
                os.makedirs(video_dir, exist_ok=True)
                videoio.videosave(
                    os.path.join(video_dir, f"step_{global_step}_val_img_{idx[0].item()}.mp4"),
                    video_frames_normalized.permute(0, 2, 3, 1).cpu().numpy(),
                    fps=5,
                )
                if args.test:
                    # Also dump individual frames as PNGs.
                    image_dir = os.path.join(val_save_dir, f"position_{focal_stack_num}", "images")
                    os.makedirs(image_dir, exist_ok=True)
                    if not args.photos:
                        for i in range(num_frames):
                            matplotlib.image.imsave(
                                os.path.join(image_dir, f"img_{idx[0].item()}_frame_{i}.png"),
                                video_frames_normalized[i].permute(1, 2, 0).cpu().numpy(),
                            )
                    else:
                        for i in range(num_frames):
                            # Use Pillow so the source image's ICC color profile can be re-embedded.
                            img = Image.fromarray((video_frames_normalized[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                            if batch['icc_profile'][0] != "none":
                                img.info['icc_profile'] = batch['icc_profile'][0]
                            img.save(os.path.join(image_dir, f"img_{idx[0].item()}_frame_{i}.png"))
del video_frames
accelerator.wait_for_everyone()
            # Free cached GPU memory before the next batch (no autograd state to keep).
with torch.no_grad():
torch.cuda.empty_cache()
del pipeline
    accelerator.wait_for_everyone()  # critical: all ranks must leave validation together
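

# A minimal sketch of how valid_net might be called from the training loop; the
# flag name (`validation_steps`) and the call site are assumptions, not part of
# this file:
#
#     if global_step % args.validation_steps == 0:
#         valid_net(args, val_dataset, val_dataloader, unet, image_encoder,
#                   vae, zero, accelerator, global_step, weight_dtype)
#         unet.train()  # valid_net switches the UNet to eval mode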