|
"""Start page of the app |
|
|
|
This page is used to initialize a model card that is either: |
|
|
|
1. based on the skops template |
|
2. empty |
|
3. loaded from an existing model card on the Hugging Face Hub
|
|
|
Optionally, users can upload a model file and sample data, edit the requirements, and choose a task.
|
|
|
""" |
|
|
|
import glob |
|
import io |
|
import os |
|
import pickle |
|
import shutil |
|
from pathlib import Path |
|
from tempfile import mkdtemp |
|
|
|
import pandas as pd |
|
import sklearn |
|
import streamlit as st |
|
from huggingface_hub import snapshot_download |
|
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError |
|
from sklearn.base import BaseEstimator |
|
from sklearn.dummy import DummyClassifier |
|
|
|
import skops.io as sio |
|
from skops import card, hub_utils |
|
from skops.io import get_untrusted_types |
|
|
|
# Scratch directory for intermediate files; ``init_repo`` writes the
# serialized model (``model.skops``) here before initializing the repo.
# NOTE(review): created at module execution time and never removed in this
# file — confirm whether cleanup happens elsewhere.
tmp_path = Path(mkdtemp(prefix="skops-"))

# Markdown introduction rendered at the top of the start page.
description = """Create a Hugging Face model repository for scikit learn models

This page aims to provide a simple interface to use the
[`skops`](https://skops.readthedocs.io/) model card and HF Hub creation
utilities.

"""
|
|
|
|
|
def load_model() -> None:
    """Load the uploaded model file into ``st.session_state.model``.

    If no file has been uploaded, a ``DummyClassifier`` is stored instead so
    that later steps always find a model in the session state. Files whose
    name ends in ``.skops`` are loaded with ``skops.io``; any other file is
    unpickled.

    Raises
    ------
    TypeError
        If the loaded object is not a scikit-learn ``BaseEstimator``.
    """
    model_file = st.session_state.get("model_file")
    if model_file is None:
        st.session_state.model = DummyClassifier()
        return

    bytes_data = model_file.getvalue()
    # Match the real extension, not just a trailing "skops" in the name.
    if model_file.name.endswith(".skops"):
        # SECURITY(review): every type reported as untrusted is trusted here,
        # which bypasses skops' safe-loading checks. Only acceptable if the
        # uploaded files are trusted.
        model = sio.loads(bytes_data, trusted=get_untrusted_types(data=bytes_data))
    else:
        # SECURITY(review): unpickling an uploaded file can execute arbitrary
        # code; only acceptable if the uploaded files are trusted.
        model = pickle.loads(bytes_data)

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not isinstance(model, BaseEstimator):
        raise TypeError("model must be an sklearn model")

    st.session_state.model = model
|
|
|
|
|
def load_data() -> None:
    """Store the uploaded CSV sample data in ``st.session_state.data``.

    When no data file is present in the session state, an empty DataFrame is
    stored instead.
    """
    uploaded = st.session_state.get("data_file")
    if not uploaded:
        st.session_state.data = pd.DataFrame([])
        return

    csv_buffer = io.BytesIO(uploaded.getvalue())
    st.session_state.data = pd.read_csv(csv_buffer)
