LogicGoInfotechSpaces commited on
Commit
7471c96
·
1 Parent(s): 2ae242d

Deploy: HF cache dir, img2img fallback, auth bypass, root route

Browse files
Files changed (3) hide show
  1. app/colorize_model.py +75 -48
  2. app/main.py +14 -0
  3. postman_collection.json +86 -0
app/colorize_model.py CHANGED
@@ -2,10 +2,11 @@
2
  ColorizeNet model wrapper for image colorization
3
  """
4
  import logging
 
5
  import torch
6
  import numpy as np
7
  from PIL import Image
8
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline
9
  from diffusers.utils import load_image
10
  from transformers import pipeline
11
  from huggingface_hub import hf_hub_download
@@ -29,60 +30,86 @@ class ColorizeModel:
29
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
30
  logger.info("Using device: %s", self.device)
31
  self.dtype = torch.float16 if self.device == "cuda" else torch.float32
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  try:
34
- # Try loading as ControlNet with Stable Diffusion
35
- logger.info("Attempting to load model as ControlNet: %s", self.model_id)
36
- try:
37
- # Load ControlNet model
38
- self.controlnet = ControlNetModel.from_pretrained(
39
- self.model_id,
40
- torch_dtype=self.dtype
41
- )
42
-
43
- # Try SDXL first, fallback to SD 1.5
44
  try:
45
- self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
46
- "stabilityai/stable-diffusion-xl-base-1.0",
47
- controlnet=self.controlnet,
48
- torch_dtype=self.dtype,
49
- safety_checker=None,
50
- requires_safety_checker=False
51
- )
52
- logger.info("Loaded with SDXL base model")
53
- except:
54
- self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
55
- "runwayml/stable-diffusion-v1-5",
56
- controlnet=self.controlnet,
57
  torch_dtype=self.dtype,
58
- safety_checker=None,
59
- requires_safety_checker=False
60
  )
61
- logger.info("Loaded with SD 1.5 base model")
62
-
63
- self.pipe.to(self.device)
64
-
65
- # Enable memory efficient attention if available
66
- if hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
67
  try:
68
- self.pipe.enable_xformers_memory_efficient_attention()
69
- logger.info("XFormers memory efficient attention enabled")
70
- except Exception as e:
71
- logger.warning("Could not enable XFormers: %s", str(e))
72
-
73
- logger.info("ColorizeNet model loaded successfully as ControlNet")
74
- self.model_type = "controlnet"
75
-
76
- except Exception as e:
77
- logger.warning("Failed to load as ControlNet: %s", str(e))
78
- # Fallback: try as image-to-image pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  logger.info("Trying to load as image-to-image pipeline...")
80
- self.pipe = pipeline(
81
- "image-to-image",
82
- model=self.model_id,
83
- device=0 if self.device == "cuda" else -1,
84
- torch_dtype=self.dtype
85
- )
 
 
 
86
  logger.info("ColorizeNet model loaded using image-to-image pipeline")
87
  self.model_type = "pipeline"
88
 
 
2
  ColorizeNet model wrapper for image colorization
3
  """
4
  import logging
5
+ import os
6
  import torch
7
  import numpy as np
8
  from PIL import Image
9
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionXLControlNetPipeline, StableDiffusionImg2ImgPipeline
10
  from diffusers.utils import load_image
11
  from transformers import pipeline
12
  from huggingface_hub import hf_hub_download
 
30
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
31
  logger.info("Using device: %s", self.device)
32
  self.dtype = torch.float16 if self.device == "cuda" else torch.float32
33
+ self.hf_token = os.getenv("HF_TOKEN") or None
34
+
35
+ # Configure writable cache to avoid permission issues on Spaces
36
+ hf_cache_dir = os.getenv("HF_HOME", "./hf_cache")
37
+ os.environ.setdefault("HF_HOME", hf_cache_dir)
38
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", hf_cache_dir)
39
+ os.environ.setdefault("TRANSFORMERS_CACHE", hf_cache_dir)
40
+ os.makedirs(hf_cache_dir, exist_ok=True)
41
+
42
+ # Avoid libgomp warning by setting a valid integer
43
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
44
 
45
  try:
46
+ # Decide whether to use ControlNet based on model_id
47
+ wants_controlnet = "control" in self.model_id.lower()
48
+
49
+ if wants_controlnet:
50
+ # Try loading as ControlNet with Stable Diffusion
51
+ logger.info("Attempting to load model as ControlNet: %s", self.model_id)
 
 
 
 
52
  try:
