Commit d1474ea ("update")
ivanovot committed
1 parent: 6cc6c10
Changed files:
- .gitignore +3 -0
- app.py +46 -41
- model/__init__.py +0 -1
- model/__pycache__/__init__.cpython-312.pyc (binary, deleted)
- model/__pycache__/model.cpython-312.pyc (binary, deleted)
- model/model.py +0 -68
- model.pth → models/model_epoch_10.pt (renamed) +2 -2
- models/model_epoch_20.pt +3 -0
- models/model_epoch_30.pt +3 -0
- models/model_epoch_40.pt +3 -0
- models/model_epoch_50.pt +3 -0
- models/training.log +204 -0
- notebooks/data.ipynb +200 -0
- notebooks/dataset.ipynb +106 -0
- notebooks/evaluate.ipynb +362 -0
- notebooks/model.ipynb +166 -0
- scr/__init__.py +0 -0
- scr/dataset.py +32 -0
- scr/model.py +40 -0
- scr/sbert.py +41 -0
- scr/train.py +117 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+venv/
+*__pycache__
+*.gradio
app.py CHANGED
@@ -1,56 +1,61 @@
 import gradio as gr
-from model import
+from scr.model import Model
+import torch
 
-#
-
-["Страницу обнови, дебил. Это тоже не оскорбление, а доказанный факт - не-дебил про себя во множественном числе писать не будет. Или мы в тебя верим - это ты и твои воображаемые друзья?"],
-["УПАДЁТ! ТАМ НЕЛЬЗЯ СТРОИТЬ! ТЕХНОЛОГИЙ НЕТ! РАЗВОРУЮТ КАК ВСЕГДА! УЖЕ ТРЕЩИНАМИ ПОШЁЛ! ТУПЫЕ КИТАЗЫ НЕ МОГУТ НИЧЕГО НОРМАЛЬНО СДЕЛАТЬ!"],
-["хорош врать, ты террорист-торчёк-шизофреник пруф: а вот без костюма да чутка учёный, миллиардер, филантроп"],
-["Мне Вас очень жаль, если для Вас оскорбления - норма"],
-["Осторожней на сверхманёврах. В предыдущей методичке у вас было написано, что добрые арабы никогда ни с кем не воевали, только торговали пряностями, лел. Шапочку из фольги сними"],
-["Так то стоит около 12,5 тысяч, но мне вышло в 6636 рублей и бесплатная доставка"],
-["Ну хочешь я стану твоим другом? Как тебя зовут? Чем увлекаешься?"],
-["Ну так это в плане изготовления изделий своими руками,а вот готовить вроде умею.Короче буду сам на себе испытывать божественный напиток и куплю огнетушитель (промышленный на всякий случай)."],
-["Я согласен, что это хорошая идея! Давайте подумаем, как можно улучшить её еще больше."],
-["Очень полезная информация, спасибо за подробное объяснение! Я многому научился."],
-["Мне нравится, как вы объясняете! Это действительно помогает разобраться в теме."],
-["Отлично написано, теперь я лучше понимаю, как работать с этим инструментом."],
-["Классная идея! Надо попробовать и посмотреть, как это работает на практике."],
-["Ваши советы очень полезны. Это точно сэкономит мне время на следующий раз."],
-["Спасибо за ваши рекомендации! Я обязательно попробую это в будущем."],
-["Мне нравится ваше решение этой проблемы. Очень креативно и практично."],
-["Спасибо за помощь! Вы помогли мне разобраться в ситуации и сэкономить много времени."],
-["Согласен с вами, это действительно важный аспект, о котором стоит задуматься."],
-["Очень вдохновляющий пост! Я буду следовать вашим рекомендациям."],
-]
+# Set up the device and load the model
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+model = Model()
+model.load_state_dict(torch.load('models/model_epoch_50.pt', map_location=device))
+model.eval()
+
+# Predict a toxicity score for the given text
 def predict_text(text):
+    # Check the text length
     word_count = len(text.split())
-    if word_count <
-        return "Слишком короткий текст"
+    if word_count < 3:
+        return "Слишком короткий текст", None
 
-
-
-    #
+    # Run the prediction
+    score = round(float(model.predict(text).item()), 5)  # round the result to 5 decimal places
+    return f"Оценка токсичности: {score}", score
+
+# Demo examples
+examples = [
+    "Этот продукт просто великолепен, спасибо!",
+    "Ты ужасен, не могу терпеть твои комментарии!",
+    "Сегодня был хороший день, несмотря на небольшой дождь.",
+    "Твой проект провалился, и это только твоя вина.",
+    "Замечательная работа, вы молодцы!"
+]
+
+# Build the interface
 demo = gr.Interface(
     fn=predict_text,  # prediction function
     inputs=gr.Textbox(
         label="Введите текст для проверки на токсичность",  # label for the input box
-        placeholder="Напишите комментарий для анализа",  # hint for
-        lines=5,  # number of lines
-        interactive=True
-    ),
-    outputs=gr.Textbox(
-        label="Результат анализа",  # label for the output
-        placeholder="Результат токсичности текста будет здесь",  # hint for the output
+        placeholder="Напишите комментарий для анализа",  # input hint
+        lines=5,  # number of lines
+        interactive=True  # enable interactivity
     ),
-
+    outputs=[
+        gr.Textbox(
+            label="Результат анализа",  # label for the output
+            placeholder="Оценка токсичности будет показана здесь",  # hint for the output
+        ),
+        gr.Slider(
+            label="Шкала токсичности",  # slider label
+            minimum=0.0,
+            maximum=1.0,
+            step=0.00001,
+            interactive=False,  # make the slider output-only
+        )
+    ],
     examples=examples,  # examples shown to users
-    title="
-    description="Введите
+    title="Toxicity Classification",  # title
+    description="Введите текст, чтобы узнать его оценку токсичности (0 - не токсичный, 1 - максимально токсичный).",  # description
+    live=True,  # rerun the model automatically as the text changes
 )
 
-# Launch the app
-demo.launch(
+# Launch the app
+demo.launch()
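With live=True, Gradio reruns predict_text as the text changes, and the returned pair maps positionally onto the two outputs: the formatted message goes to the Textbox and the numeric score to the read-only Slider (None is returned for the score when the text is too short). A self-contained sketch of the same wiring with the model stubbed out; fake_score is invented for illustration and is not the repo's Model.predict:

import gradio as gr

def fake_score(text: str) -> float:
    # Stand-in for Model.predict(); illustration only.
    return min(1.0, text.count('!') * 0.2)

def predict_text(text):
    if len(text.split()) < 3:
        return "Слишком короткий текст", None
    score = round(fake_score(text), 5)
    return f"Оценка токсичности: {score}", score

demo = gr.Interface(
    fn=predict_text,
    inputs=gr.Textbox(lines=5),
    outputs=[
        gr.Textbox(),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.00001, interactive=False),
    ],
    live=True,
)

if __name__ == "__main__":
    demo.launch()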
model/__init__.py DELETED
@@ -1 +0,0 @@
-from .model import model
model/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (176 Bytes)

model/__pycache__/model.cpython-312.pyc DELETED
Binary file (3.69 kB)
model/model.py DELETED
@@ -1,68 +0,0 @@
-import torch
-from transformers import BertTokenizer, BertModel
-import torch.nn as nn
-
-class PowerfulBinaryTextClassifier(nn.Module):
-    def __init__(self, model_name, lstm_hidden_size=256, num_layers=3, dropout_rate=0.2):
-        super(PowerfulBinaryTextClassifier, self).__init__()
-        self.bert = BertModel.from_pretrained(model_name)
-
-        # Several LSTM layers with a large hidden state
-        self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size,
-                            hidden_size=lstm_hidden_size,
-                            num_layers=num_layers,
-                            batch_first=True,
-                            bidirectional=True,
-                            dropout=dropout_rate if num_layers > 1 else 0)
-
-        # Fully connected block
-        self.fc = nn.Sequential(
-            nn.Linear(lstm_hidden_size * 2, 2),  # fully connected layer
-            nn.Sigmoid()
-        )
-
-        self.tokenizer = BertTokenizer.from_pretrained(model_name)  # initialize the tokenizer
-
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    def forward(self, input_ids, attention_mask):
-        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
-        bert_outputs = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)
-
-        # Apply the LSTM
-        lstm_out, _ = self.lstm(bert_outputs)  # (batch_size, sequence_length, lstm_hidden_size * 2)
-
-        # Take the output of the last time step for classification
-        last_time_step = lstm_out[:, -1, :]  # (batch_size, lstm_hidden_size * 2)
-
-        logits = self.fc(last_time_step)  # apply the fully connected block
-
-        logits[:, 1] -= 0.995  # shift the output for class 1 down by a fixed offset
-
-        return logits  # return outputs for the two classes
-
-    def predict(self, text):
-        self.to(self.device)  # move the model to the selected device
-
-        # Tokenize the text
-        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)
-        input_ids = inputs['input_ids'].to(self.device)  # move to the device
-        attention_mask = inputs['attention_mask'].to(self.device)  # move to the device
-
-        # Get the prediction
-        self.eval()  # switch the model to evaluation mode
-        with torch.no_grad():
-            preds = self(input_ids, attention_mask)  # get the outputs
-
-        # Return the index of the class with the highest score
-        return torch.argmax(preds, dim=1).item()  # return the class index
-
-    def load_weights(self, filepath):
-        # Load the model weights
-        self.load_state_dict(torch.load(filepath, map_location=self.device, weights_only=True))
-
-# Example model initialization
-model_name = "DeepPavlov/rubert-base-cased"
-model = PowerfulBinaryTextClassifier(model_name)
-
-model.load_weights('model.pth')
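Note why this file had to go beyond the refactor: its predict() returns torch.argmax(...), a hard 0/1 class index, while the rewritten app.py drives a continuous 0-1 slider. For contrast, one standard way to get a continuous score out of a two-class head is a softmax over the class outputs; this is an illustration only, not code from the commit:

import torch

# Hypothetical two-class output from a head like the deleted one.
logits = torch.tensor([[0.9, 0.4]])
# Softmax turns the pair into probabilities; index 1 = "toxic" class.
prob_toxic = torch.softmax(logits, dim=1)[0, 1].item()
print(round(prob_toxic, 5))  # ~0.37754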
model.pth → models/model_epoch_10.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:0171b7a667399691c6a0c9ca1ac01a2455cdd294db6c4c97a3f5717a75e97f38
+size 2793946
models/model_epoch_20.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efd2666876a870ebb7dd175e9d14e5f4a7f1c7c3adc577ad3e148480e7906326
+size 2793946

models/model_epoch_30.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef89dab538f52b47be6a9525e672c6f8e03dcbff6bb4b6b3b26c3b824adec089
+size 2793946

models/model_epoch_40.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:26e450b74eb7fb8fbbe6e64490aa71374fa6c9c1343bed927c238475e1920b17
+size 2793946

models/model_epoch_50.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:463cef09e540c6cfc18dfbe817d616403d097b0e30f5f3f1ca2e6c4cdf54d4ad
+size 2793946
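These are Git LFS pointer files, so the diff records only hashes and sizes. All checkpoints report the same 2,793,946 bytes, which is consistent with each epoch saving the same float32 state dict of the small MLP shown later in notebooks/model.ipynb: a 1024→512→256→128→64→1 stack has about 697k parameters, roughly 2.79 MB of weights plus serialization overhead. A quick check of that arithmetic:

# Layer widths taken from the architecture printout in notebooks/model.ipynb.
dims = [1024, 512, 256, 128, 64, 1]
# weights (in * out) plus biases (out) per linear layer
params = sum(i * o + o for i, o in zip(dims, dims[1:]))
print(params, params * 4)  # 697345 parameters, 2789380 bytes as float32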
models/training.log ADDED
@@ -0,0 +1,204 @@
+2024-11-28 15:01:21,773 - INFO - Training started
+2024-11-28 15:01:21,774 - INFO - Training set size: 11529, Test set size: 2883
+2024-11-28 15:01:21,774 - INFO - Epoch 1/50 started
+2024-11-28 15:01:23,378 - INFO - Average loss for epoch 1: 0.3625
+2024-11-28 15:01:23,511 - INFO - Test accuracy after epoch 1: 90.36%
+2024-11-28 15:01:23,527 - INFO - Model saved: models\model_epoch_1.pt
+2024-11-28 15:01:23,527 - INFO - Epoch 2/50 started
+2024-11-28 15:01:25,024 - INFO - Average loss for epoch 2: 0.3320
+2024-11-28 15:01:25,131 - INFO - Test accuracy after epoch 2: 90.11%
+2024-11-28 15:01:25,135 - INFO - Model saved: models\model_epoch_2.pt
+2024-11-28 15:01:25,135 - INFO - Epoch 3/50 started
+2024-11-28 15:01:26,611 - INFO - Average loss for epoch 3: 0.3238
+2024-11-28 15:01:26,708 - INFO - Test accuracy after epoch 3: 90.01%
+2024-11-28 15:01:26,712 - INFO - Model saved: models\model_epoch_3.pt
+2024-11-28 15:01:26,712 - INFO - Epoch 4/50 started
+2024-11-28 15:01:28,205 - INFO - Average loss for epoch 4: 0.3089
+2024-11-28 15:01:28,324 - INFO - Test accuracy after epoch 4: 89.94%
+2024-11-28 15:01:28,328 - INFO - Model saved: models\model_epoch_4.pt
+2024-11-28 15:01:28,328 - INFO - Epoch 5/50 started
+2024-11-28 15:01:29,770 - INFO - Average loss for epoch 5: 0.2921
+2024-11-28 15:01:29,896 - INFO - Test accuracy after epoch 5: 89.70%
+2024-11-28 15:01:29,900 - INFO - Model saved: models\model_epoch_5.pt
+2024-11-28 15:01:29,900 - INFO - Epoch 6/50 started
+2024-11-28 15:01:31,476 - INFO - Average loss for epoch 6: 0.2802
+2024-11-28 15:01:31,583 - INFO - Test accuracy after epoch 6: 89.77%
+2024-11-28 15:01:31,587 - INFO - Model saved: models\model_epoch_6.pt
+2024-11-28 15:01:31,587 - INFO - Epoch 7/50 started
+2024-11-28 15:01:33,168 - INFO - Average loss for epoch 7: 0.2598
+2024-11-28 15:01:33,291 - INFO - Test accuracy after epoch 7: 88.62%
+2024-11-28 15:01:33,295 - INFO - Model saved: models\model_epoch_7.pt
+2024-11-28 15:01:33,295 - INFO - Epoch 8/50 started
+2024-11-28 15:01:34,921 - INFO - Average loss for epoch 8: 0.2509
+2024-11-28 15:01:35,040 - INFO - Test accuracy after epoch 8: 88.97%
+2024-11-28 15:01:35,044 - INFO - Model saved: models\model_epoch_8.pt
+2024-11-28 15:01:35,044 - INFO - Epoch 9/50 started
+2024-11-28 15:01:36,579 - INFO - Average loss for epoch 9: 0.2382
+2024-11-28 15:01:36,691 - INFO - Test accuracy after epoch 9: 88.87%
+2024-11-28 15:01:36,695 - INFO - Model saved: models\model_epoch_9.pt
+2024-11-28 15:01:36,695 - INFO - Epoch 10/50 started
+2024-11-28 15:01:38,227 - INFO - Average loss for epoch 10: 0.2180
+2024-11-28 15:01:38,341 - INFO - Test accuracy after epoch 10: 88.73%
+2024-11-28 15:01:38,345 - INFO - Model saved: models\model_epoch_10.pt
+2024-11-28 15:01:38,345 - INFO - Epoch 11/50 started
+2024-11-28 15:01:39,881 - INFO - Average loss for epoch 11: 0.2175
+2024-11-28 15:01:40,020 - INFO - Test accuracy after epoch 11: 89.04%
+2024-11-28 15:01:40,025 - INFO - Model saved: models\model_epoch_11.pt
+2024-11-28 15:01:40,025 - INFO - Epoch 12/50 started
+2024-11-28 15:01:41,607 - INFO - Average loss for epoch 12: 0.2038
+2024-11-28 15:01:41,715 - INFO - Test accuracy after epoch 12: 88.80%
+2024-11-28 15:01:41,719 - INFO - Model saved: models\model_epoch_12.pt
+2024-11-28 15:01:41,719 - INFO - Epoch 13/50 started
+2024-11-28 15:01:43,369 - INFO - Average loss for epoch 13: 0.2036
+2024-11-28 15:01:43,484 - INFO - Test accuracy after epoch 13: 89.28%
+2024-11-28 15:01:43,491 - INFO - Model saved: models\model_epoch_13.pt
+2024-11-28 15:01:43,491 - INFO - Epoch 14/50 started
+2024-11-28 15:01:45,156 - INFO - Average loss for epoch 14: 0.1987
+2024-11-28 15:01:45,267 - INFO - Test accuracy after epoch 14: 89.18%
+2024-11-28 15:01:45,271 - INFO - Model saved: models\model_epoch_14.pt
+2024-11-28 15:01:45,271 - INFO - Epoch 15/50 started
+2024-11-28 15:01:46,926 - INFO - Average loss for epoch 15: 0.1865
+2024-11-28 15:01:47,042 - INFO - Test accuracy after epoch 15: 88.52%
+2024-11-28 15:01:47,046 - INFO - Model saved: models\model_epoch_15.pt
+2024-11-28 15:01:47,047 - INFO - Epoch 16/50 started
+2024-11-28 15:01:48,560 - INFO - Average loss for epoch 16: 0.1825
+2024-11-28 15:01:48,667 - INFO - Test accuracy after epoch 16: 88.94%
+2024-11-28 15:01:48,671 - INFO - Model saved: models\model_epoch_16.pt
+2024-11-28 15:01:48,671 - INFO - Epoch 17/50 started
+2024-11-28 15:01:50,298 - INFO - Average loss for epoch 17: 0.1839
+2024-11-28 15:01:50,418 - INFO - Test accuracy after epoch 17: 89.04%
+2024-11-28 15:01:50,423 - INFO - Model saved: models\model_epoch_17.pt
+2024-11-28 15:01:50,423 - INFO - Epoch 18/50 started
+2024-11-28 15:01:51,971 - INFO - Average loss for epoch 18: 0.1800
+2024-11-28 15:01:52,082 - INFO - Test accuracy after epoch 18: 89.21%
+2024-11-28 15:01:52,086 - INFO - Model saved: models\model_epoch_18.pt
+2024-11-28 15:01:52,086 - INFO - Epoch 19/50 started
+2024-11-28 15:01:53,641 - INFO - Average loss for epoch 19: 0.1724
+2024-11-28 15:01:53,752 - INFO - Test accuracy after epoch 19: 88.83%
+2024-11-28 15:01:53,756 - INFO - Model saved: models\model_epoch_19.pt
+2024-11-28 15:01:53,756 - INFO - Epoch 20/50 started
+2024-11-28 15:01:55,336 - INFO - Average loss for epoch 20: 0.1795
+2024-11-28 15:01:55,476 - INFO - Test accuracy after epoch 20: 88.66%
+2024-11-28 15:01:55,480 - INFO - Model saved: models\model_epoch_20.pt
+2024-11-28 15:01:55,480 - INFO - Epoch 21/50 started
+2024-11-28 15:01:57,055 - INFO - Average loss for epoch 21: 0.1759
+2024-11-28 15:01:57,178 - INFO - Test accuracy after epoch 21: 88.87%
+2024-11-28 15:01:57,182 - INFO - Model saved: models\model_epoch_21.pt
+2024-11-28 15:01:57,182 - INFO - Epoch 22/50 started
+2024-11-28 15:01:58,830 - INFO - Average loss for epoch 22: 0.1708
+2024-11-28 15:01:58,957 - INFO - Test accuracy after epoch 22: 88.69%
+2024-11-28 15:01:58,961 - INFO - Model saved: models\model_epoch_22.pt
+2024-11-28 15:01:58,961 - INFO - Epoch 23/50 started
+2024-11-28 15:02:00,613 - INFO - Average loss for epoch 23: 0.1746
+2024-11-28 15:02:00,715 - INFO - Test accuracy after epoch 23: 88.76%
+2024-11-28 15:02:00,720 - INFO - Model saved: models\model_epoch_23.pt
+2024-11-28 15:02:00,720 - INFO - Epoch 24/50 started
+2024-11-28 15:02:02,331 - INFO - Average loss for epoch 24: 0.1745
+2024-11-28 15:02:02,445 - INFO - Test accuracy after epoch 24: 87.96%
+2024-11-28 15:02:02,449 - INFO - Model saved: models\model_epoch_24.pt
+2024-11-28 15:02:02,449 - INFO - Epoch 25/50 started
+2024-11-28 15:02:03,994 - INFO - Average loss for epoch 25: 0.1769
+2024-11-28 15:02:04,107 - INFO - Test accuracy after epoch 25: 88.66%
+2024-11-28 15:02:04,111 - INFO - Model saved: models\model_epoch_25.pt
+2024-11-28 15:02:04,111 - INFO - Epoch 26/50 started
+2024-11-28 15:02:05,674 - INFO - Average loss for epoch 26: 0.1730
+2024-11-28 15:02:05,778 - INFO - Test accuracy after epoch 26: 88.66%
+2024-11-28 15:02:05,782 - INFO - Model saved: models\model_epoch_26.pt
+2024-11-28 15:02:05,782 - INFO - Epoch 27/50 started
+2024-11-28 15:02:07,327 - INFO - Average loss for epoch 27: 0.1621
+2024-11-28 15:02:07,430 - INFO - Test accuracy after epoch 27: 88.35%
+2024-11-28 15:02:07,434 - INFO - Model saved: models\model_epoch_27.pt
+2024-11-28 15:02:07,434 - INFO - Epoch 28/50 started
+2024-11-28 15:02:08,943 - INFO - Average loss for epoch 28: 0.1720
+2024-11-28 15:02:09,064 - INFO - Test accuracy after epoch 28: 87.72%
+2024-11-28 15:02:09,068 - INFO - Model saved: models\model_epoch_28.pt
+2024-11-28 15:02:09,068 - INFO - Epoch 29/50 started
+2024-11-28 15:02:10,699 - INFO - Average loss for epoch 29: 0.1615
+2024-11-28 15:02:10,808 - INFO - Test accuracy after epoch 29: 89.21%
+2024-11-28 15:02:10,812 - INFO - Model saved: models\model_epoch_29.pt
+2024-11-28 15:02:10,812 - INFO - Epoch 30/50 started
+2024-11-28 15:02:12,426 - INFO - Average loss for epoch 30: 0.1733
+2024-11-28 15:02:12,539 - INFO - Test accuracy after epoch 30: 89.21%
+2024-11-28 15:02:12,543 - INFO - Model saved: models\model_epoch_30.pt
+2024-11-28 15:02:12,543 - INFO - Epoch 31/50 started
+2024-11-28 15:02:14,079 - INFO - Average loss for epoch 31: 0.1624
+2024-11-28 15:02:14,189 - INFO - Test accuracy after epoch 31: 88.59%
+2024-11-28 15:02:14,194 - INFO - Model saved: models\model_epoch_31.pt
+2024-11-28 15:02:14,194 - INFO - Epoch 32/50 started
+2024-11-28 15:02:15,776 - INFO - Average loss for epoch 32: 0.1611
+2024-11-28 15:02:15,895 - INFO - Test accuracy after epoch 32: 89.11%
+2024-11-28 15:02:15,899 - INFO - Model saved: models\model_epoch_32.pt
+2024-11-28 15:02:15,899 - INFO - Epoch 33/50 started
+2024-11-28 15:02:17,411 - INFO - Average loss for epoch 33: 0.1609
+2024-11-28 15:02:17,517 - INFO - Test accuracy after epoch 33: 89.32%
+2024-11-28 15:02:17,522 - INFO - Model saved: models\model_epoch_33.pt
+2024-11-28 15:02:17,522 - INFO - Epoch 34/50 started
+2024-11-28 15:02:19,038 - INFO - Average loss for epoch 34: 0.1652
+2024-11-28 15:02:19,166 - INFO - Test accuracy after epoch 34: 88.87%
+2024-11-28 15:02:19,170 - INFO - Model saved: models\model_epoch_34.pt
+2024-11-28 15:02:19,170 - INFO - Epoch 35/50 started
+2024-11-28 15:02:20,770 - INFO - Average loss for epoch 35: 0.1793
+2024-11-28 15:02:20,886 - INFO - Test accuracy after epoch 35: 88.66%
+2024-11-28 15:02:20,891 - INFO - Model saved: models\model_epoch_35.pt
+2024-11-28 15:02:20,891 - INFO - Epoch 36/50 started
+2024-11-28 15:02:22,635 - INFO - Average loss for epoch 36: 0.1607
+2024-11-28 15:02:22,748 - INFO - Test accuracy after epoch 36: 88.90%
+2024-11-28 15:02:22,753 - INFO - Model saved: models\model_epoch_36.pt
+2024-11-28 15:02:22,753 - INFO - Epoch 37/50 started
+2024-11-28 15:02:24,389 - INFO - Average loss for epoch 37: 0.1572
+2024-11-28 15:02:24,527 - INFO - Test accuracy after epoch 37: 88.90%
+2024-11-28 15:02:24,534 - INFO - Model saved: models\model_epoch_37.pt
+2024-11-28 15:02:24,534 - INFO - Epoch 38/50 started
+2024-11-28 15:02:26,082 - INFO - Average loss for epoch 38: 0.1631
+2024-11-28 15:02:26,181 - INFO - Test accuracy after epoch 38: 88.17%
+2024-11-28 15:02:26,185 - INFO - Model saved: models\model_epoch_38.pt
+2024-11-28 15:02:26,185 - INFO - Epoch 39/50 started
+2024-11-28 15:02:27,680 - INFO - Average loss for epoch 39: 0.1643
+2024-11-28 15:02:27,787 - INFO - Test accuracy after epoch 39: 88.62%
+2024-11-28 15:02:27,791 - INFO - Model saved: models\model_epoch_39.pt
+2024-11-28 15:02:27,791 - INFO - Epoch 40/50 started
+2024-11-28 15:02:29,421 - INFO - Average loss for epoch 40: 0.1578
+2024-11-28 15:02:29,538 - INFO - Test accuracy after epoch 40: 87.96%
+2024-11-28 15:02:29,542 - INFO - Model saved: models\model_epoch_40.pt
+2024-11-28 15:02:29,542 - INFO - Epoch 41/50 started
+2024-11-28 15:02:31,150 - INFO - Average loss for epoch 41: 0.1579
+2024-11-28 15:02:31,270 - INFO - Test accuracy after epoch 41: 88.69%
+2024-11-28 15:02:31,275 - INFO - Model saved: models\model_epoch_41.pt
+2024-11-28 15:02:31,275 - INFO - Epoch 42/50 started
+2024-11-28 15:02:33,082 - INFO - Average loss for epoch 42: 0.1575
+2024-11-28 15:02:33,226 - INFO - Test accuracy after epoch 42: 88.52%
+2024-11-28 15:02:33,231 - INFO - Model saved: models\model_epoch_42.pt
+2024-11-28 15:02:33,232 - INFO - Epoch 43/50 started
+2024-11-28 15:02:34,811 - INFO - Average loss for epoch 43: 0.1574
+2024-11-28 15:02:34,911 - INFO - Test accuracy after epoch 43: 89.11%
+2024-11-28 15:02:34,916 - INFO - Model saved: models\model_epoch_43.pt
+2024-11-28 15:02:34,916 - INFO - Epoch 44/50 started
+2024-11-28 15:02:36,448 - INFO - Average loss for epoch 44: 0.1642
+2024-11-28 15:02:36,553 - INFO - Test accuracy after epoch 44: 89.04%
+2024-11-28 15:02:36,557 - INFO - Model saved: models\model_epoch_44.pt
+2024-11-28 15:02:36,557 - INFO - Epoch 45/50 started
+2024-11-28 15:02:38,065 - INFO - Average loss for epoch 45: 0.1583
+2024-11-28 15:02:38,181 - INFO - Test accuracy after epoch 45: 88.41%
+2024-11-28 15:02:38,185 - INFO - Model saved: models\model_epoch_45.pt
+2024-11-28 15:02:38,185 - INFO - Epoch 46/50 started
+2024-11-28 15:02:39,689 - INFO - Average loss for epoch 46: 0.1613
+2024-11-28 15:02:39,805 - INFO - Test accuracy after epoch 46: 89.21%
+2024-11-28 15:02:39,809 - INFO - Model saved: models\model_epoch_46.pt
+2024-11-28 15:02:39,809 - INFO - Epoch 47/50 started
+2024-11-28 15:02:41,364 - INFO - Average loss for epoch 47: 0.1598
+2024-11-28 15:02:41,467 - INFO - Test accuracy after epoch 47: 88.55%
+2024-11-28 15:02:41,471 - INFO - Model saved: models\model_epoch_47.pt
+2024-11-28 15:02:41,471 - INFO - Epoch 48/50 started
+2024-11-28 15:02:42,977 - INFO - Average loss for epoch 48: 0.1697
+2024-11-28 15:02:43,086 - INFO - Test accuracy after epoch 48: 88.48%
+2024-11-28 15:02:43,090 - INFO - Model saved: models\model_epoch_48.pt
+2024-11-28 15:02:43,090 - INFO - Epoch 49/50 started
+2024-11-28 15:02:44,596 - INFO - Average loss for epoch 49: 0.1618
+2024-11-28 15:02:44,717 - INFO - Test accuracy after epoch 49: 89.25%
+2024-11-28 15:02:44,721 - INFO - Model saved: models\model_epoch_49.pt
+2024-11-28 15:02:44,721 - INFO - Epoch 50/50 started
+2024-11-28 15:02:46,298 - INFO - Average loss for epoch 50: 0.1642
+2024-11-28 15:02:46,414 - INFO - Test accuracy after epoch 50: 88.97%
+2024-11-28 15:02:46,418 - INFO - Model saved: models\model_epoch_50.pt
+2024-11-28 15:02:46,418 - INFO - Training complete.
+2024-11-28 15:02:46,418 - INFO - Training complete
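The log's "timestamp - level - message" lines match the default output of Python's standard logging module. The commit does not show how scr/train.py configures logging, so the following is only a sketch of a setup that would produce this format:

import logging

# Sketch of a logging configuration consistent with models/training.log;
# the actual setup in scr/train.py is not shown in this diff.
logging.basicConfig(
    filename='models/training.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logging.info('Training started')
logging.info('Training set size: %d, Test set size: %d', 11529, 2883)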
notebooks/data.ipynb ADDED (200 lines; cells and outputs shown below)

In [1]:
import pandas as pd
import numpy as np

import os
os.chdir('..')

In [2]:
os.listdir('.')
Out[2]:
['data', 'model', 'notebooks', 'test.ipynb', 'Untitled-1.ipynb', 'вф']

In [3]:
df = pd.read_csv('data/data.csv')
df['comment'] = df['comment'].str.replace('\n', '', regex=False)
df
Out[3]:
                                                 comment  toxic
0                     Верблюдов-то за что? Дебилы, бл...    1.0
1      Хохлы, это отдушина затюканого россиянина, мол...    1.0
2                                Собаке - собачья смерть    1.0
3      Страницу обнови, дебил. Это тоже не оскорблени...    1.0
4      тебя не убедил 6-страничный пдф в том, что Скр...    1.0
...                                                  ...    ...
14407  Вонючий совковый скот прибежал и ноет. А вот и...    1.0
14408  А кого любить? Гоблина тупорылого что-ли? Или ...    1.0
14409  Посмотрел Утомленных солнцем 2. И оказалось, ч...    0.0
14410  КРЫМОТРЕД НАРУШАЕТ ПРАВИЛА РАЗДЕЛА Т.К В НЕМ Н...    1.0
14411  До сих пор пересматриваю его видео. Орамбо кст...    0.0

[14412 rows x 2 columns]

In [4]:
df['toxic'].value_counts()
Out[4]:
toxic
0.0    9586
1.0    4826
Name: count, dtype: int64
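The value_counts output shows roughly a 2:1 skew toward non-toxic comments (9586 vs 4826). A standard remedy is to up-weight the rarer class in the loss; whether scr/train.py does this is not shown in the commit, and since the Model ends in a Sigmoid it was presumably trained on probabilities, so take this as a sketch of the technique rather than the repo's code:

import torch
import torch.nn as nn

# Class counts from the notebook output above.
n_negative, n_positive = 9586, 4826

# pos_weight > 1 makes errors on the rarer toxic class cost more;
# BCEWithLogitsLoss expects raw logits (no Sigmoid on the model output).
pos_weight = torch.tensor([n_negative / n_positive])  # ~1.99
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)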
notebooks/dataset.ipynb ADDED (106 lines; cells and outputs shown below)

In [1]:
import os
os.chdir('..')

import torch
from scr.sbert import sbert
from scr.dataset import TextDataset

(stderr)
a:\python\312\Lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
a:\python\312\Lib\site-packages\transformers\utils\hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
  warnings.warn(

In [2]:
dataset = torch.load('data/dataset.pt')

(stderr)
C:\Users\timof\AppData\Local\Temp\ipykernel_12548\753396127.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  dataset = torch.load('data/dataset.pt')

In [3]:
from torch.utils.data import DataLoader

# Check the size
print(f"Размер датасета: {len(dataset)}")

# Create the DataLoader
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Iterate over the data in batches
for texts, labels in dataloader:
    print("Тексты:", texts)
    print("Метки:", labels)
    break

(stdout)
Размер датасета: 14412
Тексты: tensor([[-0.7709,  0.2756, -1.8136, ..., -0.1891,  0.6464, -0.0877],
        [ 0.0737,  0.2665, -0.2466, ...,  0.1983,  0.9042,  0.7120],
        [-0.4836,  0.2575, -0.3310, ..., -0.0648,  0.6074, -0.2436],
        ...,
        [ 0.5273,  0.2523, -0.4174, ..., -0.1361,  0.0777,  0.1805],
        [-0.6573,  0.1075, -1.1338, ...,  0.0145,  0.0062,  0.1264],
        [ 0.4965,  0.1897, -1.8090, ..., -0.0378,  0.2283,  0.6433]])
Метки: tensor([1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0.],
       dtype=torch.float64)
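The batch above is a (16, 1024) float tensor of SBERT sentence embeddings paired with float64 labels, so TextDataset is evidently a map-style dataset over precomputed embeddings. A minimal sketch consistent with that output; the actual 32-line scr/dataset.py is not shown in this diff and may differ:

import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
    """Map-style dataset over precomputed sentence embeddings.

    A reconstruction consistent with the notebook output (1024-dim
    embeddings, float64 labels); the real scr/dataset.py may differ.
    """
    def __init__(self, embeddings: torch.Tensor, labels: torch.Tensor):
        self.embeddings = embeddings  # shape (N, 1024)
        self.labels = labels          # shape (N,)

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]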
notebooks/evaluate.ipynb ADDED (362 lines; cells and outputs shown below)

In [1]:
import os
os.chdir('..')

import pandas as pd
import torch
from tqdm import tqdm

tqdm.pandas()

from scr.model import Model

(stderr: same tqdm and TRANSFORMERS_CACHE warnings as in notebooks/dataset.ipynb)

In [2]:
model = Model()
model.eval()
model.load_state_dict(torch.load('models/model_epoch_50.pt'))

(stderr: same torch.load weights_only FutureWarning as in notebooks/dataset.ipynb)

Out[2]:
<All keys matched successfully>

In [3]:
df = pd.read_csv('data/data.csv')
df['comment'] = df['comment'].str.replace('\n', '', regex=False)
df
Out[3]:
(same 14412 rows x 2 columns preview as in notebooks/data.ipynb)

In [4]:
df['predict'] = df['comment'].progress_apply(lambda x: round(float(model.predict(x).item()), 3))
df

(stderr)
100%|██████████| 14412/14412 [02:38<00:00, 90.67it/s]

Out[4]:
                                                 comment  toxic  predict
0                     Верблюдов-то за что? Дебилы, бл...    1.0    1.000
1      Хохлы, это отдушина затюканого россиянина, мол...    1.0    1.000
2                                Собаке - собачья смерть    1.0    1.000
3      Страницу обнови, дебил. Это тоже не оскорблени...    1.0    1.000
4      тебя не убедил 6-страничный пдф в том, что Скр...    1.0    1.000
...                                                  ...    ...      ...
14407  Вонючий совковый скот прибежал и ноет. А вот и...    1.0    1.000
14408  А кого любить? Гоблина тупорылого что-ли? Или ...    1.0    0.487
14409  Посмотрел Утомленных солнцем 2. И оказалось, ч...    0.0    0.499
14410  КРЫМОТРЕД НАРУШАЕТ ПРАВИЛА РАЗДЕЛА Т.К В НЕМ Н...    1.0    1.000
14411  До сих пор пересматриваю его видео. Орамбо кст...    0.0    0.000

[14412 rows x 3 columns]

In [8]:
(df['toxic'] == df['predict'].apply(round).astype(int)).value_counts()
Out[8]:
True     14056
False      356
Name: count, dtype: int64
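14056 of 14412 predictions match the labels, i.e. about 97.5% accuracy over the full dataset. Note this includes the 11529 training examples, so it is not a held-out estimate; the training log above reports roughly 89% on the 2883-example test split. Reproducing the arithmetic:

correct, total = 14056, 14412
print(f"accuracy = {correct / total:.4f}")  # 0.9753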
notebooks/model.ipynb
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "a:\\python\\312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "a:\\python\\312\\Lib\\site-packages\\transformers\\utils\\hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.chdir('..')\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "from scr.dataset import TextDataset\n",
    "from scr.model import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\timof\\AppData\\Local\\Temp\\ipykernel_20804\\3112888757.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  dataset = torch.load('data/dataset.pt')\n"
     ]
    }
   ],
   "source": [
    "dataset = torch.load('data/dataset.pt')\n",
    "dataloader = DataLoader(dataset, batch_size=16, shuffle=True, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Model(\n",
       "  (model): Sequential(\n",
       "    (0): Block(\n",
       "      (model): Sequential(\n",
       "        (0): Linear(in_features=1024, out_features=512, bias=True)\n",
       "        (1): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (1): LeakyReLU(negative_slope=0.01)\n",
       "    (2): Block(\n",
       "      (model): Sequential(\n",
       "        (0): Linear(in_features=512, out_features=256, bias=True)\n",
       "        (1): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (3): LeakyReLU(negative_slope=0.01)\n",
       "    (4): Block(\n",
       "      (model): Sequential(\n",
       "        (0): Linear(in_features=256, out_features=128, bias=True)\n",
       "        (1): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (5): LeakyReLU(negative_slope=0.01)\n",
       "    (6): Block(\n",
       "      (model): Sequential(\n",
       "        (0): Linear(in_features=128, out_features=64, bias=True)\n",
       "        (1): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (7): LeakyReLU(negative_slope=0.01)\n",
       "    (8): Block(\n",
       "      (model): Sequential(\n",
       "        (0): Linear(in_features=64, out_features=1, bias=True)\n",
       "        (1): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (9): Sigmoid()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Model()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\timof\\AppData\\Local\\Temp\\ipykernel_20804\\3887862913.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load('models/model_epoch_50.pt'))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('models/model_epoch_50.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
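The stderr output in the notebook above comes from `torch.load` defaulting to `weights_only=False`, which unpickles arbitrary objects. For the pure state-dict checkpoints under `models/`, the warning can be avoided by opting in to the safe path; a minimal sketch, assuming a PyTorch version that already accepts the `weights_only` flag:

import torch
from scr.model import Model

model = Model()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, and silences the FutureWarning.
model.load_state_dict(torch.load('models/model_epoch_50.pt', weights_only=True))
model.eval()

`data/dataset.pt`, by contrast, stores a whole pickled `TextDataset`, so it still requires `weights_only=False` (or an explicit `torch.serialization.add_safe_globals` allowlist) and should only be loaded from trusted files.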
scr/__init__.py
ADDED
File without changes
scr/dataset.py
ADDED
@@ -0,0 +1,32 @@
from torch.utils.data import DataLoader, Dataset
import torch
from scr.sbert import vectorize as vec
import pandas as pd
from tqdm import tqdm

tqdm.pandas()

df = pd.read_csv('data/data.csv')

class TextDataset(Dataset):
    def __init__(self, df):
        self.vectors = torch.stack(list(df['comment'].progress_apply(lambda x: vec(x).squeeze(0))))
        self.labels = torch.tensor(df['toxic'].values)

    def __getitem__(self, index):
        return self.vectors[index], self.labels[index]

    def __len__(self):
        return len(self.labels)

if __name__ == '__main__':
    # Vectorize the dataset and save it to disk
    dataset = TextDataset(df)
    torch.save(dataset, 'data/dataset.pt')

    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, pin_memory=True)

    # Smoke-test one batch through the DataLoader
    for texts, labels in dataloader:
        print(texts, labels)
        break
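Note that `pd.read_csv('data/data.csv')` runs at import time, so even `from scr.dataset import TextDataset` needs the CSV on disk; rebuilding the cached tensors is done by running the module itself from the repo root (`python -m scr.dataset`), which vectorizes every comment once and writes `data/dataset.pt` for training to reuse.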
scr/model.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
from .sbert import vectorize as vec

class Block(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        return self.model(x)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            Block(1024, 512),
            nn.LeakyReLU(),

            Block(512, 256),
            nn.LeakyReLU(),

            Block(256, 128),
            nn.LeakyReLU(),

            Block(128, 64),
            nn.LeakyReLU(),

            Block(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

    def predict(self, text):
        return self(vec(text))
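A minimal end-to-end inference sketch for this classifier (hypothetical usage; it assumes a trained checkpoint such as `models/model_epoch_50.pt` and downloads the SBERT encoder on first import):

import torch
from scr.model import Model

model = Model()
model.load_state_dict(torch.load('models/model_epoch_50.pt', weights_only=True))
model.eval()  # matters here: every Block applies Dropout, including the one on the single output logit

with torch.no_grad():
    prob = model.predict("Спасибо за помощь!").item()  # sigmoid probability in [0, 1]
print(f"toxicity: {prob:.3f}")

`predict` re-embeds the text with SBERT on every call, so for scoring many comments it is cheaper to batch them through `vectorize` and call the model once.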
scr/sbert.py
ADDED
@@ -0,0 +1,41 @@
from transformers import AutoTokenizer, AutoModel
import torch

# Pick the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mean pooling - uses the attention mask so padding tokens do not skew the average
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
    sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
    return sum_embeddings / sum_mask

# Load the model and tokenizer from HuggingFace
tokenizer = AutoTokenizer.from_pretrained("ai-forever/sbert_large_mt_nlu_ru")
sbert = AutoModel.from_pretrained("ai-forever/sbert_large_mt_nlu_ru").to(device)  # move the model to the device

def vectorize(texts, batch_size=32):
    if isinstance(texts, str):
        texts = [texts]  # wrap a single string in a list

    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        encoded_input = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors='pt'
        ).to(device)

        with torch.no_grad():
            model_output = sbert(**encoded_input)
            batch_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).cpu()
        embeddings.append(batch_embeddings)

    # Concatenate the per-batch embeddings into a single tensor
    return torch.cat(embeddings, dim=0)
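For reference, a small usage sketch of `vectorize`; the 1024-dimensional output of `sbert_large_mt_nlu_ru` is what the classifier's first `Block(1024, 512)` expects:

from scr.sbert import vectorize

emb = vectorize(["Привет!", "Как дела?"])  # batched encoding + mean pooling, truncated to 64 tokens
print(emb.shape)  # torch.Size([2, 1024])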
scr/train.py
ADDED
@@ -0,0 +1,117 @@
import os
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import logging

from scr.sbert import sbert
from scr.dataset import TextDataset
from scr.model import Model

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training parameters
EPOCHS = 50
BATCH_SIZE = 16
LEARNING_RATE = 1e-3
MODEL_SAVE_DIR = "models"
LOG_DIR = "models"
TEST_SPLIT = 0.2  # 80/20 split

# Ensure directories exist
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)

# Setup logging
log_file = os.path.join(LOG_DIR, "training.log")
logging.basicConfig(
    filename=log_file,
    filemode="w",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger()

# Load dataset
full_dataset = torch.load('data/dataset.pt')

# Split dataset into training and testing sets
train_size = int((1 - TEST_SPLIT) * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

# Initialize model
model = Model().to(device)

# Loss function and optimizer
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Accuracy calculation
def calculate_accuracy(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for texts, labels in dataloader:
            texts, labels = texts.to(device), labels.to(device, dtype=torch.float)
            outputs = model(texts).squeeze(1)
            predictions = (outputs > 0.5).float()  # Binary classification threshold
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total if total > 0 else 0

# Training function
def train_model(model, train_dataloader, test_dataloader, criterion, optimizer, epochs, device):
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        logger.info(f"Epoch {epoch + 1}/{epochs} started")

        # Training loop
        for texts, labels in tqdm(
            train_dataloader,
            desc=f"Epoch {epoch + 1}/{epochs}",
            ncols=12  # Limit progress bar width to 12 characters
        ):
            texts, labels = texts.to(device), labels.to(device, dtype=torch.float)

            # Reset gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(texts).squeeze(1)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Accumulate loss
            epoch_loss += loss.item()

        # Log average loss for the epoch
        avg_loss = epoch_loss / len(train_dataloader)
        logger.info(f"Average loss for epoch {epoch + 1}: {avg_loss:.4f}")

        # Evaluate on the test set
        accuracy = calculate_accuracy(model, test_dataloader, device)
        logger.info(f"Test accuracy after epoch {epoch + 1}: {accuracy:.2f}%")

        # Save model
        model_path = os.path.join(MODEL_SAVE_DIR, f"model_epoch_{epoch + 1}.pt")
        torch.save(model.state_dict(), model_path)
        logger.info(f"Model saved: {model_path}")

    logger.info("Training complete.")

# Run training
if __name__ == "__main__":
    logger.info("Training started")
    logger.info(f"Training set size: {len(train_dataset)}, Test set size: {len(test_dataset)}")
    train_model(model, train_dataloader, test_dataloader, criterion, optimizer, EPOCHS, device)
    logger.info("Training complete")
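A self-contained sketch for evaluating a saved checkpoint outside the training script (it assumes `data/dataset.pt` and `models/model_epoch_50.pt` exist; note it scores the whole cached dataset, not the held-out split that `random_split` produced during training):

import torch
from torch.utils.data import DataLoader
from scr.model import Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model().to(device)
model.load_state_dict(torch.load('models/model_epoch_50.pt', map_location=device, weights_only=True))
model.eval()

dataset = torch.load('data/dataset.pt', weights_only=False)  # trusted local pickle of TextDataset
loader = DataLoader(dataset, batch_size=16)

correct = total = 0
with torch.no_grad():
    for texts, labels in loader:
        texts, labels = texts.to(device), labels.to(device, dtype=torch.float)
        preds = (model(texts).squeeze(1) > 0.5).float()  # 0.5 decision threshold, as in train.py
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"accuracy: {100 * correct / total:.2f}%")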