File size: 6,004 Bytes
3793db4
 
c583103
3793db4
 
dd706fe
 
39ccd13
3793db4
 
 
3adc964
3793db4
 
 
 
 
 
 
 
77e2bb8
 
 
 
 
 
9dccda9
77e2bb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3793db4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b15f954
3793db4
 
 
 
 
 
 
665b792
b15f954
 
 
3793db4
b15f954
783873f
 
3793db4
b15f954
3793db4
 
b15f954
 
3793db4
b15f954
3793db4
 
 
 
 
 
 
 
b15f954
3793db4
c8f1793
b15f954
 
 
d0b7fe0
c8f1793
b15f954
3793db4
 
783873f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3793db4
 
 
 
783873f
 
 
3793db4
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
import torch
import torch.nn.functional as F
import re
import nltk
# Fetch the 'stopwords' and 'wordnet' NLTK resources when the script starts;
# the script later calls stopwords.words('english') and builds a
# WordNetLemmatizer, which read these resources.
nltk.download('stopwords')
nltk.download('wordnet')
import os
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import urllib.request


# Model initialisation: checkpoints that take part in the vote, in the order
# they are loaded and reported.
model_names = [
    "distilbert-base-uncased",
    "prajjwal1/bert-tiny",
    "roberta-base",
    "google/mobilebert-uncased",
    "albert-base-v2",
    "xlm-roberta-base",
]

# Registries keyed by checkpoint name; filled in by the loading loop.
models, tokenizers = {}, {}

# === Weight file name -> download URL ===
# Each key is the local file name the fine-tuned weights are stored under.
model_urls = {
    "best_model_albert-base-v2.pth": (
        "https://www.dropbox.com/scl/fi/adulme5xarg6hgxbs26fm/best_model_albert-base-v2.pth?rlkey=y17x3sw1frk83yfzt8zc00458&st=43uha18d&dl=1"
    ),
    "best_model_distilbert-base-uncased.pth": (
        "https://www.dropbox.com/scl/fi/8y3oyfbzmbmn427e1ei3d/best_model_distilbert-base-uncased.pth?rlkey=u9rd40tdd3p781r4xtv8wi5t6&st=nfzq7x8j&dl=1"
    ),
    "best_model_google_mobilebert-uncased.pth": (
        "https://www.dropbox.com/scl/fi/7zdarid2no1fw0b8hk0tf/best_model_google_mobilebert-uncased.pth?rlkey=w13j1jampxlt8himivj090nwv&st=0zq6yofp&dl=1"
    ),
    "best_model_prajjwal1_bert-tiny.pth": (
        "https://www.dropbox.com/scl/fi/vscwewy4uo58o7xswokxt/best_model_prajjwal1_bert-tiny.pth?rlkey=uav8aas7fxb5nl2w5iacg1qyb&st=12mzggan&dl=1"
    ),
    "best_model_roberta-base.pth": (
        "https://www.dropbox.com/scl/fi/zqmlzt0q6knjv096yswsr/best_model_roberta-base.pth?rlkey=hi8ddi23dnz45xt3jomxq0pek&st=2axjymyt&dl=1"
    ),
    "best_model_xlm-roberta-base.pth": (
        "https://www.dropbox.com/scl/fi/2gao9iqesou9kb633vvan/best_model_xlm-roberta-base.pth?rlkey=acyvwt8qtle8wzle5idfo8241&st=8livizox&dl=1"
    ),
}


# === Download any weight files that are not already present locally ===
for filename, url in model_urls.items():
    if os.path.exists(filename):
        continue

    print(f"Lejupielādē: {filename}")
    # Download under a temporary name and rename only after the transfer
    # finished. Writing straight to `filename` would leave a truncated file
    # behind when the transfer is interrupted; the existence check above
    # would then skip the download on the next start and loading the weights
    # would fail on the corrupt file.
    partial_path = f"{filename}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filename)
        print(f"  → Saglabāts: {filename}")
    except Exception as e:
        # Best-effort: report the failure and continue with the next file.
        print(f"  [!] Kļūda lejupielādējot {filename}: {e}")
        try:
            os.remove(partial_path)
        except OSError:
            # Nothing was written (or it is already gone); nothing to clean up.
            pass


# Load every tokenizer and fine-tuned classifier listed in model_names.
for model_name in model_names:
    # Tokenizer
    tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name, max_length=512)

    # Model with a 3-class classification head
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
    models[model_name] = classifier

    # Weight files are named after the checkpoint with '/' replaced by '_'.
    model_file_name = model_name.replace('/', '_')
    weights_path = f"best_model_{model_file_name}.pth"
    # NOTE(review): torch.load unpickles the file; these files are downloaded
    # from an external URL — consider weights_only=True if the installed
    # torch version supports it.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    classifier.load_state_dict(state_dict)

    # Keep the model on the CPU and switch it to inference mode
    models[model_name] = classifier.to('cpu')
    models[model_name].eval()

# Class index -> label shown to the user
labels = dict(enumerate(("Safe", "Spam", "Phishing")))

# Shared NLTK helpers used by preprocess()
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

