import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import os
import numpy as np
model_path = 'srcs/model_modify.pth'
# токенизатор
tokenizer = AutoTokenizer.from_pretrained('cointegrated/rubert-tiny-toxicity')
model = AutoModelForSequenceClassification.from_pretrained('cointegrated/rubert-tiny-toxicity', num_labels=1, ignore_mismatched_sizes=True)
# весов модифицированной модели
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
image = Image.open("media/oritoxic.jpg")
df = pd.read_csv("media/Toxic_labeled.csv")
loss_values = [0.4063596375772262, 0.402279906166038, 0.3998144585561736, 0.39567733055365567,
0.3921396666608141, 0.38956182373070186, 0.3866641920902114, 0.3879134839351564,
0.38288725781591604, 0.38198364493999004]
#Боковая панель
selected_option = st.sidebar.selectbox("Выберите из списка", ["Определение токсичность текста", "Информация о датасете", "Информация о модели"])
#st.title("Главная страница")
if selected_option == "Определение токсичность текста":
st.markdown("
Приложение для определения токсичности текста
",
unsafe_allow_html=True)
st.image(image, use_column_width=True)
user_input = st.text_area("")
# Функция предсказания токсичности
def predict_toxicity(text):
inputs = tokenizer(text, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
probability = torch.sigmoid(logits).item()
prediction = "токсичный" if probability >= 0.5 else "не токсичный"
return prediction, probability
# Тык на кнопу
if st.button("Оценить токсичность"):
if user_input:
prediction, toxicity_probability = predict_toxicity(user_input)
st.write(f'Вероятность токсичности: {toxicity_probability:.4f}')
# Прогресс бар
if 'toxicity_probability' in locals():
progress_percentage = int(toxicity_probability * 100)
progress_bar_color = f'linear-gradient(to right, rgba(0, 0, 255, 0.5) {progress_percentage}%, rgba(255, 0, 0, 0.5) {progress_percentage}%)'
st.markdown(f'',
unsafe_allow_html=True)
elif selected_option == "Информация о датасете":
st.header("Информация о датасете:")
st.dataframe(df.head())
st.write(f"Объем выборки: 14412")
st.subheader("Баланс классов в датасете:")
st.write(f"Количество записей в классе 0.0: {len(df[df['toxic'] == 0.0])}")
st.write(f"Количество записей в классе 1.0: {len(df[df['toxic'] == 1.0])}")
fig, ax = plt.subplots()
df['toxic'].value_counts().plot(kind='bar', ax=ax, color=['skyblue', 'orange'])
ax.set_xticklabels(['Не токсичный', 'Токсичный'], rotation=0)
ax.set_xlabel('Класс')
ax.set_ylabel('Количество записей')
ax.set_title('Распределение по классам')
st.pyplot(fig)
elif selected_option == "Информация о модели":
st.subheader("Информация о модели:")
st.write(f"Модель: Rubert tiny toxicity")
st.subheader("Информация о процессе обучения")
# график лосса
#st.subheader("График потерь в процессе обучения")
#st.line_chart([0.5181976270121774, 0.4342067330899996, 0.41386983832460666]) # Замените данными из ваших эпох
for epoch, loss in enumerate(loss_values, start=1):
st.write(f"Epoch {epoch}/{len(loss_values)}, Loss: {loss}
", unsafe_allow_html=True)
st.markdown(
"""
Количество эпох: 10
Размер батча: 8
Оптимизатор: Adam
Функция потерь: BCEWithLogitsLoss
learning rate: 0.00001
""",
unsafe_allow_html=True
)
st.subheader("Метрики модели:")
st.write(f"Accuracy: {0.8366:.4f}")
st.write(f"Precision: {0.8034:.4f}")
st.write(f"Recall: {0.6777:.4f}")
st.write(f"F1 Score: {0.7352:.4f}")
st.subheader("Код")
bert_model_code = """
model = BertModel(
embeddings=BertEmbeddings(
word_embeddings=Embedding(29564, 312, padding_idx=0),
position_embeddings=Embedding(512, 312),
token_type_embeddings=Embedding(2, 312),
LayerNorm=LayerNorm((312,), eps=1e-12, elementwise_affine=True),
dropout=Dropout(p=0.1, inplace=False),
),
encoder=BertEncoder(
layer=ModuleList(
BertLayer(
attention=BertAttention(
self=BertSelfAttention(
query=Linear(in_features=312, out_features=312, bias=True),
key=Linear(in_features=312, out_features=312, bias=True),
value=Linear(in_features=312, out_features=312, bias=True),
dropout=Dropout(p=0.1, inplace=False),
),
output=BertSelfOutput(
dense=Linear(in_features=312, out_features=312, bias=True),
LayerNorm=LayerNorm((312,), eps=1e-12, elementwise_affine=True),
dropout=Dropout(p=0.1, inplace=False),
),
),
intermediate=BertIntermediate(
dense=Linear(in_features=312, out_features=600, bias=True),
intermediate_act_fn=GELUActivation(),
),
output=BertOutput(
dense=Linear(in_features=600, out_features=312, bias=True),
LayerNorm=LayerNorm((312,), eps=1e-12, elementwise_affine=True),
dropout=Dropout(p=0.1, inplace=False),
),
)
)
),
pooler=BertPooler(
dense=Linear(in_features=312, out_features=312, bias=True),
activation=Tanh(),
),
dropout=Dropout(p=0.1, inplace=False),
classifier=Linear(in_features=312, out_features=1, bias=True),
)
"""
# Отображение кода в Streamlit
st.code(bert_model_code, language="python")