rrevoid commited on
Commit
ad3d6a3
1 Parent(s): 9baaef5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import tokenizers
3
 
 
4
  import streamlit as st
5
  import torch.nn as nn
6
  from transformers import RobertaTokenizer, RobertaModel
@@ -29,19 +30,26 @@ cats = ['Computer Science', 'Economics', 'Electrical Engineering',
29
  def predict(outputs):
30
  top = 0
31
  probs = nn.functional.softmax(outputs, dim=1).tolist()[0]
 
 
 
32
 
33
  for prob, cat in sorted(zip(probs, cats), reverse=True):
34
  if top < 95:
35
  percent = prob * 100
36
  top += percent
37
- st.write(f'{cat}: {round(percent, 1)}')
 
 
 
 
38
 
39
  tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
40
  model = init_model()
41
 
42
  st.markdown("### Title")
43
 
44
- title = st.text_area("Enter title", height=20)
45
 
46
  st.markdown("### Abstract")
47
 
@@ -50,6 +58,7 @@ abstract = st.text_area("Enter abstract", height=200)
50
  if not title:
51
  st.warning("Please fill out so required fields")
52
  else:
 
53
  encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
54
  max_length = 512, truncation=True)
55
  with torch.no_grad():
 
1
  import torch
2
  import tokenizers
3
 
4
+ import pandas as pd
5
  import streamlit as st
6
  import torch.nn as nn
7
  from transformers import RobertaTokenizer, RobertaModel
 
30
  def predict(outputs):
31
  top = 0
32
  probs = nn.functional.softmax(outputs, dim=1).tolist()[0]
33
+
34
+ top_cats = []
35
+ top_probs = []
36
 
37
  for prob, cat in sorted(zip(probs, cats), reverse=True):
38
  if top < 95:
39
  percent = prob * 100
40
  top += percent
41
+ top_cats.append(cat)
42
+ top_probs.append(prob)
43
+
44
+ chart_data = pd.DataFrame(top_probs, columns=top_cats)
45
+ st.bar_chart(chart_data)
46
 
47
  tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
48
  model = init_model()
49
 
50
  st.markdown("### Title")
51
 
52
+ title = st.text_area("* Enter title (required)", height=20)
53
 
54
  st.markdown("### Abstract")
55
 
 
58
  if not title:
59
  st.warning("Please fill out so required fields")
60
  else:
61
+ st.markdown("### Result")
62
  encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
63
  max_length = 512, truncation=True)
64
  with torch.no_grad():