Spaces:
Runtime error
Runtime error
File size: 4,907 Bytes
f5a1b52 3cba90e 9c313f7 f5a1b52 8a7f283 58eeaa0 8a7f283 4448ca2 f5a1b52 9c313f7 4448ca2 860bbbc f5a1b52 9c313f7 58eeaa0 cb37e2c cb7bd99 58eeaa0 b81c70b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from platform import processor
import streamlit as st
from load_data import candidate_labels
import numpy as np
from load_data import *
import pickle
import torch
from BART_utils import get_taggs
from stqdm import stqdm
import pandas as pd
def transform_data(data, filetype=True):
    """Read an uploaded tabular file into a DataFrame.

    Args:
        data: A path or file-like object holding the file contents (in this
            script, the object returned by ``st.file_uploader``).
        filetype: ``True`` to parse ``data`` as CSV with ``pd.read_csv``;
            ``False`` to parse it as an Excel workbook with ``pd.read_excel``.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    # Parse the ``data`` argument itself rather than a module-level global,
    # so the function works for whatever file object the caller passes in.
    if filetype:
        return pd.read_csv(data)
    return pd.read_excel(data)
def convert_df(df):
    """Serialise ``df`` to CSV text (index column included) and encode it as UTF-8.

    The resulting bytes are what the bulk-upload branch hands to
    ``st.download_button``.
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
stqdm.pandas()
st.title("Domain and Usage tagger")
st.subheader("๋ฌธ์ฅ์ ์
๋ ฅํ๋ฉด ์ฃผ์ / ์ฉ๋ ํ๊ทธ๋ฅผ ์์ฑํฉ๋๋ค (EN์ง์)")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device == "cpu":
processor = "๐ฅ๏ธ"
else:
processor = "๐ฝ"
st.subheader("Running on {}".format(device + processor))
bulk = st.checkbox("ํ์ผ์ ์
๋ก๋ํ์๊ฒ ์ด์?")
if not bulk:
user_input = st.text_area(
"๐ํ๊ทธ๋ฅผ ์์ฑํ ๋ฌธ์ฅ์ ์
๋ ฅํ์ธ์ - ํ์ฌ ์๋ฌธ๋ง ์ง์๋ฉ๋๋ค.", """NLI-based Zero Shot Text Classification
Yin et al. proposed a method for using pre-trained NLI models as a ready-made zero-shot sequence classifiers. The method works by posing the sequence to be classified as the NLI premise and to construct a hypothesis from each candidate label. The probabilities for entailment and contradiction are then converted to label probabilities."""
)
thred = st.slider(
"๐ํ๊ทธ ์์ฑ thredhold ์ค์ . ๊ฒฐ๊ณผ๊ฐ ๋์ค์ง ์์๊ฒฝ์ฐ, threshold๋ฅผ 0์ ๊ฐ๊น๊ฒ ๋ฎ์ถ์ธ์!",
0.0,
1.0,
0.5,
step=0.01,
)
if thred:
st.write(thred, " ์ด์์ confidence level์ธ ํ๊ทธ๋ง ์์ฑํฉ๋๋ค.")
maximum = st.number_input("๐์ต๋ ํ๊ทธ ๊ฐฏ์ ์ค์ ", 0, 10, 5, step=1)
st.write("์ต๋ {} ๊ฐ์ ํ๊ทธ ์์ฑ".format(maximum))
check_source = st.checkbox("๐ท๏ธ์ฉ์ฒ / ์ถ์ฒ ํ๊ทธ ์์ฑ")
submit = st.button("๐ํด๋ฆญํด์ ํ๊ทธ ์์ฑ")
if submit:
with st.spinner("โํ๊ทธ๋ฅผ ์์ฑํ๋ ์ค์
๋๋ค..."):
result = get_taggs(user_input, candidate_labels, thred)
result = result[:maximum]
st.subheader("๐ํน์ ์ด๋ฐ ์ฃผ์ ์ ๋ฌธ์ฅ์ธ๊ฐ์? : ")
if len(result) == 0:
st.write("๐ข์ ๋ฐ..๊ฒฐ๊ณผ๊ฐ ์์ต๋๋ค. Threshold๋ฅผ ๋ฎ์ถฐ๋ณด์ธ์!")
for i in result:
st.write("โก๏ธ " + i[0], "{}%".format(int(i[1] * 100)))
if check_source:
with st.spinner("โ์ฌ์ฉ ๋ชฉ์ ํ๊ทธ ์์ฑ์ค..."):
source_result = get_taggs(user_input, source, thred=0)
st.subheader("๐ํน์ ์ด ์ฌ์ฉ๋ชฉ์ ์ ๋ฌธ์ฅ์ธ๊ฐ์? : ")
for i in source_result[:3]:
st.write("๐ท๏ธ " + i[0], "{}%".format(int(i[1] * 100)))
else:
st.write("๐์ปฌ๋ผ๋ช
์ 'text'๋ก ์ค์ ํด, ํ์ผ์ ์
๋ก๋ํด์ฃผ์ธ์!")
filetype = st.checkbox("๐Using CSV? (์ฒดํฌํ์ง ์์ผ๋ฉด xlsx ์ฌ์ฉ): ")
uploaded_file = st.file_uploader("Choose an csv file")
if uploaded_file is not None:
df = transform_data(uploaded_file, filetype)
st.write(df)
thred = st.slider(
"๐ํ๊ทธ ์์ฑ thredhold ์ค์ . ๊ฒฐ๊ณผ๊ฐ ๋์ค์ง ์์๊ฒฝ์ฐ, threshold๋ฅผ 0์ ๊ฐ๊น๊ฒ ๋ฎ์ถ์ธ์!",
0.0,
1.0,
0.5,
step=0.01,
)
if thred:
st.write(thred, " ์ด์์ confidence level์ธ ํ๊ทธ๋ง ์์ฑํฉ๋๋ค.")
maximum = st.number_input("๐์ต๋ ํ๊ทธ ๊ฐฏ์ ์ค์ ", 0, 10, 5, step=1)
st.write("์ต๋ {} ๊ฐ์ ํ๊ทธ ์์ฑ".format(maximum))
check_source = st.checkbox("๐ท๏ธ์ฉ์ฒ / ์ถ์ฒ ํ๊ทธ ์์ฑ")
submit = st.button("๐ํด๋ฆญํด์ ํ๊ทธ ์์ฑ")
if submit:
with st.spinner("โํ๊ทธ๋ฅผ ์์ฑํ๋ ์ค์
๋๋ค..."):
df["generated_tag"] = df["text"].progress_apply(lambda x : get_taggs(x, candidate_labels, thred)[:maximum])
if check_source:
with st.spinner("โ์ฌ์ฉ ๋ชฉ์ ํ๊ทธ ์์ฑ์ค..."):
df["source"] = df["text"].progress_apply(lambda x : get_taggs(x, source, thred=0))
csv = convert_df(df)
to_json = {}
for idx, row in df.iterrows():
to_json[row.text] = {}
to_json[row.text]["generated_tag"] = row.generated_tag
to_json[row.text]["source"] = row.source
st.download_button(
"Press to Download",
csv,
"file.csv",
"text/csv",
key='download-csv'
)
st.write("๐Outcome: ")
st.write(to_json) |