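# Fish-Weight / fish_model_to_hub.py
# Author: brendenc
# Commit e485423: "Upload fish_model_to_hub.py"
#
# Trains a GradientBoostingRegressor that predicts fish weight from the
# brendenc/Fish dataset and publishes it to the Hugging Face Hub with skops.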
import os
import pickle
from pathlib import Path

import pandas as pd
import sklearn
from datasets import load_dataset
from sklearn.compose import make_column_selector
from sklearn.compose import make_column_transformer
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from skops import card
from skops import hub_utils
# Hugging Face access token used by the final push. It is read from the
# HF_TOKEN environment variable rather than hard-coded, so no secret lives in
# the script. When the variable is unset this is None, and huggingface_hub
# should fall back to the token saved by `huggingface-cli login`.
my_token = os.environ.get("HF_TOKEN")

# Load our data: the "train" split of the brendenc/Fish dataset as a DataFrame.
dataset = load_dataset("brendenc/Fish")
df = pd.DataFrame(dataset["train"][:])
# Split off the regression target; every remaining column is a model input.
target = df["Weight"]
df = df.drop(columns="Weight")
# One-hot encode the "object"-dtype (string) columns and pass every other
# column through unchanged. handle_unknown="ignore" encodes categories unseen
# during fit as all zeros instead of raising.
# The dense-output flag was renamed from `sparse` to `sparse_output` in
# scikit-learn 1.2, and `sparse` was removed in 1.4. Try the new name first and
# fall back for older releases, so the script runs on both.
try:
    dense_encoder = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
except TypeError:  # scikit-learn < 1.2 does not accept `sparse_output`
    dense_encoder = OneHotEncoder(sparse=False, handle_unknown="ignore")

one_hot_encoder = make_column_transformer(
    (dense_encoder, make_column_selector(dtype_include="object")),
    remainder="passthrough",
)

# Train model: the encoder feeds a gradient-boosted regressor. The fixed seed
# keeps fits reproducible.
pipe = make_pipeline(one_hot_encoder, GradientBoostingRegressor(random_state=42))
pipe.fit(df, target)
# Persist the fitted pipeline with pickle. The resulting file path and the
# local repository directory name are used by the Hub steps that follow.
model_path = "example.pkl"
local_repo = "fish-model"
with open(model_path, "wb") as model_file:
    pickle.dump(pipe, model_file)
# Initialize a local Hub repository in `local_repo` from the pickled model.
# The requirement string records the scikit-learn version the model was
# trained with, and `data` passes the training features for the tabular task.
# NOTE(review): single "=" is the form used in the skops examples (conda
# style); confirm it matches what the Hub inference backend expects.
sklearn_requirement = f"scikit-learn={sklearn.__version__}"
hub_utils.init(
    model=model_path,
    requirements=[sklearn_requirement],
    dst=local_repo,
    task="tabular-regression",
    data=df,
)
# Create the model card. Its metadata comes from the configuration that
# hub_utils.init wrote into the local repository. Reuse `local_repo` here
# rather than repeating the directory name, so the two cannot drift apart.
model_card = card.Card(pipe, metadata=card.metadata_from_config(Path(local_repo)))
limitations = "This model is intended for educational purposes."
model_description = "This is a GradientBoostingRegressor on a fish dataset."
model_card_authors = "Brenden Connors"
# Fill the card's named sections with `add`.
model_card.add(
    model_card_authors=model_card_authors,
    limitations=limitations,
    model_description=model_description,
)
# Metadata fields can be set directly on the card's metadata object.
model_card.metadata.license = "mit"
# Write the card into the local repository as its README.
model_card.save(Path(local_repo) / "README.md")
# Push the local repository to the Hugging Face Hub. create_remote=True
# creates the remote repository if it does not exist yet.
# Hub repository ids have exactly two parts, "<namespace>/<repo_name>"; an id
# with an extra path segment is rejected when the repo is created or pushed.
repo_id = "scikit-learn/Fish-Weight"
hub_utils.push(
    repo_id=repo_id,
    source=local_repo,
    token=my_token,
    commit_message="Adding model files",
    create_remote=True,
)