Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import numpy as np | |
from PIL import Image | |
# Latent size: generate() samples noise of shape (1, IN_CHANNELS, 1, 1) and
# FaceGenerator is built with IN_CHANNELS input channels.
IN_CHANNELS = 100
# Device used for both the model (load_model) and the sampled noise (generate).
device = "cpu"
# Checkpoint read by load_model() via torch.load; the path is relative to the
# current working directory.
weights_path = "./weights/face_generator_v2.pth"
class FaceGenerator(nn.Module):
    """Generator that upsamples a latent noise tensor into an RGB image.

    Five transposed convolutions (kernel 4, stride 2) double the spatial size
    at every stage after the first. For a ``(batch, in_channels, 1, 1)`` input
    the output is ``(batch, 3, 64, 64)``, and the final sigmoid keeps every
    value in ``[0, 1]``.

    The layers must stay in ``self.main`` in exactly this order: the
    ``state_dict`` keys are positional (``main.0.weight``,
    ``main.1.running_mean``, ...) and ``load_model`` loads them with
    ``strict=True``.

    Args:
        in_channels: Number of channels of the input noise tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Shapes below assume a 1x1 spatial input, which is what generate() feeds;
        # H_out = (H_in - 1) * stride - 2 * padding + kernel_size.
        self.main = nn.Sequential(
            nn.ConvTranspose2d(in_channels, 1024, 4, 2, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),  # [batch_size, 1024, 4, 4]
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),  # [batch_size, 512, 8, 8]
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),  # [batch_size, 256, 16, 16]
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),  # [batch_size, 128, 32, 32]
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Sigmoid(),  # [batch_size, 3, 64, 64], values in [0, 1]
        )

    def forward(self, x):
        """Run the generator stack on ``x`` and return the image tensor."""
        return self.main(x)
def load_model():
    """Build a FaceGenerator, restore its weights and return it in eval mode.

    The network is created with ``IN_CHANNELS`` input channels and placed on
    ``device``. The parameters are then loaded from ``weights_path`` with
    ``strict=True``, so any missing or unexpected key raises.

    NOTE(review): ``torch.load`` unpickles the file, so only point
    ``weights_path`` at a trusted checkpoint.
    """
    net = FaceGenerator(IN_CHANNELS)
    net = net.to(device=device)
    checkpoint = torch.load(weights_path, map_location=torch.device(device))
    net.load_state_dict(checkpoint, strict=True)
    # eval() switches BatchNorm to its running statistics and returns the module itself.
    net = net.eval()
    print("[!] Model Loaded..")
    return net
def generate(model, size=(256, 256)):
    """Sample one random image from ``model`` and return it as a PIL image.

    Args:
        model: Generator such as the one returned by ``load_model``. It is
            assumed to map a ``(1, IN_CHANNELS, 1, 1)`` noise tensor to a
            ``(1, 3, H, W)`` tensor with values in ``[0, 1]`` (not checked here).
        size: ``(width, height)`` of the returned image. The default
            reproduces the previous fixed 256x256 output.

    Returns:
        A ``PIL.Image.Image`` built from the generated pixels and resized to
        ``size``.
    """
    noise = torch.randn((1, IN_CHANNELS, 1, 1)).to(device)
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        output = model(noise)
    # Drop just the batch axis ([1, 3, H, W] -> [3, H, W]); squeeze() would also
    # collapse any other size-1 axis. Then move channels last for PIL.
    pixels = output[0].permute(1, 2, 0).cpu().numpy()
    # Round instead of truncating when quantizing to 8 bits, and clip so an
    # out-of-range value cannot wrap around in the uint8 cast.
    pixels = np.clip(np.rint(pixels * 255.0), 0, 255).astype(np.uint8)
    image = Image.fromarray(pixels)
    # Image.ADAPTIVE is a palette/quantize constant, not a resampling filter; it
    # only worked because it shares the value 1 with LANCZOS. Name the filter.
    return image.resize(size, resample=Image.LANCZOS)