radames committed
Commit 42e31b9
1 Parent(s): bef266a

classification code

Files changed (2)
  1. app.py +72 -52
  2. classifier.py +70 -0
app.py CHANGED
@@ -31,6 +31,10 @@ S3_DATA_FOLDER = Path("sd-multiplayer-data")
 
 DB_FOLDER = Path("diffusers-gallery-data")
 
+CLASSIFIER_URL = "https://radames-aesthetic-style-nsfw-classifier.hf.space/run/inference"
+ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
+
+
 s3 = boto3.client(service_name='s3',
                   aws_access_key_id=AWS_ACCESS_KEY_ID,
                   aws_secret_access_key=AWS_SECRET_KEY)
@@ -76,9 +80,9 @@ def fetch_models(page=0):
     }
 
 
-def fetch_model_card(model):
+def fetch_model_card(model_id):
     response = requests.get(
-        f'https://huggingface.co/{model["id"]}/raw/main/README.md')
+        f'https://huggingface.co/{model_id}/raw/main/README.md')
     return response.text
 
 
@@ -94,16 +98,31 @@ async def find_image_in_model_card(text):
     return await asyncio.gather(*tasks)
 
 
-def run_inference(endpoint, img):
-    headers = {'Authorization': f'Bearer {HF_TOKEN}',
-               "X-Wait-For-Model": "true",
-               "X-Use-Cache": "true"}
-
-    response = requests.post(endpoint, headers=headers, data=img)
-    return response.json() if response.ok else []
-
-
-async def get_all_models():
+def run_classifier(images):
+    images = [i for i in images if i is not None]
+    if len(images) > 0:
+        # classifying only the first image
+        images_urls = [ASSETS_URL + images[0]]
+        response = requests.post(CLASSIFIER_URL, json={"data": [
+            {"urls": images_urls},  # json urls: list of images urls
+            False,  # enable/disable gallery image output
+            None,  # single image input
+            None,  # files input
+        ]}).json()
+
+        # data response is array data:[[{img0}, {img1}, {img2}...], Label, Gallery],
+        class_data = response['data'][0][0]
+        print(class_data)
+        class_data_parsed = {row['label']: round(
+            row['score'], 3) for row in class_data}
+
+        # update row data with classificator data
+        return class_data_parsed
+    else:
+        return {}
+
+
+async def get_all_new_models():
     initial = fetch_models(0)
     num_pages = ceil(initial['numTotalItems'] / initial['numItemsPerPage'])
 
@@ -112,54 +131,55 @@ async def get_all_models():
     print(f"Found {num_pages} pages")
 
     # fetch all models
-    models = []
+    new_models = []
     for page in tqdm(range(0, num_pages)):
         print(f"Fetching page {page} of {num_pages}")
         page_models = fetch_models(page)
-        models += page_models['models']
-
-    with open(DB_FOLDER / "models_temp.json", "w") as f:
-        json.dump(models, f)
-
-    # fetch datacards and images
-    print(f"Found {len(models)} models")
-    final_models = []
-    for model in tqdm(models):
-        print(f"Fetching model {model['id']}")
-        model_card = fetch_model_card(model)
-        images = await find_image_in_model_card(model_card)
-        # style = await run_inference(f"https://api-inference.huggingface.co/models/{model['id']}", images[0])
-        style = []
-        # aesthetic = await run_inference(f"https://api-inference.huggingface.co/models/{model['id']}", images[0])
-        aesthetic = []
-        final_models.append(
-            {**model, "images": images, "style": style, "aesthetic": aesthetic}
-        )
-    return final_models
-
+        new_models += page_models['models']
+    return new_models
 
 async def sync_data():
     print("Fetching models")
-    models = await get_all_models()
-
+    new_models = await get_all_new_models()
+    print(f"Found {len(new_models)} models")
+    # save list of all models for ids
     with open(DB_FOLDER / "models.json", "w") as f:
-        json.dump(models, f)
+        json.dump(new_models, f)
     # with open(DB_FOLDER / "models.json", "r") as f:
-    #     models = json.load(f)
-    # open temp db
-    print("Updating database")
+    #     new_models = json.load(f)
+
+    new_models_ids = [model['id'] for model in new_models]
+
+    # get existing models
     with database.get_db() as db:
         cursor = db.cursor()
-        for model in models:
-            try:
-                cursor.execute("INSERT INTO models(id, data) VALUES (?, ?)",
-                               [model['id'], json.dumps(model)])
-            except Exception as e:
-                print(model['id'], model)
-        db.commit()
-    print("Updating repository")
-    subprocess.Popen(
-        "git add . && git commit --amend -m 'update' && git push --force", cwd=DB_FOLDER, shell=True)
+        cursor.execute("SELECT id FROM models")
+        existing_models = [row['id'] for row in cursor.fetchall()]
+        models_ids_to_add = list(set(new_models_ids) - set(existing_models))
+        # find all models id to add from new_models
+        models = [model for model in new_models if model['id'] in models_ids_to_add]
+
+    print(f"Found {len(models)} new models")
+    for model in tqdm(models):
+        model_id = model['id']
+        model_card = fetch_model_card(model_id)
+        images = await find_image_in_model_card(model_card)
+        classifier = run_classifier(images)
+        # update model row with image and classifier data
+        with database.get_db() as db:
+            cursor = db.cursor()
+            cursor.execute("INSERT INTO models(id, data) VALUES (?, ?)",
+                           [model_id, json.dumps({
+                               **model,
+                               "images": images,
+                               "class": classifier
+                           })])
+            db.commit()
+
+
+    # print("Updating repository")
+    # subprocess.Popen(
+    #     "git add . && git commit --amend -m 'update' && git push --force", cwd=DB_FOLDER, shell=True)
 
 
 app = FastAPI()
@@ -174,7 +194,7 @@ app.add_middleware(
 
 # @ app.get("/sync")
 # async def sync(background_tasks: BackgroundTasks):
-#     background_tasks.add_task(sync_data)
+#     await sync_data()
 #     return "Synced data to huggingface datasets"
 
 
@@ -189,16 +209,16 @@ def get_page(page: int = 1):
     cursor.execute("""
         SELECT *, COUNT(*) OVER() AS total
         FROM models
-        WHERE json_extract(data, '$.likes') > 5
-        ORDER BY json_extract(data, '$.likes') DESC, datetime(json_extract(data, '$.lastModified')) DESC
+        WHERE json_extract(data, '$.likes') > 4
+        ORDER BY datetime(json_extract(data, '$.lastModified')) DESC
         LIMIT ? OFFSET ?
     """, (MAX_PAGE_SIZE, (page - 1) * MAX_PAGE_SIZE))
     results = cursor.fetchall()
-    total = results[0][3] if results else 0
+    total = results[0]['total'] if results else 0
     total_pages = (total + MAX_PAGE_SIZE - 1) // MAX_PAGE_SIZE
 
     return {
-        "models": [json.loads(result[1]) for result in results],
+        "models": [json.loads(result['data']) for result in results],
        "totalPages": total_pages
     }
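
Note on the response handling in run_classifier: per the inline comments above, the Space returns a Gradio-style payload whose data[0] holds per-image predictions, and the first image's list of {label, score} entries is flattened into a dict. A minimal sketch of that parsing on an illustrative payload (the label names and scores below are assumptions, not real Space output):

# Sketch of the parsing done in run_classifier(); sample payload is illustrative only.
sample_response = {
    "data": [
        [  # one entry per input image URL
            [  # predictions for the first (and only) image sent
                {"label": "aesthetic", "score": 0.712},
                {"label": "not_aesthetic", "score": 0.288},
            ],
        ],
        "Label",  # label output slot
        None,     # gallery output slot (disabled in the request)
    ]
}

class_data = sample_response["data"][0][0]
class_data_parsed = {row["label"]: round(row["score"], 3) for row in class_data}
# class_data_parsed -> {"aesthetic": 0.712, "not_aesthetic": 0.288}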
 
classifier.py ADDED
@@ -0,0 +1,70 @@
+import os
+import re
+import requests
+import json
+import subprocess
+from io import BytesIO
+import uuid
+
+from math import ceil
+from tqdm import tqdm
+from pathlib import Path
+
+from db import Database
+
+DB_FOLDER = Path("diffusers-gallery-data")
+
+database = Database(DB_FOLDER)
+
+
+CLASSIFIER_URL = "https://radames-aesthetic-style-nsfw-classifier.hf.space/run/inference"
+ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
+
+
+def main():
+
+    with database.get_db() as db:
+        cursor = db.cursor()
+        cursor.execute("""
+        SELECT *
+        FROM models
+        """)
+        results = list(cursor.fetchall())
+
+    for row in tqdm(results):
+        row_id = row['id']
+        # keep json data on row_data
+        row_data = json.loads(row['data'])
+        print("updating row", row_id)
+        images = row_data['images']
+
+        # filter nones
+        images = [i for i in images if i is not None]
+        if len(images) > 0:
+            # classifying only the first image
+            images_urls = [ASSETS_URL + images[0]]
+            response = requests.post(CLASSIFIER_URL, json={"data": [
+                {"urls": images_urls},  # json urls: list of images urls
+                False,  # enable/disable gallery image output
+                None,  # single image input
+                None,  # files input
+            ]}).json()
+
+            # data response is array data:[[{img0}, {img1}, {img2}...], Label, Gallery],
+            class_data = response['data'][0][0]
+            class_data_parsed = {row['label']: round(
+                row['score'], 3) for row in class_data}
+
+            # update row data with classificator data
+            row_data['class'] = class_data_parsed
+        else:
+            row_data['class'] = {}
+        with database.get_db() as db:
+            cursor = db.cursor()
+            cursor.execute("UPDATE models SET data = ? WHERE id = ?",
+                           [json.dumps(row_data), row_id])
+            db.commit()
+
+
+if __name__ == "__main__":
+    main()
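
classifier.py is a one-off backfill: it walks every existing row in the models table, calls the same classifier Space on the first stored image, and writes the scores back into the row's JSON under a "class" key (an empty dict when no image is available). A minimal sketch of reading a row back after running the script, assuming the same db.Database helper and schema used above:

# Sketch: inspect one backfilled row; assumes classifier.py has already run.
import json
from pathlib import Path

from db import Database

database = Database(Path("diffusers-gallery-data"))

with database.get_db() as db:
    cursor = db.cursor()
    cursor.execute("SELECT id, data FROM models LIMIT 1")
    row = cursor.fetchone()

if row is not None:
    data = json.loads(row['data'])
    # "class" holds {label: score} pairs, or {} when no image was classified
    print(row['id'], data.get('class', {}))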