53
+ # Load ControlNet model
54
+ self.controlnet = ControlNetModel.from_pretrained(
55
+ self.model_id,
 
 
 
 
 
 
 
 
 
56
  torch_dtype=self.dtype,
57
+ token=self.hf_token,
58
+ cache_dir=hf_cache_dir
59
  )
60
+
61
+ # Try SDXL first, fallback to SD 1.5
 
 
 
 
62
  try:
63
+ self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
64
+ "stabilityai/stable-diffusion-xl-base-1.0",
65
+ controlnet=self.controlnet,
66
+ torch_dtype=self.dtype,
67
+ safety_checker=None,
68
+ requires_safety_checker=False,
69
+ token=self.hf_token,
70
+ cache_dir=hf_cache_dir
71
+ )
72
+ logger.info("Loaded with SDXL base model")
73
+ except Exception:
74
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
75
+ "runwayml/stable-diffusion-v1-5",
76
+ controlnet=self.controlnet,
77
+ torch_dtype=self.dtype,
78
+ safety_checker=None,
79
+ requires_safety_checker=False,
80
+ token=self.hf_token,
81
+ cache_dir=hf_cache_dir
82
+ )
83
+ logger.info("Loaded with SD 1.5 base model")
84
+
85
+ self.pipe.to(self.device)
86
+
87
+ # Enable memory efficient attention if available
88
+ if hasattr(self.pipe, "enable_xformers_memory_efficient_attention"):
89
+ try:
90
+ self.pipe.enable_xformers_memory_efficient_attention()
91
+ logger.info("XFormers memory efficient attention enabled")
92
+ except Exception as e:
93
+ logger.warning("Could not enable XFormers: %s", str(e))
94
+
95
+ logger.info("ColorizeNet model loaded successfully as ControlNet")
96
+ self.model_type = "controlnet"
97
+ except Exception as e:
98
+ logger.warning("Failed to load as ControlNet: %s", str(e))
99
+ wants_controlnet = False # fall through to pipeline
100
+
101
+ if not wants_controlnet:
102
+ # Load as image-to-image pipeline
103
  logger.info("Trying to load as image-to-image pipeline...")
104
+ self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
105
+ self.model_id,
106
+ torch_dtype=self.dtype,
107
+ safety_checker=None,
108
+ requires_safety_checker=False,
109
+ use_safetensors=True,
110
+ cache_dir=hf_cache_dir,
111
+ token=self.hf_token
112
+ ).to(self.device)
113
  logger.info("ColorizeNet model loaded using image-to-image pipeline")
114
  self.model_type = "pipeline"
115
 
app/main.py CHANGED
@@ -74,6 +74,16 @@ app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads")
74
  # Initialize ColorizeNet model
75
  colorize_model = None
76
 
 
 
 
 
 
 
 
 
 
 
77
  @app.on_event("startup")
78
  async def startup_event():
79
  """Initialize the colorization model on startup"""
@@ -109,6 +119,10 @@ async def verify_request(request: Request):
109
  - Firebase Auth id_token via Authorization: Bearer <id_token>
110
  - Firebase App Check token via X-Firebase-AppCheck (when ENABLE_APP_CHECK=true)
111
  """
 
 
 
 
112
  # Try Firebase Auth id_token first if present
113
  bearer = _extract_bearer_token(request.headers.get("Authorization"))
114
  if bearer:
 
74
  # Initialize ColorizeNet model
75
  colorize_model = None
76
 
77
+ @app.get("/")
78
+ async def root():
79
+ return {
80
+ "app": "Colorize API",
81
+ "version": "1.0.0",
82
+ "health": "/health",
83
+ "upload": "/upload",
84
+ "colorize": "/colorize"
85
+ }
86
+
87
  @app.on_event("startup")
88
  async def startup_event():
89
  """Initialize the colorization model on startup"""
 
119
  - Firebase Auth id_token via Authorization: Bearer <id_token>
120
  - Firebase App Check token via X-Firebase-AppCheck (when ENABLE_APP_CHECK=true)
121
  """
122
+ # If Firebase is not initialized or auth is explicitly disabled, allow
123
+ if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
124
+ return True
125
+
126
  # Try Firebase Auth id_token first if present
127
  bearer = _extract_bearer_token(request.headers.get("Authorization"))
128
  if bearer:
postman_collection.json CHANGED
@@ -6,6 +6,65 @@
6
  "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json"
7
  },
