File size: 4,754 Bytes
bed6925
 
 
3230342
721e8e1
bed6925
 
3230342
 
 
bed6925
8a55818
3230342
bed6925
 
235d32d
 
988cfb6
 
 
235d32d
 
 
 
 
988cfb6
721e8e1
235d32d
fac56da
0a9e16a
bed6925
 
 
 
 
988cfb6
bed6925
7546ea6
 
 
 
 
 
 
de1e010
 
 
7546ea6
 
de1e010
7546ea6
 
988cfb6
7546ea6
de1e010
7546ea6
 
 
 
 
 
 
 
de1e010
5746974
7546ea6
5746974
 
 
7546ea6
 
de1e010
7546ea6
5746974
 
7546ea6
5746974
7546ea6
de1e010
5746974
 
 
 
 
 
 
7546ea6
 
 
 
5746974
 
 
 
7546ea6
5746974
 
 
 
 
 
de1e010
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import os
import re

def load_model(model_name):
    token = os.getenv("HG_TOKEN")
    if not token:
        raise ValueError("Hugging Face API token not found. Please set HG_TOKEN environment variable.")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1, use_auth_token=token).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
    return model, tokenizer, device

import re

def clean_text(text):
    text = text.strip()
    text = text.replace('\n', ' ')
    text = re.sub(r'[^\w\s,.]', '', text, flags=re.UNICODE)  # Оставляем только буквы, пробелы, точки и запятые
    text = re.sub(r'\d+', '', text)  # Удаляем цифры
    if re.search(r'[а-яА-Я]', text):  
        text = re.sub(r'\b(?!@|https?://|www\.)[a-zA-Z]+\b', '', text)  # Удаляем латинские слова
    text = re.sub(r'\s+', ' ', text).strip()  
    return text.lower()


def classify_message(model, tokenizer, device, message):
    encoding = tokenizer(message, padding=True, truncation=False, return_tensors='pt')
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask).logits
        pred = torch.sigmoid(outputs).cpu().numpy()[0][0]
    return pred

def spam_classifier_interface(message, model_choice):
    if model_choice == "bert":
        model1_name = "ru-spam/ruSpamNS_v9_Detector"
        model2_name = "ru-spam/ruSpamNS_v9_Precision"
    elif model_choice == "tinybert":
        model1_name = "ru-spam/ruSpamNS_v9_Detector_tiny"
        model2_name = "ru-spam/ruSpamNS_v9_Precision_tiny"
    elif model_choice == "v10":
        model1_name = "NeuroSpaceX/ruSpamNS_v10"
        model2_name = "NeuroSpaceX/ruSpamNS_v10"
    else:
        return "Ошибка: неверный выбор модели", None
    
    detector_model, detector_tokenizer, device = load_model(model1_name)
    precision_model, precision_tokenizer, _ = load_model(model2_name)
    message = clean_text(message)
    detector_prob = classify_message(detector_model, detector_tokenizer, device, message)
    
    result_message = f"Детектор: вероятность спама {detector_prob:.2f}\n"
    if 0.5 <= detector_prob <= 0.8:
        precision_prob = classify_message(precision_model, precision_tokenizer, device, message)
        result_message += "Модель сомневается. Используется уточняющая модель.\n"
        result_message += f"Уточнение: вероятность спама {precision_prob:.2f}\n"
        final_result = "Спам" if precision_prob >= 0.5 else "Не спам"
    else:
        final_result = "Спам" if detector_prob >= 0.5 else "Не спам"
    
    result_message += f"Итог: {final_result}"
    return result_message, gr.update(value=detector_prob, maximum=1.0)

interface = gr.Interface(
    fn=spam_classifier_interface,
    inputs=[
        gr.Textbox(label="Введите сообщение для классификации", placeholder="Введите текст...", lines=3, elem_id="input-text"),
        gr.Radio(["bert", "tinybert", "v10"], label="Выберите модель", elem_id="model-choice")
    ],
    outputs=[
        gr.Textbox(label="Результат", placeholder="Результат появится здесь...", elem_id="result-text"),
        gr.Slider(label="Вероятность детектора", minimum=0.0, maximum=1.0, step=0.01, interactive=False, elem_id="detector-slider")
    ],
    title="Классификатор Спам/Не Спам",
    description="Классифицируйте сообщения как спам или не спам, используя модели bert, tinybert или v10. Если детектор сомневается (50-80%), используется уточняющая модель для окончательного вердикта.",
    theme="default",
    css="""
        #input-text {
            border: 2px solid #4CAF50;
            border-radius: 5px;
            padding: 10px;
        }
        #model-choice {
            font-size: 16px;
            color: #333;
        }
        #result-text {
            font-weight: bold;
            color: #333;
        }
        #detector-slider {
            background-color: #f1f1f1;
            border-radius: 10px;
        }
    """
)

interface.launch()