Konstantin Gordeev commited on
Commit
b60d6e8
1 Parent(s): 16a6ec1
Files changed (1) hide show
  1. app.py +46 -6
app.py CHANGED
@@ -1,12 +1,52 @@
1
  import streamlit as st
2
- from transformers import pipeline
 
3
 
4
- st.markdown("## Movie genre classification")
5
- st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
 
 
6
 
7
 
8
- pipe = pipeline("text-classification", "Tejas3/distillbert_110_uncased_movie_genre")
 
 
 
9
 
10
- text = st.text_area("TEXT HERE")
 
 
 
11
 
12
- st.markdown(f"{pipe(text)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import DistilBertModel, DistilBertTokenizer
3
+ import torch
4
 
5
+ model_path = './models/pytorch_distilbert.bin'
6
+ vocab_path = './models/vocab_distilbert.bin'
7
+ device = torch.device('cpu')
8
+ MAX_LEN = 512
9
 
10
 
11
+ def get_labels(text, model, tokenizer, count_labels=8):
12
+ tokens = tokenizer(text, return_tensors='pt')
13
+ outputs = model(**tokens)
14
+ probs = torch.nn.Softmax()(outputs.logits)
15
 
16
+ labels = ['Computer_science', 'Economics',
17
+ 'Electrical_Engineering_and_Systems_Science', 'Mathematics',
18
+ 'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
19
+ 'Statistics']
20
 
21
+ sort_lst = sorted([(prob, label) for prob, label in zip(probs.detach().numpy()[0], labels)], key=lambda x: -x[0])
22
+ cumsum = 0
23
+ result_labels = []
24
+ for pair in sort_lst:
25
+ cumsum += pair[0]
26
+ if cumsum > 0.95 and len(result_labels) >= 1:
27
+ return result_labels
28
+ result_labels.append(pair[1])
29
+
30
+
31
+ @st.cache(allow_output_mutation=True)
32
+ def load_model():
33
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
34
+ model = DistilBertModel.from_pretrained("distilbert-base-cased", num_labels=8)
35
+ model.load_state_dict(torch.load('weight_model'))
36
+ return model, tokenizer
37
+
38
+
39
+ tokenizer = DistilBertTokenizer.from_pretrained(vocab_path)
40
+ model = torch.load(model_path, map_location=torch.device(device))
41
+
42
+ st.markdown("### Movie genre classification")
43
+
44
+ text = st.text_area("Write some movie description")
45
+
46
+ if st.button('Predict'):
47
+ with st.spinner("Wait..."):
48
+ if not text:
49
+ st.error("Write something.")
50
+ else:
51
+ pred = predict(text, model.to(device))
52
+ st.success("\n\n".join(pred))