radames commited on
Commit
8964ef4
·
1 Parent(s): a4b71bb

resize images and use s3

Browse files
Files changed (3) hide show
  1. app.py +76 -44
  2. requirements.txt +3 -1
  3. schema.sql +1 -1
app.py CHANGED
@@ -1,46 +1,72 @@
1
  import os
2
  import re
3
- import asyncio
4
  import aiohttp
5
  import requests
6
  import json
 
 
 
 
 
 
7
  from tqdm import tqdm
8
  from pathlib import Path
9
  from huggingface_hub import Repository
10
-
11
  from fastapi import FastAPI, BackgroundTasks
12
  from fastapi_utils.tasks import repeat_every
13
  from fastapi.middleware.cors import CORSMiddleware
 
14
 
15
  from db import Database
16
 
 
 
 
 
 
17
  HF_TOKEN = os.environ.get("HF_TOKEN")
18
 
 
19
 
20
  DB_FOLDER = Path("diffusers-gallery-data")
21
 
 
 
 
22
 
23
- repo = Repository(
24
- local_dir=DB_FOLDER,
25
- repo_type="dataset",
26
- clone_from="huggingface-projects/diffusers-gallery-data",
27
- use_auth_token=True,
28
- )
29
- repo.git_pull()
 
30
 
31
  database = Database(DB_FOLDER)
32
 
33
 
34
- async def check_image_url(url):
35
- async with aiohttp.ClientSession() as session:
36
- async with session.head(url) as resp:
37
- if resp.status == 200 and resp.content_type.startswith("image/"):
38
- return url
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  def fetch_models(page=0):
42
  response = requests.get(
43
- f'https://huggingface.co/models-json?pipeline_tag=text-to-image&p={page}&sort=likes')
44
  data = response.json()
