ivanovot commited on
Commit
d1474ea
·
1 Parent(s): 6cc6c10
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ venv/
2
+ *__pycache__
3
+ *.gradio
app.py CHANGED
@@ -1,56 +1,61 @@
1
  import gradio as gr
2
- from model import model # Импортируем вашу модель
 
3
 
4
- # Токсичные и нетоксичные комментарии для тестирования
5
- examples = [
6
- ["Страницу обнови, дебил. Это тоже не оскорбление, а доказанный факт - не-дебил про себя во множественном числе писать не будет. Или мы в тебя верим - это ты и твои воображаемые друзья?"],
7
- ["УПАД Т! ТАМ НЕЛЬЗЯ СТРОИТЬ! ТЕХНОЛОГИЙ НЕТ! РАЗВОРУЮТ КАК ВСЕГДА! УЖЕ ТРЕЩИНАМИ ПОШ Л! ТУПЫЕ КИТАЗЫ НЕ МОГУТ НИЧЕГО НОРМАЛЬНО СДЕЛАТЬ!"],
8
- ["хорош врать, ты террорист-торчёк-шизофреник пруф: а вот без костюма да чутка учёный, миллиардер, филантроп"],
9
- ["Мне Вас очень жаль, если для Вас оскорбления - норма"],
10
- ["Осторожней на сверхманёврах. В предыдущей методичке у вас было написано, что добрые арабы никогда ни с кем не воевали, только торговали пряностями, лел. Шапочку из фольги сними"],
11
- ["Так то стоит около 12,5 тысяч, но мне вышло в 6636 рублей и бесплатная доставка"],
12
- ["Ну хочешь я стану твоим другом? Как тебя зовут? Чем увлекаешься?"],
13
- ["Ну так это в плане изготовления изделий своими руками,а вот готовить вроде умею.Короче буду сам на себе испытывать божественный напиток и куплю огнетушитель (промышленный на всякий случай)."],
14
- ["Я согласен, что это хорошая идея! Давайте подумаем, как можно улучшить её еще больше."],
15
- ["Очень полезная информация, спасибо за подробное объяснение! Я многому научился."],
16
- ["Мне нравится, как вы объясняете! Это действительно помогает разобраться в теме."],
17
- ["Отлично написано, теперь я лучше понимаю, как работать с этим инструментом."],
18
- ["Классная идея! Надо попробовать и посмотреть, как это работает на практике."],
19
- ["Ваши советы очень полезны. Это точно сэкономит мне время на следующий раз."],
20
- ["Спасибо за ваши рекомендации! Я обязательно попробую это в будущем."],
21
- ["Мне нравится ваше решение этой проблемы. Очень креативно и практично."],
22
- ["Спасибо за помощь! Вы помогли мне разобраться в ситуации и сэкономить много времени."],
23
- ["Согласен с вами, это действительно важный аспект, о котором стоит задуматься."],
24
- ["Очень вдохновляющий пост! Я буду следовать вашим рекомендациям."],
25
- ]
26
 
27
- # Функция для предсказания токсичности текста
 
 
 
 
28
  def predict_text(text):
 
29
  word_count = len(text.split())
30
- if word_count < 7:
31
- return "Слишком короткий текст"
32
 
33
- output = model.predict(text)
34
- return "Токсичный" if output == 1 else "Не токсичный"
 
35
 
36
- # Создаем интерфейс с улучшениями
 
 
 
 
 
 
 
 
 
37
  demo = gr.Interface(
38
  fn=predict_text, # Функция для предсказания
39
  inputs=gr.Textbox(
40
  label="Введите текст для проверки на токсичность", # Подпись для текстового поля
41
- placeholder="Напишите комментарий для анализа", # Подсказка для поля ввода
42
- lines=5, # Количество строк для текста
43
- interactive=True, # Интерактивность для пользователя
44
- ),
45
- outputs=gr.Textbox(
46
- label="Результат анализа", # Подпись для вывода
47
- placeholder="Результат токсичности текста будет здесь", # Подсказка для вывода
48
  ),
49
- live=True, # Включаем live обновление
 
 
 
 
 
 
 
 
 
 
 
 
50
  examples=examples, # Примеры для пользователей
51
- title="Тестирование токсичности текста", # Заголовок интерфейса
52
- description="Введите любой текст, чтобы проверить его на токсичность. Модель проанализирует, является ли текст токсичным или нет.", # Описание
 
53
  )
54
 
55
- # Запуск приложения с улучшенным интерфейсом
56
- demo.launch(server_name="127.0.0.1", server_port=7860)
 
1
  import gradio as gr
2
+ from scr.model import Model
3
+ import torch
4
 
5
+ # Настройка устройства и загрузка модели
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ model = Model()
9
+ model.load_state_dict(torch.load('models/model_epoch_50.pt', map_location=device))
10
+ model.eval()
11
+
12
+ # Функция для предсказания оценки токсичности текста
13
  def predict_text(text):
14
+ # Проверяем длину текста
15
  word_count = len(text.split())
16
+ if word_count < 3:
17
+ return "Слишком короткий текст", None
18
 
19
+ # Предсказываем результат
20
+ score = round(float(model.predict(text).item()), 5) # Приводим результат к числу с 5 знаками после запятой
21
+ return f"Оценка токсичности: {score}", score
22
 
23
+ # Примеры для демонстрации
24
+ examples = [
25
+ "Этот продукт ��росто великолепен, спасибо!",
26
+ "Ты ужасен, не могу терпеть твои комментарии!",
27
+ "Сегодня был хороший день, несмотря на небольшой дождь.",
28
+ "Твой проект провалился, и это только твоя вина.",
29
+ "Замечательная работа, вы молодцы!"
30
+ ]
31
+
32
+ # Создаем интерфейс
33
  demo = gr.Interface(
34
  fn=predict_text, # Функция для предсказания
35
  inputs=gr.Textbox(
36
  label="Введите текст для проверки на токсичность", # Подпись для текстового поля
37
+ placeholder="Напишите комментарий для анализа", # Подсказка для ввода
38
+ lines=5, # Количество строк
39
+ interactive=True # Включаем интерактивность
 
 
 
 
40
  ),
41
+ outputs=[
42
+ gr.Textbox(
43
+ label="Результат анализа", # Подпись для вывода
44
+ placeholder="Оценка токсичности будет показана здесь", # Подсказка для вывода
45
+ ),
46
+ gr.Slider(
47
+ label="Шкала токсичности", # Подпись шкалы
48
+ minimum=0.0,
49
+ maximum=1.0,
50
+ step=0.00001,
51
+ interactive=False, # Делаем слайдер только для вывода
52
+ )
53
+ ],
54
  examples=examples, # Примеры для пользователей
55
+ title="Toxicity Classification", # Заголовок
56
+ description="Введите текст, чтобы узнать его оценку токсичности (0 - не токсичный, 1 - максимально токсичный).", # Описание
57
+ live=True, # Автоматический запуск модели при изменении текста
58
  )
59
 
60
+ # Запуск приложения
61
+ demo.launch()
model/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .model import model
 
 
model/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (176 Bytes)
 
model/__pycache__/model.cpython-312.pyc DELETED
Binary file (3.69 kB)
 
