Benjamin Bossan
Small update to help
d8d447c
raw
history blame contribute delete
No virus
7.17 kB
"""Start page of the app
This page is used to initialize a model card that is one of:
1. based on the skops template
2. empty
3. loaded from an existing repository on the Hugging Face Hub
Optionally, users can add a model file, data, requirements, and choose a task.
"""
import glob
import io
import os
import pickle
import shutil
from pathlib import Path
from tempfile import mkdtemp
import pandas as pd
import sklearn
import streamlit as st
from huggingface_hub import snapshot_download
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
from sklearn.base import BaseEstimator
from sklearn.dummy import DummyClassifier
import skops.io as sio
from skops import card, hub_utils
# Scratch directory created once, when this module is executed. init_repo()
# writes the serialized model ("model.skops") here before passing it on to
# hub_utils.init.
tmp_path = Path(mkdtemp(prefix="skops-"))  # temporary files
# Introductory markdown rendered at the top of the start page by
# start_input_form().
description = """Create a Hugging Face model repository for scikit learn models
This page aims to provide a simple interface to use the
[`skops`](https://skops.readthedocs.io/) model card and HF Hub creation
utilities.
"""
def load_model() -> None:
    """Deserialize the uploaded model file into ``st.session_state.model``.

    Used as the ``on_change`` callback of the model file uploader. When there
    is no file under ``st.session_state.model_file``, a ``DummyClassifier``
    is stored as a placeholder instead.

    Files whose name ends with ``"skops"`` are read with ``skops.io.loads``;
    everything else is treated as a pickle.

    Raises
    ------
    AssertionError
        If the deserialized object is not a scikit-learn ``BaseEstimator``.
    """
    uploaded = st.session_state.get("model_file")
    if uploaded is None:
        st.session_state.model = DummyClassifier()
        return

    bytes_data = uploaded.getvalue()
    if uploaded.name.endswith("skops"):
        # NOTE(security): trusted=True loads the file without rejecting
        # unknown types, so the upload is fully trusted here.
        model = sio.loads(bytes_data, trusted=True)
    else:
        # NOTE(security): unpickling runs arbitrary code embedded in the
        # file. This is user-uploaded data; only acceptable if the app runs
        # in a sandbox where that is tolerable.
        model = pickle.loads(bytes_data)

    # Raise explicitly instead of using `assert`: assert statements are
    # removed when Python runs with -O, which would silently skip this check.
    # The exception type is kept so existing handlers behave the same.
    if not isinstance(model, BaseEstimator):
        raise AssertionError("model must be an sklearn model")
    st.session_state.model = model
def load_data() -> None:
    """Parse the uploaded CSV sample into ``st.session_state.data``.

    Used as the ``on_change`` callback of the data file uploader. An empty
    ``DataFrame`` is stored when no data file is present.
    """
    uploaded = st.session_state.get("data_file")
    if not uploaded:
        st.session_state.data = pd.DataFrame([])
        return

    # read_csv wants a file-like object, so wrap the raw upload bytes.
    buffer = io.BytesIO(uploaded.getvalue())
    st.session_state.data = pd.read_csv(buffer)
def _clear_repo(path: str) -> None:
for file_path in glob.glob(str(Path(path) / "*")):
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
def init_repo() -> None:
    """(Re-)initialize the local repository at ``st.session_state.hf_path``.

    The folder is emptied first. Then the current model is dumped in skops
    format to the scratch directory and passed to ``hub_utils.init`` together
    with the task, sample data and requirements found in the session state.
    Defaults are used for whatever is absent: no requirements, the task
    ``"tabular-classification"`` and an empty ``DataFrame``.

    This is best effort: errors raised while dumping the model or
    initializing the repo are printed and not re-raised.
    """
    path = st.session_state.hf_path
    _clear_repo(path)

    requirements = []
    task = "tabular-classification"
    data = pd.DataFrame([])

    if "requirements" in st.session_state:
        # One requirement per line. Blank or whitespace-only lines (e.g. a
        # trailing empty line in the text area) are dropped so that no empty
        # requirement string reaches hub_utils.init.
        requirements = [
            line.strip()
            for line in st.session_state.requirements.splitlines()
            if line.strip()
        ]
    if "task" in st.session_state:
        task = st.session_state.task
    if "data_file" in st.session_state:
        load_data()
        data = st.session_state.data

    # For text tasks the DataFrame is converted to nested lists before it is
    # passed on.
    if task.startswith("text") and isinstance(data, pd.DataFrame):
        data = data.values.tolist()

    try:
        file_name = tmp_path / "model.skops"
        sio.dump(st.session_state.model, file_name)
        hub_utils.init(
            model=file_name,
            dst=path,
            task=task,
            data=data,
            requirements=requirements,
        )
    except Exception as exc:
        # Deliberately broad: a failed init must not take down the page.
        print("Uh oh, something went wrong when initializing the repo:", exc)
def create_skops_model_card() -> None:
    """Create a model card without overriding the template and go to "edit".

    Re-initializes the repo, builds a ``card.Card`` for the current model
    using the metadata read from the repo config, and stores it in the
    session state with card type ``"skops"``.
    """
    init_repo()
    state = st.session_state

    repo_metadata = card.metadata_from_config(state.hf_path)
    state.model_card = card.Card(model=state.model, metadata=repo_metadata)
    state.model_card_type = "skops"
    state.screen.state = "edit"
