radames commited on
Commit
ee20381
1 Parent(s): 22e6fb5

db for cache

Browse files
Files changed (5) hide show
  1. .gitignore +3 -1
  2. app.py +105 -33
  3. db.py +54 -0
  4. requirements.txt +6 -5
  5. schema.sql +14 -0
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  venv
2
- gradio_cached_examples
 
 
 
1
  venv
2
+ gradio_cached_examples
3
+ __pycache__/
4
+ cache/
app.py CHANGED
@@ -1,6 +1,3 @@
1
- import os
2
- import random
3
- import gradio as gr
4
  import numpy as np
5
  import PIL.Image
6
  import torch
@@ -10,23 +7,95 @@ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
10
  from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
11
  from fastapi import FastAPI
12
  import uvicorn
13
- from pydantic import BaseModel
14
  from fastapi.middleware.cors import CORSMiddleware
15
- from fastapi.responses import RedirectResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
- class GenerateRequest(BaseModel):
19
- prompt: str
20
- negative_prompt: str = ""
21
- seed: int = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
  app = FastAPI()
25
  origins = [
26
- "http://localhost.tiangolo.com",
27
- "https://localhost.tiangolo.com",
28
- "http://localhost",
29
- "http://localhost:8080",
30
  ]
31
 
32
  app.add_middleware(
@@ -38,34 +107,37 @@ app.add_middleware(
38
  )
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  @app.get("/")
42
  async def main():
43
  # redirect to https://huggingface.co/spaces/multimodalart/stable-cascade
44
- return RedirectResponse("https://multimodalart-stable-cascade.hf.space/?__theme=system")
 
 
45
 
46
 
47
  if __name__ == "__main__":
48
  uvicorn.run(app, host="0.0.0.0", port=7860)
49
 
50
- # MAX_SEED = np.iinfo(np.int32).max
51
- # USE_TORCH_COMPILE = False
52
-
53
- # dtype = torch.bfloat16
54
- # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
55
- # if torch.cuda.is_available():
56
- # prior_pipeline = StableCascadePriorPipeline.from_pretrained(
57
- # "stabilityai/stable-cascade-prior", torch_dtype=dtype) # .to(device)
58
- # decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
59
- # "stabilityai/stable-cascade", torch_dtype=dtype) # .to(device)
60
- # prior_pipeline.to(device)
61
- # decoder_pipeline.to(device)
62
-
63
- # if USE_TORCH_COMPILE:
64
- # prior_pipeline.prior = torch.compile(
65
- # prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
66
- # decoder_pipeline.decoder = torch.compile(
67
- # decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
68
-
69
 
70
  # else:
71
  # prior_pipeline = None
 
 
 
 
1
  import numpy as np
2
  import PIL.Image
3
  import torch
 
7
  from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
8
  from fastapi import FastAPI
9
  import uvicorn
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.responses import RedirectResponse, StreamingResponse
12
+ import io
13
+ import os
14
+ from pathlib import Path
15
+ from db import Database
16
+ import uuid
17
+ import logging
18
+ logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
19
+
20
+ MAX_SEED = np.iinfo(np.int32).max
21
+ USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
22
+ SPACE_ID = os.environ.get('SPACE_ID', '')
23
+
24
+ DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
25
+ IMGS_PATH = DB_PATH / "imgs"
26
+ DB_PATH.mkdir(exist_ok=True, parents=True)
27
+ IMGS_PATH.mkdir(exist_ok=True, parents=True)
28
+
29
+ database = Database(DB_PATH)
30
+
31
+ with database() as db:
32
+ cursor = db.cursor()
33
+ cursor.execute("SELECT * FROM cache")
34
+ print(list(cursor.fetchall()))
35
+
36
+ dtype = torch.bfloat16
37
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
38
+ if torch.cuda.is_available():
39
+ prior_pipeline = StableCascadePriorPipeline.from_pretrained(
40
+ "stabilityai/stable-cascade-prior", torch_dtype=dtype
41
+ ) # .to(device)
42
+ decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
43
+ "stabilityai/stable-cascade", torch_dtype=dtype
44
+ ) # .to(device)
45
+ prior_pipeline.to(device)
46
+ decoder_pipeline.to(device)
47
+
48
+ if USE_TORCH_COMPILE:
49
+ prior_pipeline.prior = torch.compile(
50
+ prior_pipeline.prior, mode="reduce-overhead", fullgraph=True
51
+ )
52
+ decoder_pipeline.decoder = torch.compile(
53
+ decoder_pipeline.decoder, mode="max-autotune", fullgraph=True
54
+ )
55
 
56
 
57
+ def generate(
58
+ prompt: str,
59
+ negative_prompt: str = "",
60
+ seed: int = 0,
61
+ width: int = 1024,
62
+ height: int = 1024,
63
+ prior_num_inference_steps: int = 20,
64
+ prior_guidance_scale: float = 4.0,
65
+ decoder_num_inference_steps: int = 10,
66
+ decoder_guidance_scale: float = 0.0,
67
+ num_images_per_prompt: int = 2,
68
+ ) -> PIL.Image.Image:
69
+
70
+ generator = torch.Generator().manual_seed(seed)
71
+ prior_output = prior_pipeline(
72
+ prompt=prompt,
73
+ height=height,
74
+ width=width,
75
+ num_inference_steps=prior_num_inference_steps,
76
+ timesteps=DEFAULT_STAGE_C_TIMESTEPS,
77
+ negative_prompt=negative_prompt,
78
+ guidance_scale=prior_guidance_scale,
79
+ num_images_per_prompt=num_images_per_prompt,
80
+ generator=generator,
81
+ )
82
+ decoder_output = decoder_pipeline(
83
+ image_embeddings=prior_output.image_embeddings,
84
+ prompt=prompt,
85
+ num_inference_steps=decoder_num_inference_steps,
86
+ # timesteps=decoder_timesteps,
87
+ guidance_scale=decoder_guidance_scale,
88
+ negative_prompt=negative_prompt,
89
+ generator=generator,
90
+ output_type="pil",
91
+ ).images
92
+
93
+ return decoder_output[0]
94
 
