File size: 3,176 Bytes
6c6aac8
 
 
 
 
 
 
 
104c2ee
 
 
6c6aac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62bf081
6c6aac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104c2ee
6c6aac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a81b55
6c6aac8
 
8a81b55
6c6aac8
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json 
import numpy as np
import joblib

from transformers import AutoTokenizer, AutoModel
import torch
import pickle

import nltk
nltk.download('stopwords')

import matplotlib.pyplot as plt

# Vocabulary lookup tables, stored as UTF-8 JSON alongside the app.
with open('srcs/vocab_to_int.json', encoding='utf-8') as vocab_file:
    vocab_to_int = json.loads(vocab_file.read())

# Reverse table; JSON object keys are strings, so lookups use str(index).
with open('srcs/int_to_vocab.json', encoding='utf-8') as reverse_vocab_file:
    int_to_vocab = json.loads(reverse_vocab_file.read())

# Model / preprocessing hyper-parameters.
SEQ_LEN = 350
EMBEDDING_DIM = 64
HIDDEN_SIZE = 64
# One more than the number of vocabulary entries
# (presumably a slot for an extra index such as padding -- TODO confirm).
VOCAB_SIZE = 1 + len(vocab_to_int)

# Embedding weights saved in NumPy's .npy format.
embedding_matrix = np.load('srcs/embedding_matrix.npy')



from srcs.srcs import LSTMConcatAttentionB, PredMaker, Text_ex, clean

# NOTE(review): joblib.load, pickle.load and torch.load all unpickle arbitrary
# Python objects. Every artifact under srcs/ must come from a trusted source.

# Two serialized estimators restored with joblib.
log_reg_vec = joblib.load('srcs/log_reg_vec.sav')
log_reg_bert = joblib.load('srcs/log_reg_bert.sav')

# Text preprocessor built from the cleaning function, the vocabulary and the
# sequence length.
texter = Text_ex(clean, vocab_to_int, SEQ_LEN)

# The LSTM is constructed with its defaults and then populated from saved weights.
lstm = LSTMConcatAttentionB()
lstm.load_state_dict(torch.load('srcs/lstm.pt'))

# Use a context manager so the file handle is closed deterministically instead
# of being left open for the lifetime of the process.
with open('srcs/vectorizer.pickle', 'rb') as vectorizer_file:
    vectorizer = pickle.load(vectorizer_file)

# NOTE(review): 'tokenzier' looks like a typo of 'tokenizer', but it is a path
# on disk -- rename the directory and this string together, or not at all.
tokenizer = AutoTokenizer.from_pretrained('srcs/tokenzier', local_files_only=True)

# NOTE(review): this restores a whole pickled object rather than a state dict.
# Recent PyTorch releases default torch.load to weights_only=True, which would
# reject such a file -- confirm against the pinned torch version.
bert = torch.load('srcs/bert.pt')

# Single callable that bundles all three models with their preprocessing.
predM = PredMaker(
    model1=log_reg_vec,
    model2=lstm,
    rubert=bert,
    model3=log_reg_bert,
    vectorizer=vectorizer,
    texter=texter,
    clean_func=clean,
    tokenizer=tokenizer,
    itc=int_to_vocab,
)




import streamlit as st
# Raw CSS injected into the page: hides the sidebar element whenever it is in
# its expanded state.
_HIDE_SIDEBAR_STYLE = """
    <style>
        section[data-testid="stSidebar"][aria-expanded="true"]{
            display: none;
        }
    </style>
    """
st.markdown(_HIDE_SIDEBAR_STYLE, unsafe_allow_html=True)

# Page heading followed by a navigation link back to the main app page.
st.title('Кинопоиск')
st.page_link("app.py", label="Home", icon='🏠')

import streamlit as st

# Free-form review text typed by the user; the box starts empty.
txt = st.text_area(
    "Введите сюда отзыв на фильм:",
    "",
    )

# Inputs shorter than this many characters are not sent to the models.
MIN_REVIEW_CHARS = 12

if len(txt) < MIN_REVIEW_CHARS:
    # Stay silent for an empty box; complain only once something was typed.
    if txt:
        st.write('Введи что-нибудь нормальное')
else:
    # predM yields three class ids, a pair `t`, per-token weights `att`
    # (presumably attention scores -- TODO confirm) and then the timings.
    res1, res2, res3, t, att, *times = predM(txt)

    # t[0] is tensor-like (first row taken as token ids), t[1] is string-like.
    t_ = t[0].numpy()[0]
    # Number of trailing positions to inspect/plot: word count of t[1] plus one.
    k = len(t[1].split()) + 1

    # NOTE(review): ids missing from int_to_vocab (or mapped to an empty value)
    # are dropped here, which can shift labels relative to the positions
    # plotted below -- verify this is intended.
    labels = [int_to_vocab[str(x)] for x in t_ if int_to_vocab.get(str(x))]
    tail_labels = labels[-k:]

    if set(tail_labels) == {"<pad>"}:
        # Every inspected position is padding: nothing usable was recognised.
        st.write('Давай по новой миша, всё @**##')
        st.write(set(tail_labels))
    else:
        st.toast('!', icon='🎉')
        di = {0: 'Плохо', 1: 'Нейтрально', 2: 'Хорошо'}
        # Class id -> Streamlit status box (red / yellow / green).
        d = {0: st.error, 1: st.warning, 2: st.success}

        # Hard-coded f1-scores shown next to each model's prediction.
        f1_scores = ('0.64', '0.70', '0.66')
        predictions = (res1, res2, res3)
        for number, (pred, elapsed, f1) in enumerate(
                zip(predictions, times, f1_scores), start=1):
            d[pred](f'Предсказание {number}-й модели: {di[pred]}')
            st.write(f'время = {round(elapsed, 3)}c, f1-score = {f1}')

        # Draw on an explicit figure and hand it to Streamlit. Rendering the
        # implicit global pyplot figure is deprecated, and the config option
        # that used to silence the warning is gone from newer Streamlit
        # releases.
        fig, ax = plt.subplots(figsize=(8, 8))
        positions = np.arange(len(t_))[-k:]
        ax.barh(positions, att[-k:])
        ax.set_yticks(positions)
        ax.set_yticklabels(tail_labels)
        ax.set_title(f'f1-score = 0.7\npred = {di[res2]}\ntime = {round(times[1],3)}c')
        st.pyplot(fig)
        # The script reruns on every interaction; close the figure so pyplot
        # does not keep accumulating them.
        plt.close(fig)