davanstrien HF staff committed on
Commit
92e2ee4
1 Parent(s): f016495
Files changed (3) hide show
  1. app.py +262 -0
  2. requirements.in +9 -0
  3. requirements.txt +209 -0
app.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import re
4
+ from typing import Dict
5
+
6
+ import gradio as gr
7
+ import httpx
8
+ from cachetools import TTLCache, cached
9
+ from cashews import NOT_NONE, cache
10
+ from dotenv import load_dotenv
11
+ from httpx import AsyncClient, Limits
12
+ from huggingface_hub import (
13
+ ModelCard,
14
+ ModelFilter,
15
+ get_repo_discussions,
16
+ hf_hub_url,
17
+ list_models,
18
+ logging,
19
+ )
20
+ from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
21
+ from tqdm.asyncio import tqdm as atqdm
22
+ from tqdm.auto import tqdm
23
+ import random
24
+
25
# Configure the cashews cache used by the ``@cache`` decorators below with an
# in-memory ("mem://") backend.
cache.setup("mem://")


def _require_env(name):
    """Return the value of environment variable *name*.

    Raises:
        KeyError: if the variable is not set.
        ValueError: if the variable is set but empty.
    """
    value = os.environ[name]
    if not value:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(f"environment variable {name} must not be empty")
    return value


# Pull in variables from a .env file (if present) before reading the environment.
load_dotenv()
token = _require_env("HUGGINGFACE_TOKEN")
user_agent = _require_env("USER_AGENT")

# Default headers sent with every request made through ``create_client``.
headers = {"user-agent": user_agent, "authorization": f"Bearer {token}"}

# Connection-pool limits applied to every AsyncClient created by this module.
limits = Limits(max_keepalive_connections=10, max_connections=50)
37
+
38
+
39
def create_client():
    """Build a new HTTP/2-enabled ``AsyncClient`` using the module's headers and limits."""
    client_options = {"headers": headers, "limits": limits, "http2": True}
    return AsyncClient(**client_options)
41
+
42
+
43
@cached(cache=TTLCache(maxsize=100, ttl=60 * 10))
def get_models(user_or_org):
    """List the ``transformers`` models authored by *user_or_org*.

    Models are requested sorted by downloads (descending) with card data and
    full metadata included. The result is materialised into a list and cached
    per author for ten minutes.
    """
    author_filter = ModelFilter(library="transformers", author=user_or_org)
    listing = list_models(
        filter=author_filter,
        sort="downloads",
        direction=-1,
        cardData=True,
        full=True,
    )
    # tqdm only reports progress while the listing is consumed.
    return list(tqdm(iter(listing)))
59
+
60
+
61
def filter_models(models):
    """Return the models that have card data but no ``base_model`` entry in it.

    Models with no (or empty) card data, and models whose card data cannot be
    read (``AttributeError``), are skipped.
    """
    kept = []
    for candidate in tqdm(models):
        try:
            metadata = candidate.cardData
            if not metadata:
                continue
            if metadata.get("base_model", None):
                # Already declares a base model; nothing to add.
                continue
        except AttributeError:
            continue
        kept.append(candidate)
    return kept
72
+
73
+
74
# Matches the boilerplate sentence in a model card's README and captures the
# link text (the base model id) of the markdown link that follows it.
# Raw strings are required: "\[", "\(" and "\s" are invalid escape sequences
# in a normal string literal and trigger warnings on modern Python.
MODEL_ID_RE_PATTERN = re.compile(
    r"This model is a fine-tuned version of \[(.*?)\]\(.*?\)"
)
# Matches an existing ``base_model: <value>`` entry and captures the value.
BASE_MODEL_PATTERN = re.compile(r"base_model:\s+(.+)")
78
+
79
+
80
@cached(cache=TTLCache(maxsize=100, ttl=60 * 3))
def has_model_card(model):
    """Return True when one of the model's sibling files is named ``README.md``."""
    repo_files = model.siblings or ()
    return any(entry.rfilename == "README.md" for entry in repo_files)
