Spaces:
Runtime error
Runtime error
api to get rooms data
Browse files- stablediffusion-infinity/app.py +58 -33
stablediffusion-infinity/app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import io
|
2 |
import os
|
|
|
3 |
|
4 |
from pathlib import Path
|
5 |
import uvicorn
|
@@ -26,6 +27,7 @@ import requests
|
|
26 |
import shortuuid
|
27 |
import re
|
28 |
import time
|
|
|
29 |
|
30 |
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
31 |
AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
|
@@ -37,22 +39,35 @@ FILE_TYPES = {
|
|
37 |
'image/png': 'png',
|
38 |
'image/jpeg': 'jpg',
|
39 |
}
|
40 |
-
|
|
|
|
|
41 |
|
42 |
app = FastAPI()
|
43 |
|
44 |
-
if not
|
45 |
print("Creating database")
|
46 |
-
print("
|
47 |
-
db = sqlite3.connect(
|
48 |
with open(Path("schema.sql"), "r") as f:
|
49 |
db.executescript(f.read())
|
50 |
db.commit()
|
51 |
db.close()
|
52 |
|
53 |
|
54 |
-
def
|
55 |
-
db = sqlite3.connect(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
db.row_factory = sqlite3.Row
|
57 |
try:
|
58 |
yield db
|
@@ -77,6 +92,11 @@ model = {}
|
|
77 |
STATIC_MASK = Image.open("mask.png")
|
78 |
|
79 |
|
|
|
|
|
|
|
|
|
|
|
80 |
def get_model():
|
81 |
if "inpaint" not in model:
|
82 |
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
|
@@ -86,31 +106,9 @@ def get_model():
|
|
86 |
torch_dtype=torch.float16,
|
87 |
vae=vae,
|
88 |
).to("cuda")
|
89 |
-
# lms = LMSDiscreteScheduler(
|
90 |
-
# beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
91 |
-
|
92 |
-
# img2img = StableDiffusionImg2ImgPipeline(
|
93 |
-
# vae=text2img.vae,
|
94 |
-
# text_encoder=text2img.text_encoder,
|
95 |
-
# tokenizer=text2img.tokenizer,
|
96 |
-
# unet=text2img.unet,
|
97 |
-
# scheduler=lms,
|
98 |
-
# safety_checker=text2img.safety_checker,
|
99 |
-
# feature_extractor=text2img.feature_extractor,
|
100 |
-
# ).to("cuda")
|
101 |
-
# try:
|
102 |
-
# total_memory = torch.cuda.get_device_properties(0).total_memory // (
|
103 |
-
# 1024 ** 3
|
104 |
-
# )
|
105 |
-
# if total_memory <= 5:
|
106 |
-
# inpaint.enable_attention_slicing()
|
107 |
-
# except:
|
108 |
-
# pass
|
109 |
model["inpaint"] = inpaint
|
110 |
-
# model["img2img"] = img2img
|
111 |
|
112 |
return model["inpaint"]
|
113 |
-
# model["img2img"]
|
114 |
|
115 |
|
116 |
# init model on startup
|
@@ -274,10 +272,10 @@ def get_room_count(room_id: str):
|
|
274 |
|
275 |
@ app.on_event("startup")
|
276 |
@ repeat_every(seconds=100)
|
277 |
-
|
|
|
278 |
try:
|
279 |
-
|
280 |
-
for db in get_db():
|
281 |
rooms = db.execute("SELECT * FROM rooms").fetchall()
|
282 |
for row in rooms:
|
283 |
room_id = row["room_id"]
|
@@ -291,14 +289,41 @@ async def sync_rooms():
|
|
291 |
print("Rooms update failed")
|
292 |
|
293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
@ app.get('/api/rooms')
|
295 |
-
async def get_rooms(db: sqlite3.Connection = Depends(
|
|
|
296 |
rooms = db.execute("SELECT * FROM rooms").fetchall()
|
297 |
return rooms
|
298 |
|
299 |
|
300 |
@ app.post('/api/auth')
|
301 |
-
async def autorize(request: Request
|
302 |
data = await request.json()
|
303 |
room = data["room"]
|
304 |
payload = {
|
|
|
1 |
import io
|
2 |
import os
|
3 |
+
from typing import Union
|
4 |
|
5 |
from pathlib import Path
|
6 |
import uvicorn
|
|
|
27 |
import shortuuid
|
28 |
import re
|
29 |
import time
|
30 |
+
import subprocess
|
31 |
|
32 |
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
33 |
AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
|
|
|
39 |
'image/png': 'png',
|
40 |
'image/jpeg': 'jpg',
|
41 |
}
|
42 |
+
S3_DATA_FOLDER = Path("sd-multiplayer-data")
|
43 |
+
ROOMS_DATA_DB = S3_DATA_FOLDER / "rooms_data.db"
|
44 |
+
ROOM_DB = Path("rooms.db")
|
45 |
|
46 |
app = FastAPI()
|
47 |
|
48 |
+
if not ROOM_DB.exists():
|
49 |
print("Creating database")
|
50 |
+
print("ROOM_DB", ROOM_DB)
|
51 |
+
db = sqlite3.connect(ROOM_DB)
|
52 |
with open(Path("schema.sql"), "r") as f:
|
53 |
db.executescript(f.read())
|
54 |
db.commit()
|
55 |
db.close()
|
56 |
|
57 |
|
58 |
+
def get_room_db():
|
59 |
+
db = sqlite3.connect(ROOM_DB, check_same_thread=False)
|
60 |
+
db.row_factory = sqlite3.Row
|
61 |
+
try:
|
62 |
+
yield db
|
63 |
+
except Exception:
|
64 |
+
db.rollback()
|
65 |
+
finally:
|
66 |
+
db.close()
|
67 |
+
|
68 |
+
|
69 |
+
def get_room_data_db():
|
70 |
+
db = sqlite3.connect(ROOMS_DATA_DB, check_same_thread=False)
|
71 |
db.row_factory = sqlite3.Row
|
72 |
try:
|
73 |
yield db
|
|
|
92 |
STATIC_MASK = Image.open("mask.png")
|
93 |
|
94 |
|
95 |
+
def sync_rooms_data_repo():
|
96 |
+
subprocess.Popen("git fetch && git reset --hard origin/main",
|
97 |
+
cwd=S3_DATA_FOLDER, shell=True)
|
98 |
+
|
99 |
+
|
100 |
def get_model():
|
101 |
if "inpaint" not in model:
|
102 |
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
|
|
|
106 |
torch_dtype=torch.float16,
|
107 |
vae=vae,
|
108 |
).to("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
model["inpaint"] = inpaint
|
|
|
110 |
|
111 |
return model["inpaint"]
|
|
|
112 |
|
113 |
|
114 |
# init model on startup
|
|
|
272 |
|
273 |
@ app.on_event("startup")
|
274 |
@ repeat_every(seconds=100)
|
275 |
+
def sync_rooms():
|
276 |
+
print("Syncing rooms active users")
|
277 |
try:
|
278 |
+
for db in get_room_db():
|
|
|
279 |
rooms = db.execute("SELECT * FROM rooms").fetchall()
|
280 |
for row in rooms:
|
281 |
room_id = row["room_id"]
|
|
|
289 |
print("Rooms update failed")
|
290 |
|
291 |
|
292 |
+
@ app.on_event("startup")
|
293 |
+
@ repeat_every(seconds=300)
|
294 |
+
def sync_room_datq():
|
295 |
+
print("Sync rooms data")
|
296 |
+
sync_rooms_data_repo()
|
297 |
+
|
298 |
+
|
299 |
+
@ app.get('/api/room_data/{room_id}')
|
300 |
+
async def get_rooms(room_id: str, start: str = None, end: str = None, db: sqlite3.Connection = Depends(get_room_data_db)):
|
301 |
+
print("Getting rooms data", room_id, start, end)
|
302 |
+
|
303 |
+
if start is None and end is None:
|
304 |
+
rooms_rows = db.execute(
|
305 |
+
"SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? ORDER BY time", (room_id,)).fetchall()
|
306 |
+
elif end is None:
|
307 |
+
rooms_rows = db.execute("SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? AND time >= ? ORDER BY time",
|
308 |
+
(room_id, start)).fetchall()
|
309 |
+
elif start is None:
|
310 |
+
rooms_rows = db.execute("SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? AND time <= ? ORDER BY time",
|
311 |
+
(room_id, end)).fetchall()
|
312 |
+
else:
|
313 |
+
rooms_rows = db.execute("SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? AND time >= ? AND time <= ? ORDER BY time",
|
314 |
+
(room_id, start, end)).fetchall()
|
315 |
+
return rooms_rows
|
316 |
+
|
317 |
+
|
318 |
@ app.get('/api/rooms')
|
319 |
+
async def get_rooms(db: sqlite3.Connection = Depends(get_room_db)):
|
320 |
+
print("Getting rooms")
|
321 |
rooms = db.execute("SELECT * FROM rooms").fetchall()
|
322 |
return rooms
|
323 |
|
324 |
|
325 |
@ app.post('/api/auth')
|
326 |
+
async def autorize(request: Request):
|
327 |
data = await request.json()
|
328 |
room = data["room"]
|
329 |
payload = {
|