nlp-bert-team / pages /2_comments.py
VerVelVel's picture
images
961ee03
import streamlit as st
import torch
import sys
from pathlib import Path
import time
import numpy as np
from transformers import AutoTokenizer
st.write("# Оценка степени токсичности пользовательского сообщения")
# st.write("Здесь вы можете загрузить картинку со своего устройства, либо при помощи ссылки")
# Добавление пути к проекту и моделям
project_root = Path(__file__).resolve().parents[1]
models_path = project_root / 'models'
sys.path.append(str(models_path))
from models.model2.preprocess_text import TextPreprocessorBERT
from models.model2.model import BERTClassifier
device = 'cpu'
# Загрузка модели и словаря
@st.cache_resource
def load_model():
model = BERTClassifier()
weights_path = models_path / 'model2' / 'model_weights_new.pth'
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
@st.cache_resource
def load_tokenizer():
return AutoTokenizer.from_pretrained('cointegrated/rubert-tiny-toxicity')
model = load_model()
tokenizer = load_tokenizer()
input_text = st.text_area('Введите текст сообщения')
if st.button('Предсказать'):
start_time = time.time()
# Применяем предобработку
preprocessor = TextPreprocessorBERT()
preprocessed_text = preprocessor.transform(input_text)
# Токенизация
tokens = tokenizer.encode_plus(
preprocessed_text,
add_special_tokens=True,
truncation=True,
max_length=100,
padding='max_length',
return_tensors='pt'
)
# Получаем input_ids и attention_mask из токенов
input_ids = tokens['input_ids'].to(device)
attention_mask = tokens['attention_mask'].to(device)
# Предсказание
with torch.no_grad():
output = model(input_ids, attention_mask=attention_mask)
# Интерпретация результата
prediction = torch.sigmoid(output).item()
end_time = time.time() # Останавливаем таймер
execution_time = end_time - start_time
if prediction > 0.5:
class_pred = 'TOXIC'
else:
class_pred = 'healthy'
st.subheader(f'Предсказанный класс токсичности: **{class_pred}** с вероятностью {prediction:.4f}')
# st.write(f'Предсказанный класс токсичности: {prediction:.4f}')
st.write(f'Время выполнения: {execution_time:.4f} секунд')
# Информация о первой модели
st.write("# Информация об обучении модели rubert-tiny-toxicity:")
st.write("Модель обучалась на предсказание 1 класса")
st.write("Размер датасета - 14412 текстов сообщений")
st.write("Проведена предобработка текста")
st.image(str(project_root / 'images/2_rubert_metrics.png'), width=1000)
st.write("Время обучения модели - 50 эпох")
st.write("Метрики на 50 эпохе:")
st.write("Train f1: 0.73, Val f1: 0.77")
st.write("Train acc: 0.73, Val acc: 0.74")