# Streamlit app: Domain and Usage tagger (zero-shot NLI text classification)
# Hosted as a HuggingFace Space.
import pickle
from platform import processor

import numpy as np
import pandas as pd
import streamlit as st
import torch
from stqdm import stqdm

from BART_utils import get_taggs
from load_data import *
from load_data import candidate_labels
def transform_data(data, filetype=True):
    """Parse an uploaded file into a pandas DataFrame.

    Args:
        data: file-like object (e.g. a Streamlit ``UploadedFile``) to read.
        filetype: True to parse as CSV, False to parse as Excel (xlsx).

    Returns:
        pandas.DataFrame with the file's contents.
    """
    # Bug fix: the original read the module-level global `uploaded_file` and
    # ignored the `data` parameter entirely. Using the parameter keeps the
    # one existing call site (which passes `uploaded_file`) behaving the
    # same while making the function self-contained and testable.
    if filetype:
        return pd.read_csv(data)
    return pd.read_excel(data)
def convert_df(df):
    """Serialize *df* to UTF-8 encoded CSV bytes (payload for st.download_button)."""
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')
# ---------------------------------------------------------------------------
# Streamlit app body. Streamlit re-executes this script top-to-bottom on
# every user interaction, so all widgets below are re-created on each run.
# ---------------------------------------------------------------------------
stqdm.pandas()  # enables DataFrame.progress_apply with a progress bar

st.title("Domain and Usage tagger")
# "Enter a sentence and domain / usage tags are generated (EN supported)"
st.subheader("문장을 입력하면 주제 / 용도 태그를 생성합니다 (EN지원)")

# Pick the inference device and an emoji badge shown next to it.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    processor = "🖥️"  # NOTE(review): shadows `platform.processor` imported above
else:
    processor = "👽"
st.subheader("Running on {}".format(device + processor))

# Toggle between single-sentence mode and bulk file-upload mode.
bulk = st.checkbox("파일을 업로드하시겠어요?")

if not bulk:
    # ---- Single-sentence mode ---------------------------------------------
    user_input = st.text_area(
        "🙏태그를 생성할 문장을 입력하세요 - 현재 영문만 지원됩니다.", """NLI-based Zero Shot Text Classification
Yin et al. proposed a method for using pre-trained NLI models as a ready-made zero-shot sequence classifiers. The method works by posing the sequence to be classified as the NLI premise and to construct a hypothesis from each candidate label. The probabilities for entailment and contradiction are then converted to label probabilities."""
    )
    # Minimum confidence a candidate label needs in order to be shown.
    thred = st.slider(
        "📈태그 생성 thredhold 설정. 결과가 나오지 않을경우, threshold를 0에 가깝게 낮추세요!",
        0.0,
        1.0,
        0.5,
        step=0.01,
    )
    if thred:
        st.write(thred, " 이상의 confidence level인 태그만 생성합니다.")
    maximum = st.number_input("🔖최대 태그 갯수 설정", 0, 10, 5, step=1)
    st.write("최대 {} 개의 태그 생성".format(maximum))
    check_source = st.checkbox("🏷️용처 / 출처 태그 생성")
    submit = st.button("👉클릭해서 태그 생성")
    if submit:
        with st.spinner("⏳태그를 생성하는 중입니다..."):
            result = get_taggs(user_input, candidate_labels, thred)
            result = result[:maximum]  # cap at the requested tag count
        st.subheader("🔖혹시 이런 주제의 문장인가요? : ")
        if len(result) == 0:
            st.write("😢이런..결과가 없습니다. Threshold를 낮춰보세요!")
        for i in result:
            # i is a (label, confidence) pair; confidence rendered as percent.
            st.write("➡️ " + i[0], "{}%".format(int(i[1] * 100)))
        if check_source:
            with st.spinner("⏳사용 목적 태그 생성중..."):
                # `source` is the usage/source label set (from load_data).
                source_result = get_taggs(user_input, source, thred=0)
            st.subheader("🔎혹시 이 사용목적의 문장인가요? : ")
            for i in source_result[:3]:
                st.write("🏷️ " + i[0], "{}%".format(int(i[1] * 100)))
else:
    # ---- Bulk mode: tag every row of an uploaded csv/xlsx file -------------
    st.write("🙏컬럼명을 'text'로 설정해, 파일을 업로드해주세요!")
    filetype = st.checkbox("🧐Using CSV? (체크하지 않으면 xlsx 사용): ")
    uploaded_file = st.file_uploader("Choose an csv file")
    if uploaded_file is not None:
        df = transform_data(uploaded_file, filetype)
        st.write(df)
        thred = st.slider(
            "📈태그 생성 thredhold 설정. 결과가 나오지 않을경우, threshold를 0에 가깝게 낮추세요!",
            0.0,
            1.0,
            0.5,
            step=0.01,
        )
        if thred:
            st.write(thred, " 이상의 confidence level인 태그만 생성합니다.")
        maximum = st.number_input("🔖최대 태그 갯수 설정", 0, 10, 5, step=1)
        st.write("최대 {} 개의 태그 생성".format(maximum))
        check_source = st.checkbox("🏷️용처 / 출처 태그 생성")
        submit = st.button("👉클릭해서 태그 생성")
        if submit:
            with st.spinner("⏳태그를 생성하는 중입니다..."):
                df["generated_tag"] = df["text"].progress_apply(
                    lambda x: get_taggs(x, candidate_labels, thred)[:maximum]
                )
            if check_source:
                with st.spinner("⏳사용 목적 태그 생성중..."):
                    df["source"] = df["text"].progress_apply(
                        lambda x: get_taggs(x, source, thred=0)
                    )
            csv = convert_df(df)
            to_json = {}
            for idx, row in df.iterrows():
                entry = {"generated_tag": row.generated_tag}
                # Bug fix: the "source" column only exists when the user
                # requested source tags; reading row.source unconditionally
                # raised AttributeError when check_source was unchecked.
                if "source" in df.columns:
                    entry["source"] = row.source
                to_json[row.text] = entry
            st.download_button(
                "Press to Download",
                csv,
                "file.csv",
                "text/csv",
                key='download-csv'
            )
            st.write("🤗Outcome: ")
            st.write(to_json)