model/model.py DELETED
@@ -1,68 +0,0 @@
1
- import torch
2
- from transformers import BertTokenizer, BertModel
3
- import torch.nn as nn
4
-
5
- class PowerfulBinaryTextClassifier(nn.Module):
6
- def __init__(self, model_name, lstm_hidden_size=256, num_layers=3, dropout_rate=0.2):
7
- super(PowerfulBinaryTextClassifier, self).__init__()
8
- self.bert = BertModel.from_pretrained(model_name)
9
-
10
- # Добавляем несколько LSTM слоев с большим размером скрытого состояния
11
- self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size,
12
- hidden_size=lstm_hidden_size,
13
- num_layers=num_layers,
14
- batch_first=True,
15
- bidirectional=True,
16
- dropout=dropout_rate if num_layers > 1 else 0)
17
-
18
- # Полносвязный блок с увеличенным количеством нейронов и слоев Dropout
19
- self.fc = nn.Sequential(
20
- nn.Linear(lstm_hidden_size * 2, 2), # полносвязный слой
21
- nn.Sigmoid()
22
- )
23
-
24
- self.tokenizer = BertTokenizer.from_pretrained(model_name) # Инициализация токенизатора
25
-
26
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
27
-
28
- def forward(self, input_ids, attention_mask):
29
- outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
30
- bert_outputs = outputs.last_hidden_state # (batch_size, sequence_length, hidden_size)
31
-
32
- # Применяем LSTM
33
- lstm_out, _ = self.lstm(bert_outputs) # (batch_size, sequence_length, lstm_hidden_size * 2)
34
-
35
- # Берем выход последнего временного шага для классификации
36
- last_time_step = lstm_out[:, -1, :] # (batch_size, lstm_hidden_size * 2)
37
-
38
- logits = self.fc(last_time_step) # Применяем полносвязный блок
39
-
40
- logits[:, 1] -= 0.995 # Умножаем логит для выбранного класса
41
-
42
- return logits # Возвращаем логиты для двух классов
43
-
44
- def predict(self, text):
45
- self.to(self.device) # Переносим модель на выбранное устройство
46
-
47
- # Токенизация текста
48
- inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)
49
- input_ids = inputs['input_ids'].to(self.device) # Переносим на устройство
50
- attention_mask = inputs['attention_mask'].to(self.device) # Переносим на устройство
51
-
52
- # Получение предсказания
53
- self.eval() # Переключаем модель в режим оценки
54
- with torch.no_grad():
55
- preds = self(input_ids, attention_mask) # Получаем логиты
56
-
57
- # Возвращаем индекс класса с наибольшей вероятностью
58
- return torch.argmax(preds, dim=1).item() # Возвращаем индекс класса
59
-
60
- def load_weights(self, filepath):
61
- # Загрузка весов модели
62
- self.load_state_dict(torch.load(filepath, map_location=self.device, weights_only=True))
63
-
64
- # Пример инициализации модели
65
- model_name = "DeepPavlov/rubert-base-cased"
66
- model = PowerfulBinaryTextClassifier(model_name)
67
-
68
- model.load_weights('model.pth')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.pth → models/model_epoch_10.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d450dffd595a4046900b27864c6df703e62a658bf5e74e3a4230a2c86040f359
3
- size 732509954
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0171b7a667399691c6a0c9ca1ac01a2455cdd294db6c4c97a3f5717a75e97f38
3
+ size 2793946
models/model_epoch_20.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efd2666876a870ebb7dd175e9d14e5f4a7f1c7c3adc577ad3e148480e7906326
3
+ size 2793946
models/model_epoch_30.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef89dab538f52b47be6a9525e672c6f8e03dcbff6bb4b6b3b26c3b824adec089
3
+ size 2793946
models/model_epoch_40.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26e450b74eb7fb8fbbe6e64490aa71374fa6c9c1343bed927c238475e1920b17
3
+ size 2793946
models/model_epoch_50.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:463cef09e540c6cfc18dfbe817d616403d097b0e30f5f3f1ca2e6c4cdf54d4ad
3
+ size 2793946
models/training.log ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2024-11-28 15:01:21,773 - INFO - Training started
2
+ 2024-11-28 15:01:21,774 - INFO - Training set size: 11529, Test set size: 2883
3
+ 2024-11-28 15:01:21,774 - INFO - Epoch 1/50 started
4
+ 2024-11-28 15:01:23,378 - INFO - Average loss for epoch 1: 0.3625
5
+ 2024-11-28 15:01:23,511 - INFO - Test accuracy after epoch 1: 90.36%
6
+ 2024-11-28 15:01:23,527 - INFO - Model saved: models\model_epoch_1.pt
7
+ 2024-11-28 15:01:23,527 - INFO - Epoch 2/50 started
8
+ 2024-11-28 15:01:25,024 - INFO - Average loss for epoch 2: 0.3320
9
+ 2024-11-28 15:01:25,131 - INFO - Test accuracy after epoch 2: 90.11%
10
+ 2024-11-28 15:01:25,135 - INFO - Model saved: models\model_epoch_2.pt
11
+ 2024-11-28 15:01:25,135 - INFO - Epoch 3/50 started
12
+ 2024-11-28 15:01:26,611 - INFO - Average loss for epoch 3: 0.3238
13
+ 2024-11-28 15:01:26,708 - INFO - Test accuracy after epoch 3: 90.01%
14
+ 2024-11-28 15:01:26,712 - INFO - Model saved: models\model_epoch_3.pt
15
+ 2024-11-28 15:01:26,712 - INFO - Epoch 4/50 started
16
+ 2024-11-28 15:01:28,205 - INFO - Average loss for epoch 4: 0.3089
17
+ 2024-11-28 15:01:28,324 - INFO - Test accuracy after epoch 4: 89.94%
18
+ 2024-11-28 15:01:28,328 - INFO - Model saved: models\model_epoch_4.pt
19
+ 2024-11-28 15:01:28,328 - INFO - Epoch 5/50 started
20
+ 2024-11-28 15:01:29,770 - INFO - Average loss for epoch 5: 0.2921
21
+ 2024-11-28 15:01:29,896 - INFO - Test accuracy after epoch 5: 89.70%
22
+ 2024-11-28 15:01:29,900 - INFO - Model saved: models\model_epoch_5.pt
23
+ 2024-11-28 15:01:29,900 - INFO - Epoch 6/50 started
24
+ 2024-11-28 15:01:31,476 - INFO - Average loss for epoch 6: 0.2802
25
+ 2024-11-28 15:01:31,583 - INFO - Test accuracy after epoch 6: 89.77%
26
+ 2024-11-28 15:01:31,587 - INFO - Model saved: models\model_epoch_6.pt
27
+ 2024-11-28 15:01:31,587 - INFO - Epoch 7/50 started
28
+ 2024-11-28 15:01:33,168 - INFO - Average loss for epoch 7: 0.2598
29
+ 2024-11-28 15:01:33,291 - INFO - Test accuracy after epoch 7: 88.62%
30
+ 2024-11-28 15:01:33,295 - INFO - Model saved: models\model_epoch_7.pt
31
+ 2024-11-28 15:01:33,295 - INFO - Epoch 8/50 started
32
+ 2024-11-28 15:01:34,921 - INFO - Average loss for epoch 8: 0.2509
33
+ 2024-11-28 15:01:35,040 - INFO - Test accuracy after epoch 8: 88.97%
34
+ 2024-11-28 15:01:35,044 - INFO - Model saved: models\model_epoch_8.pt
35
+ 2024-11-28 15:01:35,044 - INFO - Epoch 9/50 started
36
+ 2024-11-28 15:01:36,579 - INFO - Average loss for epoch 9: 0.2382
37
+ 2024-11-28 15:01:36,691 - INFO - Test accuracy after epoch 9: 88.87%
38
+ 2024-11-28 15:01:36,695 - INFO - Model saved: models\model_epoch_9.pt
39
+ 2024-11-28 15:01:36,695 - INFO - Epoch 10/50 started
40
+ 2024-11-28 15:01:38,227 - INFO - Average loss for epoch 10: 0.2180
41
+ 2024-11-28 15:01:38,341 - INFO - Test accuracy after epoch 10: 88.73%
42
+ 2024-11-28 15:01:38,345 - INFO - Model saved: models\model_epoch_10.pt
43
+ 2024-11-28 15:01:38,345 - INFO - Epoch 11/50 started
44
+ 2024-11-28 15:01:39,881 - INFO - Average loss for epoch 11: 0.2175
45
+ 2024-11-28 15:01:40,020 - INFO - Test accuracy after epoch 11: 89.04%
46
+ 2024-11-28 15:01:40,025 - INFO - Model saved: models\model_epoch_11.pt
47
+ 2024-11-28 15:01:40,025 - INFO - Epoch 12/50 started
48
+ 2024-11-28 15:01:41,607 - INFO - Average loss for epoch 12: 0.2038
49
+ 2024-11-28 15:01:41,715 - INFO - Test accuracy after epoch 12: 88.80%
50
+ 2024-11-28 15:01:41,719 - INFO - Model saved: models\model_epoch_12.pt
51
+ 2024-11-28 15:01:41,719 - INFO - Epoch 13/50 started
52
+ 2024-11-28 15:01:43,369 - INFO - Average loss for epoch 13: 0.2036
53
+ 2024-11-28 15:01:43,484 - INFO - Test accuracy after epoch 13: 89.28%
54
+ 2024-11-28 15:01:43,491 - INFO - Model saved: models\model_epoch_13.pt
55
+ 2024-11-28 15:01:43,491 - INFO - Epoch 14/50 started
56
+ 2024-11-28 15:01:45,156 - INFO - Average loss for epoch 14: 0.1987
57
+ 2024-11-28 15:01:45,267 - INFO - Test accuracy after epoch 14: 89.18%
58
+ 2024-11-28 15:01:45,271 - INFO - Model saved: models\model_epoch_14.pt
59
+ 2024-11-28 15:01:45,271 - INFO - Epoch 15/50 started
60
+ 2024-11-28 15:01:46,926 - INFO - Average loss for epoch 15: 0.1865
61
+ 2024-11-28 15:01:47,042 - INFO - Test accuracy after epoch 15: 88.52%
62
+ 2024-11-28 15:01:47,046 - INFO - Model saved: models\model_epoch_15.pt
63
+ 2024-11-28 15:01:47,047 - INFO - Epoch 16/50 started
64
+ 2024-11-28 15:01:48,560 - INFO - Average loss for epoch 16: 0.1825
65
+ 2024-11-28 15:01:48,667 - INFO - Test accuracy after epoch 16: 88.94%
66
+ 2024-11-28 15:01:48,671 - INFO - Model saved: models\model_epoch_16.pt
67
+ 2024-11-28 15:01:48,671 - INFO - Epoch 17/50 started
68
+ 2024-11-28 15:01:50,298 - INFO - Average loss for epoch 17: 0.1839
69
+ 2024-11-28 15:01:50,418 - INFO - Test accuracy after epoch 17: 89.04%
70
+ 2024-11-28 15:01:50,423 - INFO - Model saved: models\model_epoch_17.pt
71
+ 2024-11-28 15:01:50,423 - INFO - Epoch 18/50 started
72
+ 2024-11-28 15:01:51,971 - INFO - Average loss for epoch 18: 0.1800
73
+ 2024-11-28 15:01:52,082 - INFO - Test accuracy after epoch 18: 89.21%
74
+ 2024-11-28 15:01:52,086 - INFO - Model saved: models\model_epoch_18.pt
75
+ 2024-11-28 15:01:52,086 - INFO - Epoch 19/50 started
76
+ 2024-11-28 15:01:53,641 - INFO - Average loss for epoch 19: 0.1724
77
+ 2024-11-28 15:01:53,752 - INFO - Test accuracy after epoch 19: 88.83%
78
+ 2024-11-28 15:01:53,756 - INFO - Model saved: models\model_epoch_19.pt
79
+ 2024-11-28 15:01:53,756 - INFO - Epoch 20/50 started
80
+ 2024-11-28 15:01:55,336 - INFO - Average loss for epoch 20: 0.1795
81
+ 2024-11-28 15:01:55,476 - INFO - Test accuracy after epoch 20: 88.66%
82
+ 2024-11-28 15:01:55,480 - INFO - Model saved: models\model_epoch_20.pt
83
+ 2024-11-28 15:01:55,480 - INFO - Epoch 21/50 started
84
+ 2024-11-28 15:01:57,055 - INFO - Average loss for epoch 21: 0.1759
85
+ 2024-11-28 15:01:57,178 - INFO - Test accuracy after epoch 21: 88.87%
86
+ 2024-11-28 15:01:57,182 - INFO - Model saved: models\model_epoch_21.pt
87
+ 2024-11-28 15:01:57,182 - INFO - Epoch 22/50 started
88
+ 2024-11-28 15:01:58,830 - INFO - Average loss for epoch 22: 0.1708
89
+ 2024-11-28 15:01:58,957 - INFO - Test accuracy after epoch 22: 88.69%
90
+ 2024-11-28 15:01:58,961 - INFO - Model saved: models\model_epoch_22.pt
91
+ 2024-11-28 15:01:58,961 - INFO - Epoch 23/50 started
92
+ 2024-11-28 15:02:00,613 - INFO - Average loss for epoch 23: 0.1746
93
+ 2024-11-28 15:02:00,715 - INFO - Test accuracy after epoch 23: 88.76%
94
+ 2024-11-28 15:02:00,720 - INFO - Model saved: models\model_epoch_23.pt
95
+ 2024-11-28 15:02:00,720 - INFO - Epoch 24/50 started
96
+ 2024-11-28 15:02:02,331 - INFO - Average loss for epoch 24: 0.1745
97
+ 2024-11-28 15:02:02,445 - INFO - Test accuracy after epoch 24: 87.96%
98
+ 2024-11-28 15:02:02,449 - INFO - Model saved: models\model_epoch_24.pt
99
+ 2024-11-28 15:02:02,449 - INFO - Epoch 25/50 started
100
+ 2024-11-28 15:02:03,994 - INFO - Average loss for epoch 25: 0.1769
101
+ 2024-11-28 15:02:04,107 - INFO - Test accuracy after epoch 25: 88.66%
102
+ 2024-11-28 15:02:04,111 - INFO - Model saved: models\model_epoch_25.pt
103
+ 2024-11-28 15:02:04,111 - INFO - Epoch 26/50 started
104
+ 2024-11-28 15:02:05,674 - INFO - Average loss for epoch 26: 0.1730
105
+ 2024-11-28 15:02:05,778 - INFO - Test accuracy after epoch 26: 88.66%
106
+ 2024-11-28 15:02:05,782 - INFO - Model saved: models\model_epoch_26.pt
107
+ 2024-11-28 15:02:05,782 - INFO - Epoch 27/50 started
108
+ 2024-11-28 15:02:07,327 - INFO - Average loss for epoch 27: 0.1621
109
+ 2024-11-28 15:02:07,430 - INFO - Test accuracy after epoch 27: 88.35%
110
+ 2024-11-28 15:02:07,434 - INFO - Model saved: models\model_epoch_27.pt
111
+ 2024-11-28 15:02:07,434 - INFO - Epoch 28/50 started
112
+ 2024-11-28 15:02:08,943 - INFO - Average loss for epoch 28: 0.1720
113
+ 2024-11-28 15:02:09,064 - INFO - Test accuracy after epoch 28: 87.72%
114
+ 2024-11-28 15:02:09,068 - INFO - Model saved: models\model_epoch_28.pt
115
+ 2024-11-28 15:02:09,068 - INFO - Epoch 29/50 started
116
+ 2024-11-28 15:02:10,699 - INFO - Average loss for epoch 29: 0.1615
117
+ 2024-11-28 15:02:10,808 - INFO - Test accuracy after epoch 29: 89.21%
118
+ 2024-11-28 15:02:10,812 - INFO - Model saved: models\model_epoch_29.pt
119
+ 2024-11-28 15:02:10,812 - INFO - Epoch 30/50 started
120
+ 2024-11-28 15:02:12,426 - INFO - Average loss for epoch 30: 0.1733
121
+ 2024-11-28 15:02:12,539 - INFO - Test accuracy after epoch 30: 89.21%
122
+ 2024-11-28 15:02:12,543 - INFO - Model saved: models\model_epoch_30.pt
123
+ 2024-11-28 15:02:12,543 - INFO - Epoch 31/50 started
124
+ 2024-11-28 15:02:14,079 - INFO - Average loss for epoch 31: 0.1624
125
+ 2024-11-28 15:02:14,189 - INFO - Test accuracy after epoch 31: 88.59%
126
+ 2024-11-28 15:02:14,194 - INFO - Model saved: models\model_epoch_31.pt
127
+ 2024-11-28 15:02:14,194 - INFO - Epoch 32/50 started
128
+ 2024-11-28 15:02:15,776 - INFO - Average loss for epoch 32: 0.1611
129
+ 2024-11-28 15:02:15,895 - INFO - Test accuracy after epoch 32: 89.11%
130
+ 2024-11-28 15:02:15,899 - INFO - Model saved: models\model_epoch_32.pt
131
+ 2024-11-28 15:02:15,899 - INFO - Epoch 33/50 started
132
+ 2024-11-28 15:02:17,411 - INFO - Average loss for epoch 33: 0.1609
133
+ 2024-11-28 15:02:17,517 - INFO - Test accuracy after epoch 33: 89.32%
134
+ 2024-11-28 15:02:17,522 - INFO - Model saved: models\model_epoch_33.pt
135
+ 2024-11-28 15:02:17,522 - INFO - Epoch 34/50 started
136
+ 2024-11-28 15:02:19,038 - INFO - Average loss for epoch 34: 0.1652
137
+ 2024-11-28 15:02:19,166 - INFO - Test accuracy after epoch 34: 88.87%
138
+ 2024-11-28 15:02:19,170 - INFO - Model saved: models\model_epoch_34.pt
139
+ 2024-11-28 15:02:19,170 - INFO - Epoch 35/50 started
140
+ 2024-11-28 15:02:20,770 - INFO - Average loss for epoch 35: 0.1793
141
+ 2024-11-28 15:02:20,886 - INFO - Test accuracy after epoch 35: 88.66%
142
+ 2024-11-28 15:02:20,891 - INFO - Model saved: models\model_epoch_35.pt
143
+ 2024-11-28 15:02:20,891 - INFO - Epoch 36/50 started
144
+ 2024-11-28 15:02:22,635 - INFO - Average loss for epoch 36: 0.1607
145
+ 2024-11-28 15:02:22,748 - INFO - Test accuracy after epoch 36: 88.90%
146
+ 2024-11-28 15:02:22,753 - INFO - Model saved: models\model_epoch_36.pt
147
+ 2024-11-28 15:02:22,753 - INFO - Epoch 37/50 started
148
+ 2024-11-28 15:02:24,389 - INFO - Average loss for epoch 37: 0.1572
149
+ 2024-11-28 15:02:24,527 - INFO - Test accuracy after epoch 37: 88.90%
150
+ 2024-11-28 15:02:24,534 - INFO - Model saved: models\model_epoch_37.pt
151
+ 2024-11-28 15:02:24,534 - INFO - Epoch 38/50 started
152
+ 2024-11-28 15:02:26,082 - INFO - Average loss for epoch 38: 0.1631
153
+ 2024-11-28 15:02:26,181 - INFO - Test accuracy after epoch 38: 88.17%
154
+ 2024-11-28 15:02:26,185 - INFO - Model saved: models\model_epoch_38.pt
155
+ 2024-11-28 15:02:26,185 - INFO - Epoch 39/50 started
156
+ 2024-11-28 15:02:27,680 - INFO - Average loss for epoch 39: 0.1643
157
+ 2024-11-28 15:02:27,787 - INFO - Test accuracy after epoch 39: 88.62%
158
+ 2024-11-28 15:02:27,791 - INFO - Model saved: models\model_epoch_39.pt
159
+ 2024-11-28 15:02:27,791 - INFO - Epoch 40/50 started
160
+ 2024-11-28 15:02:29,421 - INFO - Average loss for epoch 40: 0.1578
161
+ 2024-11-28 15:02:29,538 - INFO - Test accuracy after epoch 40: 87.96%
162
+ 2024-11-28 15:02:29,542 - INFO - Model saved: models\model_epoch_40.pt
163
+ 2024-11-28 15:02:29,542 - INFO - Epoch 41/50 started
164
+ 2024-11-28 15:02:31,150 - INFO - Average loss for epoch 41: 0.1579
165
+ 2024-11-28 15:02:31,270 - INFO - Test accuracy after epoch 41: 88.69%
166
+ 2024-11-28 15:02:31,275 - INFO - Model saved: models\model_epoch_41.pt
167
+ 2024-11-28 15:02:31,275 - INFO - Epoch 42/50 started
168
+ 2024-11-28 15:02:33,082 - INFO - Average loss for epoch 42: 0.1575
169
+ 2024-11-28 15:02:33,226 - INFO - Test accuracy after epoch 42: 88.52%
170
+ 2024-11-28 15:02:33,231 - INFO - Model saved: models\model_epoch_42.pt
171
+ 2024-11-28 15:02:33,232 - INFO - Epoch 43/50 started
172
+ 2024-11-28 15:02:34,811 - INFO - Average loss for epoch 43: 0.1574
173
+ 2024-11-28 15:02:34,911 - INFO - Test accuracy after epoch 43: 89.11%
174
+ 2024-11-28 15:02:34,916 - INFO - Model saved: models\model_epoch_43.pt
175
+ 2024-11-28 15:02:34,916 - INFO - Epoch 44/50 started
176
+ 2024-11-28 15:02:36,448 - INFO - Average loss for epoch 44: 0.1642
177
+ 2024-11-28 15:02:36,553 - INFO - Test accuracy after epoch 44: 89.04%
178
+ 2024-11-28 15:02:36,557 - INFO - Model saved: models\model_epoch_44.pt
179
+ 2024-11-28 15:02:36,557 - INFO - Epoch 45/50 started
180
+ 2024-11-28 15:02:38,065 - INFO - Average loss for epoch 45: 0.1583
181
+ 2024-11-28 15:02:38,181 - INFO - Test accuracy after epoch 45: 88.41%
182
+ 2024-11-28 15:02:38,185 - INFO - Model saved: models\model_epoch_45.pt
183
+ 2024-11-28 15:02:38,185 - INFO - Epoch 46/50 started
184
+ 2024-11-28 15:02:39,689 - INFO - Average loss for epoch 46: 0.1613
185
+ 2024-11-28 15:02:39,805 - INFO - Test accuracy after epoch 46: 89.21%
186
+ 2024-11-28 15:02:39,809 - INFO - Model saved: models\model_epoch_46.pt
187
+ 2024-11-28 15:02:39,809 - INFO - Epoch 47/50 started
188
+ 2024-11-28 15:02:41,364 - INFO - Average loss for epoch 47: 0.1598
189
+ 2024-11-28 15:02:41,467 - INFO - Test accuracy after epoch 47: 88.55%
190
+ 2024-11-28 15:02:41,471 - INFO - Model saved: models\model_epoch_47.pt
191
+ 2024-11-28 15:02:41,471 - INFO - Epoch 48/50 started
192
+ 2024-11-28 15:02:42,977 - INFO - Average loss for epoch 48: 0.1697
193
+ 2024-11-28 15:02:43,086 - INFO - Test accuracy after epoch 48: 88.48%
194
+ 2024-11-28 15:02:43,090 - INFO - Model saved: models\model_epoch_48.pt
195
+ 2024-11-28 15:02:43,090 - INFO - Epoch 49/50 started
196
+ 2024-11-28 15:02:44,596 - INFO - Average loss for epoch 49: 0.1618
197
+ 2024-11-28 15:02:44,717 - INFO - Test accuracy after epoch 49: 89.25%
198
+ 2024-11-28 15:02:44,721 - INFO - Model saved: models\model_epoch_49.pt
199
+ 2024-11-28 15:02:44,721 - INFO - Epoch 50/50 started
200
+ 2024-11-28 15:02:46,298 - INFO - Average loss for epoch 50: 0.1642
201
+ 2024-11-28 15:02:46,414 - INFO - Test accuracy after epoch 50: 88.97%
202
+ 2024-11-28 15:02:46,418 - INFO - Model saved: models\model_epoch_50.pt
203
+ 2024-11-28 15:02:46,418 - INFO - Training complete.
204
+ 2024-11-28 15:02:46,418 - INFO - Training complete
notebooks/data.ipynb ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import pandas as pd\n",
10
+ "import numpy as np\n",
11
+ "\n",
12
+ "import os\n",
13
+ "os.chdir('..')"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 2,
19
+ "metadata": {},
20
+ "outputs": [
21
+ {
22
+ "data": {
23
+ "text/plain": [
24
+ "['data', 'model', 'notebooks', 'test.ipynb', 'Untitled-1.ipynb', 'вф']"
25
+ ]
26
+ },
27
+ "execution_count": 2,
28
+ "metadata": {},
29
+ "output_type": "execute_result"
30
+ }
31
+ ],
32
+ "source": [
33
+ "os.listdir('.')"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 3,
39
+ "metadata": {},
40
+ "outputs": [
41
+ {
42
+ "data": {
43
+ "text/html": [
44
+ "<div>\n",
45
+ "<style scoped>\n",
46
+ " .dataframe tbody tr th:only-of-type {\n",
47
+ " vertical-align: middle;\n",
48
+ " }\n",
49
+ "\n",
50
+ " .dataframe tbody tr th {\n",
51
+ " vertical-align: top;\n",
52
+ " }\n",
53
+ "\n",
54
+ " .dataframe thead th {\n",
55
+ " text-align: right;\n",
56
+ " }\n",
57
+ "</style>\n",
58
+ "<table border=\"1\" class=\"dataframe\">\n",
59
+ " <thead>\n",
60
+ " <tr style=\"text-align: right;\">\n",
61
+ " <th></th>\n",
62
+ " <th>comment</th>\n",
63
+ " <th>toxic</th>\n",
64
+ " </tr>\n",
65
+ " </thead>\n",
66
+ " <tbody>\n",
67
+ " <tr>\n",
68
+ " <th>0</th>\n",
69
+ " <td>Верблюдов-то за что? Дебилы, бл...</td>\n",
70
+ " <td>1.0</td>\n",
71
+ " </tr>\n",
72
+ " <tr>\n",
73
+ " <th>1</th>\n",
74
+ " <td>Хохлы, это отдушина затюканого россиянина, мол...</td>\n",
75
+ " <td>1.0</td>\n",
76
+ " </tr>\n",
77
+ " <tr>\n",
78
+ " <th>2</th>\n",
79
+ " <td>Собаке - собачья смерть</td>\n",
80
+ " <td>1.0</td>\n",
81
+ " </tr>\n",
82
+ " <tr>\n",
83
+ " <th>3</th>\n",
84
+ " <td>Страницу обнови, дебил. Это тоже не оскорблени...</td>\n",
85
+ " <td>1.0</td>\n",
86
+ " </tr>\n",
87
+ " <tr>\n",
88
+ " <th>4</th>\n",
89
+ " <td>тебя не убедил 6-страничный пдф в том, что Скр...</td>\n",
90
+ " <td>1.0</td>\n",
91
+ " </tr>\n",
92
+ " <tr>\n",
93
+ " <th>...</th>\n",
94
+ " <td>...</td>\n",
95
+ " <td>...</td>\n",
96
+ " </tr>\n",
97
+ " <tr>\n",
98
+ " <th>14407</th>\n",
99
+ " <td>Вонючий совковый скот прибежал и ноет. А вот и...</td>\n",
100
+ " <td>1.0</td>\n",
101
+ " </tr>\n",
102
+ " <tr>\n",
103
+ " <th>14408</th>\n",
104
+ " <td>А кого любить? Гоблина тупорылого что-ли? Или ...</td>\n",
105
+ " <td>1.0</td>\n",
106
+ " </tr>\n",
107
+ " <tr>\n",
108
+ " <th>14409</th>\n",
109
+ " <td>Посмотрел Утомленных солнцем 2. И оказалось, ч...</td>\n",
110
+ " <td>0.0</td>\n",
111
+ " </tr>\n",
112
+ " <tr>\n",
113
+ " <th>14410</th>\n",
114
+ " <td>КРЫМОТРЕД НАРУШАЕТ ПРАВИЛА РАЗДЕЛА Т.К В НЕМ Н...</td>\n",
115
+ " <td>1.0</td>\n",
116
+ " </tr>\n",
117
+ " <tr>\n",
118
+ " <th>14411</th>\n",
119
+ " <td>До сих пор пересматриваю его видео. Орамбо кст...</td>\n",
120
+ " <td>0.0</td>\n",
121
+ " </tr>\n",
122
+ " </tbody>\n",
123
+ "</table>\n",
124
+ "<p>14412 rows × 2 columns</p>\n",
125
+ "</div>"
126
+ ],
127
+ "text/plain": [
128
+ " comment toxic\n",
129
+ "0 Верблюдов-то за что? Дебилы, бл... 1.0\n",
130
+ "1 Хохлы, это отдушина затюканого россиянина, мол... 1.0\n",
131
+ "2 Собаке - собачья смерть 1.0\n",
132
+ "3 Страницу обнови, дебил. Это тоже не оскорблени... 1.0\n",
133
+ "4 тебя не убедил 6-страничный пдф в том, что Скр... 1.0\n",
134
+ "... ... ...\n",
135
+ "14407 Вонючий совковый скот прибежал и ноет. А вот и... 1.0\n",
136
+ "14408 А кого любить? Гоблина тупорылого что-ли? Или ... 1.0\n",
137
+ "14409 Посмотрел Утомленных солнцем 2. И оказалось, ч... 0.0\n",
138
+ "14410 КРЫМОТРЕД НАРУШАЕТ ПРАВИЛА РАЗДЕЛА Т.К В НЕМ Н... 1.0\n",
139
+ "14411 До сих пор пересматриваю его видео. Орамбо кст... 0.0\n",
140
+ "\n",
141
+ "[14412 rows x 2 columns]"
142
+ ]
143
+ },
144
+ "execution_count": 3,
145
+ "metadata": {},
146
+ "output_type": "execute_result"
147
+ }
148
+ ],
149
+ "source": [
150
+ "df = pd.read_csv('data/data.csv')\n",
151
+ "df['comment'] = df['comment'].str.replace('\\n', '', regex=False)\n",
152
+ "df"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 4,
158
+ "metadata": {},
159
+ "outputs": [
160
+ {
161
+ "data": {
162
+ "text/plain": [
163
+ "toxic\n",
164
+ "0.0 9586\n",
165
+ "1.0 4826\n",
166
+ "Name: count, dtype: int64"
167
+ ]
168
+ },
169
+ "execution_count": 4,
170
+ "metadata": {},
171
+ "output_type": "execute_result"
172
+ }
173
+ ],
174
+ "source": [
175
+ "df['toxic'].value_counts()"
176
+ ]
177
+ }
178
+ ],
179
+ "metadata": {
180
+ "kernelspec": {
181
+ "display_name": "Python 3",
182
+ "language": "python",
183
+ "name": "python3"
184
+ },
185
+ "language_info": {
186
+ "codemirror_mode": {
187
+ "name": "ipython",
188
+ "version": 3
189
+ },
190
+ "file_extension": ".py",
191
+ "mimetype": "text/x-python",
192
+ "name": "python",
193
+ "nbconvert_exporter": "python",
194
+ "pygments_lexer": "ipython3",
195
+ "version": "3.12.4"
196
+ }
197
+ },
198
+ "nbformat": 4,
199
+ "nbformat_minor": 2
200
+ }
notebooks/dataset.ipynb ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "a:\\python\\312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n",
14
+ "a:\\python\\312\\Lib\\site-packages\\transformers\\utils\\hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
15
+ " warnings.warn(\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "import os\n",
21
+ "os.chdir('..')\n",
22
+ "\n",
23
+ "import torch\n",
24
+ "from scr.sbert import sbert\n",
25
+ "from scr.dataset import TextDataset"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 2,
31
+ "metadata": {},
32
+ "outputs": [
33
+ {
34
+ "name": "stderr",
35
+ "output_type": "stream",
36
+ "text": [
37
+ "C:\\Users\\timof\\AppData\\Local\\Temp\\ipykernel_12548\\753396127.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
38
+ " dataset = torch.load('data/dataset.pt')\n"
39
+ ]
40
+ }
41
+ ],
42
+ "source": [
43
+ "dataset = torch.load('data/dataset.pt')"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": 3,
49
+ "metadata": {},
50
+ "outputs": [
51
+ {
52
+ "name": "stdout",
53
+ "output_type": "stream",
54
+ "text": [
55
+ "Размер датасета: 14412\n",
56
+ "Тексты: tensor([[-0.7709, 0.2756, -1.8136, ..., -0.1891, 0.6464, -0.0877],\n",
57
+ " [ 0.0737, 0.2665, -0.2466, ..., 0.1983, 0.9042, 0.7120],\n",
58
+ " [-0.4836, 0.2575, -0.3310, ..., -0.0648, 0.6074, -0.2436],\n",
59
+ " ...,\n",
60
+ " [ 0.5273, 0.2523, -0.4174, ..., -0.1361, 0.0777, 0.1805],\n",
61
+ " [-0.6573, 0.1075, -1.1338, ..., 0.0145, 0.0062, 0.1264],\n",
62
+ " [ 0.4965, 0.1897, -1.8090, ..., -0.0378, 0.2283, 0.6433]])\n",
63
+ "Метки: tensor([1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0.],\n",
64
+ " dtype=torch.float64)\n"
65
+ ]
66
+ }
67
+ ],
68
+ "source": [
69
+ "from torch.utils.data import DataLoader\n",
70
+ "\n",
71
+ "# Проверяем размер\n",
72
+ "print(f\"Размер датасета: {len(dataset)}\")\n",
73
+ "\n",
74
+ "# Создаем DataLoader\n",
75
+ "dataloader = DataLoader(dataset, batch_size=16, shuffle=True)\n",
76
+ "\n",
77
+ "# Обрабатываем данные в батчах\n",
78
+ "for texts, labels in dataloader:\n",
79
+ " print(\"Тексты:\", texts)\n",
80
+ " print(\"Метки:\", labels)\n",
81
+ " break"
82
+ ]
83
+ }
84
+ ],
85
+ "metadata": {
86
+ "kernelspec": {
87
+ "display_name": "Python 3",
88
+ "language": "python",
89
+ "name": "python3"
90
+ },
91
+ "language_info": {
92
+ "codemirror_mode": {
93
+ "name": "ipython",
94
+ "version": 3
95
+ },
96
+ "file_extension": ".py",
97
+ "mimetype": "text/x-python",
98
+ "name": "python",
99
+ "nbconvert_exporter": "python",
100
+ "pygments_lexer": "ipython3",
101
+ "version": "3.12.4"
102
+ }
103
+ },
104
+ "nbformat": 4,
105
+ "nbformat_minor": 2
106
+ }
notebooks/evaluate.ipynb ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "a:\\python\\312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n",
14
+ "a:\\python\\312\\Lib\\site-packages\\transformers\\utils\\hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
15
+ " warnings.warn(\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "import os\n",
21
+ "os.chdir('..')\n",
22
+ "\n",
23
+ "import pandas as pd\n",
24
+ "import torch\n",
25
+ "from tqdm import tqdm\n",
26
+ "\n",
27
+ "tqdm.pandas()\n",
28
+ "\n",
29
+ "from scr.model import Model"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 2,
35
+ "metadata": {},
36
+ "outputs": [
37
+ {
38
+ "name": "stderr",
39
+ "output_type": "stream",
40
+ "text": [
41
+ "C:\\Users\\timof\\AppData\\Local\\Temp\\ipykernel_19588\\3061579655.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
42
+ " model.load_state_dict(torch.load('models/model_epoch_50.pt'))\n"
43
+ ]
44
+ },
45
+ {
46
+ "data": {
47
+ "text/plain": [
48
+ "<All keys matched successfully>"
49
+ ]
50
+ },
51
+ "execution_count": 2,
52
+ "metadata": {},
53
+ "output_type": "execute_result"
54
+ }
55
+ ],
56
+ "source": [
57
+ "model = Model()\n",
58
+ "model.eval()\n",
59
+ "model.load_state_dict(torch.load('models/model_epoch_50.pt'))"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 3,
65
+ "metadata": {},
66
+ "outputs": [
67
+ {
68
+ "data": {
69
+ "text/html": [
70
+ "<div>\n",
71
+ "<style scoped>\n",
72
+ " .dataframe tbody tr th:only-of-type {\n",
73
+ " vertical-align: middle;\n",
74
+ " }\n",
75
+ "\n",
76
+ " .dataframe tbody tr th {\n",
77
+ " vertical-align: top;\n",
78
+ " }\n",
79
+ "\n",
80
+ " .dataframe thead th {\n",
81
+ " text-align: right;\n",
82
+ " }\n",
83
+ "</style>\n",
84
+ "<table border=\"1\" class=\"dataframe\">\n",
85
+ " <thead>\n",
86
+ " <tr style=\"text-align: right;\">\n",
87
+ " <th></th>\n",
88
+ " <th>comment</th>\n",
89
+ " <th>toxic</th>\n",
90
+ " </tr>\n",
91
+ " </thead>\n",
92
+ " <tbody>\n",
93
+ " <tr>\n",
94
+ " <th>0</th>\n",
95
+ " <td>Верблюдов-то за что? Дебилы, бл...</td>\n",
96
+ " <td>1.0</td>\n",
97
+ " </tr>\n",
98
+ " <tr>\n",
99
+ " <th>1</th>\n",
100
+ " <td>Хохлы, это отдушина затюканого россиянина, мол...</td>\n",
101
+ " <td>1.0</td>\n",
102
+ " </tr>\n",
103
+ " <tr>\n",
104
+ " <th>2</th>\n",
105
+ " <td>Собаке - собачья смерть</td>\n",
106
+ " <td>1.0</td>\n",
107
+ " </tr>\n",
108
+ " <tr>\n",
109
+ " <th>3</th>\n",
110
+ " <td>Страницу обнови, дебил. Это тоже не оскорблени...</td>\n",
111
+ " <td>1.0</td>\n",
112
+ " </tr>\n",
113
+ " <tr>\n",
114
+ " <th>4</th>\n",
115
+ " <td>тебя не убедил 6-страничный пдф в том, что Скр...</td>\n",
116
+ " <td>1.0</td>\n",
117
+ " </tr>\n",
118
+ " <tr>\n",
119
+ " <th>...</th>\n",
120
+ " <td>...</td>\n",
121
+ " <td>...</td>\n",
122
+ " </tr>\n",
123
+ " <tr>\n",
124
+ " <th>14407</th>\n",
125
+ " <td>Вонючий совковый скот прибежал и ноет. А вот и...</td>\n",
126
+ " <td>1.0</td>\n",
127
+ " </tr>\n",
128
+ " <tr>\n",
129
+ " <th>14408</th>\n",
130
+ " <td>А кого любить? Гоблина тупорылого что-ли? Или ...</td>\n",
131
+ " <td>1.0</td>\n",
132
+ " </tr>\n",
133
+ " <tr>\n",
134
+ " <th>14409</th>\n",
135
+ " <td>Посмот��ел Утомленных солнцем 2. И оказалось, ч...</td>\n",
136
+ " <td>0.0</td>\n",
137
+ " </tr>\n",
138
+ " <tr>\n",
139
+ " <th>14410</th>\n",
140
+ " <td>КРЫМОТРЕД НАРУШАЕТ ПРАВИЛА РАЗДЕЛА Т.К В НЕМ Н...</td>\n",
141
+ " <td>1.0</td>\n",
142
+ " </tr>\n",
143
+ " <tr>\n",
144
+ " <th>14411</th>\n",
145
+ " <td>До сих пор пересматриваю его видео. Орамбо кст...</td>\n",
146
+ " <td>0.0</td>\n",
147
+ " </tr>\n",
148
+ " </tbody>\n",
149
+ "</table>\n",
150
+ "<p>14412 rows × 2 columns</p>\n",
151
+ "</div>"
152
+ ],
153
+ "text/plain": [
154
+ " comment toxic\n",
155
+ "0 Верблюдов-то за что? Дебилы, бл... 1.0\n",
156
+ "1 Хохлы, это отдушина затюканого россиянина, мол... 1.0\n",
157
+ "2 Собаке - собачья смерть 1.0\n",
158
+ "3 Страницу обнови, дебил. Это тоже не оскорблени... 1.0\n",
159
+ "4 тебя не убедил 6-страничный пдф в том, что Скр... 1.0\n",
160
+ "... ... ...\n",
161
+ "14407 Вонючий совковый скот прибежал и ноет. А вот и... 1.0\n",
162
+ "14408 А кого любить? Гоблина тупорылого что-ли? Или ... 1.0\n",
163
+ "14409 Посмотрел Утомленных солнцем 2. И оказалось, ч... 0.0\n",
164
+ "14410 КРЫМОТРЕД НАРУШАЕТ ПРАВИЛА РАЗДЕЛА Т.К В НЕМ Н... 1.0\n",
165
+ "14411 До сих пор пересматриваю его видео. Орамбо кст... 0.0\n",
166
+ "\n",
167
+ "[14412 rows x 2 columns]"
168
+ ]
169
+ },
170
+ "execution_count": 3,
171
+ "metadata": {},
172
+ "output_type": "execute_result"
173
+ }
174
+ ],
175
+ "source": [
176
+ "df = pd.read_csv('data/data.csv')\n",
177
+ "df['comment'] = df['comment'].str.replace('\\n', '', regex=False)\n",
178
+ "df"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 4,
184
+ "metadata": {},
185
+ "outputs": [
186
+ {
187
+ "name": "stderr",
188
+ "output_type": "stream",
189
+ "text": [
190
+ "100%|██████████| 14412/14412 [02:38<00:00, 90.67it/s]\n"
191
+ ]
192
+ },
193
+ {
194
+ "data": {
195
+ "text/html": [
196
+ "<div>\n",
197
+ "<style scoped>\n",
198
+ " .dataframe tbody tr th:only-of-type {\n",
199
+ " vertical-align: middle;\n",
200
+ " }\n",
201
+ "\n",
202
+ " .dataframe tbody tr th {\n",
203
+ " vertical-align: top;\n",
204
+ " }\n",
205
+ "\n",
206
+ " .dataframe thead th {\n",
207
+ " text-align: right;\n",
208
+ " }\n",
209
+ "</style>\n",
210
+ "<table border=\"1\" class=\"dataframe\">\n",
211
+ " <thead>\n",
212
+ " <tr style=\"text-align: right;\">\n",
213
+ " <th></th>\n",
214
+ " <th>comment</th>\n",
215
+ " <th>toxic</th>\n",
216
+ " <th>predict</th>\n",
217
+ " </tr>\n",
218
+ " </thead>\n",
219
+ " <tbody>\n",
220
+ " <tr>\n",
221
+ " <th>0</th>\n",
222
+ " <td>Верблюдов-то за что? Дебилы, бл...</td>\n",
223
+ " <td>1.0</td>\n",
224
+ " <td>1.000</td>\n",
225
+ " </tr>\n",
226
+ " <tr>\n",
227
+ " <th>1</th>\n",
228
+ " <td>Хохлы, это отдушина затюканого россиянина, мол...</td>\n",
229
+ " <td>1.0</td>\n",
230
+ " <td>1.000</td>\n",
231
+ " </tr>\n",
232
+ " <tr>\n",
233
+ " <th>2</th>\n",
234
+ " <td>Собаке - собачья смерть</td>\n",
235
+ " <td>1.0</td>\n",
236
+ " <td>1.000</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <th>3</th>\n",
240
+ " <td>Страницу обнови, дебил. Это тоже не оскорблени...</td>\n",
241
+ " <td>1.0</td>\n",
242
+ " <td>1.000</td>\n",
243
+ " </tr>\n",
244
+ " <tr>\n",
245
+ " <th>4</th>\n",
246
+ " <td>тебя не убедил 6-страничный пдф в том, что Скр...</td>\n",
247
+ " <td>1.0</td>\n",
248
+ " <td>1.000</td>\n",
249
+ " </tr>\n",
250
+ " <tr>\n",
251
+ " <th>...</th>\n",
252
+ " <td>...</td>\n",
253
+ " <td>...</td>\n",
254
+ " <td>...</td>\n",
255
+ " </tr>\n",
256
+ " <tr>\n",
257
+ " <th>14407</th>\n",
258
+ " <td>Вонючий совковый скот прибежал и ноет. А вот и...</td>\n",
259
+ " <td>1.0</td>\n",
260
+ " <td>1.000</td>\n",
261
+ " </tr>\n",
262
+ " <tr>\n",
263
+ " <th>14408</th>\n",
264
+ " <td>А кого любить? Гоблина тупорылого что-ли? Или ...</td>\n",
265
+ " <td>1.0</td>\n",
266
+ " <td>0.487</td>\n",
267
+ " </tr>\n",
268
+ " <tr>\n",
269
+ " <th>14409</th>\n",
270
+ " <td>Посмотрел Утомленных солнцем 2. И оказалось, ч...</td>\n",
271
+ " <td>0.0</td>\n",
272
+ " <td>0.499</td>\n",
273
+ " </tr>\n",
274
+ " <tr>\n",
275
+ " <th>14410</th>\n",
276
+ " <td>КРЫМОТРЕД НАРУШАЕТ ПРАВИЛА РАЗДЕЛА Т.К В НЕМ Н...</td>\n",
277
+ " <td>1.0</td>\n",
278
+ " <td>1.000</td>\n",
279
+ " </tr>\n",
280
+ " <tr>\n",
281
+ " <th>14411</th>\n",
282
+ " <td>До сих пор пересматриваю его видео. Орамбо кст...</td>\n",
283
+ " <td>0.0</td>\n",
284
+ " <td>0.000</td>\n",
285
+ " </tr>\n",
286
+ " </tbody>\n",
287
+ "</table>\n",
288
+ "<p>14412 rows × 3 columns</p>\n",
289
+ "</div>"
290
+ ],
291
+ "text/plain": [
292
+ " comment toxic predict\n",
293
+ "0 Верблюдов-то за что? Дебилы, бл... 1.0 1.000\n",
294
+ "1 Хохлы, это отдушина затюканого россиянина, мол... 1.0 1.000\n",
295
+ "2 Собаке - собачья смерть 1.0 1.000\n",
296
+ "3 Страницу обнови, дебил. Это тоже не оскорблени... 1.0 1.000\n",
297
+ "4 тебя не убедил 6-страничный пдф в том, что Скр... 1.0 1.000\n",
298
+ "... ... ... ...\n",
299
+ "14407 Вонючий совковый скот прибежал и ноет. А вот и... 1.0 1.000\n",
300
+ "14408 А кого любить? Гоблина тупорылого что-ли? Или ... 1.0 0.487\n",
301
+ "14409 Посмотрел Утомленных солнцем 2. И оказалось, ч... 0.0 0.499\n",
302
+ "14410 КРЫМОТРЕД НАРУШАЕТ ПРАВИЛА РАЗДЕЛА Т.К В НЕМ Н... 1.0 1.000\n",
303
+ "14411 До сих пор пересматриваю его видео. Орамбо кст... 0.0 0.000\n",
304
+ "\n",
305
+ "[14412 rows x 3 columns]"
306
+ ]
307
+ },
308
+ "execution_count": 4,
309
+ "metadata": {},
310
+ "output_type": "execute_result"
311
+ }
312
+ ],
313
+ "source": [
314
+ "df['predict'] = df['comment'].progress_apply(lambda x: round(float(model.predict(x).item()), 3))\n",
315
+ "df"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": 8,
321
+ "metadata": {},
322
+ "outputs": [
323
+ {
324
+ "data": {
325
+ "text/plain": [
326
+ "True 14056\n",
327
+ "False 356\n",
328
+ "Name: count, dtype: int64"
329
+ ]
330
+ },
331
+ "execution_count": 8,
332
+ "metadata": {},
333
+ "output_type": "execute_result"
334
+ }
335
+ ],
336
+ "source": [
337
+ "(df['toxic'] == df['predict'].apply(round).astype(int)).value_counts()"
338
+ ]
339
+ }
340
+ ],
341
+ "metadata": {
342
+ "kernelspec": {
343
+ "display_name": "Python 3",
344
+ "language": "python",
345
+ "name": "python3"
346
+ },
347
+ "language_info": {
348
+ "codemirror_mode": {
349
+ "name": "ipython",
350
+ "version": 3
351
+ },
352
+ "file_extension": ".py",
353
+ "mimetype": "text/x-python",
354
+ "name": "python",
355
+ "nbconvert_exporter": "python",
356
+ "pygments_lexer": "ipython3",
357
+ "version": "3.12.4"
358
+ }
359
+ },
360
+ "nbformat": 4,
361
+ "nbformat_minor": 2
362
+ }
notebooks/model.ipynb ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "a:\\python\\312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n",
14
+ "a:\\python\\312\\Lib\\site-packages\\transformers\\utils\\hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
15
+ " warnings.warn(\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "import os\n",
21
+ "os.chdir('..')\n",
22
+ "\n",
23
+ "import pandas as pd\n",
24
+ "import torch\n",
25
+ "from torch.utils.data import DataLoader, Dataset\n",
26
+ "\n",
27
+ "from scr.dataset import TextDataset\n",
28
+ "from scr.model import Model"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 2,
34
+ "metadata": {},
35
+ "outputs": [
36
+ {
37
+ "name": "stderr",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "C:\\Users\\timof\\AppData\\Local\\Temp\\ipykernel_20804\\3112888757.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
41
+ " dataset = torch.load('data/dataset.pt')\n"
42
+ ]
43
+ }
44
+ ],
45
+ "source": [
46
+ "dataset = torch.load('data/dataset.pt')\n",
47
+ "dataloader = DataLoader(dataset, batch_size=16, shuffle=True, pin_memory=True)"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": 3,
53
+ "metadata": {},
54
+ "outputs": [
55
+ {
56
+ "data": {
57
+ "text/plain": [
58
+ "Model(\n",
59
+ " (model): Sequential(\n",
60
+ " (0): Block(\n",
61
+ " (model): Sequential(\n",
62
+ " (0): Linear(in_features=1024, out_features=512, bias=True)\n",
63
+ " (1): Dropout(p=0.2, inplace=False)\n",
64
+ " )\n",
65
+ " )\n",
66
+ " (1): LeakyReLU(negative_slope=0.01)\n",
67
+ " (2): Block(\n",
68
+ " (model): Sequential(\n",
69
+ " (0): Linear(in_features=512, out_features=256, bias=True)\n",
70
+ " (1): Dropout(p=0.2, inplace=False)\n",
71
+ " )\n",
72
+ " )\n",
73
+ " (3): LeakyReLU(negative_slope=0.01)\n",
74
+ " (4): Block(\n",
75
+ " (model): Sequential(\n",
76
+ " (0): Linear(in_features=256, out_features=128, bias=True)\n",
77
+ " (1): Dropout(p=0.2, inplace=False)\n",
78
+ " )\n",
79
+ " )\n",
80
+ " (5): LeakyReLU(negative_slope=0.01)\n",
81
+ " (6): Block(\n",
82
+ " (model): Sequential(\n",
83
+ " (0): Linear(in_features=128, out_features=64, bias=True)\n",
84
+ " (1): Dropout(p=0.2, inplace=False)\n",
85
+ " )\n",
86
+ " )\n",
87
+ " (7): LeakyReLU(negative_slope=0.01)\n",
88
+ " (8): Block(\n",
89
+ " (model): Sequential(\n",
90
+ " (0): Linear(in_features=64, out_features=1, bias=True)\n",
91
+ " (1): Dropout(p=0.2, inplace=False)\n",
92
+ " )\n",
93
+ " )\n",
94
+ " (9): Sigmoid()\n",
95
+ " )\n",
96
+ ")"
97
+ ]
98
+ },
99
+ "execution_count": 3,
100
+ "metadata": {},
101
+ "output_type": "execute_result"
102
+ }
103
+ ],
104
+ "source": [
105
+ "model = Model()\n",
106
+ "model.eval()"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 4,
112
+ "metadata": {},
113
+ "outputs": [
114
+ {
115
+ "name": "stderr",
116
+ "output_type": "stream",
117
+ "text": [
118
+ "C:\\Users\\timof\\AppData\\Local\\Temp\\ipykernel_20804\\3887862913.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
119
+ " model.load_state_dict(torch.load('models/model_epoch_50.pt'))\n"
120
+ ]
121
+ },
122
+ {
123
+ "data": {
124
+ "text/plain": [
125
+ "<All keys matched successfully>"
126
+ ]
127
+ },
128
+ "execution_count": 4,
129
+ "metadata": {},
130
+ "output_type": "execute_result"
131
+ }
132
+ ],
133
+ "source": [
134
+ "model.load_state_dict(torch.load('models/model_epoch_50.pt'))"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": []
143
+ }
144
+ ],
145
+ "metadata": {
146
+ "kernelspec": {
147
+ "display_name": "Python 3",
148
+ "language": "python",
149
+ "name": "python3"
150
+ },
151
+ "language_info": {
152
+ "codemirror_mode": {
153
+ "name": "ipython",
154
+ "version": 3
155
+ },
156
+ "file_extension": ".py",
157
+ "mimetype": "text/x-python",
158
+ "name": "python",
159
+ "nbconvert_exporter": "python",
160
+ "pygments_lexer": "ipython3",
161
+ "version": "3.12.4"
162
+ }
163
+ },
164
+ "nbformat": 4,
165
+ "nbformat_minor": 2
166
+ }
scr/__init__.py ADDED
File without changes
scr/dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader, Dataset
2
+ import torch
3
+ from scr.sbert import vectorize as vec
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+
7
+ tqdm.pandas()
8
+
9
+ df = pd.read_csv('data/data.csv')
10
+
11
class TextDataset(Dataset):
    """Pre-vectorized text-classification dataset.

    Every comment is embedded once, up front, with the SBERT ``vectorize``
    helper (imported as ``vec``), so ``__getitem__`` is a cheap tensor lookup.
    """

    def __init__(self, df):
        """Build the dataset from a DataFrame.

        Args:
            df: pandas DataFrame with a ``'comment'`` column of raw text and
                a ``'toxic'`` column of 0/1 labels.
        """
        # vec(x) returns a (1, dim) embedding; squeeze to (dim,) before stacking.
        self.vectors = torch.stack(list(df['comment'].progress_apply(lambda x: vec(x).squeeze(0))))
        # Fix: cast labels to float32 at construction. ``df[...].values`` yields
        # float64, which previously forced every consumer (e.g. the training
        # loop) to re-cast each batch before BCELoss could be applied.
        self.labels = torch.tensor(df['toxic'].values, dtype=torch.float32)

    def __getitem__(self, index):
        """Return the (embedding, label) pair for one sample."""
        return self.vectors[index], self.labels[index]

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.labels)
21
+
22
if __name__ == '__main__':
    # Embed the whole corpus once and persist it for scr.train to reuse.
    built = TextDataset(df)
    torch.save(built, 'data/dataset.pt')

    loader = DataLoader(built, batch_size=16, shuffle=True, pin_memory=True)

    # Smoke test: pull a single batch through the DataLoader.
    for vec_batch, label_batch in loader:
        print(vec_batch, label_batch)
        break