8
  "item": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  {
10
  "name": "Health",
11
  "request": {
@@ -27,6 +86,11 @@
27
  "request": {
28
  "method": "POST",
29
  "header": [
 
 
 
 
 
30
  {
31
  "key": "X-Firebase-AppCheck",
32
  "value": "{{app_check_token}}",
@@ -60,6 +124,11 @@
60
  "request": {
61
  "method": "POST",
62
  "header": [
 
 
 
 
 
63
  {
64
  "key": "X-Firebase-AppCheck",
65
  "value": "{{app_check_token}}",
@@ -93,6 +162,11 @@
93
  "request": {
94
  "method": "GET",
95
  "header": [
 
 
 
 
 
96
  {
97
  "key": "X-Firebase-AppCheck",
98
  "value": "{{app_check_token}}",
@@ -155,6 +229,18 @@
155
  "key": "base_url",
156
  "value": "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
157
  },
 
 
 
 
 
 
 
 
 
 
 
 
158
  {
159
  "key": "app_check_token",
160
  "value": ""
 
6
  "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json"
7
  },
8
  "item": [
9
+ {
10
+ "name": "Authentication",
11
+ "item": [
12
+ {
13
+ "name": "Login (Firebase Auth - email/password)",
14
+ "event": [
15
+ {
16
+ "listen": "test",
17
+ "script": {
18
+ "type": "text/javascript",
19
+ "exec": [
20
+ "try {",
21
+ " const res = pm.response.json();",
22
+ " if (res.idToken) pm.collectionVariables.set('id_token', res.idToken);",
23
+ " if (res.refreshToken) pm.collectionVariables.set('refresh_token', res.refreshToken);",
24
+ " if (res.localId) pm.collectionVariables.set('local_id', res.localId);",
25
+ "} catch (e) {",
26
+ " console.log('Failed to parse login response', e);",
27
+ "}"
28
+ ]
29
+ }
30
+ }
31
+ ],
32
+ "request": {
33
+ "method": "POST",
34
+ "header": [
35
+ {
36
+ "key": "Content-Type",
37
+ "value": "application/json"
38
+ }
39
+ ],
40
+ "body": {
41
+ "mode": "raw",
42
+ "raw": "{\n \"email\": \"{{email}}\",\n \"password\": \"{{password}}\",\n \"returnSecureToken\": true\n}"
43
+ },
44
+ "url": {
45
+ "raw": "https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={{firebase_api_key}}",
46
+ "protocol": "https",
47
+ "host": [
48
+ "identitytoolkit",
49
+ "googleapis",
50
+ "com"
51
+ ],
52
+ "path": [
53
+ "v1",
54
+ "accounts:signInWithPassword"
55
+ ],
56
+ "query": [
57
+ {
58
+ "key": "key",
59
+ "value": "{{firebase_api_key}}"
60
+ }
61
+ ]
62
+ },
63
+ "description": "Obtain Firebase Auth id_token using email/password. Stores id_token in collection variable {{id_token}}."
64
+ }
65
+ }
66
+ ]
67
+ },
68
  {
69
  "name": "Health",
70
  "request": {
 
86
  "request": {
87
  "method": "POST",
88
  "header": [
89
+ {
90
+ "key": "Authorization",
91
+ "value": "Bearer {{id_token}}",
92
+ "type": "text"
93
+ },
94
  {
95
  "key": "X-Firebase-AppCheck",
96
  "value": "{{app_check_token}}",
 
124
  "request": {
125
  "method": "POST",
126
  "header": [
127
+ {
128
+ "key": "Authorization",
129
+ "value": "Bearer {{id_token}}",
130
+ "type": "text"
131
+ },
132
  {
133
  "key": "X-Firebase-AppCheck",
134
  "value": "{{app_check_token}}",
 
162
  "request": {
163
  "method": "GET",
164
  "header": [
165
+ {
166
+ "key": "Authorization",
167
+ "value": "Bearer {{id_token}}",
168
+ "type": "text"
169
+ },
170
  {
171
  "key": "X-Firebase-AppCheck",
172
  "value": "{{app_check_token}}",
 
229
  "key": "base_url",
230
  "value": "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
231
  },
232
+ {
233
+ "key": "firebase_api_key",
234
+ "value": "AIzaSyBIB6rcfyyqy5niERTXWvVD714Ter4Vx68"
235
+ },
236
+ {
237
+ "key": "email",
238
+ "value": "itisha.logico@gmail.com"
239
+ },
240
+ {
241
+ "key": "password",
242
+ "value": "123456"
243
+ },
244
  {
245
  "key": "app_check_token",
246
  "value": ""