Spaces:
Runtime error
Runtime error
Konstantin Gordeev
commited on
Commit
•
49e8106
1
Parent(s):
b60d6e8
Update model
Browse files
app.py
CHANGED
@@ -1,46 +1,37 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import
|
3 |
import torch
|
|
|
|
|
4 |
|
5 |
-
model_path = '
|
6 |
-
vocab_path = './models/vocab_distilbert.bin'
|
7 |
device = torch.device('cpu')
|
8 |
-
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
outputs = model(**tokens)
|
14 |
-
probs = torch.nn.Softmax()(outputs.logits)
|
15 |
-
|
16 |
-
labels = ['Computer_science', 'Economics',
|
17 |
-
'Electrical_Engineering_and_Systems_Science', 'Mathematics',
|
18 |
-
'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
|
19 |
-
'Statistics']
|
20 |
-
|
21 |
-
sort_lst = sorted([(prob, label) for prob, label in zip(probs.detach().numpy()[0], labels)], key=lambda x: -x[0])
|
22 |
-
cumsum = 0
|
23 |
-
result_labels = []
|
24 |
-
for pair in sort_lst:
|
25 |
-
cumsum += pair[0]
|
26 |
-
if cumsum > 0.95 and len(result_labels) >= 1:
|
27 |
-
return result_labels
|
28 |
-
result_labels.append(pair[1])
|
29 |
|
30 |
|
31 |
@st.cache(allow_output_mutation=True)
|
32 |
def load_model():
|
33 |
-
tokenizer =
|
34 |
-
model =
|
35 |
-
model.load_state_dict(torch.load(
|
36 |
return model, tokenizer
|
37 |
|
38 |
|
39 |
-
tokenizer
|
40 |
-
|
|
|
|
|
|
|
|
|
41 |
|
42 |
st.markdown("### Movie genre classification")
|
43 |
|
|
|
|
|
44 |
text = st.text_area("Write some movie description")
|
45 |
|
46 |
if st.button('Predict'):
|
@@ -48,5 +39,6 @@ if st.button('Predict'):
|
|
48 |
if not text:
|
49 |
st.error("Write something.")
|
50 |
else:
|
51 |
-
pred = predict(text, model
|
52 |
-
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
import torch
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
|
7 |
+
model_path = 'model'
|
|
|
8 |
device = torch.device('cpu')
|
9 |
+
model_name = 'distilbert-base-cased'
|
10 |
|
11 |
+
genres = np.array(['Animation', 'Comedy', 'Adult', 'Adventure', 'Musical', 'History', 'Reality-TV', 'Film-Noir',
|
12 |
+
'Sport', 'Biography', 'Drama', 'Fantasy', 'Romance', 'Thriller', 'News', 'Documentary', 'Sci-Fi', 'Music',
|
13 |
+
'Family', 'Mystery', 'Crime', 'Horror', 'War', 'Action', 'Western'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
@st.cache(allow_output_mutation=True)
|
17 |
def load_model():
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
19 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(genres))
|
20 |
+
model.load_state_dict(torch.load(model_path))
|
21 |
return model, tokenizer
|
22 |
|
23 |
|
24 |
+
def predict(text: str, tokenizer, model):
|
25 |
+
tokens = tokenizer.encode(text)
|
26 |
+
probas = torch.nn.Softmax(dim=1)(model(torch.as_tensor([tokens], device=device))[0]).detach().numpy()[0]
|
27 |
+
top_5_index = probas.argsort()[:-6:-1]
|
28 |
+
return dict(zip(genres[top_5_index], probas[top_5_index]))
|
29 |
+
|
30 |
|
31 |
st.markdown("### Movie genre classification")
|
32 |
|
33 |
+
model, tokenizer = load_model()
|
34 |
+
|
35 |
text = st.text_area("Write some movie description")
|
36 |
|
37 |
if st.button('Predict'):
|
|
|
39 |
if not text:
|
40 |
st.error("Write something.")
|
41 |
else:
|
42 |
+
pred = predict(text, tokenizer, model)
|
43 |
+
result = pd.DataFrame(list(pred.values()), index=list(pred.keys()), columns=['Probability'])
|
44 |
+
st.write(result)
|