LogicGoInfotechSpaces commited on
Commit
ab9de00
·
verified ·
1 Parent(s): b860aca

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +156 -98
app/main.py CHANGED
@@ -1,22 +1,20 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Header
2
  from fastapi.responses import FileResponse
3
  from huggingface_hub import hf_hub_download
4
- import uuid
5
- import os
6
- import io
7
- import json
8
  from PIL import Image
9
  import torch
10
- from torchvision import transforms
 
11
 
12
- # -------------------------------------------------
13
- # 🚀 FastAPI App
14
- # -------------------------------------------------
15
- app = FastAPI(title="Text-Guided Image Colorization API")
16
 
17
- # -------------------------------------------------
18
- # 🔐 Firebase Initialization (ENV-based)
19
- # -------------------------------------------------
20
  try:
21
  import firebase_admin
22
  from firebase_admin import credentials, app_check
@@ -34,142 +32,202 @@ try:
34
  except Exception as e:
35
  print("❌ Firebase initialization failed:", e)
36
 
37
- # -------------------------------------------------
38
- # 📁 Directories (FIXED FOR HUGGINGFACE SPACES)
39
- # -------------------------------------------------
40
- UPLOAD_DIR = "/tmp/uploads"
41
- RESULTS_DIR = "/tmp/results"
42
  os.makedirs(UPLOAD_DIR, exist_ok=True)
43
  os.makedirs(RESULTS_DIR, exist_ok=True)
44
 
45
- # -------------------------------------------------
46
- # 🧠 Load GAN Colorization Model
47
- # -------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  MODEL_REPO = "Hammad712/GAN-Colorization-Model"
49
  MODEL_FILENAME = "generator.pt"
50
 
51
  print("⬇️ Downloading model...")
52
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
53
 
54
- print("📦 Loading model weights...")
 
55
  state_dict = torch.load(model_path, map_location="cpu")
 
 
56
 
57
- # NOTE: Replace with real model architecture
58
- # from model import ColorizeNet
59
- # model = ColorizeNet()
60
- # model.load_state_dict(state_dict)
61
- # model.eval()
62
 
63
  def colorize_image(img: Image.Image):
64
- """ Dummy colorizer (replace with real model.predict) """
65
- transform = transforms.ToTensor()
66
- tensor = transform(img.convert("L")).unsqueeze(0)
67
- tensor = tensor.repeat(1, 3, 1, 1)
68
- output_img = transforms.ToPILImage()(tensor.squeeze())
69
- return output_img
70
-
71
- # -------------------------------------------------
72
- # 🩺 Health Check
73
- # -------------------------------------------------
74
- @app.get("/health")
75
- def health_check():
76
- return {"status": "healthy", "model_loaded": True}
77
 
78
- # -------------------------------------------------
79
- # 🔐 Firebase Token Validator
80
- # -------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def verify_app_check_token(token: str):
82
  if not token or len(token) < 20:
83
  raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")
84
  return True
85
 
86
- # -------------------------------------------------
87
- # 📤 Upload Image
88
- # -------------------------------------------------
 
 
 
 
 
 
 
89
  @app.post("/upload")
90
- async def upload_image(
91
- file: UploadFile = File(...),
92
- x_firebase_appcheck: str = Header(None)
93
- ):
94
- verify_app_check_token(x_firebase_appcheck)
95
 
96
- if not file.content_type.startswith("image/"):
97
- raise HTTPException(status_code=400, detail="Invalid file type")
98
 
99
  image_id = f"{uuid.uuid4()}.jpg"
100
- file_path = os.path.join(UPLOAD_DIR, image_id)
101
 
102
- with open(file_path, "wb") as f:
103
  f.write(await file.read())
104
 
105
- base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
106
 
107
  return {
108
  "success": True,
109
- "image_id": image_id.replace(".jpg", ""),
110
- "file_url": f"{base_url}/uploads/{image_id}"
111
  }
112
 
113
- # -------------------------------------------------
114
- # 🎨 Colorize Image
115
- # -------------------------------------------------
116
  @app.post("/colorize")
117
- async def colorize(
118
- file: UploadFile = File(...),
119
- x_firebase_appcheck: str = Header(None)
120
- ):
121
- verify_app_check_token(x_firebase_appcheck)
122
 
123
- if not file.content_type.startswith("image/"):
124
- raise HTTPException(status_code=400, detail="Invalid file type")
125
 
126
  img = Image.open(io.BytesIO(await file.read()))
127
  output_img = colorize_image(img)
128
 
129
  result_id = f"{uuid.uuid4()}.jpg"
130
- output_path = os.path.join(RESULTS_DIR, result_id)
131
- output_img.save(output_path)
132
 
133
- base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
134
 
135
  return {
136
  "success": True,
137
- "result_id": result_id.replace(".jpg", ""),
138
- "download_url": f"{base_url}/results/{result_id}",
139
- "api_download": f"{base_url}/download/{result_id.replace('.jpg','')}"
140
  }
