# Source listing captured from a Hugging Face Space page
# (commit 171c1a6, file size 2,667 bytes).
# The page showed the Space status as "Runtime error" at capture time.
import os
import random
from typing import Sequence
import datasets
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
from setfit import SetFitModel
# Configure the browser tab and overall page layout.
st.set_page_config(
    page_title="mtg-coloridentity-multilabel-classification",
    page_icon="🧙",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items=None,
)

# Hugging Face cache directory: the HF_HOME environment variable when set,
# otherwise ~/.cache/huggingface.
default_hf_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
HF_HOME = os.getenv("HF_HOME", default_hf_home)

# Repository id used as the default for both the model and the dataset loaders.
coloridentity_model = "joshuasundance/mtg-coloridentity-multilabel-classification"

# Single-letter codes and the matching color names. `labels` doubles as the
# bar categories and the seaborn palette in prob_bars.
# NOTE(review): the order presumably matches the model's output order - confirm.
colors = ["B", "G", "R", "U", "W"]
labels = ["black", "green", "red", "blue", "white"]

# Apply seaborn's default theme to subsequent matplotlib plots.
sns.set()

# Two side-by-side columns: the input widgets go in col1, the chart in col2.
col1, col2 = st.columns(2)
@st.cache_resource
def get_model(
    model_id: str = coloridentity_model,
    cache_dir: str = HF_HOME,
    **kwargs,
) -> SetFitModel:
    """Load a SetFit model, memoized with ``st.cache_resource``.

    Args:
        model_id: Identifier passed to ``SetFitModel.from_pretrained``.
        cache_dir: Directory passed as ``cache_dir`` to ``from_pretrained``.
        **kwargs: Forwarded unchanged to ``SetFitModel.from_pretrained``.

    Returns:
        The model object returned by ``SetFitModel.from_pretrained``.
    """
    loaded_model = SetFitModel.from_pretrained(
        model_id,
        cache_dir=cache_dir,
        **kwargs,
    )
    return loaded_model
@st.cache_data
def get_data(
    repo_id: str = coloridentity_model,
    cache_dir: str = HF_HOME,
    **kwargs,
) -> datasets.Dataset:
    """Load every split of a dataset repo and merge them into one dataset.

    The result is memoized with ``st.cache_data``.

    Args:
        repo_id: Identifier passed to ``datasets.load_dataset``.
        cache_dir: Directory passed as ``cache_dir`` to ``load_dataset``.
        **kwargs: Forwarded unchanged to ``datasets.load_dataset``.

    Returns:
        The concatenation of all values of the loaded mapping.
    """
    # NOTE(review): assumes load_dataset returns a mapping of splits (it is
    # consumed via .values()); confirm no caller passes a kwarg changing that.
    splits = datasets.load_dataset(repo_id, cache_dir=cache_dir, **kwargs)
    every_split = [split for split in splits.values()]
    return datasets.concatenate_datasets(every_split)
def get_random_text() -> str:
    """Return the ``"text"`` field of one randomly chosen row of ``dataset``.

    Reads the module-level ``dataset``, which must support ``len()`` and
    ``select(indices)``.

    Returns:
        The ``"text"`` value of the selected row.

    Raises:
        ValueError: If ``dataset`` is empty (raised by ``random.randrange``).
    """
    # randrange excludes its upper bound. The previous randint(0, len(dataset))
    # included it and could pick len(dataset), one past the last valid row.
    row_index = random.randrange(len(dataset))  # nosec
    return dataset.select([row_index])[0]["text"]
@st.cache_data
def get_preds(input_text: str) -> Sequence[float]:
    """Return the model's predicted probabilities for ``input_text``.

    Calls ``predict_proba`` on the module-level ``model``; the result is
    memoized with ``st.cache_data``.

    Args:
        input_text: Text passed directly to ``model.predict_proba``.

    Returns:
        Whatever ``model.predict_proba`` returns for this input.
    """
    probabilities = model.predict_proba(input_text)
    return probabilities
def prob_bars(preds: Sequence[float]) -> None:
    """Draw a labeled bar chart of the predictions and render it in Streamlit.

    Each value in ``preds`` is paired, in order, with an entry of the
    module-level ``labels`` list; pairing stops at the shorter of the two.

    Args:
        preds: Values convertible with ``float()``, one per label.
    """
    rows = list(zip(labels, (float(p) for p in preds)))
    df = pd.DataFrame(rows, columns=["Color", "Probability"])

    # Hold the figure explicitly so it can be closed afterwards. Figures
    # created through pyplot stay registered until closed, so a new unclosed
    # figure per call would pile up across reruns.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.barplot(x="Color", y="Probability", data=df, palette=labels, ax=ax)

        # Write each bar's value just above its top edge.
        for bar in ax.patches:
            height = bar.get_height()
            ax.annotate(
                f"{height:.4f}",
                (bar.get_x() + bar.get_width() / 2.0, height),
                ha="center",
                va="center",
                xytext=(0, 9),
                textcoords="offset points",
            )

        ax.set_title("Prediction Probabilities")
        ax.set_xlabel("Color")
        ax.set_ylabel("Probability")
        st.pyplot(fig)
    finally:
        plt.close(fig)
# Module-level objects read by get_preds and get_random_text.
model = get_model()
dataset = get_data()

# Seed the text box with a random dataset row the first time through;
# afterwards the value kept in session state is used.
default_text = get_random_text()
if "input_text" not in st.session_state:
    st.session_state["input_text"] = default_text

# Left column: re-roll button and the editable card text.
with col1:
    if st.button("🎲 Roll the Dice"):
        st.session_state["input_text"] = get_random_text()
    input_text = st.text_area(
        label="Card name and text",
        value=st.session_state["input_text"],
        height=400,
    )

preds = get_preds(input_text)

# Right column: probability chart for the current text.
with col2:
    prob_bars(preds)
|