scr/model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .sbert import vectorize as vec
4
+
5
class Block(nn.Module):
    """A single Linear -> Dropout(0.2) stage; the building block of ``Model``.

    The submodule is stored under ``self.model`` so checkpoints saved with
    the original layout keep matching state-dict keys.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        stages = [nn.Linear(input_dim, output_dim), nn.Dropout(0.2)]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)


class Model(nn.Module):
    """MLP toxicity classifier over 1024-dim SBERT sentence embeddings.

    Maps an embedding to a probability in [0, 1] via a stack of
    Block/LeakyReLU stages and a final Sigmoid.
    """

    def __init__(self):
        super().__init__()
        widths = (1024, 512, 256, 128, 64, 1)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(Block(fan_in, fan_out))
            # Hidden stages get LeakyReLU; the final stage feeds a Sigmoid.
            # NOTE(review): the last Block applies Dropout to the single
            # logit before the Sigmoid — inactive in eval(), but likely
            # unintended during training; confirm before retraining.
            layers.append(nn.LeakyReLU() if fan_out != 1 else nn.Sigmoid())
        # Same module indices (0..9) as the original hand-written Sequential,
        # so existing checkpoints load unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

    def predict(self, text):
        # Embed raw text with SBERT, then score it with the MLP.
        return self(vec(text))
scr/sbert.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+
4
+ # Определяем доступное устройство
5
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6
+
7
+ # Mean Pooling - Учитывает attention mask для корректного усреднения
8
def mean_pooling(model_output, attention_mask):
    """Masked mean over token embeddings.

    Averages token-level embeddings from the transformer output while
    ignoring padding positions, using the attention mask as weights.

    Args:
        model_output: HF model output; index 0 holds token embeddings of
            shape (batch, seq_len, dim).
        attention_mask: (batch, seq_len) 0/1 mask of real tokens.

    Returns:
        (batch, dim) tensor of mean-pooled sentence embeddings.
    """
    tokens = model_output[0]
    # Broadcast the mask across the embedding dimension.
    mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
    # Clamp keeps the denominator positive even for an all-padding row.
    denom = torch.clamp(mask.sum(dim=1), min=1e-9)
    return (tokens * mask).sum(dim=1) / denom
14
+
15
+ # Загрузка модели и токенизатора с HuggingFace
16
+ tokenizer = AutoTokenizer.from_pretrained("ai-forever/sbert_large_mt_nlu_ru")
17
+ sbert = AutoModel.from_pretrained("ai-forever/sbert_large_mt_nlu_ru").to(device) # Перенос модели на устройство
18
+
19
def vectorize(texts, batch_size=32):
    """Embed text(s) with the module-level SBERT model.

    Args:
        texts: a single string or a list of strings.
        batch_size: number of texts tokenized and embedded per forward pass.

    Returns:
        (len(texts), dim) float tensor of sentence embeddings on the CPU.
    """
    if isinstance(texts, str):
        texts = [texts]  # normalize the single-string case to a list

    chunks = []
    for start in range(0, len(texts), batch_size):
        encoded = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors='pt'
        ).to(device)

        # Inference only: no gradients needed for the frozen encoder.
        with torch.no_grad():
            output = sbert(**encoded)

        # Pool per batch, then move off the GPU immediately to bound memory.
        chunks.append(mean_pooling(output, encoded['attention_mask']).cpu())

    return torch.cat(chunks, dim=0)
41
+
scr/train.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.utils.data import DataLoader, random_split
4
+ from tqdm import tqdm
5
+ import logging
6
+
7
+ from scr.sbert import sbert
8
+ from scr.dataset import TextDataset
9
+ from scr.model import Model
10
+
11
+ # Device configuration
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ # Training parameters
15
+ EPOCHS = 50
16
+ BATCH_SIZE = 16
17
+ LEARNING_RATE = 1e-3
18
+ MODEL_SAVE_DIR = "models"
19
+ LOG_DIR = "models"
20
+ TEST_SPLIT = 0.2 # 80/20 split
21
+
22
+ # Ensure directories exist
23
+ os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
24
+ os.makedirs(LOG_DIR, exist_ok=True)
25
+
26
+ # Setup logging
27
+ log_file = os.path.join(LOG_DIR, "training.log")
28
+ logging.basicConfig(
29
+ filename=log_file,
30
+ filemode="w",
31
+ level=logging.INFO,
32
+ format="%(asctime)s - %(levelname)s - %(message)s"
33
+ )
34
+ logger = logging.getLogger()
35
+
36
+ # Load dataset
37
+ full_dataset = torch.load('data/dataset.pt')
38
+
39
+ # Split dataset into training and testing sets
40
+ train_size = int((1 - TEST_SPLIT) * len(full_dataset))
41
+ test_size = len(full_dataset) - train_size
42
+ train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
43
+
44
+ train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
45
+ test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)
46
+
47
+ # Initialize model
48
+ model = Model().to(device)
49
+
50
+ # Loss function and optimizer
51
+ criterion = torch.nn.BCELoss()
52
+ optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
53
+
54
+ # Accuracy calculation
55
def calculate_accuracy(model, dataloader, device):
    """Compute classification accuracy (percent) over a dataloader.

    Thresholds the model's sigmoid output at 0.5 and compares against the
    ground-truth labels. Switches the model to eval mode and disables
    gradient tracking for the whole pass.

    Args:
        model: network producing (batch, 1) probabilities.
        dataloader: yields (inputs, labels) batches.
        device: torch device to move batches onto.

    Returns:
        Accuracy as a float percentage in [0, 100]; 0 for an empty loader.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device, dtype=torch.float)
            probs = model(batch_x).squeeze(1)
            # 0.5 is the decision threshold for the binary classifier.
            preds = (probs > 0.5).float()
            hits += (preds == batch_y).sum().item()
            seen += batch_y.size(0)
    return 100 * hits / seen if seen > 0 else 0
