Spaces:
Runtime error
Runtime error
Konstantin Gordeev
commited on
Commit
•
b60d6e8
1
Parent(s):
16a6ec1
Nothing
Browse files
app.py
CHANGED
@@ -1,12 +1,52 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import
|
|
|
3 |
|
4 |
-
|
5 |
-
|
|
|
|
|
6 |
|
7 |
|
8 |
-
|
|
|
|
|
|
|
9 |
|
10 |
-
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import DistilBertModel, DistilBertTokenizer
|
3 |
+
import torch
|
4 |
|
5 |
+
model_path = './models/pytorch_distilbert.bin'
|
6 |
+
vocab_path = './models/vocab_distilbert.bin'
|
7 |
+
device = torch.device('cpu')
|
8 |
+
MAX_LEN = 512
|
9 |
|
10 |
|
11 |
+
def get_labels(text, model, tokenizer, count_labels=8):
|
12 |
+
tokens = tokenizer(text, return_tensors='pt')
|
13 |
+
outputs = model(**tokens)
|
14 |
+
probs = torch.nn.Softmax()(outputs.logits)
|
15 |
|
16 |
+
labels = ['Computer_science', 'Economics',
|
17 |
+
'Electrical_Engineering_and_Systems_Science', 'Mathematics',
|
18 |
+
'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
|
19 |
+
'Statistics']
|
20 |
|
21 |
+
sort_lst = sorted([(prob, label) for prob, label in zip(probs.detach().numpy()[0], labels)], key=lambda x: -x[0])
|
22 |
+
cumsum = 0
|
23 |
+
result_labels = []
|
24 |
+
for pair in sort_lst:
|
25 |
+
cumsum += pair[0]
|
26 |
+
if cumsum > 0.95 and len(result_labels) >= 1:
|
27 |
+
return result_labels
|
28 |
+
result_labels.append(pair[1])
|
29 |
+
|
30 |
+
|
31 |
+
@st.cache(allow_output_mutation=True)
|
32 |
+
def load_model():
|
33 |
+
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
|
34 |
+
model = DistilBertModel.from_pretrained("distilbert-base-cased", num_labels=8)
|
35 |
+
model.load_state_dict(torch.load('weight_model'))
|
36 |
+
return model, tokenizer
|
37 |
+
|
38 |
+
|
39 |
+
tokenizer = DistilBertTokenizer.from_pretrained(vocab_path)
|
40 |
+
model = torch.load(model_path, map_location=torch.device(device))
|
41 |
+
|
42 |
+
st.markdown("### Movie genre classification")
|
43 |
+
|
44 |
+
text = st.text_area("Write some movie description")
|
45 |
+
|
46 |
+
if st.button('Predict'):
|
47 |
+
with st.spinner("Wait..."):
|
48 |
+
if not text:
|
49 |
+
st.error("Write something.")
|
50 |
+
else:
|
51 |
+
pred = predict(text, model.to(device))
|
52 |
+
st.success("\n\n".join(pred))
|