qc7 commited on
Commit
f889f9f
1 Parent(s): 09648fc

Copy app,py

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+
5
+ import transformers
6
+ from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification
7
+
8
+ @st.cache(suppress_st_warning=True, hash_funcs={transformers.AutoTokenizer: lambda _: None})
9
+ def load_tok_and_model():
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModelForSequenceClassification.from_pretrained(".")
12
+ return tokenizer, model
13
+
14
+
15
+ CATEGORIES = ["Computer Science", "Economics", "Electrical Engineering", "Mathematics",
16
+ "Q. Biology", "Q. Finances", "Statistics" , "Physics"]
17
+
18
+
19
+ @st.cache(suppress_st_warning=True, hash_funcs={transformers.AutoTokenizer: lambda _: None})
20
+ def forward_pass(title, abstract, tokenizer, model):
21
+ title_tensor = torch.tensor(tokenizer(title, padding="max_length", truncation=True, max_length=32)['input_ids'])
22
+ abstract_tensor = torch.tensor(tokenizer(abstract, padding="max_length", truncation=True, max_length=480)['input_ids'])
23
+
24
+ embeddings = torch.cat((title_tensor, abstract_tensor))
25
+ assert embeddings.shape == (512,)
26
+ with torch.no_grad():
27
+ logits = model(embeddings[None])['logits'][0]
28
+ assert logits.shape == (8,)
29
+ probs = torch.softmax(logits).data.cpu().numpy()
30
+
31
+ return probs
32
+
33
+ st.title("Classification of arXiv articles' main topic")
34
+ st.markdown("Please provide both summary and title when possible")
35
+
36
+ tokenizer, model = load_tok_and_model()
37
+
38
+ title = st.text_area(label='Title', height=200)
39
+ abstract = st.text_area(label='Abstract', height=200)
40
+ button = st.button('Run classifier')
41
+
42
+ if button:
43
+ probs = forward_pass(title, abstract, tokenizer, model)
44
+ st.write(probs)