# Spaces: Configuration error
import torch

from .networks import define_G
# Device the generator is moved to: CUDA when PyTorch reports it available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class Model(torch.nn.Module):
    """Wrap the generator ``netG`` and split its output into image layers.

    The generator is built by ``define_G(cfg)`` and moved to the module-level
    ``device``. :meth:`render` interprets its output as a 4-D tensor whose
    channels 0-2 are an edit layer and whose channel 3 is an alpha matte.
    """

    # Backdrop color used for ``edit_on_greenscreen``: one value per edit
    # channel, in the same [0, 1] range the network output is checked against.
    # Subclasses may override it.
    GREENSCREEN_COLOR = (0.0, 177 / 255, 64 / 255)

    def __init__(self, cfg):
        """Build the generator from ``cfg`` and place it on ``device``.

        Args:
            cfg: configuration passed unchanged to ``define_G``; also kept
                on ``self.cfg``.
        """
        super().__init__()
        self.cfg = cfg
        self.netG = define_G(cfg).to(device)

    def render(self, net_output, bg_image=None):
        """Decompose a generator output into edit/alpha layers and composites.

        Args:
            net_output: 4-D tensor with channels on dim 1 (presumably
                (N, C, H, W) -- confirm against ``define_G``); channels 0-2
                are the edit layer, channel 3 is the alpha matte. All values
                must lie in [0, 1].
            bg_image: optional background that must broadcast against the
                3-channel edit layer. When given, an alpha composite over it
                is added to the result.

        Returns:
            dict with ``"edit"``, ``"alpha"`` (the alpha channel repeated to
            3 channels), ``"edit_on_greenscreen"`` and, only when
            ``bg_image`` is not None, ``"composite"``.

        Raises:
            AssertionError: if ``net_output`` has values outside [0, 1]
                (the check is skipped under ``python -O``).
        """
        assert net_output.min() >= 0 and net_output.max() <= 1, "net_output must lie in [0, 1]"
        edit = net_output[:, :3]
        # Materialise alpha with 3 channels: callers receive it in this form.
        alpha = net_output[:, 3].unsqueeze(1).repeat(1, 3, 1, 1)
        # A (1, C, 1, 1) color broadcasts over batch and spatial dims, so no
        # full-size background tensor has to be allocated and filled.
        greenscreen = torch.tensor(
            self.GREENSCREEN_COLOR, dtype=edit.dtype, device=edit.device
        ).view(1, -1, 1, 1)
        outputs = {
            "edit": edit,
            "alpha": alpha,
            "edit_on_greenscreen": alpha * edit + (1 - alpha) * greenscreen,
        }
        if bg_image is not None:
            outputs["composite"] = (1 - alpha) * bg_image + alpha * edit
        return outputs

    def forward(self, input):  # ``input`` shadows the builtin; name kept for callers
        """Run the generator on whichever inputs are present and render them.

        Args:
            input: mapping that may contain ``"input_crop"`` (augmented
                examples) and/or ``"input_image"`` (the entire image without
                augmentations). Each tensor is both the generator input and
                the background for the composite.

        Returns:
            dict mapping ``"output_crop"`` / ``"output_image"`` (only for the
            inputs that are present, crop first) to the :meth:`render` dict,
            where every tensor is replaced by a one-element list holding its
            first entry along dim 0.
        """
        outputs = {}
        for in_key, out_key in (
            ("input_crop", "output_crop"),  # augmented examples
            ("input_image", "output_image"),  # entire image, w/o augmentations
        ):
            if in_key in input:
                image = input[in_key]
                outputs[out_key] = self.render(self.netG(image), bg_image=image)
        # Move outputs to lists, keeping only the first sample of each tensor.
        return {
            out_key: {name: [value[0]] for name, value in rendered.items()}
            for out_key, rendered in outputs.items()
        }