Benjamin Bossan commited on
Commit
76ffd6d
1 Parent(s): ceead2c

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +391 -0
  2. make-data.py +26 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HF space creator starting from an sklearn model
2
+
3
+ import base64
4
+ import glob
5
+ import io
6
+ import json
7
+ import os
8
+ import pickle
9
+ import re
10
+ import shutil
11
+ from pathlib import Path
12
+ from tempfile import mkdtemp
13
+
14
+ import pandas as pd
15
+ import sklearn
16
+ import streamlit as st
17
+ from sklearn.base import BaseEstimator
18
+
19
+ import skops.io as sio
20
+ from skops import card, hub_utils
21
+
22
+ st.set_page_config(layout="wide")
23
+ st.title("Skops space creator for sklearn")
24
+
25
+
26
+ PLACEHOLDER = "[More Information Needed]"
27
+ PLOT_PREFIX = "__plot__:"
28
+ custom_sections: dict[str, str] = {}
29
+ tmp_repo = Path(mkdtemp(prefix="skops-"))
30
+ left_col, right_col = st.columns([1, 2])
31
+
32
+ # a hacky way to "persist" custom sections
33
+ CUSTOM_SECTIONS_CACHE_FILE = ".custom-sections.json"
34
+
35
+
36
+ def _clear_custom_section_cache():
37
+ with open(CUSTOM_SECTIONS_CACHE_FILE, "w") as f:
38
+ f.write("")
39
+
40
+
41
+ def _load_custom_section_cache():
42
+ global custom_sections
43
+
44
+ # in case file doesn't exist yet, create it
45
+ if not os.path.exists(CUSTOM_SECTIONS_CACHE_FILE):
46
+ Path(CUSTOM_SECTIONS_CACHE_FILE).touch()
47
+
48
+ with open(CUSTOM_SECTIONS_CACHE_FILE, "r") as f:
49
+ try:
50
+ custom_sections = json.load(f)
51
+ except ValueError:
52
+ pass
53
+
54
+
55
+ def _write_custom_section_cache():
56
+ with open(CUSTOM_SECTIONS_CACHE_FILE, "w") as f:
57
+ json.dump(custom_sections, f)
58
+
59
+
60
+ def _remove_custom_section(key):
61
+ del custom_sections[key]
62
+ _write_custom_section_cache()
63
+
64
+
65
+ def _clear_repo(path):
66
+ for file_path in glob.glob(str(Path(path) / "*")):
67
+ if os.path.isfile(file_path) or os.path.islink(file_path):
68
+ os.unlink(file_path)
69
+ elif os.path.isdir(file_path):
70
+ shutil.rmtree(file_path)
71
+
72
+
73
+ def _write_plot(plot_name, plot_file):
74
+ with open(plot_name, "wb") as f:
75
+ f.write(plot_file)
76
+
77
+
78
+ def init_repo():
79
+ _clear_repo(tmp_repo)
80
+
81
+ try:
82
+ file_name = Path(mkdtemp(prefix="skops-")) / "model.skops"
83
+ sio.dump(model, file_name)
84
+ hub_utils.init(
85
+ model=file_name,
86
+ dst=tmp_repo,
87
+ task=task,
88
+ data=data,
89
+ requirements=requirements,
90
+ )
91
+ except Exception as exc:
92
+ print("Uh oh, something went wrong when initializing the repo:", exc)
93
+
94
+
95
+ def load_model():
96
+ if model_file is None:
97
+ return
98
+
99
+ bytes_data = model_file.getvalue()
100
+ model = pickle.loads(bytes_data)
101
+ assert isinstance(model, BaseEstimator), "model must be an sklearn model"
102
+ return model
103
+
104
+
105
+ def load_data():
106
+ if data_file is None:
107
+ return
108
+
109
+ bytes_data = io.BytesIO(data_file.getvalue())
110
+ df = pd.read_csv(bytes_data)
111
+ return df
112
+
113
+
114
+ def _parse_metrics(metrics):
115
+ metrics_table = {}
116
+ for line in metrics.splitlines():
117
+ line = line.strip()
118
+ name, _, val = line.partition("=")
119
+ try:
120
+ # try to coerce to float but don't error if it fails
121
+ val = float(val.strip())
122
+ except ValueError:
123
+ pass
124
+ metrics_table[name.strip()] = val
125
+ return metrics_table
126
+
127
+
128
+ def _create_model_card():
129
+ if model is None or data is None:
130
+ st.text("*some data is missing to render the model card*")
131
+ return
132
+
133
+ init_repo()
134
+ metadata = card.metadata_from_config(tmp_repo)
135
+ model_card = card.Card(model=model, metadata=metadata)
136
+
137
+ if model_description:
138
+ model_card.add(**{"Model description": model_description})
139
+
140
+ if intended_uses:
141
+ model_card.add(
142
+ **{"Model description/Intended uses & limitations": intended_uses}
143
+ )
144
+
145
+ if metrics:
146
+ metrics_table = _parse_metrics(metrics)
147
+ model_card.add_metrics(**metrics_table)
148
+
149
+ if authors:
150
+ model_card.add(**{"Model Card Authors": authors})
151
+
152
+ if contact:
153
+ model_card.add(**{"Model Card Contact": contact})
154
+
155
+ if citation:
156
+ model_card.add(**{"Citation": citation})
157
+
158
+ if custom_sections:
159
+ for key, val in custom_sections.items():
160
+ if not key:
161
+ continue
162
+
163
+ if key.startswith(PLOT_PREFIX):
164
+ key = key[len(PLOT_PREFIX):]
165
+ model_card.add_plot(**{key: val})
166
+ else:
167
+ model_card.add(**{key: val})
168
+
169
+ return model_card
170
+
171
+
172
+ def _process_card_for_rendering(rendered: str) -> tuple[str, str]:
173
+ idx = rendered[1:].index("\n---") + 1
174
+ metadata = rendered[3:idx]
175
+ rendered = rendered[idx + 4 :] # noqa: E203
176
+
177
+ # below is a hack to display the images in streamlit
178
+ # https://discuss.streamlit.io/t/image-in-markdown/13274/10 The problem is
179
+
180
+ # that streamlit does not display images in markdown, so we need to replace
181
+ # them with html. However, we only want that in the rendered markdown, not
182
+ # in the card that is produced for the hub
183
+ def markdown_images(markdown):
184
+ # example image markdown:
185
+ # ![Test image](images/test.png "Alternate text")
186
+ images = re.findall(
187
+ r'(!\[(?P<image_title>[^\]]+)\]\((?P<image_path>[^\)"\s]+)\s*([^\)]*)\))',
188
+ markdown
189
+ )
190
+ return images
191
+
192
+ def img_to_bytes(img_path):
193
+ img_bytes = Path(img_path).read_bytes()
194
+ encoded = base64.b64encode(img_bytes).decode()
195
+ return encoded
196
+
197
+ def img_to_html(img_path, img_alt):
198
+ img_format = img_path.split(".")[-1]
199
+ img_html = (
200
+ f'<img src="data:image/{img_format.lower()};'
201
+ f'base64,{img_to_bytes(img_path)}" '
202
+ f'alt="{img_alt}" '
203
+ 'style="max-width: 100%;">'
204
+ )
205
+ return img_html
206
+
207
+ def markdown_insert_images(markdown):
208
+ images = markdown_images(markdown)
209
+
210
+ for image in images:
211
+ image_markdown = image[0]
212
+ image_alt = image[1]
213
+ image_path = image[2]
214
+ markdown = markdown.replace(image_markdown, img_to_html(image_path, image_alt))
215
+ return markdown
216
+
217
+ rendered_with_img = markdown_insert_images(rendered)
218
+ return metadata, rendered_with_img
219
+
220
+
221
+ def display_model_card():
222
+ model_card = _create_model_card()
223
+ if not model_card:
224
+ return
225
+
226
+ rendered = model_card.render()
227
+ metadata, rendered = _process_card_for_rendering(rendered)
228
+ # idx = rendered[1:].index("\n---") + 1
229
+ # metadata = rendered[3:idx]
230
+ # rendered = rendered[idx + 4 :] # noqa: E203
231
+
232
+ with right_col:
233
+ # strip metadata
234
+ with st.expander("show metadata"):
235
+ st.text(metadata)
236
+ st.markdown(rendered, unsafe_allow_html=True)
237
+
238
+
239
+ def download_model_card():
240
+ model_card = _create_model_card()
241
+ if model_card is not None:
242
+ return model_card.render()
243
+ return ""
244
+
245
+
246
+ def add_custom_section():
247
+ # this is required to "refresh" these variables...
248
+ global section_name, section_content
249
+ section_name = st.session_state.key_section_name
250
+ section_content = st.session_state.key_section_content
251
+
252
+ if not section_name or not section_content:
253
+ return
254
+
255
+ custom_sections[section_name] = section_content
256
+ _write_custom_section_cache()
257
+
258
+
259
+ def add_custom_plot():
260
+ # this is required to "refresh" these variables...
261
+ global section_name, section_content
262
+ plot_name = st.session_state.key_plot_name
263
+ plot_file = st.session_state.key_plot_file
264
+
265
+ if not plot_name or not plot_file:
266
+ return
267
+
268
+ # store plot in temp repo
269
+ file_name = plot_file.name.replace(" ", "_")
270
+ file_path = str(tmp_repo / file_name)
271
+ with open(file_path, "wb") as f:
272
+ f.write(plot_file.getvalue())
273
+
274
+ custom_sections[str(PLOT_PREFIX + plot_name)] = file_path
275
+ _write_custom_section_cache()
276
+
277
+
278
+ with left_col:
279
+ # This contains every element required to edit the model card
280
+ model = None
281
+ data = None
282
+ section_name = None
283
+ section_content = None
284
+
285
+ model_file = st.file_uploader("Upload a model*", on_change=load_model)
286
+ data_file = st.file_uploader(
287
+ "Upload X data (csv)*", type=["csv"], on_change=load_data
288
+ )
289
+
290
+ task = st.selectbox(
291
+ label="Choose the task type*",
292
+ options=[
293
+ "tabular-classification",
294
+ "tabular-regression",
295
+ "text-classification",
296
+ "text-regression",
297
+ ],
298
+ on_change=init_repo,
299
+ )
300
+
301
+ requirements = st.text_input(
302
+ label="Requirements*",
303
+ value=[f"scikit-learn=={sklearn.__version__}\n"],
304
+ on_change=init_repo,
305
+ )
306
+
307
+ if model_file is not None:
308
+ model = load_model()
309
+
310
+ if data_file is not None:
311
+ data = load_data()
312
+
313
+ if model is not None and data is not None:
314
+ init_repo()
315
+
316
+ model_description = st.text_input("Model description", value=PLACEHOLDER)
317
+ intended_uses = st.text_area(
318
+ "Intended uses & limitations", height=2, value=PLACEHOLDER
319
+ )
320
+ metrics = st.text_area("Metrics (e.g. 'accuracy = 0.95'), one metric per line")
321
+ authors = st.text_area(
322
+ "Authors",
323
+ value="This model card is written by following authors:\n\n" + PLACEHOLDER,
324
+ )
325
+ contact = st.text_area(
326
+ "Contact",
327
+ value="You can contact the model card authors through following channels:\n\n"
328
+ + PLACEHOLDER,
329
+ )
330
+ citation = st.text_area(
331
+ "Citation",
332
+ value="Below you can find information related to citation.\n\nBibTex:\n\n```\n"
333
+ + PLACEHOLDER
334
+ + "\n```",
335
+ height=5,
336
+ )
337
+
338
+ with st.form("custom-section", clear_on_submit=True):
339
+ section_name = st.text_input(
340
+ "Section name (use '/' for subsections, e.g. 'Model description/My new"
341
+ " section')",
342
+ key="key_section_name",
343
+ )
344
+ section_content = st.text_area(
345
+ "Content of the new section", key="key_section_content"
346
+ )
347
+ submit_new_section = st.form_submit_button(
348
+ "Create new section", on_click=add_custom_section
349
+ )
350
+
351
+ with st.form("custom-plots", clear_on_submit=True):
352
+ plot_name = st.text_input(
353
+ "Section name (use '/' for subsections, e.g. 'Model description/My new"
354
+ " plot')",
355
+ key="key_plot_name",
356
+ )
357
+ plot_file = st.file_uploader(
358
+ "Upload a figure*", key="key_plot_file"
359
+ )
360
+
361
+ submit_new_plot = st.form_submit_button(
362
+ "Add plot", on_click=add_custom_plot
363
+ )
364
+
365
+ _load_custom_section_cache()
366
+ for key in custom_sections:
367
+ if not key:
368
+ continue
369
+
370
+ if key.startswith(PLOT_PREFIX):
371
+ st.button(
372
+ f"Remove plot '{key[len(PLOT_PREFIX):]}'", on_click=_remove_custom_section, args=(key,)
373
+ )
374
+ else:
375
+ st.button(
376
+ f"Remove section '{key}'", on_click=_remove_custom_section, args=(key,)
377
+ )
378
+
379
+ if custom_sections:
380
+ st.button(
381
+ f"Remove all ({len(custom_sections)}) custom elements",
382
+ on_click=_clear_custom_section_cache,
383
+ )
384
+
385
+
386
+ with right_col:
387
+ # this contains the rendered model card
388
+ st.button(label="Render model card", on_click=display_model_card)
389
+ rendered = download_model_card()
390
+ if rendered:
391
+ st.download_button(label="Download model card (markdown format)", data=rendered)
make-data.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # companion script to st-space-creator.py
2
+
3
+ import pickle
4
+
5
+ import pandas as pd
6
+ from sklearn.datasets import make_classification
7
+ from sklearn.linear_model import LogisticRegression
8
+ from sklearn.pipeline import Pipeline
9
+ from sklearn.preprocessing import StandardScaler
10
+
11
+ X, y = make_classification()
12
+ df = pd.DataFrame(X)
13
+
14
+ clf = Pipeline(
15
+ [
16
+ ("scale", StandardScaler()),
17
+ ("clf", LogisticRegression(random_state=0)),
18
+ ]
19
+ )
20
+ clf.fit(X, y)
21
+
22
+ with open("logreg.pkl", "wb") as f:
23
+ pickle.dump(clf, f)
24
+
25
+
26
+ df.to_csv("data.csv", index=False)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ pandas
2
+ scikit-learn
3
+ skops