File size: 4,608 Bytes
ddcecdf
603596f
7ae2fd5
 
 
 
320d835
7ae2fd5
603596f
7ae2fd5
78184f0
 
 
7ae2fd5
78184f0
 
 
 
 
 
 
 
 
7ae2fd5
 
ddcecdf
 
 
 
 
 
 
 
 
 
 
7ae2fd5
ddcecdf
 
 
 
dcc574b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddcecdf
 
 
 
 
 
 
 
 
 
 
 
 
78184f0
ddcecdf
 
 
78184f0
ddcecdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78184f0
 
ddcecdf
 
78184f0
 
 
 
 
 
ddcecdf
78184f0
 
 
 
 
 
 
 
 
ddcecdf
78184f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations
import psutil
import pandas as pd
import streamlit as st
import plotly.express as px
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
from zeroshot_classification.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier

# Print the machine's total physical memory (in bytes, as reported by psutil)
# every time this script executes.
print(f"Total mem: {psutil.virtual_memory().total}")

def init_state(key: str):
    """Make sure ``key`` exists in Streamlit's session state.

    A key that is already present keeps its current value; a missing key is
    created with the placeholder value ``None``.

    Args:
        key: Name of the session-state entry to initialise.
    """
    if key in st.session_state:
        return
    st.session_state[key] = None


# Session-state entries the rest of the script reads; each one is created
# (as None) if it does not exist yet.
_STATE_KEYS = (
    "current_model",
    "current_model_option",
    "current_method_option",
    "current_prediction",
    "current_chart",
)
for state_key in _STATE_KEYS:
    init_state(state_key)


def load_model(model_option: str, method_option: str, random_state: int = 0):
    """Instantiate the selected zero-shot classifier and keep it in session state.

    The classifier is stored in ``st.session_state.current_model``. A spinner is
    shown while the model is being constructed and a success message once it
    is ready.

    Args:
        model_option: Model name forwarded to the classifier as ``model_name``.
        method_option: Selected method label. The value equal to
            ``METHOD_OPTIONS["nli"]`` selects ``NLIZeroshotClassifier``; any
            other value selects ``NSPZeroshotClassifier``.
        random_state: Seed forwarded to the classifier as ``random_state``.
    """
    # Compare against the shared METHOD_OPTIONS entry (the same value the UI
    # offers in its radio widget) instead of a duplicated string literal, so
    # the two can never drift apart and silently load the wrong classifier.
    if method_option == METHOD_OPTIONS["nli"]:
        classifier_cls = NLIZeroshotClassifier
    else:
        classifier_cls = NSPZeroshotClassifier

    with st.spinner("Loading selected model..."):
        st.session_state.current_model = classifier_cls(
            model_name=model_option, random_state=random_state
        )
        st.success("Model loaded!")


def visualize_output(labels: list[str], probabilities: list[float]):
    """Build a horizontal bar chart of the score assigned to each label.

    Args:
        labels: Candidate label names, one per bar.
        probabilities: Score of each label, aligned with ``labels``.

    Returns:
        A Plotly Express bar figure (290x500, legend hidden) built from the
        label/score pairs sorted by descending score.
    """
    frame = pd.DataFrame({"labels": labels, "probability": probabilities})
    ranked = frame.sort_values(by="probability", ascending=False)

    figure = px.bar(
        ranked,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    )
    # update_layout modifies the figure in place, so the same object is returned.
    figure.update_layout(
        xaxis={"title": "probability", "visible": True, "showticklabels": True},
        yaxis={"title": None, "visible": True, "showticklabels": True},
        margin={"l": 10, "r": 10, "t": 50, "b": 10},  # left, right, top, bottom
        showlegend=False,
    )
    return figure


# ---------------------------------------------------------------------------
# Page layout: method/model selection, prompt configuration and prediction.
# ---------------------------------------------------------------------------
st.title("Zero-shot Turkish Text Classification")
method_option = st.radio(
    "Select a zero-shot classification method.",
    [
        METHOD_OPTIONS["nli"],
        METHOD_OPTIONS["nsp"],
    ],
)
# The radio offers exactly two methods, so if/else guarantees that
# `model_option` is always bound.
if method_option == METHOD_OPTIONS["nli"]:
    model_option = st.selectbox(
        "Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3
    )
else:
    model_option = st.selectbox(
        "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0
    )

# Reload when either the model or the method changed since the last load.
selection_changed = (
    model_option != st.session_state.current_model_option
    or method_option != st.session_state.current_method_option
)
if selection_changed:
    load_model(model_option, method_option)
    # Record the selection only after the load succeeded. If loading raises,
    # the next rerun retries instead of pairing a stale (or missing) model
    # with the new selection.
    st.session_state.current_model_option = model_option
    st.session_state.current_method_option = method_option


st.header("Configure prompts and labels")
col1, col2 = st.columns(2)
col1.subheader("Candidate labels")
labels = col1.text_area(
    label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
    value="spor,dünya,siyaset,ekonomi,sanat",
    key="current_labels",
)

col1.header("Make predictions")
text = col1.text_area(
    "Enter a sentence or a paragraph to classify.",
    value="Ian Anderson, Jethro Tull konserinde yan flüt çalarak zeybek oynadı.",
    key="current_text",
)
col2.subheader("Prompt template")
prompt_template = col2.text_area(
    label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
    value="Bu metin {} kategorisine aittir",
    key="current_template",
)
# Empty header used as a vertical spacer in the second column.
col2.header("")


make_pred = col1.button("Predict")
if make_pred:
    # Strip surrounding whitespace and drop empty entries so that input such
    # as "spor, dünya," does not produce " dünya" or "" as candidate labels.
    candidate_labels = [
        label.strip() for label in labels.split(",") if label.strip()
    ]

    if st.session_state.current_model is None:
        col1.error("No model is loaded. Please select a model first.")
    elif not candidate_labels:
        col1.warning("Please enter at least one candidate label.")
    elif not text.strip():
        col1.warning("Please enter some text to classify.")
    else:
        st.session_state.current_prediction = (
            st.session_state.current_model.predict_on_texts(
                [text],
                candidate_labels=candidate_labels,
                prompt_template=prompt_template,
            )
        )
        result = st.session_state.current_prediction[0]
        # The per-label values are looked up under either "scores" or
        # "probabilities"; use whichever key is present.
        score_key = next(
            (key for key in ("scores", "probabilities") if key in result), None
        )
        if score_key is None:
            # Nothing to plot: clear any stale chart instead of re-drawing it.
            st.session_state.current_chart = None
            col2.warning("The model returned no scores to visualize.")
        else:
            st.session_state.current_chart = visualize_output(
                result["labels"], result[score_key]
            )
            col2.plotly_chart(
                st.session_state.current_chart, use_container_width=True
            )