import os

import matplotlib.image
import numpy as np
import torch
import torchmetrics.image
import videoio
from accelerate.logging import get_logger
from PIL import Image
from torchmetrics import MetricCollection

from svd_pipeline import StableVideoDiffusionPipeline
from utils import load_image

logger = get_logger(__name__, log_level="INFO")

def valid_net(args, val_dataset, val_dataloader, unet, image_encoder, vae, zero, accelerator, global_step, weight_dtype):
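    """Run validation: generate videos with the fine-tuned SVD pipeline and save them.

    Notes inferred from the code below (not from external documentation):
    `unet`, `image_encoder`, and `vae` are the (unwrapped) model components used
    to build the pipeline; `zero`, `focal_stack_num`, and `args.conditioning`
    are forwarded to the custom StableVideoDiffusionPipeline from
    `svd_pipeline`; `weight_dtype` sets the inference dtype; `val_dataset` is
    accepted but unused in this function.
    """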
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} videos."
    )
    # The models need unwrapping for compatibility with distributed training mode.
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=unet,
        image_encoder=image_encoder,
        vae=vae,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    pipeline.set_progress_bar_config(disable=True)
    # Run inference.
    val_save_dir = os.path.join(args.output_dir, "validation_images")
    print("Validation images will be saved to ", val_save_dir)
    os.makedirs(val_save_dir, exist_ok=True)

    num_frames = args.num_frames
    unet.eval()
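
    # torch.autocast expects a device *type* ("cuda"/"cpu"), not a device index,
    # hence stripping the ":0" suffix from accelerator.device. Autocast is only
    # enabled when Accelerate is running in fp16 mixed precision.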
    with torch.autocast(
        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
    ):
        for batch in val_dataloader:
            # Release cached GPU memory; torch.no_grad() ensures no autograd
            # graph keeps activations alive between batches.
            with torch.no_grad():
                torch.cuda.empty_cache()
            pixel_values = batch["pixel_values"].to(accelerator.device)
            original_pixel_values = batch["original_pixel_values"].to(accelerator.device)
            idx = batch["idx"].to(accelerator.device)
            if "focal_stack_num" in batch:
                focal_stack_num = batch["focal_stack_num"][0].item()
            else:
                focal_stack_num = None
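
            # A few notes on the settings below, assuming the custom pipeline in
            # `svd_pipeline` keeps the stock SVD semantics for these arguments:
            # equal min/max guidance scales disable SVD's per-frame guidance ramp
            # (a constant scale of 1.5 is applied to every frame),
            # decode_chunk_size=8 bounds VAE-decoder memory by decoding 8 frames
            # at a time, and noise_aug_strength=0 adds no noise to the
            # conditioning image.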
            svd_output, gt_frames = pipeline(
                pixel_values,
                height=pixel_values.shape[3],
                width=pixel_values.shape[4],
                num_frames=args.num_frames,
                decode_chunk_size=8,
                motion_bucket_id=0 if args.conditioning != "ablate_time" else focal_stack_num,
                min_guidance_scale=1.5,
                max_guidance_scale=1.5,
                reconstruction_guidance_scale=args.reconstruction_guidance,
                fps=7,
                noise_aug_strength=0,
                accelerator=accelerator,
                weight_dtype=weight_dtype,
                conditioning=args.conditioning,
                focal_stack_num=focal_stack_num,
                zero=zero,
                # generator=generator,
            )
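            # Take the first (and, per the batch-size-1 assumption below, only)
            # sample; judging by the permutes further down, frames are laid out
            # as (C, T, H, W) at this point.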
            video_frames = svd_output.frames[0]
            gt_frames = gt_frames[0]

            with torch.no_grad():
                if args.num_frames == 10:
                    # Drop the trailing frame from the generated, ground-truth,
                    # and original clips.
                    video_frames = video_frames[:, :-1]
                    gt_frames = gt_frames[:, :-1]
                    original_pixel_values = original_pixel_values[:, :-1]

                if len(original_pixel_values.shape) == 5:
                    pixel_values = original_pixel_values[0]  # assuming batch size is 1
                else:
                    pixel_values = original_pixel_values.repeat(num_frames, 1, 1, 1)

                # Map the inputs and generated frames from [-1, 1] back to [0, 1]
                # (gt_frames is already in [0, 1]), then reorder to (T, C, H, W).
                pixel_values_normalized = torch.clamp(pixel_values * 0.5 + 0.5, 0, 1)
                video_frames_normalized = torch.clamp(video_frames * 0.5 + 0.5, 0, 1)
                video_frames_normalized = video_frames_normalized.permute(1, 0, 2, 3)
                gt_frames = torch.clamp(gt_frames, 0, 1)
                gt_frames = gt_frames.permute(1, 0, 2, 3)
                # Resize everything to the nearest even height/width (likely so
                # the video encoder, which typically needs even dimensions, is happy).
                target_size = ((pixel_values.shape[2] // 2) * 2, (pixel_values.shape[3] // 2) * 2)
                video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, target_size, mode="bilinear")
                gt_frames = torch.nn.functional.interpolate(gt_frames, target_size, mode="bilinear")
                pixel_values_normalized = torch.nn.functional.interpolate(pixel_values_normalized, target_size, mode="bilinear")
                os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/videos"), exist_ok=True)
                videoio.videosave(
                    os.path.join(
                        val_save_dir,
                        f"position_{focal_stack_num}/videos/step_{global_step}_val_img_{idx[0].item()}.mp4",
                    ),
                    video_frames_normalized.permute(0, 2, 3, 1).cpu().numpy(),
                    fps=5,
                )
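                # Note: the fps=7 passed to the pipeline above is SVD's fps
                # micro-conditioning signal; fps=5 here only sets the playback
                # rate of the saved mp4.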
                if args.test:
                    # Also save the individual frames as PNGs.
                    os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/images"), exist_ok=True)
                    # Iterate over the actual frame count, which can be
                    # num_frames - 1 after the truncation above.
                    for i in range(video_frames_normalized.shape[0]):
                        frame_path = os.path.join(val_save_dir, f"position_{focal_stack_num}/images/img_{idx[0].item()}_frame_{i}.png")
                        if not args.photos:
                            matplotlib.image.imsave(frame_path, video_frames_normalized[i].permute(1, 2, 0).cpu().numpy())
                        else:
                            # Use Pillow so the source image's ICC profile can be re-attached.
                            img = Image.fromarray((video_frames_normalized[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                            if batch["icc_profile"][0] != "none":
                                # PNG saving picks the profile up from img.info.
                                img.info["icc_profile"] = batch["icc_profile"][0]
                            img.save(frame_path)

            del video_frames
            accelerator.wait_for_everyone()

    # Release cached GPU memory once more before tearing down the pipeline;
    # torch.no_grad() ensures no autograd graph keeps tensors alive.
    with torch.no_grad():
        torch.cuda.empty_cache()
    del pipeline
    accelerator.wait_for_everyone()  # Critical: every process must leave validation together.
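
# A minimal usage sketch (hypothetical: `validation_steps` is an assumed
# argument; the remaining names come from the surrounding training script).
# Because valid_net calls accelerator.wait_for_everyone(), every process must
# enter it, not just the main one:
#
#     if global_step % args.validation_steps == 0:
#         valid_net(args, val_dataset, val_dataloader,
#                   accelerator.unwrap_model(unet), image_encoder, vae, zero,
#                   accelerator, global_step, weight_dtype)
#         unet.train()  # valid_net leaves the UNet in eval mode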