import torch
from torchvision import transforms
from PIL import Image
import numpy as np
class Predictor:
    """Stylizes a single PIL image with the style-transfer model."""

    def __init__(self, st_model, device, img_size):
        self.device = device
        self.st_model = st_model.to(device)
        self.st_model.eval()
        # ImageNet normalization statistics.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.transformer = transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    def eval_image(self, img, style_1, style_2=None, alpha=0.5):
        img = self.transformer(img).to(self.device)
        with torch.no_grad():  # inference only, no gradients needed
            gen = self.st_model(img.unsqueeze(0), style_1, style_2, alpha)
        # CHW -> HWC, rescale to [0, 255], and clip before the uint8 cast
        # so out-of-range values do not wrap around.
        out = np.moveaxis(gen[0].cpu().numpy(), 0, 2) * 255.0
        return Image.fromarray(np.uint8(out.clip(0, 255)))
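# --- Usage sketch (illustrative, not from the original file) ---
# Assumes a trained style-transfer model `st_model` whose forward pass takes
# (content_batch, style_1, style_2, alpha); `load_st_model` and the style
# tensors are hypothetical stand-ins for however the Space loads them.
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
# st_model, style_1, style_2 = load_st_model()  # hypothetical loader
# predictor = Predictor(st_model, device, img_size=256)
# content = Image.open("content.jpg").convert("RGB")
# stylized = predictor.eval_image(content, style_1, style_2, alpha=0.5)
# stylized.save("stylized.jpg")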
class WebcamPredictor:
    """Stylizes raw webcam frames given as CHW float arrays in [0, 1]."""

    def __init__(self, st_model, device):
        self.device = device
        self.st_model = st_model.to(device)
        self.st_model.eval()
        # ImageNet statistics reshaped to (3, 1, 1) so they broadcast
        # over CHW frames.
        self.mean = np.expand_dims(np.array([0.485, 0.456, 0.406]), (1, 2))
        self.std = np.expand_dims(np.array([0.229, 0.224, 0.225]), (1, 2))

    def eval_image(self, img, style_1, style_2=None, alpha=0.5):
        # Normalize in NumPy, then move the frame to the model's device.
        img = (img - self.mean) / self.std
        img = torch.from_numpy(img).float().to(self.device)
        with torch.no_grad():  # inference only, no gradients needed
            gen = self.st_model(img.unsqueeze(0), style_1, style_2, alpha)
        # Result stays in CHW layout; clip before the uint8 cast.
        return np.uint8((gen[0].cpu().numpy() * 255.0).clip(0, 255))
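# --- Webcam usage sketch (illustrative, not from the original file) ---
# An OpenCV frame arrives as HWC uint8 BGR, while eval_image expects a CHW
# float array in [0, 1] in RGB order (to match the ImageNet statistics).
# `cap`, `st_model`, and `style_1` are assumed to come from the setup above.
#
# import cv2
# wp = WebcamPredictor(st_model, device)
# ok, frame = cap.read()                          # HWC uint8 BGR
# rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)    # BGR -> RGB
# chw = np.moveaxis(rgb, 2, 0) / 255.0            # HWC -> CHW, scale to [0, 1]
# out = wp.eval_image(chw, style_1)               # uint8 CHW result
# bgr = cv2.cvtColor(np.moveaxis(out, 0, 2), cv2.COLOR_RGB2BGR)
# cv2.imshow("stylized", bgr)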