87
+
88
+
89
@cached(cache=TTLCache(maxsize=100, ttl=60))
def check_already_has_base_model(text):
    """Return True if *text* already contains a ``base_model:`` entry."""
    found = BASE_MODEL_PATTERN.search(text)
    return found is not None
92
+
93
+
94
@cached(cache=TTLCache(maxsize=100, ttl=60))
def extract_model_name(text):
    """Return the base model id named in the fine-tuning sentence of *text*, or None."""
    found = MODEL_ID_RE_PATTERN.search(text)
    if found is None:
        return None
    return found.group(1)
97
+
98
+
99
+ # semaphore = asyncio.Semaphore(10) # Maximum number of concurrent tasks
100
+
101
+
102
@cache(ttl=120, condition=NOT_NONE)
async def check_readme_for_match(model):
    """Fetch the model's README and return the base model id it mentions.

    Returns:
        The captured base model id, or None when the model has no README, the
        request fails or is not a 200, the README already declares a
        ``base_model``, or no fine-tuning sentence is found.
    """
    if not has_model_card(model):
        return None
    model_card_url = hf_hub_url(model.modelId, "README.md")
    client = create_client()
    try:
        # ``async with`` guarantees the client and its connection pool are
        # closed on every path; previously each call leaked an open client.
        async with client:
            resp = await client.get(model_card_url)
            if resp.status_code != 200:
                return None
            if check_already_has_base_model(resp.text):
                return None
            return extract_model_name(resp.text)
    except (httpx.ConnectError, httpx.ReadTimeout, httpx.ConnectTimeout):
        # Expected transient network failures: treat as "no match".
        return None
    except Exception as e:
        # Best-effort lookup: report anything unexpected and carry on.
        print(e)
        return None
123
+
124
+
125
@cache(ttl=120, condition=NOT_NONE)
async def check_model_exists(model, match):
    """Check that the candidate base model *match* exists on the Hub.

    Returns:
        ``{"modelid": ..., "match": ...}`` when the Hub API answers 200,
        ``False`` when it answers 401, and None for any other status or when
        the request fails.
    """
    client = create_client()
    url = f"https://huggingface.co/api/models/{match}"
    try:
        # ``async with`` guarantees the client and its connection pool are
        # closed on every path; previously each call leaked an open client.
        async with client:
            resp = await client.get(url)
        if resp.status_code == 200:
            return {"modelid": model.modelId, "match": match}
        if resp.status_code == 401:
            return False
        # Any other status: no usable answer.
        return None
    except (httpx.ConnectError, httpx.ReadTimeout, httpx.ConnectTimeout):
        # Expected transient network failures: treat as "unknown".
        return None
    except Exception as e:
        # Best-effort lookup: report anything unexpected and carry on.
        print(e)
        return None
144
+
145
+
146
@cache(ttl=120, condition=NOT_NONE)
async def check_model(model):
    """Resolve a model to its base-model payload, or None if its README yields no match."""
    candidate = await check_readme_for_match(model)
    if not candidate:
        return None
    return await check_model_exists(model, candidate)
151
+
152
+
153
async def prep_tasks(models):
    """Run ``check_model`` concurrently for every model and collect the results.

    Results are gathered in completion order, not input order.
    """
    pending = [asyncio.create_task(check_model(entry)) for entry in models]
    outcomes = []
    for finished in atqdm.as_completed(pending):
        outcomes.append(await finished)
    return outcomes
159
+
160
+
161
def get_data_for_user(user_or_org):
    """Return the base-model payloads found for *user_or_org*'s models.

    Each payload is a dict with ``modelid`` and ``match`` keys, as produced by
    ``check_model_exists``.
    """
    models = filter_models(get_models(user_or_org))
    results = asyncio.run(prep_tasks(models))
    # check_model_exists returns False on HTTP 401 and None when nothing was
    # found; only real payload dicts may reach update_metadata, which indexes
    # into them. Filtering on ``is not None`` alone let False through.
    return [r for r in results if r]
