import torch
import torch.nn as nn
import numpy as np
from PIL import Image

device = "cpu"
weights_path = "./weights/face_generator_v2.pth"
IN_CHANNELS = 100
class FaceGenerator(nn.Module):
    def __init__(self, in_channels):
        super(FaceGenerator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(in_channels, 1024, 4, 2, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),  # [batch_size, 1024, 4, 4]
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),  # [batch_size, 512, 8, 8]
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),  # [batch_size, 256, 16, 16]
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),  # [batch_size, 128, 32, 32]
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Sigmoid(),  # [batch_size, 3, 64, 64], values in [0, 1]
        )

    def forward(self, x):
        return self.main(x)
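
# The shape comments above follow the transposed-convolution output formula
# H_out = (H_in - 1) * stride - 2 * padding + kernel_size, so the 1x1 latent
# grows 1 -> 4 -> 8 -> 16 -> 32 -> 64. A minimal sanity check (not in the
# original source) would be:
#
#     _g = FaceGenerator(IN_CHANNELS)
#     assert _g(torch.randn(1, IN_CHANNELS, 1, 1)).shape == (1, 3, 64, 64)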
def load_model():
    model = FaceGenerator(IN_CHANNELS).to(device=device)
    model.load_state_dict(torch.load(weights_path, map_location=torch.device(device)), strict=True)
    model = model.eval()  # inference mode: fixes BatchNorm to its running stats
    print("[!] Model loaded.")
    return model
def generate(model):
    noise = torch.randn((1, IN_CHANNELS, 1, 1)).to(device)
    with torch.no_grad():  # no gradients needed at inference time
        image_op = model(noise).squeeze()
    image_op = image_op.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
    image_op = image_op * 255.0  # Sigmoid output is already in [0, 1]
    image_op = image_op.astype(np.uint8)
    image_op = Image.fromarray(image_op)
    # Image.ADAPTIVE is a palette constant, not a resampling filter;
    # LANCZOS is the intended high-quality filter for this upscaling.
    image_op = image_op.resize((256, 256), resample=Image.LANCZOS)
    return image_op
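
# A minimal usage sketch. It assumes the checkpoint exists at weights_path;
# the entry point and output filename are illustrative, not from the original.
if __name__ == "__main__":
    generator = load_model()
    face = generate(generator)
    face.save("generated_face.png")  # hypothetical output path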