SupermanxKiaski's picture
Upload 356 files
16d007c
raw
history blame
No virus
1.56 kB
import torch
from .networks import define_G
# Device that Model moves its generator onto: CUDA (no explicit index) when
# torch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Model(torch.nn.Module):
    """Wrap a generator network and turn its output into renderable layers.

    The generator is built by ``define_G(cfg)`` and moved to the module-level
    ``device``. ``render`` splits a generator output into an RGB edit layer
    and an alpha matte and composites them; ``forward`` runs the generator on
    the inputs it finds in a dict and renders each result.
    """

    # RGB colour, as values in [0, 1], of the solid backdrop used for the
    # "edit_on_greenscreen" layer: (0, 177, 64) / 255.
    GREENSCREEN_COLOR = (0.0, 177 / 255, 64 / 255)

    def __init__(self, cfg):
        """Build the generator from ``cfg`` and place it on ``device``.

        Args:
            cfg: Configuration passed unchanged to ``define_G``; it is also
                stored on the instance as ``self.cfg``.
        """
        super().__init__()
        self.cfg = cfg
        self.netG = define_G(cfg).to(device)

    def render(self, net_output, bg_image=None):
        """Split a generator output into edit/alpha layers and composite them.

        Args:
            net_output: Generator output whose values must lie in [0, 1].
                Channels 0-2 are read as the RGB edit layer and channel 3 as
                the alpha matte. Presumably shaped (N, C, H, W) with C >= 4;
                the ``repeat(1, 3, 1, 1)`` below expects a 4-D tensor.
            bg_image: Optional background that must broadcast against the
                3-channel edit layer. When given, a ``"composite"`` entry is
                added to the result.

        Returns:
            dict with ``"edit"`` (channels 0-2), ``"alpha"`` (channel 3
            repeated to 3 channels), ``"edit_on_greenscreen"`` (edit blended
            by alpha over a solid ``GREENSCREEN_COLOR`` backdrop) and, only
            if ``bg_image`` was given, ``"composite"`` (edit blended by alpha
            over ``bg_image``).

        Raises:
            ValueError: if ``net_output.min() >= 0 and net_output.max() <= 1``
                does not hold.
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently drop this range validation.
        if not (net_output.min() >= 0 and net_output.max() <= 1):
            raise ValueError("net_output values must lie in [0, 1]")
        edit = net_output[:, :3]
        # Slicing 3:4 keeps the channel axis (same as [:, 3].unsqueeze(1)).
        alpha = net_output[:, 3:4].repeat(1, 3, 1, 1)
        # A (1, 3, 1, 1) colour tensor with edit's dtype/device broadcasts over
        # batch and spatial axes, so no full-size backdrop is materialised.
        greenscreen = edit.new_tensor(self.GREENSCREEN_COLOR).view(1, 3, 1, 1)
        outputs = {
            "edit": edit,
            "alpha": alpha,
            "edit_on_greenscreen": alpha * edit + (1 - alpha) * greenscreen,
        }
        if bg_image is not None:
            outputs["composite"] = (1 - alpha) * bg_image + alpha * edit
        return outputs

    def forward(self, input):
        """Run the generator on each input present in ``input`` and render it.

        Args:
            input: Mapping that may contain ``"input_crop"`` (augmented
                examples) and/or ``"input_image"`` (the entire image, w/o
                augmentations). Each tensor is fed to ``netG`` and also used
                as ``bg_image`` for ``render``. The parameter name shadows the
                builtin but is kept so keyword callers keep working.

        Returns:
            dict mapping ``"output_crop"`` / ``"output_image"`` (only for the
            inputs that were present, in that order) to the ``render`` result,
            with every value replaced by a one-element list holding only its
            first batch item.
        """
        outputs = {}
        key_pairs = (("input_crop", "output_crop"), ("input_image", "output_image"))
        for in_key, out_key in key_pairs:
            if in_key not in input:
                continue
            source = input[in_key]
            rendered = self.render(self.netG(source), bg_image=source)
            # Keep only the first item of the batch, wrapped in a list.
            outputs[out_key] = {key: [value[0]] for key, value in rendered.items()}
        return outputs