|
|
|
|
|
def _clear_repo(path: str) -> None: |
|
for file_path in glob.glob(str(Path(path) / "*")): |
|
if os.path.isfile(file_path) or os.path.islink(file_path): |
|
os.unlink(file_path) |
|
elif os.path.isdir(file_path): |
|
shutil.rmtree(file_path) |
|
|
|
|
|
def init_repo() -> None:
    """(Re)initialize the local repository at ``st.session_state.hf_path``.

    The directory is emptied first. The current ``st.session_state.model`` is
    then serialized with ``skops.io`` into ``tmp_path`` and passed to
    ``hub_utils.init`` together with the requirements, task, and sample data
    taken from the session state. Defaults used when a value is absent are:
    no requirements, ``"tabular-classification"``, and an empty DataFrame.

    Any exception raised while dumping the model or initializing the repo is
    printed and shown in the UI via ``st.error``; it is not re-raised.
    """
    path = st.session_state.hf_path
    _clear_repo(path)

    requirements = []
    if "requirements" in st.session_state:
        requirements = st.session_state.requirements.splitlines()

    task = st.session_state.get("task", "tabular-classification")

    data = pd.DataFrame([])
    if "data_file" in st.session_state:
        load_data()
        data = st.session_state.data

    # Text tasks receive the sample data as a plain list of rows.
    # NOTE(review): each row is itself a list (one entry per column); confirm
    # this is the shape hub_utils.init expects for text tasks.
    if task.startswith("text") and isinstance(data, pd.DataFrame):
        data = data.values.tolist()

    try:
        file_name = tmp_path / "model.skops"
        sio.dump(st.session_state.model, file_name)

        hub_utils.init(
            model=file_name,
            dst=path,
            task=task,
            data=data,
            requirements=requirements,
        )
    except Exception as exc:
        # Keep the server-side log line, but also surface the failure to the
        # user instead of failing silently in the UI.
        print("Uh oh, something went wrong when initializing the repo:", exc)
        st.error(f"Uh oh, something went wrong when initializing the repo: {exc}")
|
|
|
|
|
def create_skops_model_card() -> None:
    """Start editing a new model card built from the default skops template.

    The local repo is (re)initialized first, and the card metadata is read
    from its config. The app screen is then switched to ``"edit"``.
    """
    init_repo()
    session = st.session_state
    card_metadata = card.metadata_from_config(session.hf_path)
    session.model_card = card.Card(model=session.model, metadata=card_metadata)
    session.model_card_type = "skops"
    session.screen.state = "edit"
|
|
|
|
|
def create_empty_model_card() -> None:
    """Start editing a new model card that uses no template.

    The local repo is (re)initialized and its config supplies the metadata.
    A single placeholder section titled "Untitled" is added before the app
    screen switches to ``"edit"``.
    """
    init_repo()
    session = st.session_state
    card_metadata = card.metadata_from_config(session.hf_path)
    blank_card = card.Card(model=session.model, metadata=card_metadata, template=None)
    blank_card.add(Untitled="[More Information Needed]")
    session.model_card = blank_card
    session.model_card_type = "empty"
    session.screen.state = "edit"
|
|
|
|
|
def create_hf_model_card() -> None:
    """Load an existing model card from the Hugging Face Hub and open the editor.

    The repo ID is read from ``st.session_state.hf_repo_id``; surrounding
    whitespace and quotes are stripped, and an empty ID does nothing. Only
    files matching the patterns below (markdown, text, images) are
    downloaded. They are copied into ``st.session_state.hf_path`` and into
    the current working directory. The repo's ``README.md`` is then parsed
    into ``st.session_state.model_card``.

    Problems (invalid/missing repo, no ``README.md``) are reported with
    ``st.error`` and leave the session state untouched.
    """
    raw_repo_id = st.session_state.get("hf_repo_id") or ""
    repo_id = raw_repo_id.strip().strip("'").strip('"')
    if not repo_id:
        return

    allow_patterns = [
        "*.md",
        "*.txt",  # was ".txt", which only matches a file literally named ".txt"
        "*.png",
        "*.gif",
        "*.jpg",
        "*.jpeg",
        "*.bmp",
        "*.webp",
    ]
    try:
        path = snapshot_download(repo_id, allow_patterns=allow_patterns)
    except (HFValidationError, RepositoryNotFoundError):
        st.error(
            f"Repository '{repo_id}' could not be found on HF Hub, "
            "please check that the repo ID is correct."
        )
        return

    if not (Path(path) / "README.md").is_file():
        st.error(f"Repository '{repo_id}' does not contain a README.md model card.")
        return

    # Accept both str and Path for the configured repo location.
    hf_path = Path(st.session_state.hf_path)
    shutil.copytree(path, hf_path, dirs_exist_ok=True)
    # NOTE(review): this also copies the downloaded files into the current
    # working directory, presumably so relative image paths in the card
    # resolve when rendered — confirm this is intended, as it can overwrite
    # files in the app's CWD.
    shutil.copytree(path, ".", dirs_exist_ok=True)

    model_card = card.parse_modelcard(hf_path / "README.md")
    st.session_state.model_card = model_card
    st.session_state.model_card_type = "loaded"
    st.session_state.screen.state = "edit"
