import random

import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from models.image_model import Model
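

# VideoModel adapts the base image Model to video editing via a layered atlas:
# the generator (self.netG, defined in the base Model) edits an atlas image,
# and the edit is warped back into per-frame crops with F.grid_sample using
# precomputed UV coordinates.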
class VideoModel(Model):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Identity preprocessing: an empty Compose applies no transforms.
        self.net_preprocess = transforms.Compose([])
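
    # Downscale a batch of crops by an integer factor using antialiased
    # bilinear resampling.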
    @staticmethod
    def resize_crops(crops, resize_factor):
        return torchvision.transforms.functional.resize(
            crops,
            [
                crops.shape[-2] // resize_factor,
                crops.shape[-1] // resize_factor,
            ],
            InterpolationMode.BILINEAR,
            antialias=True,
        )
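
    # process_crops edits an atlas crop with the generator and maps the result
    # back to three frame crops via UV sampling. A minimal, self-contained
    # sketch of the grid_sample step (toy tensors, not part of the model):
    #
    #   atlas = torch.rand(1, 3, 64, 64)        # (N, C, H_atlas, W_atlas)
    #   uv = torch.rand(1, 32, 32, 2) * 2 - 1   # (N, H_out, W_out, 2) in [-1, 1]
    #   crop = F.grid_sample(atlas, uv, mode="bilinear", align_corners=True)
    #   crop.shape                              # torch.Size([1, 3, 32, 32])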
    def process_crops(self, uv_values, crops, original_crops, alpha=None):
        resized_crops = []
        cnn_output_crops = []
        render_dict = {"edit": [], "alpha": [], "edit_on_greenscreen": [], "composite": []}

        # Sample the atlas crop into each of the three frames' coordinates.
        atlas_crop = crops[0]
        for i in range(3):
            grid_sampled_atlas_crop = F.grid_sample(
                atlas_crop,
                uv_values[i],
                mode="bilinear",
                align_corners=self.config["align_corners"],
            ).clamp(min=0.0, max=1.0)
            resized_crops.append(grid_sampled_atlas_crop)

        # Edit the atlas crop and composite the edit over the original atlas.
        cnn_output = self.netG(atlas_crop)
        cnn_output_crops.append(cnn_output[:, :3])
        rendered_atlas_crops = self.render(cnn_output, bg_image=atlas_crop)

        # Warp each rendered output back into the three frame crops.
        for key, value in rendered_atlas_crops.items():
            for i in range(3):
                sampled_frame_crop = F.grid_sample(
                    value,
                    uv_values[i],
                    mode="bilinear",
                    align_corners=self.config["align_corners"],
                ).clamp(min=0.0, max=1.0)
                if alpha is not None:
                    sampled_frame_crop = sampled_frame_crop * alpha[i]
                    if key == "edit_on_greenscreen":
                        # Fill the transparent region with the green-screen color (0, 177, 64).
                        greenscreen = torch.zeros_like(sampled_frame_crop)
                        greenscreen[:, 1, :, :] = 177 / 255
                        greenscreen[:, 2, :, :] = 64 / 255
                        sampled_frame_crop += (1 - alpha[i]) * greenscreen
                render_dict[key].append(sampled_frame_crop.squeeze(0))

        # Also pass one of the three original frame crops, chosen at random,
        # through the network.
        frame_index = random.randint(0, 2)
        rec_crop = original_crops[frame_index]
        resized_crops.append(rec_crop)
        cnn_output = self.netG(rec_crop)
        if alpha is not None:
            alpha_crop = alpha[frame_index]
            cnn_output = cnn_output * alpha_crop
        cnn_output_crops.append(cnn_output[:, :3])
        rendered_frame_crop = self.render(cnn_output, bg_image=original_crops[frame_index])
        for key, value in rendered_frame_crop.items():
            render_dict[key].append(value.squeeze(0))

        return render_dict, resized_crops, cnn_output_crops
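
    # process_atlas edits the full atlas image in a single pass (no UV
    # warping), compositing the generator output over the atlas itself.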
    def process_atlas(self, atlas):
        atlas_edit = self.netG(atlas)
        rendered_atlas = self.render(atlas_edit, bg_image=atlas)
        return rendered_atlas
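
    # forward runs either the foreground or the background branch, depending
    # on which layer is being fine-tuned, and optionally also edits a full
    # input image when one is provided.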
    def forward(self, input_dict):
        inputs = input_dict["global_crops"]
        outputs = {"background": {}, "foreground": {}}
        if self.config["finetune_foreground"]:
            if self.config["multiply_foreground_alpha"]:
                alpha = inputs["foreground_alpha"]
            else:
                alpha = None
            foreground_outputs, resized_crops, cnn_output_crops = self.process_crops(
                inputs["foreground_uvs"],
                inputs["foreground_atlas_crops"],
                inputs["original_foreground_crops"],
                alpha=alpha,
            )
            outputs["foreground"]["output_crop"] = foreground_outputs
            outputs["foreground"]["cnn_inputs"] = resized_crops
            outputs["foreground"]["cnn_outputs"] = cnn_output_crops
            if "input_image" in input_dict:
                outputs["foreground"]["output_image"] = self.process_atlas(input_dict["input_image"])
        elif self.config["finetune_background"]:
            background_outputs, resized_crops, cnn_output_crops = self.process_crops(
                inputs["background_uvs"],
                inputs["background_atlas_crops"],
                inputs["original_background_crops"],
            )
            outputs["background"]["output_crop"] = background_outputs
            outputs["background"]["cnn_inputs"] = resized_crops
            outputs["background"]["cnn_outputs"] = cnn_output_crops
            if "input_image" in input_dict:
                outputs["background"]["output_image"] = self.process_atlas(input_dict["input_image"])
        return outputs
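

# Sketch of the expected input layout (shapes are assumptions for
# illustration; the actual keys are produced by whatever dataset builds
# "global_crops"):
#
#   input_dict = {
#       "global_crops": {
#           "foreground_uvs": [uv_0, uv_1, uv_2],                   # each (1, H, W, 2), in [-1, 1]
#           "foreground_atlas_crops": [atlas_crop],                 # (1, 3, H_atlas, W_atlas)
#           "original_foreground_crops": [crop_0, crop_1, crop_2],  # each (1, 3, H, W)
#           "foreground_alpha": [alpha_0, alpha_1, alpha_2],        # each (1, 1, H, W)
#       },
#   }
#   outputs = VideoModel(config)(input_dict)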