167
+
168
+
169
# Logger obtained from huggingface_hub's logging helper.
logger = logging.get_logger()

# Token handed to ``push_to_hub`` when opening pull requests.
token = os.environ.get("HUGGINGFACE_TOKEN")
172
+
173
+
174
def generate_issue_text(based_model_regex_match, opened_by=None):
    """Compose the markdown description used for the base_model pull request.

    Args:
        based_model_regex_match: base model id extracted from the README.
        opened_by: name of the user who requested the PR; interpolated as-is.
    """
    model_link = f"[`{based_model_regex_match}`](https://huggingface.co/{based_model_regex_match})"
    paragraphs = [
        f"This pull request aims to enrich the metadata of your model by adding {model_link} as a `base_model` field, situated in the `YAML` block of your model's `README.md`.",
        "How did we find this information? We performed a regular expression match on your `README.md` file to determine the connection.",
        "\n".join(
            [
                "**Why add this?** Enhancing your model's metadata in this way:",
                "- **Boosts Discoverability** - It becomes straightforward to trace the relationships between various models on the Hugging Face Hub.",
                "- **Highlights Impact** - It showcases the contributions and influences different models have within the community.",
            ]
        ),
        "For a hands-on example of how such metadata can play a pivotal role in mapping model connections, take a look at [librarian-bots/base_model_explorer](https://huggingface.co/spaces/librarian-bots/base_model_explorer).",
        f"This PR comes courtesy of [Librarian Bot](https://huggingface.co/librarian-bot) by request of {opened_by}",
    ]
    # Paragraphs are separated by one blank line.
    return "\n\n".join(paragraphs)
186
+
187
+
188
# Title used for the bot's pull requests; also used to detect an earlier one.
_BASE_MODEL_PR_TITLE = "Librarian Bot: Add base_model information to model"


def _bot_already_opened_pr(repo_id):
    """Return True if librarian-bot already has a base_model PR on *repo_id*."""
    discussions = list(get_repo_discussions(repo_id))
    if not discussions:
        return False
    logger.info("found previous discussions")
    pull_requests = [item for item in discussions if item.is_pull_request]
    if not pull_requests:
        return False
    logger.info("found previous pull requests")
    for pull_request in pull_requests:
        if pull_request.author != "librarian-bot":
            continue
        logger.info("previously opened PR")
        if pull_request.title == _BASE_MODEL_PR_TITLE:
            logger.info("previously opened PR to add base_model tag")
            return True
    return False


def update_metadata(metadata_payload: Dict[str, str], user_making_request=None):
    """Open a PR adding ``base_model`` to a model card, unless one already exists.

    Args:
        metadata_payload: dict with ``modelid`` (target repo) and ``match``
            (base model id); mutated in place with an ``opened_pr`` flag.
        user_making_request: name credited in the PR description.

    Returns:
        The same ``metadata_payload``. ``opened_pr`` is True when a PR was
        pushed or an earlier bot PR with the same title was found, and False
        when the repo is missing or the Hub returned an HTTP error.
    """
    metadata_payload["opened_pr"] = False
    base_model_id = metadata_payload["match"]
    repo_id = metadata_payload["modelid"]
    try:
        card = ModelCard.load(repo_id)
    except RepositoryNotFoundError:
        return metadata_payload
    card.data["base_model"] = base_model_id
    description = generate_issue_text(base_model_id, opened_by=user_making_request)
    try:
        if not _bot_already_opened_pr(repo_id):
            card.push_to_hub(
                repo_id,
                token=token,
                repo_type="model",
                create_pr=True,
                commit_message=_BASE_MODEL_PR_TITLE,
                commit_description=description,
            )
        metadata_payload["opened_pr"] = True
    except HfHubHTTPError:
        # Leave ``opened_pr`` False; the caller only counts successes.
        pass
    return metadata_payload