def create_empty_model_card() -> None:
    """Create a template-less model card and go to the "edit" screen.

    Re-initializes the repo, builds a ``card.Card`` with ``template=None``,
    adds a single placeholder section named "Untitled", and stores the card
    in the session state with card type ``"empty"``.
    """
    init_repo()
    state = st.session_state

    repo_metadata = card.metadata_from_config(state.hf_path)
    blank_card = card.Card(
        model=state.model,
        metadata=repo_metadata,
        template=None,
    )
    # Seed the otherwise empty card with one placeholder section.
    blank_card.add(Untitled="[More Information Needed]")

    state.model_card = blank_card
    state.model_card_type = "empty"
    state.screen.state = "edit"
def create_hf_model_card() -> None:
    """Load the model card of an existing HF Hub repo and go to "edit".

    The repo ID is read from ``st.session_state.hf_repo_id``; surrounding
    whitespace and quotes are removed. Nothing happens for an empty ID.
    Only markdown, text and image files are downloaded. They are copied both
    to ``st.session_state.hf_path`` and to the current working directory,
    and the repo's ``README.md`` is parsed into a model card that is stored
    in the session state with card type ``"loaded"``.

    An error message is shown (and the screen is left unchanged) when the
    repo cannot be found or does not contain a ``README.md``.
    """
    repo_id = st.session_state.get("hf_repo_id", "").strip().strip("'").strip('"')
    if not repo_id:
        return

    # Every entry is a glob. The text-file entry used to be ".txt", which
    # only matches a file literally named ".txt".
    allow_patterns = [
        "*.md",
        "*.txt",
        "*.png",
        "*.gif",
        "*.jpg",
        "*.jpeg",
        "*.bmp",
        "*.webp",
    ]
    try:
        path = snapshot_download(repo_id, allow_patterns=allow_patterns)
    except (HFValidationError, RepositoryNotFoundError):
        st.error(
            f"Repository '{repo_id}' could not be found on HF Hub, "
            "please check that the repo ID is correct."
        )
        return

    # move everything to the hf_path and working dir
    hf_path = st.session_state.hf_path
    shutil.copytree(path, hf_path, dirs_exist_ok=True)
    shutil.copytree(path, ".", dirs_exist_ok=True)

    readme_path = hf_path / "README.md"
    if not readme_path.exists():
        # Without a README there is no model card to parse; tell the user
        # instead of failing inside the parser.
        st.error(
            f"Repository '{repo_id}' does not contain a README.md, "
            "so there is no model card to load."
        )
        return

    model_card = card.parse_modelcard(readme_path)
    st.session_state.model_card = model_card
    st.session_state.model_card_type = "loaded"
    st.session_state.screen.state = "edit"
def add_help_button():
    """Render the "Instructions" button, which switches to the help screen."""

    def _show_help():
        # Button callback: flip the screen state to "help".
        st.session_state.screen.state = "help"

    st.button(
        "Instructions",
        key="get_help",
        help="Detailed explanation of this space",
        on_click=_show_help,
    )
def start_input_form():
    """Render the start page: uploads, task and requirements, card creation.

    Fills in session-state defaults for ``model``, ``data`` and
    ``model_card`` when they are missing, then draws the widgets from top to
    bottom. The widget callbacks (``load_model``, ``load_data``,
    ``init_repo`` and the ``create_*_model_card`` functions) do the actual
    work.
    """
    state = st.session_state

    # Defaults are built lazily so nothing is constructed for keys that
    # already exist in the session.
    session_defaults = (
        ("model", DummyClassifier),
        ("data", lambda: pd.DataFrame([])),
        ("model_card", lambda: None),
    )
    for key, make_default in session_defaults:
        if key not in state:
            state[key] = make_default()

    task_options = [
        "tabular-classification",
        "tabular-regression",
        "text-classification",
        "text-regression",
    ]
    default_requirements = f"scikit-learn=={sklearn.__version__}\n"

    st.markdown(description)
    add_help_button()

    # --- model upload ---
    st.markdown("---")
    st.text(
        "Upload an sklearn model (strongly recommended)\n"
        "The model can be used to automatically populate fields in the model card."
    )
    # The uploader is only drawn while no model file is held in the session.
    model_already_uploaded = bool(state.get("model_file"))
    if not model_already_uploaded:
        st.file_uploader(
            "Upload an sklearn model (pickle or skops format)",
            key="model_file",
            on_change=load_model,
        )

    # --- data upload ---
    st.markdown("---")
    st.text(
        "Upload samples from your data (in csv format)\n"
        "This sample data can be attached to the metadata of the model card"
    )
    st.file_uploader(
        "Upload input data (csv)",
        type=["csv"],
        key="data_file",
        on_change=load_data,
    )

    # --- task type ---
    st.markdown("---")
    st.selectbox(
        label="Choose the task type",
        options=task_options,
        key="task",
        on_change=init_repo,
    )

    # --- requirements ---
    st.markdown("---")
    st.text_area(
        label="Requirements",
        value=default_requirements,
        key="requirements",
        on_change=init_repo,
    )

    # --- three ways to get a model card ---
    st.markdown("---")
    st.markdown("Choose one of the options below to get started:")
    left, middle, right = st.columns([2, 2, 2])
    with left:
        st.button("Create a new skops model card", on_click=create_skops_model_card)
    with middle:
        st.button("Create a new empty model card", on_click=create_empty_model_card)
    with right, st.form("Load existing model card from HF Hub", clear_on_submit=False):
        st.markdown("Load existing model card from HF Hub")
        st.text_input("Repo name (e.g. 'gpt2')", key="hf_repo_id")
        st.form_submit_button("Load", on_click=create_hf_model_card)
# Render the start page whenever this module is executed.
start_input_form()