saniaE commited on
Commit
e1aa346
·
1 Parent(s): efbd9ad

created fastapi

Browse files
Files changed (2) hide show
  1. app.py +118 -0
  2. models.py +31 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import torch
4
+ from torch import nn
5
+ from PIL import Image
6
+ import torchvision.utils as vutils
7
+ from fastapi import FastAPI, Response, HTTPException, Query
8
+ from fastapi.responses import StreamingResponse
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ from models import Generator
13
+
14
+ app = FastAPI()
15
+
16
+ # CORS Configuration
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"],
20
+ allow_methods=["*"],
21
+ allow_headers=["*"],
22
+ )
23
+
24
+ # Configuration Constants
25
+ Z_DIM = 100
26
+ DEVICE = torch.device("cpu")
27
+ REPO_ID = "SaniaE/GeoGen"
28
+ FILENAME = "dcgans_model_checkpoint.pt"
29
+
30
+ # Global model variable
31
+ gen_model = None
32
+
33
+ @app.on_event("startup")
34
+ def load_model():
35
+ global gen_model
36
+ try:
37
+ token = os.getenv("HF_TOKEN")
38
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
39
+
40
+ checkpoint = torch.load(model_path, map_location=DEVICE)
41
+
42
+ gen_model = Generator(z_dim=Z_DIM).to(DEVICE)
43
+ gen_model.load_state_dict(checkpoint["gen_state_dict"])
44
+ gen_model.eval()
45
+ print("Model loaded successfully.")
46
+ except Exception as e:
47
+ print(f"Error loading model: {e}")
48
+
49
+
50
+ def postprocess_image(tensor):
51
+ # Unnormalize: tanh output [-1, 1] -> [0, 1]
52
+ img_tensor = (tensor + 1) / 2
53
+ img_tensor = img_tensor.clamp(0, 1)
54
+
55
+ # Use make_grid to handle single or batch images
56
+ grid = vutils.make_grid(img_tensor, padding=0, normalize=False)
57
+
58
+ # Convert to HWC format for PIL
59
+ ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
60
+ return Image.fromarray(ndarr)
61
+
62
+
63
+ def get_image_stream(tensor):
64
+ """Helper to convert tensor to a streaming-ready PNG."""
65
+ img_tensor = (tensor + 1) / 2
66
+ img_tensor = img_tensor.clamp(0, 1)
67
+ grid = vutils.make_grid(img_tensor, padding=0)
68
+ ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
69
+
70
+ pil_img = Image.fromarray(ndarr)
71
+ buf = io.BytesIO()
72
+ pil_img.save(buf, format="PNG")
73
+ buf.seek(0)
74
+ return buf
75
+
76
+
77
+ @app.get("/")
78
+ def read_root():
79
+ return {"status": "online", "model": REPO_ID}
80
+
81
+
82
+ @app.get("/generate")
83
+ def generate_random():
84
+ """Endpoint 1: Purely random generation for 'Discovery'."""
85
+ if gen_model is None: raise HTTPException(status_code=503)
86
+
87
+ with torch.inference_mode():
88
+ noise = torch.randn(1, Z_DIM, device=DEVICE)
89
+ fake_img = gen_model(noise)
90
+ return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
91
+
92
+
93
+ @app.get("/explore")
94
+ def explore_latent(
95
+ seed: int,
96
+ x_shift: float = Query(0.0, ge=-5.0, le=5.0),
97
+ y_shift: float = Query(0.0, ge=-5.0, le=5.0)
98
+ ):
99
+ """Endpoint 2: Controlled generation for 'Tuning'."""
100
+ if gen_model is None: raise HTTPException(status_code=503)
101
+
102
+ try:
103
+ with torch.inference_mode():
104
+ # Use the seed to recreate the base 'personality' of the image
105
+ torch.manual_seed(seed)
106
+ if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
107
+
108
+ noise = torch.randn(1, Z_DIM, device=DEVICE)
109
+
110
+ # Apply shifts to specific dimensions
111
+ noise[0, 0] += x_shift
112
+ noise[0, 1] += y_shift
113
+
114
+ fake_img = gen_model(noise)
115
+ return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
116
+
117
+ except Exception as e:
118
+ raise HTTPException(status_code=500, detail=str(e))
models.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ class Generator(nn.Module):
4
+ def __init__(self, z_dim=100, input_channels=3, hidden_dim=64):
5
+ super(Generator, self).__init__()
6
+ self.z_dim = z_dim
7
+ self.gen = nn.Sequential(
8
+ self.generator_block(z_dim, hidden_dim * 32, stride=1, padding=0),
9
+ self.generator_block(hidden_dim * 32, hidden_dim * 16),
10
+ self.generator_block(hidden_dim * 16, hidden_dim * 8),
11
+ self.generator_block(hidden_dim * 8, hidden_dim * 4),
12
+ self.generator_block(hidden_dim * 4, hidden_dim * 2),
13
+ self.generator_block(hidden_dim * 2, hidden_dim),
14
+ self.generator_block(hidden_dim, input_channels, final_layer=True)
15
+ )
16
+
17
+ def generator_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
18
+ if not final_layer:
19
+ return nn.Sequential(
20
+ nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
21
+ nn.InstanceNorm2d(output_channels, affine=True),
22
+ nn.ReLU(inplace=True)
23
+ )
24
+ else:
25
+ return nn.Sequential(
26
+ nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding),
27
+ nn.Tanh()
28
+ )
29
+
30
+ def forward(self, noise):
31
+ return self.gen(noise.view(len(noise), self.z_dim, 1, 1))