mbabanov commited on
Commit
15e4f68
1 Parent(s): 33d6eda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -1
app.py CHANGED
@@ -37,14 +37,27 @@ def LoadModel():
37
 
38
  model = LoadModel()
39
 
 
 
40
  def process(title, summary):
41
  text = title + summary
 
 
42
  model.eval()
43
  lines = [text]
44
  X = tokenizer(lines, padding=True, truncation=True, return_tensors="pt")
45
  out = model(X)
46
  probs = torch.exp(out[0])
47
- return probs
 
 
 
 
 
 
 
 
 
48
 
49
  title = st.text_area("Title", height=30)
50
 
 
37
 
38
  model = LoadModel()
39
 
40
+ classes = ['Computer Science', 'Mathematics', 'Physics', 'Quantitative Biology', 'Statistics']
41
+
42
  def process(title, summary):
43
  text = title + summary
44
+ if not text.strip():
45
+ return ''
46
  model.eval()
47
  lines = [text]
48
  X = tokenizer(lines, padding=True, truncation=True, return_tensors="pt")
49
  out = model(X)
50
  probs = torch.exp(out[0])
51
+ sorted_indexes = torch.argsort(probs, descending=True)
52
+ probs_sum = idx = 0
53
+ str = ''
54
+ while probs_sum < 0.95:
55
+ prob_idx = sorted_indexes[idx]
56
+ prob = probs[prob_idx]
57
+ str += f'{classes[prob_idx]}: {prob:.3f}\n'
58
+ idx += 1
59
+ probs_sum += prob
60
+ return str
61
 
62
  title = st.text_area("Title", height=30)
63