radames HF staff commited on
Commit
4281ee1
1 Parent(s): bc62831

api to get rooms data

Browse files
Files changed (1) hide show
  1. stablediffusion-infinity/app.py +58 -33
stablediffusion-infinity/app.py CHANGED
@@ -1,5 +1,6 @@
1
  import io
2
  import os
 
3
 
4
  from pathlib import Path
5
  import uvicorn
@@ -26,6 +27,7 @@ import requests
26
  import shortuuid
27
  import re
28
  import time
 
29
 
30
  AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
31
  AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
@@ -37,22 +39,35 @@ FILE_TYPES = {
37
  'image/png': 'png',
38
  'image/jpeg': 'jpg',
39
  }
40
- DB_PATH = Path("rooms.db")
 
 
41
 
42
  app = FastAPI()
43
 
44
- if not DB_PATH.exists():
45
  print("Creating database")
46
- print("DB_PATH", DB_PATH)
47
- db = sqlite3.connect(DB_PATH)
48
  with open(Path("schema.sql"), "r") as f:
49
  db.executescript(f.read())
50
  db.commit()
51
  db.close()
52
 
53
 
54
- def get_db():
55
- db = sqlite3.connect(DB_PATH, check_same_thread=False)
 
 
 
 
 
 
 
 
 
 
 
56
  db.row_factory = sqlite3.Row
57
  try:
58
  yield db
@@ -77,6 +92,11 @@ model = {}
77
  STATIC_MASK = Image.open("mask.png")
78
 
79
 
 
 
 
 
 
80
  def get_model():
81
  if "inpaint" not in model:
82
  vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
@@ -86,31 +106,9 @@ def get_model():
86
  torch_dtype=torch.float16,
87
  vae=vae,
88
  ).to("cuda")
89
- # lms = LMSDiscreteScheduler(
90
- # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
91
-
92
- # img2img = StableDiffusionImg2ImgPipeline(
93
- # vae=text2img.vae,
94
- # text_encoder=text2img.text_encoder,
95
- # tokenizer=text2img.tokenizer,
96
- # unet=text2img.unet,
97
- # scheduler=lms,
98
- # safety_checker=text2img.safety_checker,
99
- # feature_extractor=text2img.feature_extractor,
100
- # ).to("cuda")
101
- # try:
102
- # total_memory = torch.cuda.get_device_properties(0).total_memory // (
103
- # 1024 ** 3
104
- # )
105
- # if total_memory <= 5:
106
- # inpaint.enable_attention_slicing()
107
- # except:
108
- # pass
109
  model["inpaint"] = inpaint
110
- # model["img2img"] = img2img
111
 
112
  return model["inpaint"]
113
- # model["img2img"]
114
 
115
 
116
  # init model on startup
@@ -274,10 +272,10 @@ def get_room_count(room_id: str):
274
 
275
  @ app.on_event("startup")
276
  @ repeat_every(seconds=100)
277
- async def sync_rooms():
 
278
  try:
279
- jwtToken = generateAuthToken()
280
- for db in get_db():
281
  rooms = db.execute("SELECT * FROM rooms").fetchall()
282
  for row in rooms:
283
  room_id = row["room_id"]
@@ -291,14 +289,41 @@ async def sync_rooms():
291
  print("Rooms update failed")
292
 
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  @ app.get('/api/rooms')
295
- async def get_rooms(db: sqlite3.Connection = Depends(get_db)):
 
296
  rooms = db.execute("SELECT * FROM rooms").fetchall()
297
  return rooms
298
 
299
 
300
  @ app.post('/api/auth')
301
- async def autorize(request: Request, db: sqlite3.Connection = Depends(get_db)):
302
  data = await request.json()
303
  room = data["room"]
304
  payload = {
 
1
  import io
2
  import os
3
+ from typing import Union
4
 
5
  from pathlib import Path
6
  import uvicorn
 
27
  import shortuuid
28
  import re
29
  import time
30
+ import subprocess
31
 
32
  AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
33
  AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
 
39
  'image/png': 'png',
40
  'image/jpeg': 'jpg',
41
  }
