radames commited on
Commit
8f77a4e
·
1 Parent(s): ee20381

only from the

Browse files
Files changed (1) hide show
  1. app.py +14 -193
app.py CHANGED
@@ -15,11 +15,14 @@ from pathlib import Path
15
  from db import Database
16
  import uuid
17
  import logging
 
 
 
18
  logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
22
- SPACE_ID = os.environ.get('SPACE_ID', '')
23
 
24
  DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
25
  IMGS_PATH = DB_PATH / "imgs"
@@ -28,11 +31,6 @@ IMGS_PATH.mkdir(exist_ok=True, parents=True)
28
 
29
  database = Database(DB_PATH)
30
 
31
- with database() as db:
32
- cursor = db.cursor()
33
- cursor.execute("SELECT * FROM cache")
34
- print(list(cursor.fetchall()))
35
-
36
  dtype = torch.bfloat16
37
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
38
  if torch.cuda.is_available():
@@ -96,6 +94,7 @@ def generate(
96
  app = FastAPI()
97
  origins = [
98
  "http://huggingface.co",
 
99
  ]
100
 
101
  app.add_middleware(
@@ -107,6 +106,15 @@ app.add_middleware(
107
  )
108
 
109
 
 
 
 
 
 
 
 
 
 
110
  @app.get("/image")
111
  async def generate_image(prompt: str, negative_prompt: str, seed: int = 2134213213):
112
  cached_img = database.check(prompt, negative_prompt, seed)
@@ -137,190 +145,3 @@ async def main():
137
 
138
  if __name__ == "__main__":
139
  uvicorn.run(app, host="0.0.0.0", port=7860)
140
-
141
-
142
- # else:
143
- # prior_pipeline = None
144
- # decoder_pipeline = None
145
-
146
-
147
- # def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
148
- # if randomize_seed:
149
- # seed = random.randint(0, MAX_SEED)
150
- # return seed
151
-
152
-
153
- # def generate(
154
- # prompt: str,
155
- # negative_prompt: str = "",
156
- # seed: int = 0,
157
- # width: int = 1024,
158
- # height: int = 1024,
159
- # prior_num_inference_steps: int = 30,
160
- # # prior_timesteps: List[float] = None,
161
- # prior_guidance_scale: float = 4.0,
162
- # decoder_num_inference_steps: int = 12,
163
- # # decoder_timesteps: List[float] = None,
164
- # decoder_guidance_scale: float = 0.0,
165
- # num_images_per_prompt: int = 2,
166
- # progress=gr.Progress(track_tqdm=True),
167
- # ) -> PIL.Image.Image:
168
-
169
- # generator = torch.Generator().manual_seed(seed)
170
- # prior_output = prior_pipeline(
171
- # prompt=prompt,
172
- # height=height,
173
- # width=width,
174
- # num_inference_steps=prior_num_inference_steps,
175
- # timesteps=DEFAULT_STAGE_C_TIMESTEPS,
176
- # negative_prompt=negative_prompt,
177
- # guidance_scale=prior_guidance_scale,
178
- # num_images_per_prompt=num_images_per_prompt,
179
- # generator=generator,
180
- # )
181
- # decoder_output = decoder_pipeline(
182
- # image_embeddings=prior_output.image_embeddings,
183
- # prompt=prompt,
184
- # num_inference_steps=decoder_num_inference_steps,
185
- # # timesteps=decoder_timesteps,
186
- # guidance_scale=decoder_guidance_scale,
187
- # negative_prompt=negative_prompt,
188
- # generator=generator,
189
- # output_type="pil",
190
- # ).images
191
-
192
- # return decoder_output[0]
193
-
194
-
195
- # examples = [
196
- # "An astronaut riding a green horse",
197
- # "A mecha robot in a favela by Tarsila do Amaral",
198
- # "The sprirt of a Tamagotchi wandering in the city of Los Angeles",
199
- # "A delicious feijoada ramen dish"
200
- # ]
201
-
202
- # with gr.Blocks() as demo:
203
- # gr.Markdown(DESCRIPTION)
204
- # gr.DuplicateButton(
205
- # value="Duplicate Space for private use",
206
- # elem_id="duplicate-button",
207
- # visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
208
- # )
209
- # with gr.Group():
210
- # with gr.Row():
211
- # prompt = gr.Text(
212
- # label="Prompt",
213
- # show_label=False,
214
- # max_lines=1,
215
- # placeholder="Enter your prompt",
216
- # container=False,
217
- # )
218
- # run_button = gr.Button("Run", scale=0)
219
- # result = gr.Image(label="Result", show_label=False)
220
- # with gr.Accordion("Advanced options", open=False):
221
- # negative_prompt = gr.Text(
222
- # label="Negative prompt",
223
- # max_lines=1,
224
- # placeholder="Enter a Negative Prompt",
225
- # )
226
-
227
- # seed = gr.Slider(
228
- # label="Seed",
229
- # minimum=0,
230
- # maximum=MAX_SEED,
231
- # step=1,
232
- # value=0,
233
- # )
234
- # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
235
- # with gr.Row():
236
- # width = gr.Slider(
237
- # label="Width",
238
- # minimum=1024,
239
- # maximum=1536,
240
- # step=512,
241
- # value=1024,
242
- # )
243
- # height = gr.Slider(
244
- # label="Height",
245
- # minimum=1024,
246
- # maximum=1536,
247
- # step=512,
248
- # value=1024,
249
- # )
250
- # num_images_per_prompt = gr.Slider(
251
- # label="Number of Images",
252
- # minimum=1,
253
- # maximum=2,
254
- # step=1,
255
- # value=1,
256
- # )
257
- # with gr.Row():
258
- # prior_guidance_scale = gr.Slider(
259
- # label="Prior Guidance Scale",
260
- # minimum=0,
261
- # maximum=20,
262
- # step=0.1,
263
- # value=4.0,
264
- # )
265
- # prior_num_inference_steps = gr.Slider(
266
- # label="Prior Inference Steps",
267
- # minimum=10,
268
- # maximum=30,
269
- # step=1,
270
- # value=20,
271
- # )
272
-
273
- # decoder_guidance_scale = gr.Slider(
274
- # label="Decoder Guidance Scale",
275
- # minimum=0,
276
- # maximum=0,
277
- # step=0.1,
278
- # value=0.0,
279
- # )
280
- # decoder_num_inference_steps = gr.Slider(
281
- # label="Decoder Inference Steps",
282
- # minimum=4,
283
- # maximum=12,
284
- # step=1,
285
- # value=10,
286
- # )
287
-
288
- # gr.Examples(
289
- # examples=examples,
290
- # inputs=prompt,
291
- # outputs=result,
292
- # fn=generate,
293
- # cache_examples=False,
294
- # )
295
-
296
- # inputs = [
297
- # prompt,
298
- # negative_prompt,
299
- # seed,
300
- # width,
301
- # height,
302
- # prior_num_inference_steps,
303
- # # prior_timesteps,
304
- # prior_guidance_scale,
305
- # decoder_num_inference_steps,
306
- # # decoder_timesteps,
307
- # decoder_guidance_scale,
308
- # num_images_per_prompt,
309
- # ]
310
- # gr.on(
311
- # triggers=[prompt.submit, negative_prompt.submit, run_button.click],
312
- # fn=randomize_seed_fn,
313
- # inputs=[seed, randomize_seed],
314
- # outputs=seed,
315
- # queue=False,
316
- # api_name=False,
317
- # ).then(
318
- # fn=generate,
319
- # inputs=inputs,
320
- # outputs=result,
321
- # api_name="run",
322
- # )
323
-
324
-
325
- # if __name__ == "__main__":
326
- # demo.queue(max_size=20).launch()
 
15
  from db import Database
16
  import uuid
17
  import logging
18
+ from fastapi import FastAPI, Request, HTTPException
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+
21
  logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
22
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
  USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
25
+ SPACE_ID = os.environ.get("SPACE_ID", "")
26
 
27
  DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
28
  IMGS_PATH = DB_PATH / "imgs"
 
31
 
32
  database = Database(DB_PATH)
33
 
 
 
 
 
 
34
  dtype = torch.bfloat16
35
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
36
  if torch.cuda.is_available():
 
94
  app = FastAPI()
95
  origins = [
96
  "http://huggingface.co",
97
+ "localhost",
98
  ]
99
 
100
  app.add_middleware(
 
106
  )
107
 
108
 
109
+ @app.middleware("http")
110
+ async def validate_origin(request: Request, call_next):
111
+ logging.info(f"Request origin: {request.headers.get('origin')}")
112
+ if request.headers.get("origin") not in origins:
113
+ raise HTTPException(status_code=403, detail="Forbidden")
114
+ response = await call_next(request)
115
+ return response
116
+
117
+
118
  @app.get("/image")
119
  async def generate_image(prompt: str, negative_prompt: str, seed: int = 2134213213):
120
  cached_img = database.check(prompt, negative_prompt, seed)
 
145
 
146
  if __name__ == "__main__":
147
  uvicorn.run(app, host="0.0.0.0", port=7860)