def preprocess(text):
    """Normalise raw e-mail text before tokenization.

    Steps, in order: lowercase the text, drop every ``http...`` run up to the
    next whitespace, replace each character that is not ``a``-``z`` or an
    apostrophe with a space, collapse whitespace, remove words found in
    ``stop_words`` and lemmatize the remaining words.

    Args:
        text: Raw text to clean.

    Returns:
        The cleaned words joined by single spaces.
    """
    cleaned = re.sub(r'http\S+', '', text.lower())  # lowercase, then strip URLs
    cleaned = re.sub(r"[^a-z']", ' ', cleaned)  # keep letters and apostrophes only
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()  # collapse runs of whitespace

    kept = []
    for token in cleaned.split():
        if token not in stop_words:
            kept.append(lemmatizer.lemmatize(token))
    return ' '.join(kept)


# Classification function (single model)
def classify_email_single_model(text, model_name, max_length=512):
    """Classify one e-mail text with a single loaded model.

    The text is cleaned with ``preprocess``, tokenized with the tokenizer
    registered under ``model_name`` and passed through the matching model.

    Args:
        text: Raw e-mail text.
        model_name: Key into the module-level ``tokenizers`` and ``models``
            dictionaries.
        max_length: Maximum tokenized sequence length; longer inputs are
            truncated to this many tokens.

    Returns:
        A dict with two keys: ``"prediction"`` is the label string for the
        highest-scoring class, ``"probabilities"`` is a NumPy array holding
        the softmax scores in percent (callers index row ``0`` to get the
        per-class values).
    """
    text = preprocess(text)
    # Pass max_length explicitly. With truncation=True alone the limit comes
    # from the tokenizer's own model_max_length, which is not guaranteed to
    # be a usable value for every checkpoint; an over-long e-mail would then
    # reach the model untruncated.
    inputs = tokenizers[model_name](
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    with torch.no_grad():
        logits = models[model_name](**inputs).logits
        probs = F.softmax(logits, dim=1)
    prediction = torch.argmax(logits, dim=1).item()
    probs_percent = probs.cpu().numpy() * 100
    return {"prediction": labels[prediction], "probabilities": probs_percent}

    
# Classification function (all models together, probabilities for each model)
def classify_email_detailed(text):
    """Classify ``text`` with every model and report votes plus probabilities.

    Args:
        text: Raw e-mail text.

    Returns:
        A string whose first line lists the vote count per label, followed by
        a blank line and then one line per model with that model's per-class
        probabilities formatted as percentages with two decimals.
    """
    tally = {"Safe": 0, "Spam": 0, "Phishing": 0}
    per_model_probs = {}

    for name in model_names:
        result = classify_email_single_model(text, name)
        tally[result['prediction']] += 1
        per_model_probs[name] = result['probabilities']

    # Vote summary: "Label: N vote(s)" entries separated by ", ".
    summary = ", ".join(
        f"{label}: {count} {'vote' if count == 1 else 'votes'}"
        for label, count in tally.items()
    )
    lines = [summary + "\n", "\n"]

    # One line per model; every entry keeps its trailing space.
    for name in model_names:
        cells = "".join(
            f"{labels[idx]}: {p:.2f}% "
            for idx, p in enumerate(per_model_probs[name][0])
        )
        lines.append(f"{name}: {cells}\n")

    return "".join(lines)

# Classification function (all models together, just the votes)
def classify_email_simple(text):
    """Classify ``text`` with every model and report only the vote counts.

    Args:
        text: Raw e-mail text.

    Returns:
        A single line listing the vote count per label, followed by a blank
        line.
    """
    tally = {"Safe": 0, "Spam": 0, "Phishing": 0}
    for name in model_names:
        tally[classify_email_single_model(text, name)['prediction']] += 1

    parts = []
    for label, count in tally.items():
        noun = "vote" if count == 1 else "votes"
        parts.append(f"{label}: {count} {noun}")

    # Entries are comma-separated; the line ends with a newline plus one
    # extra blank line.
    return ", ".join(parts) + "\n" + "\n"


def classify_email(text, mode):
    """Dispatch to the votes-only or the detailed report based on ``mode``.

    Args:
        text: Raw e-mail text.
        mode: The selected report type; ``"Tikai balsis"`` selects the
            votes-only report, any other value selects the detailed one.

    Returns:
        The report string produced by the selected classifier function.
    """
    handler = classify_email_simple if mode == "Tikai balsis" else classify_email_detailed
    return handler(text)

    
# Gradio UI
# Input widgets: the e-mail text and the report type.
email_box = gr.Textbox(lines=10, placeholder="Ievietojiet savu e-pastu šeit...")
mode_radio = gr.Radio(
    choices=["Tikai balsis", "Balsis un varbūtības"],
    label='Klasifikācijas veids',
)

# Wire the widgets to classify_email; the result is shown as plain text.
demo = gr.Interface(
    fn=classify_email,
    inputs=[email_box, mode_radio],
    outputs="text",
    title="E-pastu klasifikators (vairāku modeļu balsošana)",
    description="Autori: Kristaps Tretjuks un Aleksejs Gorlovičs",
)

# Start the app with a public share link.
demo.launch(share=True)