95
 
96
  app = FastAPI()
97
  origins = [
98
+ "http://huggingface.co",
 
 
 
99
  ]
100
 
101
  app.add_middleware(
 
107
  )
108
 
109
 
110
+ @app.get("/image")
111
+ async def generate_image(prompt: str, negative_prompt: str, seed: int = 2134213213):
112
+ cached_img = database.check(prompt, negative_prompt, seed)
113
+ if cached_img:
114
+ logging.info(f"Image found in cache: {cached_img[0]}")
115
+ return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")
116
+
117
+ logging.info(f"Image not found in cache, generating new image")
118
+ pil_image = generate(prompt, negative_prompt, seed)
119
+ img_id = str(uuid.uuid4())
120
+ img_path = IMGS_PATH / f"{img_id}.jpg"
121
+ pil_image.save(img_path)
122
+ img_io = io.BytesIO()
123
+ pil_image.save(img_io, "JPEG")
124
+ img_io.seek(0)
125
+ database.insert(prompt, negative_prompt, str(img_path), seed)
126
+
127
+ return StreamingResponse(img_io, media_type="image/jpeg")
128
+
129
+
130
  @app.get("/")
131
  async def main():
132
  # redirect to https://huggingface.co/spaces/multimodalart/stable-cascade
133
+ return RedirectResponse(
134
+ "https://multimodalart-stable-cascade.hf.space/?__theme=system"
135
+ )
136
 
137
 
138
  if __name__ == "__main__":
139
  uvicorn.run(app, host="0.0.0.0", port=7860)
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  # else:
143
  # prior_pipeline = None
db.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from pathlib import Path
3
+
4
+
5
+ class Database:
6
+ def __init__(self, db_path=None):
7
+ if db_path is None:
8
+ raise ValueError("db_path must be provided")
9
+ self.db_path = db_path
10
+ self.db_file = self.db_path / "cache.db"
11
+ if not self.db_file.exists():
12
+ print("Creating database")
13
+ print("DB_FILE", self.db_file)
14
+ db = sqlite3.connect(self.db_file)
15
+ with open(Path("schema.sql"), "r") as f:
16
+ db.executescript(f.read())
17
+ db.commit()
18
+ db.close()
19
+
20
+ def get_db(self):
21
+ db = sqlite3.connect(self.db_file, check_same_thread=False)
22
+ db.row_factory = sqlite3.Row
23
+ return db
24
+
25
+ def __enter__(self):
26
+ self.db = self.get_db()
27
+ return self.db
28
+
29
+ def __exit__(self, exc_type, exc_value, traceback):
30
+ self.db.close()
31
+
32
+ def __call__(self):
33
+ return self
34
+
35
+ def insert(self, prompt: str, negative_prompt: str, image_path: str, seed: int):
36
+ with self() as db:
37
+ cursor = db.cursor()
38
+ cursor.execute(
39
+ "INSERT INTO cache (prompt, negative_prompt, image_path, seed) VALUES (?, ?, ?, ?)",
40
+ (prompt, negative_prompt, image_path, seed),
41
+ )
42
+ db.commit()
43
+
44
+ def check(self, prompt: str, negative_prompt: str, seed: int):
45
+ with self() as db:
46
+ cursor = db.cursor()
47
+ cursor.execute(
48
+ "SELECT image_path FROM cache WHERE prompt = ? AND negative_prompt = ? AND seed = ? ORDER BY RANDOM() LIMIT 1",
49
+ (prompt, negative_prompt, seed),
50
+ )
51
+ image_path = cursor.fetchone()
52
+ if image_path:
53
+ return image_path
54
+ return False
requirements.txt CHANGED
@@ -1,8 +1,9 @@
1
  git+https://github.com/kashif/diffusers.git@wuerstchen-v3
 
 
 
 
 
2
  accelerate
3
  safetensors
4
- transformers
5
- gradio
6
- fastapi
7
- pydantic
8
- uvicorn
 
1
  git+https://github.com/kashif/diffusers.git@wuerstchen-v3
2
+ fastapi==0.109.2
3
+ numpy==1.26.4
4
+ Pillow==10.2.0
5
+ torch==2.2.0
6
+ uvicorn==0.27.1
7
  accelerate
8
  safetensors
9
+ transformers
 
 
 
 
schema.sql ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PRAGMA foreign_keys = OFF;
2
+
3
+ BEGIN TRANSACTION;
4
+
5
+ CREATE TABLE cache (
6
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
7
+ prompt TEXT NOT NULL,
8
+ negative_prompt TEXT NOT NULL,
9
+ image_path TEXT NOT NULL,
10
+ seed INTEGER NOT NULL,
11
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
12
+ );
13
+
14
+ COMMIT;