# NOTE: removed Hugging Face Spaces page-capture residue (status badges,
# commit hashes, line-number gutter) that made this file invalid Python.
import gradio as gr
import torch
import torch.nn.functional as F
import re
import nltk
# One-time corpus downloads needed by WordNetLemmatizer and the stopword list.
nltk.download('stopwords')
nltk.download('wordnet')
import os
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import urllib.request
# Model initialization: the six base architectures that are each fine-tuned
# (checkpoints loaded below) and vote on the final label.
model_names = ["distilbert-base-uncased", "prajjwal1/bert-tiny", "roberta-base", "google/mobilebert-uncased", "albert-base-v2", "xlm-roberta-base"]
models = {}
tokenizers = {}
# === Checkpoint file name -> Dropbox URL (dl=1 forces a direct download) ===
model_urls = {
"best_model_albert-base-v2.pth": "https://www.dropbox.com/scl/fi/adulme5xarg6hgxbs26fm/best_model_albert-base-v2.pth?rlkey=y17x3sw1frk83yfzt8zc00458&st=43uha18d&dl=1",
"best_model_distilbert-base-uncased.pth": "https://www.dropbox.com/scl/fi/8y3oyfbzmbmn427e1ei3d/best_model_distilbert-base-uncased.pth?rlkey=u9rd40tdd3p781r4xtv8wi5t6&st=nfzq7x8j&dl=1",
"best_model_google_mobilebert-uncased.pth": "https://www.dropbox.com/scl/fi/7zdarid2no1fw0b8hk0tf/best_model_google_mobilebert-uncased.pth?rlkey=w13j1jampxlt8himivj090nwv&st=0zq6yofp&dl=1",
"best_model_prajjwal1_bert-tiny.pth": "https://www.dropbox.com/scl/fi/vscwewy4uo58o7xswokxt/best_model_prajjwal1_bert-tiny.pth?rlkey=uav8aas7fxb5nl2w5iacg1qyb&st=12mzggan&dl=1",
"best_model_roberta-base.pth": "https://www.dropbox.com/scl/fi/zqmlzt0q6knjv096yswsr/best_model_roberta-base.pth?rlkey=hi8ddi23dnz45xt3jomxq0pek&st=2axjymyt&dl=1",
"best_model_xlm-roberta-base.pth": "https://www.dropbox.com/scl/fi/2gao9iqesou9kb633vvan/best_model_xlm-roberta-base.pth?rlkey=acyvwt8qtle8wzle5idfo8241&st=8livizox&dl=1",
}
# === Download each fine-tuned checkpoint if it is not already on disk ===
# Bug fix: the f-string placeholders had been mangled to a literal "(unknown)",
# so the logs never showed which file was being fetched; restore {filename}.
for filename, url in model_urls.items():
    if not os.path.exists(filename):
        print(f"Lejupielādē: {filename}")
        try:
            urllib.request.urlretrieve(url, filename)
            print(f" → Saglabāts: {filename}")
        except Exception as e:
            # Best-effort: log and continue so the remaining models still load.
            print(f" [!] Kļūda lejupielādējot {filename}: {e}")
# Instantiate each tokenizer/model pair and load its fine-tuned weights.
for model_name in model_names:
    # Tokenizer for this architecture.
    tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name, max_length=512)
    # Base model with a 3-class classification head (Safe / Spam / Phishing).
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
    # Checkpoint files use '_' where the hub name has '/' (e.g. prajjwal1/bert-tiny).
    checkpoint_path = f"best_model_{model_name.replace('/', '_')}.pth"
    classifier.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    # Inference-only, on CPU.
    models[model_name] = classifier.to('cpu')
    models[model_name].eval()
# Class index -> human-readable label.
labels = {0: "Safe", 1: "Spam", 2: "Phishing"}
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))
def preprocess(text):
    """Normalize raw e-mail text for classification.

    Lowercases, strips URLs and non-letter characters, collapses whitespace,
    then drops English stopwords and lemmatizes the remaining words.
    """
    cleaned = text.lower()                           # drop capitalization
    cleaned = re.sub(r'http\S+', '', cleaned)        # remove URLs
    cleaned = re.sub(r"[^a-z']", ' ', cleaned)       # keep letters/apostrophes only
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()   # collapse extra whitespace
    kept = (lemmatizer.lemmatize(word)
            for word in cleaned.split()
            if word not in stop_words)
    return ' '.join(kept)
