import torch from torchvision import transforms from PIL import Image import numpy as np class Predictor: def __init__(self, st_model, device, img_size): self.device = device self.st_model = st_model.to(device) self.st_model.eval() self.mean = [0.485, 0.456, 0.406] self.std = [0.229, 0.224, 0.225] self.transformer = transforms.Compose([ transforms.Resize(img_size), transforms.CenterCrop(img_size), transforms.ToTensor(), transforms.Normalize(mean=self.mean, std=self.std) ]) def eval_image(self, img, style_1, style_2=None, alpha=0.5): img = self.transformer(img).to(self.device) gen = self.st_model(img.unsqueeze(0), style_1, style_2, alpha) return Image.fromarray(np.uint8(np.moveaxis(gen[0].cpu().detach().numpy()*255.0, 0, 2))) class WebcamPredictor: def __init__(self, st_model, device): self.device = device self.st_model = st_model.to(device) self.st_model.eval() self.mean = np.array([0.485, 0.456, 0.406]) self.std = np.array([0.229, 0.224, 0.225]) self.mean = np.expand_dims(self.mean, (1,2)) self.std = np.expand_dims(self.std, (1,2)) def eval_image(self, img, style_1, style_2=None, alpha=0.5): img = (img - self.mean) / self.std img = torch.from_numpy(img).to(self.device) img = img.float() gen = self.st_model(img.unsqueeze(0), style_1, style_2, alpha) return np.uint8(gen[0].cpu().detach().numpy()*255.0)