radames committed on
Commit
a101d9b
1 Parent(s): e8ea372
Files changed (4) hide show
  1. app.py +145 -104
  2. db.py +29 -0
  3. requirements.txt +9 -6
  4. schema.sql +10 -58
app.py CHANGED
@@ -1,120 +1,161 @@
1
- import boto3
2
  import os
3
  import re
 
 
 
4
  import json
5
- from pathlib import Path
6
- import sqlite3
7
- from huggingface_hub import Repository, HfFolder
8
- import tqdm
9
- import subprocess
10
 
11
- from fastapi import FastAPI
 
 
12
  from fastapi_utils.tasks import repeat_every
 
 
 
13
 
 
14
 
15
- AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
16
- AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
17
- AWS_S3_BUCKET_NAME = os.getenv('AWS_S3_BUCKET_NAME')
18
-
19
- s3 = boto3.client(service_name='s3',
20
- aws_access_key_id=AWS_ACCESS_KEY_ID,
21
- aws_secret_access_key=AWS_SECRET_KEY)
22
-
23
-
24
- paginator = s3.get_paginator('list_objects_v2')
25
-
26
-
27
- S3_DATA_FOLDER = Path("sd-multiplayer-data")
28
- ROOMS_DATA_DB = S3_DATA_FOLDER / "rooms_data.db"
29
-
30
-
31
- repo = Repository(
32
- local_dir=S3_DATA_FOLDER,
33
- repo_type="dataset",
34
- clone_from="huggingface-projects/sd-multiplayer-data",
35
- use_auth_token=True,
36
- )
37
- repo.git_pull()
38
-
39
-
40
- if not ROOMS_DATA_DB.exists():
41
- print("Creating database")
42
- print("ROOMS_DATA_DB", ROOMS_DATA_DB)
43
- db = sqlite3.connect(ROOMS_DATA_DB)
44
- with open(Path("schema.sql"), "r") as f:
45
- db.executescript(f.read())
46
- db.commit()
47
- db.close()
48
-
49
-
50
- def get_db(db_path):
51
- db = sqlite3.connect(db_path, check_same_thread=False)
52
- db.row_factory = sqlite3.Row
53
- try:
54
- yield db
55
- except Exception:
56
- db.rollback()
57
- finally:
58
- db.close()
59
-
60
-
61
- def sync_rooms_to_dataset():
62
- for room_data_db in get_db(ROOMS_DATA_DB):
63
- rooms = room_data_db.execute("SELECT * FROM rooms").fetchall()
64
- cursor = room_data_db.cursor()
65
- for row in tqdm.tqdm(rooms):
66
- room_id = row["room_id"]
67
- print("syncing room data: ", room_id)
68
-
69
- objects = []
70
- for result in paginator.paginate(Bucket=AWS_S3_BUCKET_NAME, Prefix=f'{room_id}/', Delimiter='/'):
71
- results = []
72
- for obj in result.get('Contents'):
73
- try:
74
- key = obj.get('Key')
75
- time = obj.get('LastModified').isoformat()
76
- split_str = re.split(r'[-/.]', key)
77
- uuid = split_str[3]
78
- x, y = [int(n)
79
- for n in re.split(r'[_]', split_str[4])]
80
- prompt = ' '.join(split_str[4:])
81
- results.append(
82
- {'x': x, 'y': y, 'prompt': prompt, 'time': time, 'key': key, 'uuid': uuid})
83
- cursor.execute(
84
- 'INSERT INTO rooms_data VALUES (NULL, ?, ?, ?, ?, ?, ?, ?)', (room_id, uuid, x, y, prompt, time, key))
85
- except Exception as e:
86
- print(e)
87
- continue
88
-
89
- objects += results
90
- room_data_db.commit()
91
-
92
- all_rows = [dict(row) for row in room_data_db.execute(
93
- "SELECT * FROM rooms_data WHERE room_id = ?", (room_id,)).fetchall()]
94
- data_path = S3_DATA_FOLDER / f"{room_id}.json"
95
- with open(data_path, 'w') as f:
96
- json.dump(all_rows, f, separators=(',', ':'))
97
- print("Updating repository")
98
- subprocess.Popen(
99
- "git add . && git commit --amend -m 'update' && git push --force", cwd=S3_DATA_FOLDER, shell=True)
100
 
101
 
102
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
 
 
104
 
105
- @app.get("/")
106
- def read_root():
107
- return "Just a bot to sync data to huggingface datasets and tweet tha latest data"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
- @app.get("/sync")
111
- def sync():
112
- sync_rooms_to_dataset()
113
  return "Synced data to huggingface datasets"
114
 
115
 
116
- @app.on_event("startup")
117
- @repeat_every(seconds=1800)
118
- def repeat_sync():
119
- sync_rooms_to_dataset()
120
- return "Synced data to huggingface datasets"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
+ import asyncio
4
+ import aiohttp
5
+ import requests
6
  import json
7
+ from tqdm import tqdm
 
 
 
 
8
 
9
+ from huggingface_hub import Repository
10
+
11
+ from fastapi import FastAPI, BackgroundTasks
12
  from fastapi_utils.tasks import repeat_every
13
+ from fastapi.staticfiles import StaticFiles
14
+
15
+ from db import Database
16
 
17
+ HF_TOKEN = os.environ.get("HF_TOKEN")
18
 
19
+ database = Database()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
async def check_image_url(url):
    """Return ``url`` if a HEAD request shows it serves an image, else ``None``.

    A URL is accepted when the final response has status 200 and a
    ``Content-Type`` beginning with ``image/``.  Connection problems,
    malformed URLs and timeouts are reported as ``None`` so that a single
    bad link cannot abort a batch of checks run with ``asyncio.gather``.
    """
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            # HEAD requests do not follow redirects unless asked to; image
            # hosts frequently redirect to a CDN, so follow them explicitly.
            async with session.head(url, allow_redirects=True) as resp:
                if resp.status == 200 and resp.content_type.startswith("image/"):
                    return url
    except (aiohttp.ClientError, asyncio.TimeoutError):
        return None
    return None
27
+
28
+
29
def fetch_models(page=0):
    """Fetch one page of text-to-image models from the Hub, sorted by likes.

    Args:
        page: Zero-based page index sent as the ``p`` query parameter.

    Returns:
        A dict with the non-private ``models`` of that page plus the paging
        fields ``numItemsPerPage``, ``numTotalItems`` and ``pageIndex``
        copied from the response.

    Raises:
        requests.RequestException: on network failure, timeout or a non-2xx
            response.
    """
    response = requests.get(
        'https://huggingface.co/models-json',
        params={'pipeline_tag': 'text-to-image', 'p': page, 'sort': 'likes'},
        timeout=30,  # never hang the sync forever on a stalled connection
    )
    # Fail loudly on HTTP errors instead of trying to parse an error page.
    response.raise_for_status()
    data = response.json()
    return {
        # A missing "private" flag is treated as public.
        "models": [model for model in data['models'] if not model.get('private')],
        "numItemsPerPage": data['numItemsPerPage'],
        "numTotalItems": data['numTotalItems'],
        "pageIndex": data['pageIndex'],
    }
39
+
40
+
41
def fetch_model_card(model):
    """Download the raw ``README.md`` (model card) of ``model``.

    Args:
        model: Mapping with at least an ``"id"`` key (the repository id).

    Returns:
        The README text, or an empty string when the request does not succeed
        (for example the repository has no README).  Returning the body of an
        error page would let callers scrape unrelated URLs out of it.

    Raises:
        requests.RequestException: on network failure or timeout.
    """
    response = requests.get(
        f'https://huggingface.co/{model["id"]}/raw/main/README.md',
        timeout=30,
    )
    if not response.ok:
        return ""
    return response.text
45
+
46
+
47
# Matches http(s) URLs ending in a known image extension.  Markdown/HTML
# delimiters are excluded from the URL body so that adjacent links such as
# "![a](u1.png)![b](u2.png)" are not merged into one match, and the extension
# must be preceded by a literal dot and end on a word boundary.
_IMAGE_URL_RE = re.compile(
    r"""https?://[^\s()<>\[\]"']+\.(?:png|jpe?g|webp)\b""",
    re.IGNORECASE,
)


def find_image_in_model_card(text):
    """Return every image URL found in ``text``, in order of appearance.

    Args:
        text: Model card contents (markdown and/or HTML).

    Returns:
        A list of URL strings ending in ``.png``, ``.jpg``, ``.jpeg`` or
        ``.webp`` (case-insensitive).  Anything after the extension, such as
        a query string, is not part of the returned URL.  The list is empty
        when nothing matches.  URLs are not checked for reachability.
    """
    return _IMAGE_URL_RE.findall(text)
59
+
60
+
61
def run_inference(endpoint, img):
    """POST ``img`` to an inference ``endpoint`` and return the decoded JSON.

    Args:
        endpoint: Full URL of the inference endpoint.
        img: Request body (image bytes or a file-like object).

    Returns:
        The parsed JSON response, or ``[]`` when the request is not
        successful or the body is not valid JSON.

    Raises:
        requests.RequestException: on network failure or timeout.
    """
    headers = {"X-Wait-For-Model": "true", "X-Use-Cache": "true"}
    # Only authenticate when a token is configured; otherwise the header
    # would literally read "Bearer None".
    if HF_TOKEN:
        headers['Authorization'] = f'Bearer {HF_TOKEN}'

    # Generous timeout: the request asks the server to wait for the model.
    response = requests.post(endpoint, headers=headers, data=img, timeout=120)
    if not response.ok:
        return []
    try:
        return response.json()
    except ValueError:
        # A 2xx response with a non-JSON body is treated like a failure.
        return []
68
 
69
+
70
def get_all_models():
    """Crawl every page of the model listing and attach model-card images.

    Returns:
        A list with one dict per model: the fields returned by
        ``fetch_models`` plus ``images`` (URLs found in the model card) and
        the placeholders ``style`` and ``aesthetic`` (currently empty lists).
    """
    initial = fetch_models()
    per_page = initial['numItemsPerPage']
    total = initial['numTotalItems']
    # Ceiling division: floor division would drop the final partial page and
    # yield zero pages whenever there are fewer models than one page holds.
    num_pages = -(-total // per_page) if per_page else 0

    print(f"Found {num_pages} pages")

    # Page 0 is already in hand; only fetch the remaining pages.
    models = list(initial['models'])
    for page in tqdm(range(1, num_pages)):
        print(f"Fetching page {page} of {num_pages}")
        models += fetch_models(page)['models']

    # Fetch model cards and extract image URLs.
    print(f"Found {len(models)} models")
    final_models = []
    for model in tqdm(models):
        print(f"Fetching model {model['id']}")
        try:
            model_card = fetch_model_card(model)
        except requests.RequestException as e:
            # One unreachable model card must not abort the whole crawl.
            print(f"Could not fetch model card for {model['id']}: {e}")
            model_card = ""
        images = find_image_in_model_card(model_card)
        # NOTE: style/aesthetic classification via run_inference is not wired
        # in yet, so both fields are stored as empty lists.
        final_models.append(
            {**model, "images": images, "style": [], "aesthetic": []}
        )
    return final_models
98
+
99
+
100
def _store_models(models):
    """Write ``models`` to data/models.json and append them to the database."""
    with open("data/models.json", "w") as f:
        json.dump(models, f)

    db = database.get_db()
    try:
        cursor = db.cursor()
        for model in models:
            try:
                cursor.execute("INSERT INTO models (data) VALUES (?)",
                               [json.dumps(model)])
            except Exception as e:
                # Best effort: report the failing model together with the
                # error and keep inserting the remaining ones.
                print(model['id'], model, e)
        db.commit()
    finally:
        # sqlite3's own context manager only commits or rolls back; the
        # connection has to be closed explicitly.
        db.close()


async def sync_data():
    """Crawl all models and persist them to the JSON snapshot and the database.

    The crawl and the writes are blocking (``requests`` and ``sqlite3``), so
    they are run in the default executor to keep the event loop free while a
    sync is in progress.
    """
    loop = asyncio.get_event_loop()
    models = await loop.run_in_executor(None, get_all_models)
    await loop.run_in_executor(None, _store_models, models)
115
+
116
+
117
+ app = FastAPI()
118
 
119
 
120
@app.get("/sync")
async def sync(background_tasks: BackgroundTasks):
    """Queue ``sync_data`` as a background task and answer immediately.

    The reply is returned as soon as the task has been queued; it does not
    wait for the sync to run or report its outcome.
    """
    background_tasks.add_task(sync_data)
    return "Synced data to huggingface datasets"
124
 
125
 
126
+ MAX_PAGE_SIZE = 30
127
+
128
+
129
@app.get("/api/models")
def get_page(page: int = 1):
    """Return one page of stored models, most liked first.

    Args:
        page: One-based page number; values below 1 are treated as 1.

    Returns:
        A dict with ``models`` (up to ``MAX_PAGE_SIZE`` decoded model dicts)
        and ``totalPages``.  ``totalPages`` is 0 when the requested page has
        no rows.
    """
    page = page if page > 0 else 1
    db = database.get_db()
    try:
        cursor = db.cursor()
        # Every sync appends a fresh row per model, so keep only the newest
        # row for each model id.  Filtering with HAVING COUNT(...) = 1 would
        # hide every model that has been synced more than once.
        cursor.execute("""
            SELECT data, COUNT(*) OVER() AS total
            FROM models
            WHERE id IN (
                SELECT MAX(id)
                FROM models
                GROUP BY json_extract(data, '$.id')
            )
            ORDER BY json_extract(data, '$.likes') DESC, id
            LIMIT ? OFFSET ?
        """, (MAX_PAGE_SIZE, (page - 1) * MAX_PAGE_SIZE))
        results = cursor.fetchall()
    finally:
        # The sqlite3 connection context manager does not close the
        # connection, so release it explicitly on every request.
        db.close()

    # The window function repeats the overall row count on every row.
    total = results[0][1] if results else 0
    total_pages = (total + MAX_PAGE_SIZE - 1) // MAX_PAGE_SIZE

    return {
        "models": [json.loads(result[0]) for result in results],
        "totalPages": total_pages
    }
153
+
154
+
155
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
156
+
157
+ # @app.on_event("startup")
158
+ # @repeat_every(seconds=1800)
159
+ # def repeat_sync():
160
+ # sync_rooms_to_dataset()
161
+ # return "Synced data to huggingface datasets"
db.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from pathlib import Path
3
+
4
+
5
class Database:
    """Small helper around the SQLite file that stores the model data.

    ``get_db()`` hands out a new connection that the caller must close.
    Using the instance itself as a context manager (``with database as db``)
    opens a connection, commits on success, rolls back on error and always
    closes it.  The connection is kept on ``self.db`` while the block runs.
    """

    DB_PATH = Path("data/")
    DB_FILE = DB_PATH / "models.db"

    def __init__(self):
        """Create the database from ``schema.sql`` if the file does not exist."""
        if not self.DB_FILE.exists():
            print("Creating database")
            print("DB_FILE", self.DB_FILE)
            # sqlite3 cannot create missing parent directories itself.
            self.DB_PATH.mkdir(parents=True, exist_ok=True)
            db = sqlite3.connect(self.DB_FILE)
            try:
                with open(Path("schema.sql"), "r") as f:
                    db.executescript(f.read())
                db.commit()
            finally:
                db.close()

    def get_db(self):
        """Open and return a new connection whose rows are ``sqlite3.Row``."""
        db = sqlite3.connect(self.DB_FILE, check_same_thread=False)
        db.row_factory = sqlite3.Row
        return db

    def __enter__(self):
        self.db = self.get_db()
        return self.db

    def __exit__(self, exc_type, exc_value, traceback):
        # Closing a connection discards uncommitted changes, so settle the
        # transaction first: commit on a clean exit, roll back on error.
        try:
            if exc_type is None:
                self.db.commit()
            else:
                self.db.rollback()
        finally:
            self.db.close()
requirements.txt CHANGED
@@ -1,6 +1,9 @@
1
- huggingface-hub==0.10
2
- fastapi-utils==0.2
3
- uvicorn==0.19
4
- tqdm==4.64
5
- boto3==1.26
6
- fastapi==0.86
 
 
 
 
1
+ huggingface-hub
2
+ fastapi-utils
3
+ uvicorn
4
+ tqdm
5
+ fastapi
6
+ requests
7
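+ # asyncio is part of the Python standard library; it is not a pip requirement
8
+ aiohttp
9
+ # sqlite3 is part of the Python standard library; it is not a pip requirement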
schema.sql CHANGED
@@ -1,59 +1,11 @@
1
- PRAGMA foreign_keys=OFF;
 
2
  BEGIN TRANSACTION;
3
- CREATE TABLE rooms (
4
- id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
5
- room_id TEXT NOT NULL
6
- );INSERT INTO rooms VALUES(1,'room-0');
7
- INSERT INTO rooms VALUES(2,'room-1');
8
- INSERT INTO rooms VALUES(3,'room-2');
9
- INSERT INTO rooms VALUES(4,'room-3');
10
- INSERT INTO rooms VALUES(5,'room-4');
11
- INSERT INTO rooms VALUES(6,'room-5');
12
- INSERT INTO rooms VALUES(7,'room-6');
13
- INSERT INTO rooms VALUES(8,'room-7');
14
- INSERT INTO rooms VALUES(9,'room-8');
15
- INSERT INTO rooms VALUES(10,'room-9');
16
- INSERT INTO rooms VALUES(11,'room-10');
17
- INSERT INTO rooms VALUES(12,'room-11');
18
- INSERT INTO rooms VALUES(13,'room-12');
19
- INSERT INTO rooms VALUES(14,'room-13');
20
- INSERT INTO rooms VALUES(15,'room-14');
21
- INSERT INTO rooms VALUES(16,'room-15');
22
- INSERT INTO rooms VALUES(17,'room-16');
23
- INSERT INTO rooms VALUES(18,'room-17');
24
- INSERT INTO rooms VALUES(19,'room-18');
25
- INSERT INTO rooms VALUES(20,'room-19');
26
- INSERT INTO rooms VALUES(21,'room-20');
27
- INSERT INTO rooms VALUES(22,'room-21');
28
- INSERT INTO rooms VALUES(23,'room-22');
29
- INSERT INTO rooms VALUES(24,'room-23');
30
- INSERT INTO rooms VALUES(25,'room-24');
31
- INSERT INTO rooms VALUES(26,'room-25');
32
- INSERT INTO rooms VALUES(27,'room-26');
33
- INSERT INTO rooms VALUES(28,'room-27');
34
- INSERT INTO rooms VALUES(29,'room-28');
35
- INSERT INTO rooms VALUES(30,'room-29');
36
- INSERT INTO rooms VALUES(31,'room-30');
37
- INSERT INTO rooms VALUES(32,'room-31');
38
- INSERT INTO rooms VALUES(33,'room-32');
39
- INSERT INTO rooms VALUES(34,'room-33');
40
- INSERT INTO rooms VALUES(35,'room-34');
41
- INSERT INTO rooms VALUES(36,'room-35');
42
- INSERT INTO rooms VALUES(37,'room-36');
43
- INSERT INTO rooms VALUES(38,'room-37');
44
- INSERT INTO rooms VALUES(39,'room-38');
45
- INSERT INTO rooms VALUES(40,'room-39');
46
- INSERT INTO rooms VALUES(41,'room-40');
47
- DELETE FROM sqlite_sequence;
48
- INSERT INTO sqlite_sequence VALUES('rooms',41);
49
- CREATE TABLE rooms_data (
50
- id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
51
- room_id TEXT NOT NULL,
52
- uuid TEXT NOT NULL,
53
- x INTEGER NOT NULL,
54
- y INTEGER NOT NULL,
55
- prompt TEXT NOT NULL,
56
- time DATETIME NOT NULL,
57
- key TEXT NOT NULL,
58
- UNIQUE (key) ON CONFLICT IGNORE
59
- );COMMIT;
 
1
-- Schema for the models database; executed once when the database file is
-- first created.
PRAGMA foreign_keys = OFF;

BEGIN TRANSACTION;

-- One row per stored model record; "data" holds the model as JSON text.
CREATE TABLE models (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    data json,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

COMMIT;