File size: 3,410 Bytes
961ee03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
import torch
import sys
from pathlib import Path
import time
import numpy as np
from transformers import AutoTokenizer


st.write("# Оценка степени токсичности пользовательского сообщения")
# st.write("Здесь вы можете загрузить картинку со своего устройства, либо при помощи ссылки")

# Добавление пути к проекту и моделям
project_root = Path(__file__).resolve().parents[1]
models_path = project_root / 'models'
sys.path.append(str(models_path))
from models.model2.preprocess_text import TextPreprocessorBERT
from models.model2.model import BERTClassifier

device = 'cpu'

# Загрузка модели и словаря
@st.cache_resource
def load_model():
    model = BERTClassifier()
    weights_path = models_path / 'model2' / 'model_weights_new.pth'
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

@st.cache_resource
def load_tokenizer():
    return AutoTokenizer.from_pretrained('cointegrated/rubert-tiny-toxicity')

model = load_model()
tokenizer = load_tokenizer()

input_text = st.text_area('Введите текст сообщения')

if st.button('Предсказать'):
    start_time = time.time()
    # Применяем предобработку
    preprocessor = TextPreprocessorBERT()
    preprocessed_text = preprocessor.transform(input_text)
    
    # Токенизация
    tokens = tokenizer.encode_plus(
        preprocessed_text,
        add_special_tokens=True,
        truncation=True,
        max_length=100,
        padding='max_length',
        return_tensors='pt'
    )
    
    # Получаем input_ids и attention_mask из токенов
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    
    # Предсказание
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
    
    # Интерпретация результата
    prediction = torch.sigmoid(output).item()
    end_time = time.time()  # Останавливаем таймер
    execution_time = end_time - start_time  
    if prediction > 0.5:
        class_pred = 'TOXIC'
    else:
        class_pred = 'healthy'
    st.subheader(f'Предсказанный класс токсичности: **{class_pred}** с вероятностью {prediction:.4f}')
    # st.write(f'Предсказанный класс токсичности: {prediction:.4f}')
    st.write(f'Время выполнения: {execution_time:.4f} секунд')



# Информация о первой модели
st.write("# Информация об обучении модели rubert-tiny-toxicity:")
st.write("Модель обучалась на предсказание 1 класса")
st.write("Размер датасета - 14412 текстов сообщений")
st.write("Проведена предобработка текста")

st.image(str(project_root / 'images/2_rubert_metrics.png'), width=1000)
st.write("Время обучения модели - 50 эпох")
st.write("Метрики на 50 эпохе:")
st.write("Train f1: 0.73, Val f1: 0.77")
st.write("Train acc: 0.73, Val acc: 0.74")