|
|
|
|
|
def add_help_button():
    """Render the "Instructions" button; clicking it switches to the help screen."""

    def _show_help():
        st.session_state.screen.state = "help"

    st.button(
        "Instructions",
        key="get_help",
        help="Detailed explanation of this space",
        on_click=_show_help,
    )
|
|
|
|
|
def _init_start_state() -> None:
    """Fill in the session-state entries the start page relies on, if missing."""
    defaults = {
        "model": DummyClassifier,
        "data": lambda: pd.DataFrame([]),
        "model_card": lambda: None,
    }
    for key, make_default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = make_default()


def _render_model_upload() -> None:
    """Render the model upload section (uploader hidden once a file is set)."""
    st.text(
        "Upload an sklearn model (strongly recommended)\n"
        "The model can be used to automatically populate fields in the model card."
    )
    has_model_file = bool(st.session_state.get("model_file"))
    if not has_model_file:
        st.file_uploader(
            "Upload an sklearn model (pickle or skops format)",
            on_change=load_model,
            key="model_file",
        )
    st.markdown("---")


def _render_data_upload() -> None:
    """Render the CSV sample-data upload section."""
    st.text(
        "Upload samples from your data (in csv format)\n"
        "This sample data can be attached to the metadata of the model card"
    )
    st.file_uploader(
        "Upload input data (csv)", type=["csv"], on_change=load_data, key="data_file"
    )
    st.markdown("---")


def _render_repo_options() -> None:
    """Render the task selector and the requirements editor."""
    task_choices = [
        "tabular-classification",
        "tabular-regression",
        "text-classification",
        "text-regression",
    ]
    st.selectbox(
        label="Choose the task type",
        options=task_choices,
        key="task",
        on_change=init_repo,
    )
    st.markdown("---")

    default_requirements = f"scikit-learn=={sklearn.__version__}\n"
    st.text_area(
        label="Requirements",
        value=default_requirements,
        key="requirements",
        on_change=init_repo,
    )
    st.markdown("---")


def _render_start_options() -> None:
    """Render the three ways to start: skops card, empty card, or Hub card."""
    st.markdown("Choose one of the options below to get started:")
    skops_col, empty_col, hub_col = st.columns([2, 2, 2])

    with skops_col:
        st.button("Create a new skops model card", on_click=create_skops_model_card)

    with empty_col:
        st.button("Create a new empty model card", on_click=create_empty_model_card)

    with hub_col:
        with st.form("Load existing model card from HF Hub", clear_on_submit=False):
            st.markdown("Load existing model card from HF Hub")
            st.text_input("Repo name (e.g. 'gpt2')", key="hf_repo_id")
            st.form_submit_button("Load", on_click=create_hf_model_card)


def start_input_form():
    """Render the start page of the app.

    The page is rendered in this order:

    1. Missing session-state defaults are set (``model``, ``data``,
       ``model_card``).
    2. The introduction and help button are shown.
    3. The model and data uploaders, task selector, and requirements editor
       follow.
    4. The buttons that create or load a model card come last.
    """
    _init_start_state()

    st.markdown(description)
    add_help_button()
    st.markdown("---")

    _render_model_upload()
    _render_data_upload()
    _render_repo_options()
    _render_start_options()
|
|
|
|
|
# Render the start page whenever this module is executed (presumably run by
# Streamlit as a page script — confirm).
start_input_form()
|
|