42
+ S3_DATA_FOLDER = Path("sd-multiplayer-data")
43
+ ROOMS_DATA_DB = S3_DATA_FOLDER / "rooms_data.db"
44
+ ROOM_DB = Path("rooms.db")
45
 
46
  app = FastAPI()
47
 
48
+ if not ROOM_DB.exists():
49
  print("Creating database")
50
+ print("ROOM_DB", ROOM_DB)
51
+ db = sqlite3.connect(ROOM_DB)
52
  with open(Path("schema.sql"), "r") as f:
53
  db.executescript(f.read())
54
  db.commit()
55
  db.close()
56
 
57
 
58
+ def get_room_db():
59
+ db = sqlite3.connect(ROOM_DB, check_same_thread=False)
60
+ db.row_factory = sqlite3.Row
61
+ try:
62
+ yield db
63
+ except Exception:
64
+ db.rollback()
65
+ finally:
66
+ db.close()
67
+
68
+
69
+ def get_room_data_db():
70
+ db = sqlite3.connect(ROOMS_DATA_DB, check_same_thread=False)
71
  db.row_factory = sqlite3.Row
72
  try:
73
  yield db
 
92
  STATIC_MASK = Image.open("mask.png")
93
 
94
 
95
+ def sync_rooms_data_repo():
96
+ subprocess.Popen("git fetch && git reset --hard origin/main",
97
+ cwd=S3_DATA_FOLDER, shell=True)
98
+
99
+
100
  def get_model():
101
  if "inpaint" not in model:
102
  vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
 
106
  torch_dtype=torch.float16,
107
  vae=vae,
108
  ).to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  model["inpaint"] = inpaint
 
110
 
111
  return model["inpaint"]
 
112
 
113
 
114
  # init model on startup
 
272
 
273
  @ app.on_event("startup")
274
  @ repeat_every(seconds=100)
275
+ def sync_rooms():
276
+ print("Syncing rooms active users")
277
  try:
278
+ for db in get_room_db():
 
279
  rooms = db.execute("SELECT * FROM rooms").fetchall()
280
  for row in rooms:
281
  room_id = row["room_id"]
 
289
  print("Rooms update failed")
290
 
291
 
292
+ @ app.on_event("startup")
293
+ @ repeat_every(seconds=300)
294
+ def sync_room_datq():
295
+ print("Sync rooms data")
296
+ sync_rooms_data_repo()
297
+
298
+
299
+ @ app.get('/api/room_data/{room_id}')
300
+ async def get_rooms(room_id: str, start: str = None, end: str = None, db: sqlite3.Connection = Depends(get_room_data_db)):
301
+ print("Getting rooms data", room_id, start, end)
302
+
303
+ if start is None and end is None:
304
+ rooms_rows = db.execute(
305
+ "SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? ORDER BY time", (room_id,)).fetchall()
306
+ elif end is None:
307
+ rooms_rows = db.execute("SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? AND time >= ? ORDER BY time",
308
+ (room_id, start)).fetchall()
309
+ elif start is None:
310
+ rooms_rows = db.execute("SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? AND time <= ? ORDER BY time",
311
+ (room_id, end)).fetchall()
312
+ else:
313
+ rooms_rows = db.execute("SELECT key, prompt, time, x, y FROM rooms_data WHERE room_id = ? AND time >= ? AND time <= ? ORDER BY time",
314
+ (room_id, start, end)).fetchall()
315
+ return rooms_rows
316
+
317
+
318
  @ app.get('/api/rooms')
319
+ async def get_rooms(db: sqlite3.Connection = Depends(get_room_db)):
320
+ print("Getting rooms")
321
  rooms = db.execute("SELECT * FROM rooms").fetchall()
322
  return rooms
323
 
324
 
325
  @ app.post('/api/auth')
326
+ async def autorize(request: Request):
327
  data = await request.json()
328
  room = data["room"]
329
  payload = {