67
+
68
+ # Training function
69
def train_model(model, train_dataloader, test_dataloader, criterion, optimizer, epochs, device):
    """Run the full training loop.

    Each epoch: one pass over ``train_dataloader`` (forward, loss, backward,
    optimizer step), then test-set accuracy, then a checkpoint of the model
    weights. Progress goes to the module-level ``logger`` and a tqdm bar.

    Args:
        model: the network to train (already on ``device``).
        train_dataloader: batched (embedding, label) training pairs.
        test_dataloader: batched pairs for per-epoch evaluation.
        criterion: loss function (BCELoss in this project).
        optimizer: optimizer over ``model.parameters()``.
        epochs: number of passes over the training data.
        device: torch device batches are moved onto.
    """
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        logger.info(f"Epoch {epoch + 1}/{epochs} started")

        progress = tqdm(
            train_dataloader,
            desc=f"Epoch {epoch + 1}/{epochs}",
            ncols=12  # keep the progress bar deliberately narrow
        )
        for batch_x, batch_y in progress:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device, dtype=torch.float)

            # Clear stale gradients before the new step.
            optimizer.zero_grad()

            scores = model(batch_x).squeeze(1)
            loss = criterion(scores, batch_y)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_dataloader)
        logger.info(f"Average loss for epoch {epoch + 1}: {avg_loss:.4f}")

        accuracy = calculate_accuracy(model, test_dataloader, device)
        logger.info(f"Test accuracy after epoch {epoch + 1}: {accuracy:.2f}%")

        # Checkpoint after every epoch so any intermediate state can be restored.
        checkpoint = os.path.join(MODEL_SAVE_DIR, f"model_epoch_{epoch + 1}.pt")
        torch.save(model.state_dict(), checkpoint)
        logger.info(f"Model saved: {checkpoint}")

    logger.info("Training complete.")
112
+ # Run training
113
if __name__ == "__main__":
    # Entry point: record the split sizes, then run the training loop with
    # the module-level model, loaders, loss, and optimizer.
    logger.info("Training started")
    logger.info(f"Training set size: {len(train_dataset)}, Test set size: {len(test_dataset)}")
    train_model(model, train_dataloader, test_dataloader, criterion, optimizer, EPOCHS, device)
    logger.info("Training complete")