ElijahDi commited on
Commit
ed0e769
·
verified ·
1 Parent(s): f0c41ac

Upload 11 files

Browse files
BERT_base_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d279897d59dadf5867bd256912786e66012cbb5a335ba9f6ef139e68d87f055
3
+ size 6991
classifier_bag.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b645bb55ece68eae1846ea54da15f4d0b241b0d8bfce1292b9922b3c381dfb2b
3
+ size 658287
classifier_tf.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa416d91b4cd841e0709fa8da8ef59a94d386eee0a6d796edda5e0057182d4b7
3
+ size 658287
lstm_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b309ab4465c41ebdd2208525b33364e29cf3d8b522a9ae9f29685443742a13c8
3
+ size 919778
main.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import time
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import tensor
8
+
9
+ import joblib
10
+ from dataclasses import dataclass
11
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
12
+ import json
13
+
14
+ from preprocessing import predict_review, data_preprocessing_hard
15
+ from model_lstm import LSTMClassifier
16
+ # from BERT_inputs import BertInputs
17
+
18
+ device = 'cpu'
19
+ classifier_bag = joblib.load('classifier_bag.pkl')
20
+ classifier_tf = joblib.load('classifier_tf.pkl')
21
+ BERT_lin_cl = joblib.load('BERT_base_model.pkl')
22
+
23
+ selected_model = st.sidebar.radio("Зачем пришел?", ("Классифиция отзывов лечебных учреждений",
24
+ "Оценка степени токсичности пользовательского сообщения",
25
+ "Генерация текста GPT-моделью по пользовательскому prompt"))
26
+
27
+ # Классификация отзыва на поликлиники
28
+ model_options = ["BagOfWords", "TF-IDF", "LSTM", "BERT-based-ru"]
29
+ if selected_model == "Классифиция отзывов лечебных учреждений":
30
+ st.title("""
31
+ Приложение классифицирует твой отзыв и подскажет позитивный он или негативный
32
+ """)
33
+ st.write("""
34
+ Классификация происходит с использованием классических ML моделей, нейросетевой модели LSTM,
35
+ и, как вариант, с использованием нейросетевой модели Bert-basic-ru для векторизации и линейной
36
+ регрессии для классификации.
37
+ """)
38
+ vectorizer_1 = joblib.load('vectorizer_bag.joblib')
39
+ vectorizer_2 = joblib.load('vectorizer_tf.joblib')
40
+
41
+ # LSTM
42
+ with open('vocab_lstm.json', 'r') as file:
43
+ vocab_to_int = json.load(file)
44
+
45
+ @dataclass
46
+ class ConfigRNN:
47
+ vocab_size: int
48
+ device : str
49
+ n_layers : int
50
+ embedding_dim : int
51
+ hidden_size : int
52
+ seq_len : int
53
+ bidirectional : bool or int
54
+
55
+ net_config = ConfigRNN(
56
+ vocab_size = len(vocab_to_int)+1,
57
+ device='cpu',
58
+ n_layers=2,
59
+ embedding_dim=64,
60
+ hidden_size=32,
61
+ seq_len = 100,
62
+ bidirectional=False)
63
+
64
+ lstm = LSTMClassifier(net_config)
65
+ lstm.load_state_dict(torch.load('lstm_model.pth', map_location=device))
66
+ lstm.to(device)
67
+ # lstm.eval()
68
+
69
+ # BERT
70
+ tokenizer = AutoTokenizer.from_pretrained("Geotrend/bert-base-ru-cased")
71
+ model = AutoModel.from_pretrained("Geotrend/bert-base-ru-cased")
72
+ # model.eval()
73
+ MAX_LEN = 200
74
+
75
+ data = pd.DataFrame({
76
+ 'Модель': ["BagOfWords", "TF-IDF", "LSTM", "BERT-based-ru"],
77
+ 'f1_macro': [0.934, 0.939, 0.009, 0.845]
78
+ })
79
+
80
+ st.subheader("""
81
+ Немного информации о точности используемых моделей после обучения:
82
+ """)
83
+ # st.write(data)
84
+ st.table(data)
85
+ user_text_input = st.text_area('Введите ваш отзыв здесь:', '')
86
+ selected_model_name = st.selectbox('Выберите модель:', model_options, index=0)
87
+
88
+ if st.button('Предсказать'):
89
+ start_time = time.time()
90
+
91
+ if selected_model_name == "BagOfWords":
92
+ X = vectorizer_1.transform([data_preprocessing_hard(user_text_input)])
93
+ predictions = classifier_bag.predict(X)
94
+
95
+ elif selected_model_name == "TF-IDF":
96
+ X = vectorizer_2.transform([data_preprocessing_hard(user_text_input)])
97
+ predictions = classifier_tf.predict(X)
98
+
99
+ elif selected_model_name == "LSTM":
100
+ predictions = predict_review(model=lstm, review_text=user_text_input, net_config=net_config,
101
+ vocab_to_int=vocab_to_int)
102
+
103
+ elif selected_model_name == "BERT-based-ru":
104
+ tokens = tokenizer.encode(user_text_input, add_special_tokens=True)
105
+ padded_tokens = tokens + [0] * (MAX_LEN - len(tokens))
106
+ input_tensor = tensor(padded_tokens).unsqueeze(0)
107
+ with torch.no_grad():
108
+ outputs = model(input_tensor)
109
+ X = outputs.last_hidden_state[:,0,:].detach().cpu().numpy()
110
+ predictions = BERT_lin_cl.predict(X)
111
+ pass
112
+
113
+ end_time = time.time()
114
+ prediction_time = end_time - start_time
115
+
116
+ model_message = f'Предсказание модели {selected_model_name}:'
117
+ if predictions >= 0.5:
118
+ # st.write(f'{model_message} кажется это положительный комментарий.')
119
+ gif_url = 'https://media2.giphy.com/media/v1.Y2lkPTc5MGI3NjExOTdnYjJ1eTE0bjRuMGptcjhpdTk2YTYzeXEzMzlidWFsamY2bW8wZyZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/LUg1GEjapflW7Vg6B9/giphy.gif'
120
+ st.image(gif_url, caption="Позитивный коментарий")
121
+ else:
122
+ # st.write(f'{model_message} кажется это негативный комментарий.')
123
+ gif_url = 'https://i.gifer.com/LdC3.gif'
124
+ st.image(gif_url, caption="Негативный коментарий")
125
+ st.write(f'Время предсказания: {prediction_time:.4f} секунд')
126
+
127
+
128
+
129
+ # Оценка степени токсичности пользовательского сообщения
130
+ elif selected_model == "Оценка степени токсичности пользовательского сообщения":
131
+ st.title("""
132
+ Приложение классифицирует токсичный комментарий или нет
133
+ """)
134
+
135
+ st.write("""
136
+ Классификация происходит с использованием нейросетевой модели rubert-tiny-toxicity.
137
+ """)
138
+
139
+ # Toxicity
140
+ model_t_checkpoint = 'cointegrated/rubert-tiny-toxicity'
141
+ tokenizer_t = AutoTokenizer.from_pretrained(model_t_checkpoint)
142
+ model_t = AutoModelForSequenceClassification.from_pretrained(model_t_checkpoint)
143
+
144
+ def text2toxicity(text, aggregate=True):
145
+ with torch.no_grad():
146
+ inputs = tokenizer_t(text, return_tensors='pt', truncation=True, padding=True).to(model_t.device)
147
+ proba = torch.sigmoid(model_t(**inputs).logits).cpu().numpy()
148
+ if isinstance(text, str):
149
+ proba = proba[0]
150
+ if aggregate:
151
+ return 1 - proba.T[0] * (1 - proba.T[-1])
152
+ return proba
153
+
154
+ user_text_input = st.text_area('Введите ваш отзыв здесь:')
155
+
156
+ if st.button('Предсказать'):
157
+ start_time = time.time()
158
+ proba = text2toxicity(user_text_input, True)
159
+ end_time = time.time()
160
+ prediction_time = end_time - start_time
161
+
162
+ model_message = f'Предсказание модели:'
163
+ if proba >= 0.5:
164
+ # st.write(f' Кажется это токсичный комментарий.')
165
+ gif_url = "https://media1.giphy.com/media/cInbau65cwPWUeGTIZ/giphy.gif?cid=6c09b952seqdtvky8yn2uq6bt3kvo1vu5sdzpkdznjvmtxsh&ep=v1_internal_gif_by_id&rid=giphy.gif&ct=s"
166
+ st.image(gif_url, caption="ТОКСИК")
167
+
168
+ else:
169
+ # st.write(f' Кажется это не токсичный комментарий.')
170
+ gif_url = 'https://i.gifer.com/origin/51/518fbbf9cf32763122f9466d3c686bb3_w200.gif'
171
+ st.image(gif_url, caption="МИЛОТА")
172
+ st.write(f'Время предсказания: {prediction_time:.4f} секунд')
173
+
174
+
175
+
176
+ # Генерация текста GPT-моделью
177
+ elif selected_model == "Генерация текста GPT-моделью по пользовательскому prompt":
178
+ st.title("""
179
+ Приложение генерирует текст по Вашему промту
180
+ """)
181
+
182
+ st.write("""
183
+ Для генерации текста используется предобученная сеть GPT.
184
+ """)
185
+ uploaded_img = st.sidebar.file_uploader('Загрузи свое космофото', type=["jpg", "png", "jpeg"])
186
+ if uploaded_img is not None:
187
+ input_img = io.imread(uploaded_img)
188
+ else:
189
+ input_img = io.imread('/Users/id/Documents/strlit/cv_project/Segm.jpg')
model_lstm.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class LSTMClassifier(nn.Module):
5
+ def __init__(self, rnn_conf) -> None:
6
+ super().__init__()
7
+
8
+ self.embedding_dim = rnn_conf.embedding_dim
9
+ self.hidden_size = rnn_conf.hidden_size
10
+ self.bidirectional = rnn_conf.bidirectional
11
+ self.n_layers = rnn_conf.n_layers
12
+
13
+ self.embedding = nn.Embedding(rnn_conf.vocab_size, self.embedding_dim)
14
+ self.lstm = nn.LSTM(
15
+ input_size = self.embedding_dim,
16
+ hidden_size = self.hidden_size,
17
+ bidirectional = self.bidirectional,
18
+ batch_first = True,
19
+ num_layers = self.n_layers
20
+ )
21
+ self.bidirect_factor = 2 if self.bidirectional else 1
22
+ self.clf = nn.Sequential(
23
+ nn.Linear(self.hidden_size * self.bidirect_factor, 32),
24
+ nn.Tanh(),
25
+ nn.Dropout(),
26
+ nn.Linear(32, 1)
27
+ )
28
+
29
+ def model_description(self):
30
+ direction = 'bidirect' if self.bidirectional else 'onedirect'
31
+ return f'lstm_{direction}_{self.n_layers}'
32
+
33
+
34
+ def forward(self, x: torch.Tensor):
35
+ embeddings = self.embedding(x)
36
+ out, _ = self.lstm(embeddings)
37
+ out = out[:, -1, :] # [все элементы батча, последний h_n, все элементы последнего h_n]
38
+ out = self.clf(out.squeeze())
39
+ return out
40
+
preprocessing.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import string
3
+ import numpy as np
4
+ import torch
5
+ import nltk
6
+ import pymorphy2
7
+
8
+ from nltk.corpus import stopwords
9
+ nltk.download('stopwords')
10
+ stop_words = set(stopwords.words('russian'))
11
+ morph = pymorphy2.MorphAnalyzer()
12
+
13
+ def data_preprocessing_hard(text: str) -> str:
14
+ text = text.lower()
15
+ text = re.sub('<.*?>', '', text)
16
+ text = re.sub(r'[^а-яА-Я\s]', '', text)
17
+ text = ''.join([c for c in text if c not in string.punctuation])
18
+ text = ' '.join([word for word in text.split() if word not in stop_words])
19
+ # text = ''.join([char for char in text if not char.isdigit()])
20
+ text = ' '.join([morph.parse(word)[0].normal_form for word in text.split()])
21
+
22
+ return text
23
+
24
+ def data_preprocessing(text: str) -> str:
25
+ """preprocessing string: lowercase, removing html-tags, punctuation and stopwords
26
+
27
+ Args:
28
+ text (str): input string for preprocessing
29
+
30
+ Returns:
31
+ str: preprocessed string
32
+ """
33
+
34
+ text = text.lower()
35
+ text = re.sub('<.*?>', '', text) # html tags
36
+ text = ''.join([c for c in text if c not in string.punctuation])# Remove punctuation
37
+ text = [word for word in text.split() if word not in stop_words]
38
+ text = ' '.join(text)
39
+ return text
40
+
41
+ def get_words_by_freq(sorted_words: list, n: int = 10) -> list:
42
+ return list(filter(lambda x: x[1] > n, sorted_words))
43
+
44
+ def padding(review_int: list, seq_len: int) -> np.array: # type: ignore
45
+ """Make left-sided padding for input list of tokens
46
+
47
+ Args:
48
+ review_int (list): input list of tokens
49
+ seq_len (int): max length of sequence, it len(review_int[i]) > seq_len it will be trimmed, else it will be padded by zeros
50
+
51
+ Returns:
52
+ np.array: padded sequences
53
+ """
54
+ features = np.zeros((len(review_int), seq_len), dtype = int)
55
+ for i, review in enumerate(review_int):
56
+ if len(review) <= seq_len:
57
+ zeros = list(np.zeros(seq_len - len(review)))
58
+ new = zeros + review
59
+ else:
60
+ new = review[: seq_len]
61
+ features[i, :] = np.array(new)
62
+
63
+ return features
64
+
65
+ def preprocess_single_string(
66
+ input_string: str,
67
+ seq_len: int,
68
+ vocab_to_int: dict,
69
+ verbose : bool = False
70
+ ) -> torch.tensor:
71
+ """Function for all preprocessing steps on a single string
72
+
73
+ Args:
74
+ input_string (str): input single string for preprocessing
75
+ seq_len (int): max length of sequence, it len(review_int[i]) > seq_len it will be trimmed, else it will be padded by zeros
76
+ vocab_to_int (dict, optional): word corpus {'word' : int index}. Defaults to vocab_to_int.
77
+
78
+ Returns:
79
+ list: preprocessed string
80
+ """
81
+
82
+ preprocessed_string = data_preprocessing(input_string)
83
+ result_list = []
84
+ for word in preprocessed_string.split():
85
+ try:
86
+ result_list.append(vocab_to_int[word])
87
+ except KeyError as e:
88
+ if verbose:
89
+ print(f'{e}: not in dictionary!')
90
+ pass
91
+ result_padded = padding([result_list], seq_len)[0]
92
+
93
+ return torch.tensor(result_padded)
94
+
95
+ def predict_review(model, review_text: str, net_config, vocab_to_int) -> torch.tensor:
96
+ sample = preprocess_single_string(review_text, net_config.seq_len, vocab_to_int)
97
+ probability_lstm = model(sample.unsqueeze(0)).to(net_config.device).sigmoid()
98
+ return probability_lstm.item()
requirements.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ certifi==2024.2.2
2
+ charset-normalizer==3.3.2
3
+ click==8.1.7
4
+ dataclasses==0.6
5
+ DAWG-Python==0.7.2
6
+ docopt==0.6.2
7
+ filelock==3.13.1
8
+ fsspec==2023.12.2
9
+ huggingface-hub==0.20.3
10
+ idna==3.6
11
+ Jinja2==3.1.3
12
+ joblib==1.3.2
13
+ MarkupSafe==2.1.4
14
+ mpmath==1.3.0
15
+ networkx==3.2.1
16
+ nltk==3.8.1
17
+ numpy==1.26.3
18
+ packaging==23.2
19
+ pandas==2.2.0
20
+ pymorphy2==0.9.1
21
+ pymorphy2-dicts-ru==2.4.417127.4579844
22
+ python-dateutil==2.8.2
23
+ pytz==2024.1
24
+ PyYAML==6.0.1
25
+ regex==2023.12.25
26
+ requests==2.31.0
27
+ safetensors==0.4.2
28
+ six==1.16.0
29
+ sympy==1.12
30
+ tokenizers==0.15.1
31
+ torch==2.2.0
32
+ tqdm==4.66.1
33
+ transformers==4.37.2
34
+ typing_extensions==4.9.0
35
+ tzdata==2023.4
36
+ urllib3==2.2.0
vectorizer_bag.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78c823e4a6a3f06b1961a0d4e28e21be547839ed9606ed879d56315d1d7c01b2
3
+ size 3357923
vectorizer_tf.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78c823e4a6a3f06b1961a0d4e28e21be547839ed9606ed879d56315d1d7c01b2
3
+ size 3357923
vocab_lstm.json ADDED
The diff for this file is too large to render. See raw diff