45
  return {
46
  "models": [model for model in data['models'] if not model['private']],
@@ -56,18 +82,16 @@ def fetch_model_card(model):
56
  return response.text
57
 
58
 
59
- def find_image_in_model_card(text):
60
  image_regex = re.compile(r'https?://\S+(?:png|jpg|jpeg|webp)')
61
  urls = re.findall(image_regex, text)
62
- # python check if arrya is not empty
63
- # if urls:
64
- # tasks = []
65
- # for url in urls:
66
- # tasks.append(check_image_url(url))
67
 
68
- # results = await asyncio.gather(*tasks)
69
- # return [result for result in results if result]
70
- return urls
 
71
 
72
 
73
  def run_inference(endpoint, img):
@@ -79,10 +103,12 @@ def run_inference(endpoint, img):
79
  return response.json() if response.ok else []
80
 
81
 
82
- def get_all_models():
83
- initial = fetch_models()
84
- num_pages = initial['numTotalItems'] // initial['numItemsPerPage']
85
 
 
 
86
  print(f"Found {num_pages} pages")
87
 
88
  # fetch all models
@@ -92,13 +118,16 @@ def get_all_models():
92
  page_models = fetch_models(page)
93
  models += page_models['models']
94
 
 
 
 
95
  # fetch datacards and images
96
  print(f"Found {len(models)} models")
97
  final_models = []
98
  for model in tqdm(models):
99
  print(f"Fetching model {model['id']}")
100
  model_card = fetch_model_card(model)
101
- images = find_image_in_model_card(model_card)
102
  # style = await run_inference(f"https://api-inference.huggingface.co/models/{model['id']}", images[0])
103
  style = []
104
  # aesthetic = await run_inference(f"https://api-inference.huggingface.co/models/{model['id']}", images[0])
@@ -110,20 +139,27 @@ def get_all_models():
110
 
111
 
112
  async def sync_data():
113
- models = get_all_models()
 
114
 
115
- with open("data/models.json", "w") as f:
116
  json.dump(models, f)
117
-
 
 
 
118
  with database.get_db() as db:
119
  cursor = db.cursor()
120
  for model in models:
121
  try:
122
- cursor.execute("INSERT INTO models (data) VALUES (?)",
123
- [json.dumps(model)])
124
  except Exception as e:
125
  print(model['id'], model)
126
  db.commit()
 
 
 
127
 
128
 
129
  app = FastAPI()
@@ -145,19 +181,14 @@ async def sync(background_tasks: BackgroundTasks):
145
  MAX_PAGE_SIZE = 30
146
 
147
 
148
- @app.get("/api/models")
149
  def get_page(page: int = 1):
150
  page = page if page > 0 else 1
151
  with database.get_db() as db:
152
  cursor = db.cursor()
153
  cursor.execute("""
154
- SELECT *
155
- FROM (
156
- SELECT *, COUNT(*) OVER() AS total
157
- FROM models
158
- GROUP BY json_extract(data, '$.id')
159
- HAVING COUNT(json_extract(data, '$.id')) = 1
160
- )
161
  ORDER BY json_extract(data, '$.likes') DESC
162
  LIMIT ? OFFSET ?
163
  """, (MAX_PAGE_SIZE, (page - 1) * MAX_PAGE_SIZE))
@@ -175,8 +206,9 @@ def get_page(page: int = 1):
175
  def read_root():
176
  return "Just a bot to sync data from diffusers gallery"
177
 
 
178
  # @app.on_event("startup")
179
- # @repeat_every(seconds=1800)
180
- # def repeat_sync():
181
- # sync_rooms_to_dataset()
182
  # return "Synced data to huggingface datasets"
 
1
  import os
2
  import re
 
3
  import aiohttp
4
  import requests
5
  import json
6
+ import subprocess
7
+ import asyncio
8
+ from io import BytesIO
9
+ import uuid
10
+
11
+ from math import ceil
12
  from tqdm import tqdm
13
  from pathlib import Path
14
  from huggingface_hub import Repository
15
+ from PIL import Image, ImageOps
16
  from fastapi import FastAPI, BackgroundTasks
17
  from fastapi_utils.tasks import repeat_every
18
  from fastapi.middleware.cors import CORSMiddleware
19
+ import boto3
20
 
21
  from db import Database
22
 
23
+ AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
24
+ AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
25
+ AWS_S3_BUCKET_NAME = os.getenv('AWS_S3_BUCKET_NAME')
26
+
27
+
28
  HF_TOKEN = os.environ.get("HF_TOKEN")
29
 
30
+ S3_DATA_FOLDER = Path("sd-multiplayer-data")
31
 
32
  DB_FOLDER = Path("diffusers-gallery-data")
33
 
34
+ s3 = boto3.client(service_name='s3',
35
+ aws_access_key_id=AWS_ACCESS_KEY_ID,
36
+ aws_secret_access_key=AWS_SECRET_KEY)
37
 
38
+
39
+ # repo = Repository(
40
+ # local_dir=DB_FOLDER,
41
+ # repo_type="dataset",
42
+ # clone_from="huggingface-projects/diffusers-gallery-data",
43
+ # use_auth_token=True,
44
+ # )
45
+ # repo.git_pull()
46
 
47
  database = Database(DB_FOLDER)
48
 
49
 
50
+ async def upload_resize_image_url(session, image_url):
51
+ print(f"Uploading image {image_url}")
52
+ async with session.get(image_url) as response:
53
+ if response.status == 200 and response.headers['content-type'].startswith('image'):
54
+ image = Image.open(BytesIO(await response.read())).convert('RGB')
55
+ # resize image proportional
56
+ image = ImageOps.fit(image, (400, 400), Image.LANCZOS)
57
+ image_bytes = BytesIO()
58
+ image.save(image_bytes, format="JPEG")
59
+ image_bytes.seek(0)
60
+ fname = f'{uuid.uuid4()}.jpg'
61
+ s3.upload_fileobj(Fileobj=image_bytes, Bucket=AWS_S3_BUCKET_NAME, Key="diffusers-gallery/" + fname,
62
+ ExtraArgs={"ContentType": "image/jpeg", "CacheControl": "max-age=31536000"})
63
+ return fname
64
+ return None
65
 
66
 
67
  def fetch_models(page=0):
68
  response = requests.get(
69
+ f'https://huggingface.co/models-json?pipeline_tag=text-to-image&p={page}')
70
  data = response.json()
71
  return {
72
  "models": [model for model in data['models'] if not model['private']],
 
82
  return response.text
83
 
84
 
85
+ async def find_image_in_model_card(text):
86
  image_regex = re.compile(r'https?://\S+(?:png|jpg|jpeg|webp)')
87
  urls = re.findall(image_regex, text)
88
+ if not urls:
89
+ return []
 
 
 
90
 
91
+ async with aiohttp.ClientSession() as session:
92
+ tasks = [asyncio.ensure_future(upload_resize_image_url(
93
+ session, image_url)) for image_url in urls[0:3]]
94
+ return await asyncio.gather(*tasks)
95
 
96
 
97
  def run_inference(endpoint, img):
 
103
  return response.json() if response.ok else []
104
 
105
 
106
+ async def get_all_models():
107
+ initial = fetch_models(0)
108
+ num_pages = ceil(initial['numTotalItems'] / initial['numItemsPerPage'])
109
 
110
+ print(
111
+ f"Total items: {initial['numTotalItems']} - Items per page: {initial['numItemsPerPage']}")
112
  print(f"Found {num_pages} pages")
113
 
114
  # fetch all models
 
118
  page_models = fetch_models(page)
119
  models += page_models['models']
120
 
121
+ with open(DB_FOLDER / "models_temp.json", "w") as f:
122
+ json.dump(models, f)
123
+
124
  # fetch datacards and images
125
  print(f"Found {len(models)} models")
126
  final_models = []
127
  for model in tqdm(models):
128
  print(f"Fetching model {model['id']}")
129
  model_card = fetch_model_card(model)
130
+ images = await find_image_in_model_card(model_card)
131
  # style = await run_inference(f"https://api-inference.huggingface.co/models/{model['id']}", images[0])
132
  style = []
133
  # aesthetic = await run_inference(f"https://api-inference.huggingface.co/models/{model['id']}", images[0])
 
139
 
140
 
141
  async def sync_data():
142
+ print("Fetching models")
143
+ models = await get_all_models()
144
 
145
+ with open(DB_FOLDER / "models.json", "w") as f:
146
  json.dump(models, f)
147
+ # with open(DB_FOLDER / "models.json", "r") as f:
148
+ # models = json.load(f)
149
+ # open temp db
150
+ print("Updating database")
151
  with database.get_db() as db:
152
  cursor = db.cursor()
153
  for model in models:
154
  try:
155
+ cursor.execute("INSERT INTO models(id, data) VALUES (?, ?)",
156
+ [model['id'], json.dumps(model)])
157
  except Exception as e:
158
  print(model['id'], model)
159
  db.commit()
160
+ print("Updating repository")
161
+ # subprocess.Popen(
162
+ # "git add . && git commit --amend -m 'update' && git push --force", cwd=DB_FOLDER, shell=True)
163
 
164
 
165
  app = FastAPI()
 
181
  MAX_PAGE_SIZE = 30
182
 
183
 
184
+ @ app.get("/api/models")
185
  def get_page(page: int = 1):
186
  page = page if page > 0 else 1
187
  with database.get_db() as db:
188
  cursor = db.cursor()
189
  cursor.execute("""
190
+ SELECT *, COUNT(*) OVER() AS total
191
+ FROM models
 
 
 
 
 
192
  ORDER BY json_extract(data, '$.likes') DESC
193
  LIMIT ? OFFSET ?
194
  """, (MAX_PAGE_SIZE, (page - 1) * MAX_PAGE_SIZE))
 
206
  def read_root():
207
  return "Just a bot to sync data from diffusers gallery"
208
 
209
+
210
  # @app.on_event("startup")
211
+ # @repeat_every(seconds=60 * 60 * 24, wait_first=False)
212
+ # async def repeat_sync():
213
+ # await sync_data()
214
  # return "Synced data to huggingface datasets"
requirements.txt CHANGED
@@ -5,4 +5,6 @@ tqdm
5
  fastapi
6
  requests
7
  asyncio
8
- aiohttp
 
 
 
5
  fastapi
6
  requests
7
  asyncio
8
+ aiohttp
9
+ Pillow
10
+ boto3
schema.sql CHANGED
@@ -3,7 +3,7 @@ PRAGMA foreign_keys = OFF;
3
  BEGIN TRANSACTION;
4
 
5
  CREATE TABLE models (
6
- id INTEGER PRIMARY KEY AUTOINCREMENT,
7
  data json,
8
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
9
  );
 
3
  BEGIN TRANSACTION;
4
 
5
  CREATE TABLE models (
6
+ id TEXT PRIMARY KEY NOT NULL,
7
  data json,
8
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
9
  );