wrapper228 committed on
Commit 3a9fe38
1 Parent(s): dd32181

Update app.py

Files changed (1)
  1. app.py +12 -4
app.py CHANGED
@@ -6,17 +6,25 @@ from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollator
 from PIL import Image


-with open('labels.pickle', 'rb') as handle:
-    labels = pickle.load(handle)
+labels = {
+    "0": "biology",
+    "1": "computer science",
+    "2": "economics",
+    "3": "electrics",
+    "4": "finance",
+    "5": "math",
+    "6": "physics",
+    "7": "statistics",
+}

 def predict_topic_by_title_and_abstract(text, model):
     st.write(labels)
+    st.write(labels.values())
     tokens = tokenizer(text, return_tensors='pt', truncation=True)
     with torch.no_grad():
         logits = model(**tokens).logits
     probs = torch.nn.functional.softmax(logits[0], dim=0).numpy() * 100
-    #ans = list(zip(probs, labels.values()))
-    ans = list(zip(probs, [x for x in range(8)]))
+    ans = list(zip(probs, labels.values()))
     ans.sort(reverse=True)
     sum = 0
     i = 0
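For context, here is a minimal sketch of how the hard-coded `labels` dict introduced by this commit feeds the rest of `predict_topic_by_title_and_abstract`. Only the `zip`/`sort` over `labels.values()` comes from the diff; the helper name `top_topics`, the `threshold` parameter, and the cumulative-probability loop are assumptions, since the hunk is cut off right after `sum = 0` and `i = 0`.

```python
import numpy as np

# Same hard-coded topic mapping the commit adds to app.py.
labels = {
    "0": "biology",
    "1": "computer science",
    "2": "economics",
    "3": "electrics",
    "4": "finance",
    "5": "math",
    "6": "physics",
    "7": "statistics",
}

def top_topics(probs, threshold=95.0):
    """Pair each probability (already scaled to percent) with its topic name
    and keep the highest-scoring topics until `threshold` percent is covered.
    The threshold loop is an assumed continuation of the truncated function."""
    ans = sorted(zip(probs, labels.values()), reverse=True)  # as in the diff
    picked, total = [], 0.0
    for p, topic in ans:
        picked.append((topic, float(p)))
        total += p
        if total >= threshold:
            break
    return picked

# Example with a dummy probability vector; in the app this would be the
# softmax output of model(**tokens).logits multiplied by 100.
probs = np.array([1.0, 60.0, 2.0, 1.0, 3.0, 25.0, 6.0, 2.0])
print(top_topics(probs))  # [('computer science', 60.0), ('math', 25.0), ...]
```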