229
+
230
+
231
def open_prs(profile: gr.OAuthProfile | None, user_or_org: str = None):
    """Open base_model PRs for the logged-in user, or for *user_or_org* if given.

    Args:
        profile: OAuth profile of the logged-in user, or None when nobody is
            logged in (presumably injected by Gradio from the type annotation;
            the click handler only passes the textbox value).
        user_or_org: optional user or organisation to target instead of the
            logged-in user. When given, at most 10 randomly chosen models are
            processed.

    Returns:
        A status message for display in the UI.
    """
    if not profile:
        return "Please login to open PR requests"
    username = profile.preferred_username
    user_to_receive_prs = user_or_org or username
    data = get_data_for_user(user_to_receive_prs)
    if user_or_org:
        # random.sample returns a new list rather than modifying ``data``; the
        # result was previously discarded, so the 10-model cap never applied.
        data = random.sample(data, min(10, len(data)))
    if not data:
        return "No PRs to open"
    results = []
    for metadata_payload in data:
        try:
            results.append(
                update_metadata(metadata_payload, user_making_request=username)
            )
        except Exception as e:
            # One failing repo must not abort the whole batch.
            logger.error(e)
    opened_count = sum(1 for r in results if r["opened_pr"])
    return f"Opened {opened_count} PRs"
251
+
252
+
253
# Gradio UI: login/logout buttons, a textbox for the target user or org, and a
# button that runs ``open_prs`` and shows its status message as markdown.
with gr.Blocks() as demo:
    gr.Markdown("# Librarian Bot")
    gr.LoginButton()
    gr.LogoutButton()
    user = gr.Textbox(label="user or org to Open PRs for")
    button = gr.Button()
    results = gr.Markdown()
    button.click(open_prs, [user], results)


