Junathan Richie commited on
Commit
a3d2818
·
1 Parent(s): 697ebf3

feat: add vanilla gan

Browse files
Files changed (4) hide show
  1. app.py +16 -3
  2. model.py → stylegan_model.py +2 -2
  3. utils.py +20 -5
  4. vanillagan_model.py +32 -0
app.py CHANGED
@@ -1,10 +1,16 @@
1
  from fastapi import FastAPI, Query
2
  from fastapi.responses import StreamingResponse
3
- from utils import load_model_pt, generate_image_stylegan, load_model_pkl, generate_image_from_pkl
 
4
 
5
  app = FastAPI()
6
- stylegan = load_model_pt("model_128.pt")
7
  styleganv2 = load_model_pkl("styleganv2.pkl")
 
 
 
 
 
8
 
9
  @app.get("/ping")
10
  def ping():
@@ -16,7 +22,14 @@ def generate_stylegan():
16
  return StreamingResponse(image_stream, media_type="image/png")
17
 
18
  @app.get("/generate/styleganv2")
19
- def generate_styleganv2(seed: int = Query(0)):
 
 
20
  image_stream = generate_image_from_pkl(styleganv2, seed=seed, trunc=1)
21
  return StreamingResponse(image_stream, media_type="image/png")
22
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, Query
2
  from fastapi.responses import StreamingResponse
3
+ from utils import load_model_pt, generate_image_stylegan, load_model_pkl, generate_image_from_pkl, generate_image_vanillagan
4
+ import random
5
 
6
  app = FastAPI()
7
+ stylegan = load_model_pt("model_128.pt", model_type="stylegan")
8
  styleganv2 = load_model_pkl("styleganv2.pkl")
9
+ gan = load_model_pt("generator_VanillaGAN.pt", model_type="vanillagan")
10
+
11
+ @app.get("/")
12
+ def root():
13
+ return {"message": "Welcome to the FastAPI StyleGAN API"}
14
 
15
  @app.get("/ping")
16
  def ping():
 
22
  return StreamingResponse(image_stream, media_type="image/png")
23
 
24
  @app.get("/generate/styleganv2")
25
+ def generate_styleganv2(seed: int = Query(-1)):
26
+ if seed == -1:
27
+ seed = random.randint(0, 65535)
28
  image_stream = generate_image_from_pkl(styleganv2, seed=seed, trunc=1)
29
  return StreamingResponse(image_stream, media_type="image/png")
30
 
31
+ @app.get("/generate/vanillagan")
32
+ def generate_vanillagan():
33
+ image_stream = generate_image_vanillagan(gan)
34
+ return StreamingResponse(image_stream, media_type="image/png")
35
+
model.py → stylegan_model.py RENAMED
@@ -1,7 +1,7 @@
1
- from torch import nn, optim
2
  import torch
3
  from torch.nn import functional as F
4
- from typing import Any, Callable, Optional
5
  import math
6
 
7
  class WSLinear(nn.Module):
 
1
+ from torch import nn
2
  import torch
3
  from torch.nn import functional as F
4
+ from typing import Optional
5
  import math
6
 
7
  class WSLinear(nn.Module):
utils.py CHANGED
@@ -1,4 +1,5 @@
1
- from model import StyleGAN
 
2
  import torch
3
  from io import BytesIO
4
  from torchvision.utils import save_image
@@ -11,10 +12,14 @@ LATENT_FEATURES = 512
11
  RESOLUTION = 128
12
 
13
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
- def load_model_pt(path='model_128.pt'):
15
- model = StyleGAN(LATENT_FEATURES, RESOLUTION).to(DEVICE)
16
- last_checkpoint = torch.load(path, map_location=DEVICE)
17
- model.load_state_dict(last_checkpoint['generator'], strict=False)
 
 
 
 
18
  model.eval()
19
  return model
20
 
@@ -28,6 +33,16 @@ def generate_image_stylegan(generator, steps=5, alpha=1.0):
28
  save_image(image, buffer, format='PNG')
29
  buffer.seek(0)
30
  return buffer
 
 
 
 
 
 
 
 
 
 
31
 
32
  def load_model_pkl(path='styleganv2.pkl'):
33
  with open(path, 'rb') as f:
 
1
+ from stylegan_model import StyleGAN
2
+ from vanillagan_model import VanillaGAN
3
  import torch
4
  from io import BytesIO
5
  from torchvision.utils import save_image
 
12
  RESOLUTION = 128
13
 
14
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ def load_model_pt(path='model_128.pt',model_type='stylegan'):
16
+ if model_type == "stylegan":
17
+ model = StyleGAN(LATENT_FEATURES, RESOLUTION).to(DEVICE)
18
+ last_checkpoint = torch.load(path, map_location=DEVICE)
19
+ model.load_state_dict(last_checkpoint['generator'], strict=False)
20
+ elif model_type == "vanillagan":
21
+ model = VanillaGAN(RESOLUTION, LATENT_FEATURES).to(DEVICE)
22
+ model.load_state_dict(torch.load(path, map_location=DEVICE))
23
  model.eval()
24
  return model
25
 
 
33
  save_image(image, buffer, format='PNG')
34
  buffer.seek(0)
35
  return buffer
36
+
37
+ def generate_image_vanillagan(generator):
38
+ with torch.no_grad():
39
+ image = generator(torch.randn(1, LATENT_FEATURES, device=DEVICE)).view(1, 3, RESOLUTION, RESOLUTION)
40
+ image = (image * 0.5 + 0.5).clamp(0, 1)
41
+
42
+ buffer = BytesIO()
43
+ save_image(image, buffer, format='PNG')
44
+ buffer.seek(0)
45
+ return buffer
46
 
47
  def load_model_pkl(path='styleganv2.pkl'):
48
  with open(path, 'rb') as f:
vanillagan_model.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn, optim
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from typing import Any, Callable, Optional
5
+ import math
6
+
7
+ class VanillaGAN(nn.Module):
8
+ def __init__(self, resolution, latent_dim, hidden_dim=512, channels=3):
9
+ super(VanillaGAN, self).__init__()
10
+ output_dim = resolution * resolution * channels
11
+
12
+ self.layers = nn.Sequential(
13
+ self.gen_block(latent_dim, hidden_dim),
14
+ self.gen_block(hidden_dim, hidden_dim*2),
15
+ self.gen_block(hidden_dim*2, hidden_dim*2),
16
+ self.gen_block(hidden_dim*2, hidden_dim),
17
+ self.gen_block(hidden_dim, hidden_dim),
18
+ self.gen_block(hidden_dim, hidden_dim//2),
19
+
20
+ nn.Linear(hidden_dim//2, output_dim),
21
+ nn.Tanh()
22
+ )
23
+
24
+ def gen_block(self, input_dim, output_dim):
25
+ return nn.Sequential(
26
+ nn.Linear(input_dim, output_dim, bias=False),
27
+ nn.BatchNorm1d(output_dim, 0.8),
28
+ nn.LeakyReLU(0.2)
29
+ )
30
+
31
+ def forward(self, x):
32
+ return self.layers(x)