# Classification function (single model)
def classify_email_single_model(text, model_name):
    """Classify *text* with one model.

    Returns a dict with the predicted label string and a (1, 3) array of
    per-class probabilities expressed as percentages.
    """
    cleaned = preprocess(text)
    encoded = tokenizers[model_name](cleaned, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = models[model_name](**encoded).logits
    predicted_class = torch.argmax(logits, dim=1).item()
    percentages = F.softmax(logits, dim=1).cpu().numpy() * 100
    return {"prediction": labels[predicted_class], "probabilities": percentages}
# Classification function (all models together, probabilities for each model)
def classify_email_detailed(text):
    """Classify *text* with every model and report the vote tally plus each
    model's per-class probabilities.

    Improvements over the original: the `response` name is no longer reused
    for two unrelated values (a dict, then a string), and the vote summary is
    built with str.join instead of a manual counter hard-coded to 3 labels,
    so it stays correct if the label set ever changes. Output is unchanged.
    """
    votes = {"Safe": 0, "Spam": 0, "Phishing": 0}
    probabilities = {}
    for model_name in model_names:
        result = classify_email_single_model(text, model_name)
        votes[result['prediction']] += 1
        probabilities[model_name] = result['probabilities']
    # Vote summary, e.g. "Safe: 2 votes, Spam: 1 vote, Phishing: 3 votes"
    summary_parts = [
        f"{label}: {vote_count} {'vote' if vote_count == 1 else 'votes'}"
        for label, vote_count in votes.items()
    ]
    response = ", ".join(summary_parts) + "\n\n"
    # One line per model listing its percentage for each class.
    for model_name in model_names:
        response += f"{model_name}: "
        for j, prob in enumerate(probabilities[model_name][0]):
            response += f"{labels[j]}: {prob:.2f}% "
        response += "\n"
    return response
# Classification function (all models together, just the votes)
def classify_email_simple(text):
    """Classify *text* with every model and return only the vote tally.

    Improvement: the vote summary is built with str.join instead of a manual
    counter hard-coded to 3 labels (`i != 3`), so it stays correct if the
    label set ever changes. Output is byte-identical to the original.
    """
    votes = {"Safe": 0, "Spam": 0, "Phishing": 0}
    for model_name in model_names:
        result = classify_email_single_model(text, model_name)
        votes[result['prediction']] += 1
    # e.g. "Safe: 2 votes, Spam: 1 vote, Phishing: 3 votes" + blank line
    summary_parts = [
        f"{label}: {vote_count} {'vote' if vote_count == 1 else 'votes'}"
        for label, vote_count in votes.items()
    ]
    return ", ".join(summary_parts) + "\n\n"
def classify_email(text, mode):
    """Dispatch to the votes-only or votes+probabilities classifier.

    ``mode`` is the Latvian radio-button label: "Tikai balsis" means
    "votes only"; any other value yields the detailed report.
    """
    wants_votes_only = (mode == "Tikai balsis")
    return classify_email_simple(text) if wants_votes_only else classify_email_detailed(text)
# Gradio UI: a textbox for the e-mail body plus a radio button choosing
# between votes-only and votes+probabilities output.
# Bug fix: the original last line ended with a stray " |" (page-scrape
# residue), which is a SyntaxError; it has been removed.
demo = gr.Interface(
    fn=classify_email,
    inputs=[gr.Textbox(lines=10, placeholder="Ievietojiet savu e-pastu šeit..."),
            gr.Radio(choices=["Tikai balsis", "Balsis un varbūtības"], label='Klasifikācijas veids')
            ],
    outputs="text",
    title="E-pastu klasifikators (vairāku modeļu balsošana)",
    description="Autori: Kristaps Tretjuks un Aleksejs Gorlovičs"
)
demo.launch(share=True)