Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import os | |
import re | |
def load_model(model_name): | |
token = os.getenv("HG_TOKEN") | |
if not token: | |
raise ValueError("Hugging Face API token not found. Please set HG_TOKEN environment variable.") | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, use_auth_token=token).to(device).eval() | |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token) | |
return model, tokenizer, device | |
import re | |
def clean_text(text): | |
text = text.strip() | |
text = text.replace('\n', ' ') | |
text = re.sub(r'[^\w\s,.]', '', text, flags=re.UNICODE) # Оставляем только буквы, пробелы, точки и запятые | |
text = re.sub(r'\d+', '', text) # Удаляем цифры | |
if re.search(r'[а-яА-Я]', text): | |
text = re.sub(r'\b(?!@|https?://|www\.)[a-zA-Z]+\b', '', text) # Удаляем латинские слова | |
text = re.sub(r'\s+', ' ', text).strip() | |
return text.lower() | |
def classify_message(model, tokenizer, device, message): | |
encoding = tokenizer(message, padding=True, truncation=False, return_tensors='pt') | |
input_ids = encoding['input_ids'].to(device) | |
attention_mask = encoding['attention_mask'].to(device) | |
with torch.no_grad(): | |
outputs = model(input_ids, attention_mask=attention_mask).logits | |
pred = torch.sigmoid(outputs).cpu().numpy()[0][0] | |
return pred | |
def spam_classifier_interface(message, model_choice): | |
if model_choice == "bert": | |
model1_name = "ru-spam/ruSpamNS_v9_Detector" | |
model2_name = "ru-spam/ruSpamNS_v9_Precision" | |
elif model_choice == "tinybert": | |
model1_name = "ru-spam/ruSpamNS_v9_Detector_tiny" | |
model2_name = "ru-spam/ruSpamNS_v9_Precision_tiny" | |
elif model_choice == "v10": | |
model1_name = "NeuroSpaceX/ruSpamNS_v10" | |
model2_name = "NeuroSpaceX/ruSpamNS_v10" | |
else: | |
return "Ошибка: неверный выбор модели", None | |
detector_model, detector_tokenizer, device = load_model(model1_name) | |
precision_model, precision_tokenizer, _ = load_model(model2_name) | |
message = clean_text(message) | |
detector_prob = classify_message(detector_model, detector_tokenizer, device, message) | |
result_message = f"Детектор: вероятность спама {detector_prob:.2f}\n" | |
if 0.5 <= detector_prob <= 0.8: | |
precision_prob = classify_message(precision_model, precision_tokenizer, device, message) | |
result_message += "Модель сомневается. Используется уточняющая модель.\n" | |
result_message += f"Уточнение: вероятность спама {precision_prob:.2f}\n" | |
final_result = "Спам" if precision_prob >= 0.5 else "Не спам" | |
else: | |
final_result = "Спам" if detector_prob >= 0.5 else "Не спам" | |
result_message += f"Итог: {final_result}" | |
return result_message, gr.update(value=detector_prob, maximum=1.0) | |
interface = gr.Interface( | |
fn=spam_classifier_interface, | |
inputs=[ | |
gr.Textbox(label="Введите сообщение для классификации", placeholder="Введите текст...", lines=3, elem_id="input-text"), | |
gr.Radio(["bert", "tinybert", "v10"], label="Выберите модель", elem_id="model-choice") | |
], | |
outputs=[ | |
gr.Textbox(label="Результат", placeholder="Результат появится здесь...", elem_id="result-text"), | |
gr.Slider(label="Вероятность детектора", minimum=0.0, maximum=1.0, step=0.01, interactive=False, elem_id="detector-slider") | |
], | |
title="Классификатор Спам/Не Спам", | |
description="Классифицируйте сообщения как спам или не спам, используя модели bert, tinybert или v10. Если детектор сомневается (50-80%), используется уточняющая модель для окончательного вердикта.", | |
theme="default", | |
css=""" | |
#input-text { | |
border: 2px solid #4CAF50; | |
border-radius: 5px; | |
padding: 10px; | |
} | |
#model-choice { | |
font-size: 16px; | |
color: #333; | |
} | |
#result-text { | |
font-weight: bold; | |
color: #333; | |
} | |
#detector-slider { | |
background-color: #f1f1f1; | |
border-radius: 10px; | |
} | |
""" | |
) | |
interface.launch() | |