import math

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms


class Generator(nn.Module):
    """Fully connected GAN generator mapping a noise vector to a flat vector.

    Layer layout (``nn.Sequential`` indices in brackets). The indices are part
    of the saved ``state_dict`` keys (e.g. ``model.3.running_mean``), so the
    order must not change:

        [0] Linear(input_size, 128)   [1] LeakyReLU(0.2)
        [2] Linear(128, 256)          [3] BatchNorm1d(256)  [4] LeakyReLU(0.2)
        [5] Linear(256, 512)          [6] BatchNorm1d(512)  [7] LeakyReLU(0.2)
        [8] Linear(512, output_channels)                    [9] Tanh

    Args:
        input_size: Number of features in each input (noise) vector.
        output_channels: Number of features in each output vector. Despite
            the name, this is a flat feature count produced by the final
            ``nn.Linear``, not an image channel count.
    """

    def __init__(self, input_size, output_channels):
        super().__init__()

        layers = [nn.Linear(input_size, 128), nn.LeakyReLU(0.2)]

        # Hidden blocks: Linear -> BatchNorm1d -> LeakyReLU, widening 128 -> 512.
        for fan_in, fan_out in ((128, 256), (256, 512)):
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.LeakyReLU(0.2),
            ]

        # Tanh bounds every output value to [-1, 1].
        layers += [nn.Linear(512, output_channels), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the generator.

        Expects ``x`` shaped (N, input_size): the first layer is
        ``nn.Linear`` over the last dimension, and the ``BatchNorm1d`` layers
        then need the feature dimension at index 1. Returns (N, output_channels).
        """
        return self.model(x)


class PreTrainedPipeline:
    """Sample images from a fully connected GAN :class:`Generator`."""

    def __init__(self, path="", latent_dim=100, image_shape=(1, 28, 28)):
        """
        Initialize model

        Args:
            path: Location of pretrained weights. NOTE(review): this argument
                is accepted for interface compatibility, but nothing here loads
                from it yet, so the generator keeps its random initialization.
            latent_dim: Length of the noise vector fed to the generator. The
                default of 100 matches the noise size this sampler always drew.
            image_shape: Shape of one generated image, either
                ``(channels, height, width)`` or ``(height, width)``. The
                generator emits ``prod(image_shape)`` values, which are
                reshaped to this shape. The ``(1, 28, 28)`` grayscale default
                is an assumption; TODO confirm it against the checkpoint this
                pipeline is meant to serve.

        Raises:
            ValueError: If ``image_shape`` is neither 2-D nor 3-D, since
                ``transforms.ToPILImage`` only converts tensors of those ranks.
        """
        image_shape = tuple(int(dim) for dim in image_shape)
        if len(image_shape) not in (2, 3):
            raise ValueError(
                f"image_shape must be (C, H, W) or (H, W), got {image_shape}"
            )

        self.latent_dim = latent_dim
        self.image_shape = image_shape

        # Generator requires explicit sizes: input = noise length,
        # output = flattened pixel count of one image.
        self.model = Generator(latent_dim, math.prod(image_shape))

        # Inference only. eval() makes BatchNorm1d use its running statistics,
        # which is also what permits the single-sample batches drawn in
        # generate_random_image (training mode raises on a batch of one).
        self.model.eval()

    def generate_random_image(self):
        """
        Generate a random image using the GAN model.

        Return:
            A :obj:`PIL.Image` with the generated image.
        """
        # A batch of one latent vector, shaped (1, latent_dim) so it matches
        # the generator's first nn.Linear(input_size, 128).
        noise = torch.randn(1, self.latent_dim)

        with torch.no_grad():
            flat_image = self.model(noise)

        # The generator returns (1, prod(image_shape)); restore the image layout.
        output_image = flat_image.view(1, *self.image_shape)

        # Tanh output lies in [-1, 1]; ToPILImage expects floats in [0, 1].
        output_image = (output_image + 1) / 2

        return transforms.ToPILImage()(output_image[0])


if __name__ == "__main__":
    # Draw one sample from a freshly built pipeline and write it to the
    # current working directory.
    sampler = PreTrainedPipeline()
    image = sampler.generate_random_image()

    output_path = "generated_image.jpg"
    image.save(output_path)

    print("Generated image saved at 'generated_image.jpg'")