File size: 1,560 Bytes
16d007c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from .networks import define_G

# Use CUDA when a GPU is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Model(torch.nn.Module):
    """Wrap a generator network and composite its RGBA output into images.

    The generator is built by ``define_G(cfg)`` and moved to ``device``.
    :meth:`render` splits its output into an RGB edit layer and an alpha
    channel, and blends them over a constant greenscreen and, optionally,
    over a background image.
    """

    # RGB colour (in [0, 1]) of the backdrop used to visualize the edit layer.
    _GREENSCREEN_RGB = (0.0, 177 / 255, 64 / 255)

    def __init__(self, cfg):
        """Build the generator from ``cfg`` and move it to ``device``.

        Args:
            cfg: configuration object, forwarded to ``define_G`` and kept on
                ``self.cfg``.
        """
        super().__init__()
        self.cfg = cfg
        self.netG = define_G(cfg).to(device)

    def render(self, net_output, bg_image=None):
        """Split generator output into layers and composite them.

        Args:
            net_output: 4-D tensor (presumably (batch, channels, H, W)) whose
                channels 0-2 hold the RGB edit and channel 3 the alpha matte.
                All values must lie in [0, 1].
            bg_image: optional background image. When given, the edit is
                alpha-blended over it. It must broadcast against the
                3-channel edit layer.

        Returns:
            dict with the following keys:

            * "edit": RGB layer.
            * "alpha": alpha repeated to 3 channels.
            * "edit_on_greenscreen": the edit blended over the greenscreen.
            * "composite": present only when ``bg_image`` is not None.

        Raises:
            ValueError: if ``net_output`` contains values outside [0, 1] or
                NaNs.
        """
        # Positive comparison form so NaNs (which compare False) are
        # rejected too. An explicit raise keeps the check active under
        # ``python -O``, where ``assert`` statements are stripped.
        if not (net_output.min() >= 0 and net_output.max() <= 1):
            raise ValueError(
                "net_output must lie in [0, 1], got range "
                f"[{net_output.min().item()}, {net_output.max().item()}]"
            )
        edit = net_output[:, :3]
        alpha = net_output[:, 3].unsqueeze(1).repeat(1, 3, 1, 1)
        # A (1, 3, 1, 1) constant broadcasts over the batch and spatial dims.
        # This gives the same values as a full-size greenscreen image
        # without allocating and filling one on every call.
        greenscreen = torch.tensor(
            self._GREENSCREEN_RGB, dtype=edit.dtype, device=edit.device
        ).view(1, 3, 1, 1)
        edit_on_greenscreen = alpha * edit + (1 - alpha) * greenscreen
        outputs = {"edit": edit, "alpha": alpha, "edit_on_greenscreen": edit_on_greenscreen}
        if bg_image is not None:
            outputs["composite"] = (1 - alpha) * bg_image + alpha * edit

        return outputs

    def forward(self, input):
        """Run the generator on each available input and render the result.

        Args:
            input: dict that may contain "input_crop" (augmented examples)
                and/or "input_image" (the entire image, without
                augmentations). Each present tensor is fed to ``netG`` and
                also used as the composite background.

        Returns:
            dict with "output_crop" and/or "output_image", one entry per
            input that was present. Each maps to the :meth:`render` dict, with
            every tensor replaced by a one-element list holding only index 0
            of its first dimension.
        """
        outputs = {}
        # Crops are processed before the full image, matching the key order
        # of the returned dict.
        for in_key, out_key in (("input_crop", "output_crop"), ("input_image", "output_image")):
            if in_key in input:
                source = input[in_key]
                outputs[out_key] = self.render(self.netG(source), bg_image=source)

        # Move outputs to lists.
        # NOTE(review): only the first batch entry is kept and any others
        # are dropped. Confirm callers always run with batch size 1.
        return {
            out_key: {key: [value[0]] for key, value in rendered.items()}
            for out_key, rendered in outputs.items()
        }