Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gradio as gr
|
3 |
+
from datasets import load_dataset
|
4 |
+
from transformers import pipeline
|
5 |
+
from textwrap import dedent
|
6 |
+
from email import message_from_file
|
7 |
+
from email.header import decode_header
|
8 |
+
|
9 |
+
# select device
|
10 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
11 |
+
|
12 |
+
# load model
|
13 |
+
pipe = pipeline(model="1aurent/distilbert-base-multilingual-cased-finetuned-email-spam", device=device)
|
14 |
+
|
15 |
+
# fn to predict from text
|
16 |
+
def classify_raw(text):
|
17 |
+
return pipe(text, top_k=2)
|
18 |
+
|
19 |
+
# fn to predict from form inputs
|
20 |
+
def classify_form(mailfrom, x_mailfrom, to, reply_to, subject):
|
21 |
+
text = dedent(f"""
|
22 |
+
From: {mailfrom}
|
23 |
+
X-MailFrom: {x_mailfrom}
|
24 |
+
To: {to}
|
25 |
+
Reply-To: {reply_to}
|
26 |
+
Subject: {subject}
|
27 |
+
""").strip()
|
28 |
+
return pipe(text, top_k=2)
|
29 |
+
|
30 |
+
# helper to extract header from email
|
31 |
+
def get_header(message, header_name: str) -> str:
|
32 |
+
try:
|
33 |
+
for payload, _ in decode_header(message[header_name]):
|
34 |
+
if type(payload) == bytes:
|
35 |
+
payload = payload.decode(errors="ignore")
|
36 |
+
header = payload
|
37 |
+
header = header.replace("\n", " ")
|
38 |
+
header = header.strip()
|
39 |
+
return header
|
40 |
+
except:
|
41 |
+
return ""
|
42 |
+
|
43 |
+
# fn to predict from email file
|
44 |
+
def classify_file(file):
|
45 |
+
message = message_from_file(open(file.name))
|
46 |
+
|
47 |
+
return classify_form(
|
48 |
+
mailfrom=get_header(message, "From"),
|
49 |
+
x_mailfrom=get_header(message, "X-MailFrom"),
|
50 |
+
to=get_header(message, "To"),
|
51 |
+
reply_to=get_header(message, "Reply-To"),
|
52 |
+
subject=get_header(message, "Subject"),
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
title = "Email Spam Classifier"
|
57 |
+
description = """
|
58 |
+
Spam or ham ?
|
59 |
+
"""
|
60 |
+
|
61 |
+
demo = gr.Blocks()
|
62 |
+
|
63 |
+
raw_interface = gr.Interface(
|
64 |
+
fn=classify_raw,
|
65 |
+
inputs=gr.Textbox(
|
66 |
+
label="Formatted Email Header",
|
67 |
+
placeholder=dedent("""
|
68 |
+
From: Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>
|
69 |
+
X-MailFrom: Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>
|
70 |
+
To: net7 <net7@bde.enseeiht.fr>
|
71 |
+
Reply-To: Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>
|
72 |
+
Subject: Re: Demande d'un H24 net7
|
73 |
+
""").strip(),
|
74 |
+
),
|
75 |
+
outputs="json",
|
76 |
+
)
|
77 |
+
|
78 |
+
form_interface = gr.Interface(
|
79 |
+
fn=classify_form,
|
80 |
+
inputs=[
|
81 |
+
gr.Textbox(
|
82 |
+
label="From",
|
83 |
+
placeholder="Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>",
|
84 |
+
),
|
85 |
+
gr.Textbox(
|
86 |
+
label="X-MailFrom",
|
87 |
+
placeholder="Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>",
|
88 |
+
),
|
89 |
+
gr.Textbox(
|
90 |
+
label="To",
|
91 |
+
placeholder="net7 <net7@bde.enseeiht.fr>",
|
92 |
+
),
|
93 |
+
gr.Textbox(
|
94 |
+
label="Reply-To",
|
95 |
+
placeholder="Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>",
|
96 |
+
),
|
97 |
+
gr.Textbox(
|
98 |
+
label="Subject",
|
99 |
+
placeholder="Re: Demande d'un H24 net7",
|
100 |
+
),
|
101 |
+
],
|
102 |
+
outputs="json",
|
103 |
+
)
|
104 |
+
|
105 |
+
file_interface = gr.Interface(
|
106 |
+
fn=classify_file,
|
107 |
+
inputs=gr.File(
|
108 |
+
label="Email File",
|
109 |
+
file_types=[".eml"],
|
110 |
+
),
|
111 |
+
outputs="json",
|
112 |
+
)
|
113 |
+
|
114 |
+
with demo:
|
115 |
+
gr.TabbedInterface(
|
116 |
+
interface_list=[
|
117 |
+
raw_interface,
|
118 |
+
form_interface,
|
119 |
+
file_interface
|
120 |
+
],
|
121 |
+
tab_names=[
|
122 |
+
"Raw Text",
|
123 |
+
"Form",
|
124 |
+
"File"
|
125 |
+
]
|
126 |
+
)
|
127 |
+
|
128 |
+
demo.launch()
|