yjernite HF Staff commited on
Commit
bbf45d0
·
verified ·
1 Parent(s): b29ac5f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +275 -0
app.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import plotly.express as px
5
+
6
+ PIPELINE_TAGS = [
7
+ 'text-generation',
8
+ 'text-to-image',
9
+ 'text-classification',
10
+ 'text2text-generation',
11
+ 'audio-to-audio',
12
+ 'feature-extraction',
13
+ 'image-classification',
14
+ 'translation',
15
+ 'reinforcement-learning',
16
+ 'fill-mask',
17
+ 'text-to-speech',
18
+ 'automatic-speech-recognition',
19
+ 'image-text-to-text',
20
+ 'token-classification',
21
+ 'sentence-similarity',
22
+ 'question-answering',
23
+ 'image-feature-extraction',
24
+ 'summarization',
25
+ 'zero-shot-image-classification',
26
+ 'object-detection',
27
+ 'image-segmentation',
28
+ 'image-to-image',
29
+ 'image-to-text',
30
+ 'audio-classification',
31
+ 'visual-question-answering',
32
+ 'text-to-video',
33
+ 'zero-shot-classification',
34
+ 'depth-estimation',
35
+ 'text-ranking',
36
+ 'image-to-video',
37
+ 'multiple-choice',
38
+ 'unconditional-image-generation',
39
+ 'video-classification',
40
+ 'text-to-audio',
41
+ 'time-series-forecasting',
42
+ 'any-to-any',
43
+ 'video-text-to-text',
44
+ 'table-question-answering',
45
+ ]
46
+
47
+ def is_audio_speech(repo_dct):
48
+ res = (repo_dct.get("pipeline_tag", None) and "audio" in repo_dct.get("pipeline_tag", "").lower()) or \
49
+ (repo_dct.get("pipeline_tag", None) and "speech" in repo_dct.get("pipeline_tag", "").lower()) or \
50
+ (repo_dct.get("tags", None) and any("audio" in tag.lower() for tag in repo_dct.get("tags", []))) or \
51
+ (repo_dct.get("tags", None) and any("speech" in tag.lower() for tag in repo_dct.get("tags", [])))
52
+ return res
53
+
54
+ def is_music(repo_dct):
55
+ res = (repo_dct.get("tags", None) and any("music" in tag.lower() for tag in repo_dct.get("tags", [])))
56
+ return res
57
+
58
+ def is_robotics(repo_dct):
59
+ res = (repo_dct.get("tags", None) and any("robot" in tag.lower() for tag in repo_dct.get("tags", [])))
60
+ return res
61
+
62
+ def is_biomed(repo_dct):
63
+ res = (repo_dct.get("tags", None) and any("bio" in tag.lower() for tag in repo_dct.get("tags", []))) or \
64
+ (repo_dct.get("tags", None) and any("medic" in tag.lower() for tag in repo_dct.get("tags", [])))
65
+ return res
66
+
67
+ def is_timeseries(repo_dct):
68
+ res = (repo_dct.get("tags", None) and any("series" in tag.lower() for tag in repo_dct.get("tags", [])))
69
+ return res
70
+
71
+ def is_science(repo_dct):
72
+ res = (repo_dct.get("tags", None) and any("science" in tag.lower() and not "bigscience" in tag for tag in repo_dct.get("tags", [])))
73
+ return res
74
+
75
+ def is_video(repo_dct):
76
+ res = (repo_dct.get("tags", None) and any("video" in tag.lower() for tag in repo_dct.get("tags", [])))
77
+ return res
78
+
79
+ def is_image(repo_dct):
80
+ res = (repo_dct.get("tags", None) and any("image" in tag.lower() for tag in repo_dct.get("tags", [])))
81
+ return res
82
+
83
+ def is_text(repo_dct):
84
+ res = (repo_dct.get("tags", None) and any("text" in tag.lower() for tag in repo_dct.get("tags", [])))
85
+ return res
86
+
87
+ TAG_FILTER_FUNCS = {
88
+ "Audio & Speech": is_audio_speech,
89
+ "Time series": is_timeseries,
90
+ "Robotics": is_robotics,
91
+ "Music": is_music,
92
+ "Video": is_video,
93
+ "Images": is_image,
94
+ "Text": is_text,
95
+ "Biomedical": is_biomed,
96
+ "Sciences": is_science,
97
+ }
98
+
99
+ def make_org_stats(repo_type, count_by, org_stats, top_k=20, filter_func=None):
100
+ assert count_by in ["likes", "downloads", "downloads_all"]
101
+ assert repo_type in ["all", "datasets", "models"]
102
+ repos = ["datasets", "models"] if repo_type == "all" else [repo_type]
103
+ if filter_func is None:
104
+ filter_func = lambda x: True
105
+ sorted_stats = sorted(
106
+ [(
107
+ author,
108
+ sum(dct[count_by] for dct in author_dct[repo] if filter_func(dct))
109
+ ) for repo in repos for author, author_dct in org_stats.items()],
110
+ key=lambda x:x[1],
111
+ reverse=True,
112
+ )
113
+ res = sorted_stats[:top_k] + [("Others...", sum(st for auth, st in sorted_stats[top_k:]))]
114
+ total_st = sum(st for o, st in res)
115
+ res_plot_df = []
116
+ for org, st in res:
117
+ if org == "Others...":
118
+ res_plot_df += [("Others...", "other", st * 100 / total_st)]
119
+ else:
120
+ for repo in repos:
121
+ for dct in org_stats[org][repo]:
122
+ if filter_func(dct):
123
+ res_plot_df += [(org, dct["id"], dct[count_by] * 100 / total_st)]
124
+ return ([(o, 100 * st / total_st) for o, st in res if st > 0], res_plot_df)
125
+
126
+ def make_figure(count_by, repo_type, org_stats, tag_filter=None, pipeline_filter=None):
127
+ assert count_by in ["downloads", "likes", "downloads_all"]
128
+ assert repo_type in ["all", "models", "datasets"]
129
+ assert tag_filter is None or pipeline_filter is None
130
+ filter_func = None
131
+ if tag_filter:
132
+ filter_func = TAG_FILTER_FUNCS[tag_filter]
133
+ if pipeline_filter:
134
+ filter_func = lambda dct: dct.get("pipeline_tag", None) and dct.get("pipeline_tag", "") == pipeline_filter
135
+ _, res_plot_df = make_org_stats(repo_type, count_by, org_stats, top_k=25, filter_func=filter_func)
136
+ df = pd.DataFrame(
137
+ dict(
138
+ organizations=[o for o, _, _ in res_plot_df],
139
+ repo=[r for _, r, _ in res_plot_df],
140
+ stats=[s for _, _, s in res_plot_df],
141
+ )
142
+ )
143
+ df[repo_type] = repo_type # in order to have a single root node
144
+ fig = px.treemap(df, path=[repo_type, 'organizations', 'repo'], values='stats')
145
+ fig.update_layout(
146
+ treemapcolorway = ["pink" for _ in range(len(res_plot_df))],
147
+ margin = dict(t=50, l=25, r=25, b=25)
148
+ )
149
+ return fig
150
+
151
+
152
+ with gr.Blocks() as demo:
153
+ org_stats_data = gr.State(value=None) # To store loaded data
154
+
155
+ with gr.Row():
156
+ gr.Markdown("""
157
+ ## Hugging Face Organization Stats
158
+
159
+ This app shows how different organizations are contributing to different aspects of the open AI ecosystem.
160
+ Use the dropdowns on the left to select repository types, metrics, and optionally tags representing topics or modalities of interest.
161
+ """)
162
+ with gr.Row():
163
+ with gr.Column(scale=1):
164
+ repo_type_dropdown = gr.Dropdown(
165
+ label="Repository Type",
166
+ choices=["all", "models", "datasets"],
167
+ value="all"
168
+ )
169
+ count_by_dropdown = gr.Dropdown(
170
+ label="Metric",
171
+ choices=["downloads", "likes", "downloads_all"],
172
+ value="downloads"
173
+ )
174
+
175
+ filter_choice_radio = gr.Radio(
176
+ label="Filter by",
177
+ choices=["None", "Tag Filter", "Pipeline Filter"],
178
+ value="None"
179
+ )
180
+
181
+ tag_filter_dropdown = gr.Dropdown(
182
+ label="Select Tag",
183
+ choices=list(TAG_FILTER_FUNCS.keys()),
184
+ value=None,
185
+ visible=False
186
+ )
187
+ pipeline_filter_dropdown = gr.Dropdown(
188
+ label="Select Pipeline Tag",
189
+ choices=PIPELINE_TAGS,
190
+ value=None,
191
+ visible=False
192
+ )
193
+
194
+ generate_plot_button = gr.Button("Generate Plot")
195
+
196
+ with gr.Column(scale=3):
197
+ plot_output = gr.Plot()
198
+
199
+ def generate_plot_on_click(repo_type, count_by, filter_choice, tag_filter, pipeline_filter, data):
200
+ # Print the current state of the input variables
201
+ print(f"Generating plot with the following inputs:")
202
+ print(f" Repository Type: {repo_type}")
203
+ print(f" Metric (Count By): {count_by}")
204
+ print(f" Filter Choice: {filter_choice}")
205
+ if filter_choice == "Tag Filter":
206
+ print(f" Tag Filter: {tag_filter}")
207
+ elif filter_choice == "Pipeline Filter":
208
+ print(f" Pipeline Filter: {pipeline_filter}")
209
+
210
+ if data is None:
211
+ print("Error: Data not loaded yet.")
212
+ return None
213
+
214
+ selected_tag_filter = None
215
+ selected_pipeline_filter = None
216
+
217
+ if filter_choice == "Tag Filter":
218
+ selected_tag_filter = tag_filter
219
+ elif filter_choice == "Pipeline Filter":
220
+ selected_pipeline_filter = pipeline_filter
221
+
222
+ fig = make_figure(
223
+ count_by=count_by,
224
+ repo_type=repo_type,
225
+ org_stats=data,
226
+ tag_filter=selected_tag_filter,
227
+ pipeline_filter=selected_pipeline_filter
228
+ )
229
+ return fig
230
+
231
+ def update_filter_visibility(filter_choice):
232
+ if filter_choice == "Tag Filter":
233
+ return gr.update(visible=True), gr.update(visible=False)
234
+ elif filter_choice == "Pipeline Filter":
235
+ return gr.update(visible=False), gr.update(visible=True)
236
+ else: # "None"
237
+ return gr.update(visible=False), gr.update(visible=False)
238
+
239
+ filter_choice_radio.change(
240
+ fn=update_filter_visibility,
241
+ inputs=[filter_choice_radio],
242
+ outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
243
+ )
244
+
245
+ # Load data once at startup
246
+ def load_org_data():
247
+ print("Loading organization statistics data...")
248
+ loaded_org_stats = json.load(open("org_to_artifacts_2l_stats.json"))
249
+ print("Data loaded successfully.")
250
+ return loaded_org_stats
251
+
252
+ demo.load(
253
+ fn=load_org_data,
254
+ inputs=[], # No inputs needed to just load data
255
+ outputs=[org_stats_data] # Only output to the state
256
+ )
257
+
258
+ # Button click event to generate plot
259
+ generate_plot_button.click(
260
+ fn=generate_plot_on_click,
261
+ inputs=[
262
+ repo_type_dropdown,
263
+ count_by_dropdown,
264
+ filter_choice_radio,
265
+ tag_filter_dropdown,
266
+ pipeline_filter_dropdown,
267
+ org_stats_data
268
+ ],
269
+ outputs=[plot_output]
270
+ )
271
+
272
+
273
+ if __name__ == "__main__":
274
+ # org_stats = json.load(open("org_to_artifacts_2l_stats.json")) # Data loading handled by demo.load
275
+ demo.launch()