141
 
142
- # -------------------------------------------------
143
- # ⬇️ Download via API (Secure)
144
- # -------------------------------------------------
145
- @app.get("/download/{file_id}")
146
- def download_result(file_id: str, x_firebase_appcheck: str = Header(None)):
147
- verify_app_check_token(x_firebase_appcheck)
148
-
149
- filename = f"{file_id}.jpg"
150
- path = os.path.join(RESULTS_DIR, filename)
151
-
152
- if not os.path.exists(path):
153
- raise HTTPException(status_code=404, detail="Result not found")
154
-
155
- return FileResponse(path, media_type="image/jpeg")
156
-
157
- # -------------------------------------------------
158
- # 🌐 Public Result File
159
- # -------------------------------------------------
160
  @app.get("/results/{filename}")
161
  def get_result(filename: str):
162
  path = os.path.join(RESULTS_DIR, filename)
163
  if not os.path.exists(path):
164
- raise HTTPException(status_code=404, detail="Result not found")
165
- return FileResponse(path, media_type="image/jpeg")
166
 
167
- # -------------------------------------------------
168
- # 🌐 Public Uploaded File
169
- # -------------------------------------------------
170
  @app.get("/uploads/{filename}")
171
  def get_upload(filename: str):
172
  path = os.path.join(UPLOAD_DIR, filename)
173
  if not os.path.exists(path):
174
- raise HTTPException(status_code=404, detail="File not found")
175
- return FileResponse(path, media_type="image/jpeg")
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Header
2
  from fastapi.responses import FileResponse
3
  from huggingface_hub import hf_hub_download
4
+ from torchvision import transforms
 
 
 
5
  from PIL import Image
6
  import torch
7
+ import torch.nn as nn
8
+ import os, uuid, io, json
9
 
10
+ # ======================================================
11
+ # 🚀 FASTAPI APP
12
+ # ======================================================
13
+ app = FastAPI(title="UNet Image Colorization API")
14
 
15
+ # ======================================================
16
+ # 🔐 FIREBASE INITIALIZATION (ENV BASED)
17
+ # ======================================================
18
  try:
19
  import firebase_admin
20
  from firebase_admin import credentials, app_check
 
32
  except Exception as e:
33
  print("❌ Firebase initialization failed:", e)
34
 
35
+ # ======================================================
36
+ # 📁 DIRECTORIES
37
+ # ======================================================
38
+ UPLOAD_DIR = "uploads"
39
+ RESULTS_DIR = "results"
40
  os.makedirs(UPLOAD_DIR, exist_ok=True)
41
  os.makedirs(RESULTS_DIR, exist_ok=True)
42
 
43
+ # ======================================================
44
+ # 🧠 SIMPLE UNET GENERATOR FOR COLORIZATION
45
+ # ======================================================
46
+ class UNet(nn.Module):
47
+ def __init__(self):
48
+ super(UNet, self).__init__()
49
+
50
+ def CBR(in_c, out_c):
51
+ return nn.Sequential(
52
+ nn.Conv2d(in_c, out_c, 3, padding=1),
53
+ nn.BatchNorm2d(out_c),
54
+ nn.ReLU(inplace=True)
55
+ )
56
+
57
+ self.enc1 = CBR(1, 64)
58
+ self.enc2 = CBR(64, 128)
59
+ self.enc3 = CBR(128, 256)
60
+ self.enc4 = CBR(256, 512)
61
+
62
+ self.pool = nn.MaxPool2d(2)
63
+
64
+ self.middle = CBR(512, 512)
65
+
66
+ self.up4 = nn.ConvTranspose2d(512, 256, 2, stride=2)
67
+ self.dec4 = CBR(512, 256)
68
+
69
+ self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
70
+ self.dec3 = CBR(256, 128)
71
+
72
+ self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
73
+ self.dec2 = CBR(128, 64)
74
+
75
+ self.out_layer = nn.Conv2d(64, 2, 1) # ab channels
76
+
77
+ def forward(self, x):
78
+ c1 = self.enc1(x)
79
+ p1 = self.pool(c1)
80
+
81
+ c2 = self.enc2(p1)
82
+ p2 = self.pool(c2)
83
+
84
+ c3 = self.enc3(p2)
85
+ p3 = self.pool(c3)
86
+
87
+ c4 = self.enc4(p3)
88
+ p4 = self.pool(c4)
89
+
90
+ mid = self.middle(p4)
91
+
92
+ u4 = self.up4(mid)
93
+ u4 = torch.cat([u4, c4], dim=1)
94
+ d4 = self.dec4(u4)
95
+
96
+ u3 = self.up3(d4)
97
+ u3 = torch.cat([u3, c3], dim=1)
98
+ d3 = self.dec3(u3)
99
+
100
+ u2 = self.up2(d3)
101
+ u2 = torch.cat([u2, c2], dim=1)
102
+ d2 = self.dec2(u2)
103
+
104
+ out = self.out_layer(d2)
105
+ return out
106
+
107
+
108
+ # ======================================================
109
+ # 🎨 LOAD MODEL WEIGHTS FROM HF
110
+ # ======================================================
111
  MODEL_REPO = "Hammad712/GAN-Colorization-Model"
