riciii7 commited on
Commit
bfc3e04
·
verified ·
1 Parent(s): af9419e

feat: update vanillagan + dcgan

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -1,28 +1,48 @@
1
  from fastapi import FastAPI, Query
 
2
  from fastapi.responses import StreamingResponse
3
- from utils import load_model_pt, generate_image_stylegan, load_model_pkl, generate_image_from_pkl, generate_image_vanillagan
4
- import random
5
 
6
  app = FastAPI()
 
 
 
 
 
 
 
 
7
  styleganv2 = load_model_pkl("styleganv2.pkl")
8
- gan = load_model_pt("generator_VanillaGAN.pt", model_type="vanillagan")
9
 
10
  @app.get("/")
11
  def root():
12
- return {"message": "Welcome to the FastAPI StyleGAN API"}
13
 
14
  @app.get("/ping")
15
  def ping():
16
  return {"status": "pong"}
17
 
 
 
 
 
 
18
  @app.get("/generate/styleganv2")
19
- def generate_styleganv2(seed: int = Query(-1)):
20
- if seed == -1:
21
- seed = random.randint(0, 65535)
22
  image_stream = generate_image_from_pkl(styleganv2, seed=seed, trunc=1)
23
  return StreamingResponse(image_stream, media_type="image/png")
24
 
 
 
 
 
 
 
 
 
 
 
25
  @app.get("/generate/vanillagan")
26
  def generate_vanillagan():
27
- image_stream = generate_image_vanillagan(gan)
28
  return StreamingResponse(image_stream, media_type="image/png")
 
1
  from fastapi import FastAPI, Query
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import StreamingResponse
4
+ from utils import load_model_pt, generate_image_stylegan, load_model_pkl, generate_image_from_pkl, generate_image_from_onnx
 
5
 
6
  app = FastAPI()
7
+
8
+ app.add_middleware(
9
+ CORSMiddleware,
10
+ allow_origins=["*"],
11
+ allow_methods=["GET"],
12
+ allow_headers=["*"],
13
+ allow_credentials=False, # gapake creds
14
+ )
15
  styleganv2 = load_model_pkl("styleganv2.pkl")
 
16
 
17
  @app.get("/")
18
  def root():
19
+ return {"message": "Welcome to the StyleGAN API!"}
20
 
21
  @app.get("/ping")
22
  def ping():
23
  return {"status": "pong"}
24
 
25
+ @app.get("/generate/stylegan")
26
+ def generate_stylegan_onnx():
27
+ image_stream = generate_image_from_onnx("stylegan.onnx", model='stylegan')
28
+ return StreamingResponse(image_stream, media_type="image/png")
29
+
30
  @app.get("/generate/styleganv2")
31
+ def generate_styleganv2(seed: int = Query(0)):
 
 
32
  image_stream = generate_image_from_pkl(styleganv2, seed=seed, trunc=1)
33
  return StreamingResponse(image_stream, media_type="image/png")
34
 
35
+ @app.get("/generate/progan")
36
+ def generate_progan():
37
+ image_stream = generate_image_from_onnx("progan.onnx", model='progan')
38
+ return StreamingResponse(image_stream, media_type="image/png")
39
+
40
+ @app.get("/generate/dcgan")
41
+ def generate_dcgan():
42
+ image_stream = generate_image_from_onnx("batik_dcgan.onnx", model='dcgan')
43
+ return StreamingResponse(image_stream, media_type="image/png")
44
+
45
  @app.get("/generate/vanillagan")
46
  def generate_vanillagan():
47
+ image_stream = generate_image_from_onnx("vanillagan.onnx", model='vanillagan')
48
  return StreamingResponse(image_stream, media_type="image/png")