Konstantin Gordeev commited on
Commit
49e8106
1 Parent(s): b60d6e8

Update model

Browse files
Files changed (1) hide show
  1. app.py +22 -30
app.py CHANGED
@@ -1,46 +1,37 @@
1
  import streamlit as st
2
- from transformers import DistilBertModel, DistilBertTokenizer
3
  import torch
 
 
4
 
5
- model_path = './models/pytorch_distilbert.bin'
6
- vocab_path = './models/vocab_distilbert.bin'
7
  device = torch.device('cpu')
8
- MAX_LEN = 512
9
 
10
-
11
- def get_labels(text, model, tokenizer, count_labels=8):
12
- tokens = tokenizer(text, return_tensors='pt')
13
- outputs = model(**tokens)
14
- probs = torch.nn.Softmax()(outputs.logits)
15
-
16
- labels = ['Computer_science', 'Economics',
17
- 'Electrical_Engineering_and_Systems_Science', 'Mathematics',
18
- 'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
19
- 'Statistics']
20
-
21
- sort_lst = sorted([(prob, label) for prob, label in zip(probs.detach().numpy()[0], labels)], key=lambda x: -x[0])
22
- cumsum = 0
23
- result_labels = []
24
- for pair in sort_lst:
25
- cumsum += pair[0]
26
- if cumsum > 0.95 and len(result_labels) >= 1:
27
- return result_labels
28
- result_labels.append(pair[1])
29
 
30
 
31
  @st.cache(allow_output_mutation=True)
32
  def load_model():
33
- tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
34
- model = DistilBertModel.from_pretrained("distilbert-base-cased", num_labels=8)
35
- model.load_state_dict(torch.load('weight_model'))
36
  return model, tokenizer
37
 
38
 
39
- tokenizer = DistilBertTokenizer.from_pretrained(vocab_path)
40
- model = torch.load(model_path, map_location=torch.device(device))
 
 
 
 
41
 
42
  st.markdown("### Movie genre classification")
43
 
 
 
44
  text = st.text_area("Write some movie description")
45
 
46
  if st.button('Predict'):
@@ -48,5 +39,6 @@ if st.button('Predict'):
48
  if not text:
49
  st.error("Write something.")
50
  else:
51
- pred = predict(text, model.to(device))
52
- st.success("\n\n".join(pred))
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
+ import numpy as np
5
+ import pandas as pd
6
 
7
+ model_path = 'model'
 
8
  device = torch.device('cpu')
9
+ model_name = 'distilbert-base-cased'
10
 
11
+ genres = np.array(['Animation', 'Comedy', 'Adult', 'Adventure', 'Musical', 'History', 'Reality-TV', 'Film-Noir',
12
+ 'Sport', 'Biography', 'Drama', 'Fantasy', 'Romance', 'Thriller', 'News', 'Documentary', 'Sci-Fi', 'Music',
13
+ 'Family', 'Mystery', 'Crime', 'Horror', 'War', 'Action', 'Western'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  @st.cache(allow_output_mutation=True)
17
  def load_model():
18
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
19
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(genres))
20
+ model.load_state_dict(torch.load(model_path))
21
  return model, tokenizer
22
 
23
 
24
+ def predict(text: str, tokenizer, model):
25
+ tokens = tokenizer.encode(text)
26
+ probas = torch.nn.Softmax(dim=1)(model(torch.as_tensor([tokens], device=device))[0]).detach().numpy()[0]
27
+ top_5_index = probas.argsort()[:-6:-1]
28
+ return dict(zip(genres[top_5_index], probas[top_5_index]))
29
+
30
 
31
  st.markdown("### Movie genre classification")
32
 
33
+ model, tokenizer = load_model()
34
+
35
  text = st.text_area("Write some movie description")
36
 
37
  if st.button('Predict'):
 
39
  if not text:
40
  st.error("Write something.")
41
  else:
42
+ pred = predict(text, tokenizer, model)
43
+ result = pd.DataFrame(list(pred.values()), index=list(pred.keys()), columns=['Probability'])
44
+ st.write(result)