File size: 1,620 Bytes
581771b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torchvision import transforms

from PIL import Image

import numpy as np

class Predictor:
    """Stylize single images with a style-transfer model.

    Each input is resized, center-cropped to ``img_size``, converted to a
    tensor and channel-normalized before being fed to the model. The model
    output is turned back into a PIL image.
    """

    def __init__(self, st_model, device, img_size):
        """
        Args:
            st_model: Model called as ``st_model(batch, style_1, style_2, alpha)``.
                It is moved to ``device`` and switched to eval mode.
            device: Torch device (or device string) used for inference.
            img_size: Size passed to both ``transforms.Resize`` and
                ``transforms.CenterCrop``.
        """
        self.device = device

        self.st_model = st_model.to(device)
        self.st_model.eval()

        # Per-channel RGB normalization statistics (the widely used ImageNet values).
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        self.transformer = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    def eval_image(self, img, style_1, style_2=None, alpha=0.5):
        """Stylize ``img`` and return the result as a PIL image.

        Args:
            img: Input image accepted by ``self.transformer`` (typically a PIL
                image). Non-RGB PIL images (e.g. RGBA, grayscale, palette) are
                converted to RGB first, because normalization uses 3 channels.
            style_1: First style argument, forwarded to the model unchanged.
            style_2: Optional second style argument, forwarded unchanged.
                Presumably blended with ``style_1`` via ``alpha`` — depends on
                the model, confirm there.
            alpha: Forwarded to the model unchanged.

        Returns:
            ``PIL.Image.Image`` built from the first item of the model output,
            with channels moved last and values scaled by 255 into uint8.
        """
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")

        batch = self.transformer(img).unsqueeze(0).to(self.device)

        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            gen = self.st_model(batch, style_1, style_2, alpha)

        # The output is scaled by 255, so it is assumed to be roughly in [0, 1].
        # Clip before the uint8 cast so out-of-range values saturate instead
        # of wrapping around.
        out = gen[0].detach().cpu().numpy() * 255.0
        out = np.clip(out, 0.0, 255.0).astype(np.uint8)
        return Image.fromarray(np.moveaxis(out, 0, 2))

class WebcamPredictor:
    """Stylize raw channel-first numpy frames with a style-transfer model.

    Unlike ``Predictor``, no resizing or cropping is performed. The result is
    returned as a channel-first uint8 numpy array rather than a PIL image.
    """

    def __init__(self, st_model, device):
        """
        Args:
            st_model: Model called as ``st_model(batch, style_1, style_2, alpha)``.
                It is moved to ``device`` and switched to eval mode.
            device: Torch device (or device string) used for inference.
        """
        self.device = device

        self.st_model = st_model.to(device)
        self.st_model.eval()

        # Per-channel statistics shaped (3, 1, 1) so they broadcast over a
        # (3, H, W) frame.
        self.mean = np.expand_dims(np.array([0.485, 0.456, 0.406]), (1, 2))
        self.std = np.expand_dims(np.array([0.229, 0.224, 0.225]), (1, 2))

    def eval_image(self, img, style_1, style_2=None, alpha=0.5):
        """Stylize one frame and return it as a uint8 array.

        Args:
            img: Channel-first frame, presumably float values in [0, 1] with
                shape (3, H, W) — TODO confirm against the capture code.
            style_1: First style argument, forwarded to the model unchanged.
            style_2: Optional second style argument, forwarded unchanged.
            alpha: Forwarded to the model unchanged.

        Returns:
            uint8 numpy array (channel-first, same layout as the model output)
            holding the first output item scaled by 255 and clipped to [0, 255].
        """
        normalized = (img - self.mean) / self.std
        # Cast on the host so only float32 data is copied to the device.
        batch = torch.from_numpy(normalized.astype(np.float32)).unsqueeze(0).to(self.device)

        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            gen = self.st_model(batch, style_1, style_2, alpha)

        # Clip before the uint8 cast so out-of-range values saturate instead
        # of wrapping around.
        out = gen[0].detach().cpu().numpy() * 255.0
        return np.clip(out, 0.0, 255.0).astype(np.uint8)