# Enable the request queue with concurrency_count=1, then start the app.
queued_demo = demo.queue(concurrency_count=1)
queued_demo.launch()
requirements.in ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ cachetools
2
+ cashews
3
+ diskcache
4
+ gradio[oauth]
5
+ httpx[http2]
6
+ huggingface_hub
7
+ python-dotenv
8
+ toolz
9
+ tqdm
requirements.txt ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is autogenerated by pip-compile with Python 3.11
3
+ # by the following command:
4
+ #
5
+ # pip-compile requirements.in
6
+ #
7
+ aiofiles==23.2.1
8
+ # via gradio
9
+ altair==5.1.1
10
+ # via gradio
11
+ annotated-types==0.5.0
12
+ # via pydantic
13
+ anyio==3.7.1
14
+ # via
15
+ # fastapi
16
+ # httpcore
17
+ # starlette
18
+ attrs==23.1.0
19
+ # via
20
+ # jsonschema
21
+ # referencing
22
+ authlib==1.2.1
23
+ # via gradio
24
+ cachetools==5.3.1
25
+ # via -r requirements.in
26
+ cashews==6.2.0
27
+ # via -r requirements.in
28
+ certifi==2023.7.22
29
+ # via
30
+ # httpcore
31
+ # httpx
32
+ # requests
33
+ cffi==1.15.1
34
+ # via cryptography
35
+ charset-normalizer==3.2.0
36
+ # via requests
37
+ click==8.1.7
38
+ # via uvicorn
39
+ contourpy==1.1.0
40
+ # via matplotlib
41
+ cryptography==41.0.3
42
+ # via authlib
43
+ cycler==0.11.0
44
+ # via matplotlib
45
+ diskcache==5.6.3
46
+ # via -r requirements.in
47
+ fastapi==0.103.1
48
+ # via gradio
49
+ ffmpy==0.3.1
50
+ # via gradio
51
+ filelock==3.12.3
52
+ # via huggingface-hub
53
+ fonttools==4.42.1
54
+ # via matplotlib
55
+ fsspec==2023.9.0
56
+ # via
57
+ # gradio-client
58
+ # huggingface-hub
59
+ gradio[oauth]==3.43.2
60
+ # via -r requirements.in
61
+ gradio-client==0.5.0
62
+ # via gradio
63
+ h11==0.14.0
64
+ # via
65
+ # httpcore
66
+ # uvicorn
67
+ h2==4.1.0
68
+ # via httpx
69
+ hpack==4.0.0
70
+ # via h2
71
+ httpcore==0.18.0
72
+ # via httpx
73
+ httpx[http2]==0.25.0
74
+ # via
75
+ # -r requirements.in
76
+ # gradio
77
+ # gradio-client
78
+ huggingface-hub==0.17.0
79
+ # via
80
+ # -r requirements.in
81
+ # gradio
82
+ # gradio-client
83
+ hyperframe==6.0.1
84
+ # via h2
85
+ idna==3.4
86
+ # via
87
+ # anyio
88
+ # httpx
89
+ # requests
90
+ importlib-resources==6.0.1
91
+ # via gradio
92
+ itsdangerous==2.1.2
93
+ # via gradio
94
+ jinja2==3.1.2
95
+ # via
96
+ # altair
97
+ # gradio
98
+ jsonschema==4.19.0
99
+ # via altair
100
+ jsonschema-specifications==2023.7.1
101
+ # via jsonschema
102
+ kiwisolver==1.4.5
103
+ # via matplotlib
104
+ markupsafe==2.1.3
105
+ # via
106
+ # gradio
107
+ # jinja2
108
+ matplotlib==3.7.2
109
+ # via gradio
110
+ numpy==1.25.2
111
+ # via
112
+ # altair
113
+ # contourpy
114
+ # gradio
115
+ # matplotlib
116
+ # pandas
117
+ orjson==3.9.7
118
+ # via gradio
119
+ packaging==23.1
120
+ # via
121
+ # altair
122
+ # gradio
123
+ # gradio-client
124
+ # huggingface-hub
125
+ # matplotlib
126
+ pandas==2.1.0
127
+ # via
128
+ # altair
129
+ # gradio
130
+ pillow==10.0.0
131
+ # via
132
+ # gradio
133
+ # matplotlib
134
+ pycparser==2.21
135
+ # via cffi
136
+ pydantic==2.3.0
137
+ # via
138
+ # fastapi
139
+ # gradio
140
+ pydantic-core==2.6.3
141
+ # via pydantic
142
+ pydub==0.25.1
143
+ # via gradio
144
+ pyparsing==3.0.9
145
+ # via matplotlib
146
+ python-dateutil==2.8.2
147
+ # via
148
+ # matplotlib
149
+ # pandas
150
+ python-dotenv==1.0.0
151
+ # via -r requirements.in
152
+ python-multipart==0.0.6
153
+ # via gradio
154
+ pytz==2023.3.post1
155
+ # via pandas
156
+ pyyaml==6.0.1
157
+ # via
158
+ # gradio
159
+ # huggingface-hub
160
+ referencing==0.30.2
161
+ # via
162
+ # jsonschema
163
+ # jsonschema-specifications
164
+ requests==2.31.0
165
+ # via
166
+ # gradio
167
+ # gradio-client
168
+ # huggingface-hub
169
+ rpds-py==0.10.2
170
+ # via
171
+ # jsonschema
172
+ # referencing
173
+ semantic-version==2.10.0
174
+ # via gradio
175
+ six==1.16.0
176
+ # via python-dateutil
177
+ sniffio==1.3.0
178
+ # via
179
+ # anyio
180
+ # httpcore
181
+ # httpx
182
+ starlette==0.27.0
183
+ # via fastapi
184
+ toolz==0.12.0
185
+ # via
186
+ # -r requirements.in
187
+ # altair
188
+ tqdm==4.66.1
189
+ # via
190
+ # -r requirements.in
191
+ # huggingface-hub
192
+ typing-extensions==4.7.1
193
+ # via
194
+ # fastapi
195
+ # gradio
196
+ # gradio-client
197
+ # huggingface-hub
198
+ # pydantic
199
+ # pydantic-core
200
+ tzdata==2023.3
201
+ # via pandas
202
+ urllib3==2.0.4
203
+ # via requests
204
+ uvicorn==0.23.2
205
+ # via gradio
206
+ websockets==11.0.3
207
+ # via
208
+ # gradio
209
+ # gradio-client