EgorDan committed on
Commit
eb4714a
·
1 Parent(s): 5963bd7

Added genre guessing

Browse files
Files changed (1) hide show
  1. pages/genre_model.py +66 -0
pages/genre_model.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from joblib import load
4
+ import textwrap
5
+ import streamlit as st
6
+
7
+ device = 'cpu'  # target device: used as map_location for the checkpoint and for the tokenizer tensors
8
+
9
+
10
+
11
class GenreNet(nn.Module):
    """Feed-forward head that maps an embedding vector to a bounded score.

    The network is a stack of ``Linear -> Dropout -> ReLU`` stages followed by
    a final ``Linear`` with one output. The output is squashed with a sigmoid
    and rescaled to ``[out_range[0], out_range[1]]``.

    Args:
        config: Mapping with the keys
            ``'dropout'``: dropout probability, either a float or a sequence
                whose first element is used for every dropout layer;
            ``'out_range'``: two-element sequence ``[low, high]`` bounding
                the output;
            ``'in_features'`` (optional, default 312): input embedding size;
            ``'hidden_sizes'`` (optional, default ``(256, 128, 64)``): widths
                of the hidden layers.

    With the default sizes the parameter names in ``state_dict()`` are
    ``head.0``, ``head.3``, ``head.6`` and ``head.9``, so existing
    checkpoints load unchanged.
    """

    def __init__(self, config):
        super().__init__()
        # Network hyper-parameters, stored exactly as they were passed in.
        self.dropout = config['dropout']
        self.out_range = config['out_range']

        # Accept a bare number as well as the original one-element list.
        if isinstance(self.dropout, (int, float)):
            drop_p = float(self.dropout)
        else:
            drop_p = self.dropout[0]

        in_features = config.get('in_features', 312)
        hidden_sizes = config.get('hidden_sizes', (256, 128, 64))

        # Fully connected head that predicts the score. The order
        # Linear, Dropout, ReLU is kept so Sequential indices (and therefore
        # checkpoint keys) stay the same.
        layers = []
        for width in hidden_sizes:
            layers.extend([
                nn.Linear(in_features, width),
                nn.Dropout(drop_p),
                nn.ReLU(),
            ])
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        self.head = nn.Sequential(*layers)

    def forward(self, emb):
        """Return the score for ``emb`` rescaled into ``out_range``.

        Args:
            emb: Tensor whose last dimension equals the input feature size.

        Returns:
            Tensor with the same leading dimensions as ``emb`` and a last
            dimension of 1, with values between ``out_range[0]`` and
            ``out_range[1]``.
        """
        low, high = self.out_range[0], self.out_range[1]
        # The sigmoid gives a value in (0, 1); stretch it to (low, high).
        return torch.sigmoid(self.head(emb)) * (high - low) + low
39
+
40
# Hyper-parameters passed to GenreNet when rebuilding it for the checkpoint below.
config = {
    'dropout': [.5],
    'out_range': [1., 5.],  # bounds used to rescale the network output
}

# NOTE(review): joblib.load and torch.load unpickle arbitrary objects —
# only point these at trusted files.
bert = load('./model.joblib')
model = GenreNet(config)
model.load_state_dict(torch.load('./pages/weights_los065_ep100_lr0001_lay256_128_64_1.pt', map_location=device))
# Switch to evaluation mode. A freshly constructed nn.Module is in training
# mode, which leaves the Dropout(p=0.5) layers active and makes every
# prediction random; torch.no_grad / inference_mode do not turn dropout off.
model.eval()
tokenizer = load('./tokenizer.joblib')
49
+
50
def embed_bert_cls(text, model, tokenizer):
    """Embed ``text`` and return the normalised first-token vector of item 0.

    Args:
        text: Input passed straight to ``tokenizer`` (the caller in this file
            passes a one-element list of strings).
        model: Callable that accepts the tokenizer's tensors as keyword
            arguments and returns an object with ``last_hidden_state``.
        tokenizer: Callable returning a mapping of tensors when invoked with
            ``padding=True, truncation=True, return_tensors='pt'``.

    Returns:
        The first row of the normalised embeddings, i.e. the vector for the
        first input only.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    # Move every tokenizer tensor to the module-level device before the forward pass.
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # Position 0 along the second axis is the first token (the CLS slot,
    # going by the function name).
    first_token = outputs.last_hidden_state[:, 0, :]
    return torch.nn.functional.normalize(first_token)[0]
57
+
58
# Rounded model score -> genre label shown to the user.
genre = {
    1: 'Романтика',
    2: 'Поэзия',
    3: 'Детектив',
    4: 'Приключения',
    5: 'Фантастика',
}

prompt = st.text_input('Узнаем жанр!')
# Inputs shorter than two characters are ignored.
if len(prompt) > 1:
    with torch.inference_mode():
        embedding = embed_bert_cls([prompt], bert, tokenizer)
        score = model(embedding).cpu().numpy().item()
    # Round the continuous score to the nearest integer key of ``genre``.
    label_id = int(round(score, 0))
    st.write('Предполагаемый жанр:', genre[label_id])