ki33elev commited on
Commit
62096a0
1 Parent(s): 41e544a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -18,6 +18,8 @@ def load_model():
18
  @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
19
  def predict(title, summary, tokenizer, model):
20
  text = title + "\n" + summary
 
 
21
  tokens = tokenizer.encode(text)
22
  with torch.no_grad():
23
  logits = model(torch.as_tensor([tokens]))[0]
@@ -38,7 +40,7 @@ def predict(title, summary, tokenizer, model):
38
 
39
  @st.cache(suppress_st_warning=True)
40
  def get_results(prediction, prediction_probs):
41
- frame = pd.DataFrame({'Topic': prediction, 'Confidence': prediction_probs})
42
  frame.index = np.arange(1, len(frame) + 1)
43
  return frame
44
 
@@ -47,7 +49,7 @@ label_to_theme = {0: 'Computer science', 1: 'Economics', 2: 'Electrical Engineer
47
 
48
  st.title("Arxiv articles classification")
49
  st.markdown("<h1 style='text-align: center;'><img width=300px src='https://media.wired.com/photos/592700e3cfe0d93c474320f1/191:100/w_1200,h_630,c_limit/faces-icon.jpg'>", unsafe_allow_html=True)
50
- st.markdown("This is an interface that can determine the article's topic based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")
51
 
52
  tokenizer, model = load_model()
53
 
@@ -58,5 +60,8 @@ button = st.button('Run')
58
  if button:
59
  prediction, prediction_probs = predict(title, summary, tokenizer, model)
60
  ans = get_results(prediction, prediction_probs)
61
- st.subheader('Results:')
62
- st.write(ans)
 
 
 
 
18
  @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
19
  def predict(title, summary, tokenizer, model):
20
  text = title + "\n" + summary
21
+ if len(text) < 20:
22
+ return 'error'
23
  tokens = tokenizer.encode(text)
24
  with torch.no_grad():
25
  logits = model(torch.as_tensor([tokens]))[0]
 
40
 
41
  @st.cache(suppress_st_warning=True)
42
  def get_results(prediction, prediction_probs):
43
+ frame = pd.DataFrame({'Category': prediction, 'Confidence': prediction_probs})
44
  frame.index = np.arange(1, len(frame) + 1)
45
  return frame
46
 
 
49
 
50
  st.title("Arxiv articles classification")
51
  st.markdown("<h1 style='text-align: center;'><img width=300px src='https://media.wired.com/photos/592700e3cfe0d93c474320f1/191:100/w_1200,h_630,c_limit/faces-icon.jpg'>", unsafe_allow_html=True)
52
+ st.markdown("This is an interface that can determine the article's category based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")
53
 
54
  tokenizer, model = load_model()
55
 
 
60
  if button:
61
  prediction, prediction_probs = predict(title, summary, tokenizer, model)
62
  ans = get_results(prediction, prediction_probs)
63
+ if ans == 'error':
64
+ st.error("Your input is too short. It is probably not a real article, please try again.")
65
+ else:
66
+ st.subheader('Results:')
67
+ st.write(ans)