Spaces:
Runtime error
Runtime error
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
class Discriminator(nn.Module):
    """MLP discriminator that maps each flattened input to a single score.

    Architecture (``self.model``)::

        Linear(prod(image_shape) -> 512) -> LeakyReLU(0.2)
        -> Linear(512 -> 256) -> LeakyReLU(0.2)
        -> Linear(256 -> 1)

    No sigmoid is applied after the last layer, so the outputs are
    unbounded scores rather than probabilities.
    """

    def __init__(
        self,
        image_shape: Tuple[int, int, int],
        use_cuda: bool = False,
        saved_model: Optional[str] = None,
    ):
        """Build the network and optionally restore saved weights.

        Args:
            image_shape: Per-sample input shape. Only its element count
                (``np.prod(image_shape)``) is used, to size the first
                linear layer.
            use_cuda: Used only when ``saved_model`` is given. If true, the
                stored tensors are mapped to the ``'cuda'`` device, otherwise
                to ``'cpu'``. The module itself is not moved to any device
                here.
            saved_model: Optional path to a state dict of ``self.model``
                (as written by ``save``). When given, it is loaded into
                ``self.model``.
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(image_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )
        if saved_model is not None:
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from trusted sources (consider weights_only=True
            # if the installed torch version supports it).
            state_dict = torch.load(
                saved_model,
                map_location=torch.device('cuda' if use_cuda else 'cpu'),
            )
            self.model.load_state_dict(state_dict)

    def forward(self, img):
        """Score a batch of inputs.

        Args:
            img: Tensor whose first dimension is the batch. All remaining
                dimensions are flattened, and their product must equal
                ``prod(image_shape)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the unbounded scores.
        """
        # reshape (unlike view) also accepts non-contiguous tensors, such as
        # the result of a transpose/permute. It only copies when it must.
        img_flat = img.reshape(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

    def save(self, to):
        """Write the state dict of ``self.model`` to ``to`` via ``torch.save``.

        The saved keys belong to the inner ``nn.Sequential``, which is the
        format the ``saved_model`` argument of ``__init__`` loads.
        """
        torch.save(self.model.state_dict(), to)