File size: 2,667 Bytes
171c1a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import random
from typing import Sequence

import datasets
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
from setfit import SetFitModel

# Page-level Streamlit configuration: browser-tab title and icon, a wide
# layout, and a sidebar that starts collapsed. menu_items=None supplies no
# custom entries for the app menu.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command executed in a script — keep this call at the top.
st.set_page_config(
    page_title="mtg-coloridentity-multilabel-classification",
    page_icon="🧙",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items=None,
)


# Local Hugging Face cache root, used only when HF_HOME is not set.
default_hf_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
# Cache directory handed to both the model loader and the dataset loader.
HF_HOME = os.getenv("HF_HOME", default_hf_home)

# Hub id used as the default for both the SetFit model and the card dataset.
coloridentity_model = "joshuasundance/mtg-coloridentity-multilabel-classification"

# Single-letter MTG color codes and their color names, aligned by index
# (B->black, G->green, R->red, U->blue, W->white).
colors = list("BGRUW")
labels = "black green red blue white".split()

# Apply seaborn's default theme globally so the matplotlib/seaborn
# probability chart drawn later picks it up.
sns.set()

# Two side-by-side layout columns: col1 holds the card-text input,
# col2 holds the probability chart.
col1, col2 = st.columns(2)


@st.cache_resource
def get_model(
    model_id: str = coloridentity_model,
    cache_dir: str = HF_HOME,
    **kwargs,
) -> SetFitModel:
    """Load a SetFit model, memoized with Streamlit's resource cache.

    Args:
        model_id: Hugging Face Hub id (or local path) of the model.
        cache_dir: Directory used for downloaded model files.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``SetFitModel.from_pretrained``.

    Returns:
        The loaded ``SetFitModel``.
    """
    load_options = {"cache_dir": cache_dir, **kwargs}
    return SetFitModel.from_pretrained(model_id, **load_options)


@st.cache_data
def get_data(
    repo_id: str = coloridentity_model,
    cache_dir: str = HF_HOME,
    **kwargs,
) -> datasets.Dataset:
    """Load a Hugging Face dataset and merge all of its splits into one.

    Args:
        repo_id: Dataset repository id passed to ``datasets.load_dataset``.
            Defaults to the same id as the classifier model.
        cache_dir: Directory used by ``datasets`` for its download cache.
        **kwargs: Extra keyword arguments forwarded to ``load_dataset``
            (for example ``split=...``).

    Returns:
        A single dataset containing the rows of every loaded split. If
        ``load_dataset`` already returns a single split (e.g. because a
        ``split`` keyword was forwarded), that dataset is returned as-is.
    """
    loaded = datasets.load_dataset(repo_id, cache_dir=cache_dir, **kwargs)
    # load_dataset returns a dict of splits when no split is requested, but a
    # single dataset when one is; only the dict form needs concatenation
    # (calling .values() on a single Dataset would raise AttributeError).
    if isinstance(loaded, (datasets.DatasetDict, datasets.IterableDatasetDict)):
        return datasets.concatenate_datasets(list(loaded.values()))
    return loaded


def get_random_text(data: "datasets.Dataset | None" = None) -> str:
    """Return the ``"text"`` field of a uniformly chosen random row.

    Args:
        data: Integer-indexable collection of rows (e.g. a
            ``datasets.Dataset``), each row a mapping with a ``"text"`` key.
            Defaults to the module-level ``dataset``.

    Returns:
        The ``"text"`` value of the selected row.

    Raises:
        ValueError: If ``data`` contains no rows.
    """
    if data is None:
        data = dataset
    n_rows = len(data)
    if n_rows == 0:
        raise ValueError("Cannot pick a random text from an empty dataset.")
    # randrange(n) yields 0..n-1; randint(0, n) would also yield n, which is
    # one past the last valid index.
    index = random.randrange(n_rows)  # nosec
    return data[index]["text"]


@st.cache_data
def get_preds(input_text: str) -> Sequence[float]:
    """Run the module-level ``model`` on ``input_text`` and return its probabilities.

    Memoized with Streamlit's data cache, keyed on ``input_text``.
    NOTE(review): presumably one probability per color label, in the order
    of ``labels`` — confirm against the model's label ordering.
    """
    probabilities = model.predict_proba(input_text)
    return probabilities


def prob_bars(preds: Sequence[float]) -> None:
    """Render a bar chart of per-color probabilities with ``st.pyplot``.

    Args:
        preds: Probabilities paired positionally with the module-level
            ``labels`` (black, green, red, blue, white). Each element only
            needs to be convertible with ``float()``. Pairing uses ``zip``,
            so extra elements on either side are ignored.
    """
    df = pd.DataFrame(
        zip(labels, (float(p) for p in preds)),
        columns=["Color", "Probability"],
    )

    # Use an explicit Figure/Axes rather than pyplot's implicit "current
    # figure": Streamlit reruns this on every interaction, and figures that
    # are never closed keep accumulating in pyplot's figure manager.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # The label names double as matplotlib color names, so each bar is
        # drawn in the color it represents.
        sns.barplot(x="Color", y="Probability", data=df, palette=labels, ax=ax)

        # Print each bar's height (4 decimals) 9 points above its top.
        for bar in ax.patches:
            height = bar.get_height()
            ax.annotate(
                format(height, ".4f"),
                (bar.get_x() + bar.get_width() / 2.0, height),
                ha="center",
                va="center",
                xytext=(0, 9),
                textcoords="offset points",
            )

        ax.set_title("Prediction Probabilities")
        ax.set_xlabel("Color")
        ax.set_ylabel("Probability")
        st.pyplot(fig)
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)


model = get_model()
dataset = get_data()

# Seed the text box with a random card only on a session's first run;
# sampling unconditionally would draw (and discard) a card on every rerun.
if "input_text" not in st.session_state:
    st.session_state.input_text = get_random_text()

with col1:
    # On the rerun triggered by a click, st.button returns True, so the new
    # card is stored before the text area below is drawn.
    if st.button("🎲 Roll the Dice"):
        st.session_state.input_text = get_random_text()
    # The session-state text is the text area's initial value; the widget's
    # current contents are returned as input_text.
    input_text = st.text_area(
        "Card name and text",
        st.session_state.input_text,
        height=400,
    )

preds = get_preds(input_text)

with col2:
    prob_bars(preds)