112
  MODEL_FILENAME = "generator.pt"
113
 
114
  print("⬇️ Downloading model...")
115
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
116
 
117
+ print("📦 Loading weights into UNet model...")
118
+ model = UNet()
119
  state_dict = torch.load(model_path, map_location="cpu")
120
+ model.load_state_dict(state_dict, strict=False)
121
+ model.eval()
122
 
123
+ # ======================================================
124
+ # 🎨 COLORIZE FUNCTION (LAB → RGB)
125
+ # ======================================================
126
+ import numpy as np
127
+ import cv2
128
 
129
  def colorize_image(img: Image.Image):
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ img = img.convert("L") # grayscale
132
+ img_np = np.array(img)
133
+
134
+ # Normalize L channel
135
+ L = img_np.astype("float32") / 255.0
136
+ L_tensor = torch.tensor(L).unsqueeze(0).unsqueeze(0)
137
+
138
+ with torch.no_grad():
139
+ ab = model(L_tensor).squeeze(0).numpy()
140
+
141
+ ab = np.transpose(ab, (1, 2, 0))
142
+
143
+ # Resize ab to match L
144
+ ab = cv2.resize(ab, (img_np.shape[1], img_np.shape[0]))
145
+
146
+ # Combine L + ab -> LAB image
147
+ LAB = np.zeros((img_np.shape[0], img_np.shape[1], 3), dtype=np.float32)
148
+ LAB[..., 0] = L * 100
149
+ LAB[..., 1:] = ab * 128
150
+
151
+ # Convert LAB → RGB
152
+ rgb = cv2.cvtColor(LAB.astype("float32"), cv2.COLOR_LAB2RGB)
153
+ rgb = np.clip(rgb, 0, 1)
154
+
155
+ rgb_img = Image.fromarray((rgb * 255).astype("uint8"))
156
+ return rgb_img
157
+
158
+ # ======================================================
159
+ # 🔐 FIREBASE CHECK
160
+ # ======================================================
161
  def verify_app_check_token(token: str):
162
  if not token or len(token) < 20:
163
  raise HTTPException(status_code=401, detail="Invalid Firebase App Check token")
164
  return True
165
 
166
+ # ======================================================
167
+ # 🩺 HEALTH CHECK
168
+ # ======================================================
169
+ @app.get("/health")
170
+ def health_check():
171
+ return {"status": "healthy", "unet_loaded": True}
172
+
173
+ # ======================================================
174
+ # 📤 UPLOAD
175
+ # ======================================================
176
  @app.post("/upload")
177
+ async def upload_image(file: UploadFile = File(...), x_firebase_appcheck: str = Header(None)):
 
 
 
 
178
 
179
+ verify_app_check_token(x_firebase_appcheck)
 
180
 
181
  image_id = f"{uuid.uuid4()}.jpg"
182
+ path = os.path.join(UPLOAD_DIR, image_id)
183
 
184
+ with open(path, "wb") as f:
185
  f.write(await file.read())
186
 
187
+ base = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
188
 
189
  return {
190
  "success": True,
191
+ "image_id": image_id[:-4],
192
+ "url": f"{base}/uploads/{image_id}"
193
  }
194
 
195
+ # ======================================================
196
+ # 🎨 COLORIZE
197
+ # ======================================================
198
  @app.post("/colorize")
199
+ async def colorize(file: UploadFile = File(...), x_firebase_appcheck: str = Header(None)):
 
 
 
 
200
 
201
+ verify_app_check_token(x_firebase_appcheck)
 
202
 
203
  img = Image.open(io.BytesIO(await file.read()))
204
  output_img = colorize_image(img)
205
 
206
  result_id = f"{uuid.uuid4()}.jpg"
207
+ path = os.path.join(RESULTS_DIR, result_id)
208
+ output_img.save(path)
209
 
210
+ base = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
211
 
212
  return {
213
  "success": True,
214
+ "result_id": result_id[:-4],
215
+ "url": f"{base}/results/{result_id}"
 
216
  }
217
 
218
+ # ======================================================
219
+ # PUBLIC FILE ENDPOINTS
220
+ # ======================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  @app.get("/results/{filename}")
222
  def get_result(filename: str):
223
  path = os.path.join(RESULTS_DIR, filename)
224
  if not os.path.exists(path):
225
+ raise HTTPException(status_code=404)
226
+ return FileResponse(path)
227
 
 
 
 
228
  @app.get("/uploads/{filename}")
229
  def get_upload(filename: str):
230
  path = os.path.join(UPLOAD_DIR, filename)
231
  if not os.path.exists(path):
232
+ raise HTTPException(status_code=404)
233
+ return FileResponse(path)