1aurent committed on
Commit
dd07c8d
·
1 Parent(s): 1c0fccd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from datasets import load_dataset
4
+ from transformers import pipeline
5
+ from textwrap import dedent
6
+ from email import message_from_file
7
+ from email.header import decode_header
8
+
9
# Run on the first CUDA device when torch reports one, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Instantiate the transformers pipeline for the email-spam model on that device.
pipe = pipeline(
    model="1aurent/distilbert-base-multilingual-cased-finetuned-email-spam",
    device=device,
)
14
+
15
+ # fn to predict from text
16
def classify_raw(text):
    """Classify an already-formatted header block.

    ``text`` is handed unchanged to the module-level ``pipe``; the call asks
    for two entries (``top_k=2``) and the pipeline's result is returned as is.
    """
    predictions = pipe(text, top_k=2)
    return predictions
18
+
19
+ # fn to predict from form inputs
20
def classify_form(mailfrom, x_mailfrom, to, reply_to, subject):
    """Classify an email described by its individual header fields.

    The five values are laid out as one ``Name: value`` pair per line, in the
    order From, X-MailFrom, To, Reply-To, Subject, and the resulting text is
    sent to the module-level ``pipe``.

    Args:
        mailfrom: Value for the ``From`` line.
        x_mailfrom: Value for the ``X-MailFrom`` line.
        to: Value for the ``To`` line.
        reply_to: Value for the ``Reply-To`` line.
        subject: Value for the ``Subject`` line.

    Returns:
        Whatever ``pipe`` returns for the assembled text with ``top_k=2``.
    """
    # Join the lines explicitly rather than dedenting an interpolated
    # template: a value containing a line break would change the common
    # margin ``dedent`` detects and leave stray indentation on every line.
    text = "\n".join(
        (
            f"From: {mailfrom}",
            f"X-MailFrom: {x_mailfrom}",
            f"To: {to}",
            f"Reply-To: {reply_to}",
            f"Subject: {subject}",
        )
    ).strip()
    return pipe(text, top_k=2)
29
+
30
+ # helper to extract header from email
31
def get_header(message, header_name: str) -> str:
    """Return one header of ``message`` as a single decoded line of text.

    The raw value is looked up with ``message[header_name]`` and split into
    fragments by ``email.header.decode_header``. Every fragment is decoded
    and the pieces are concatenated, so a header that mixes encoded words
    with plain text (for example an encoded display name followed by an
    address) is returned in full rather than only its last fragment.

    Args:
        message: Object supporting ``message[header_name]`` lookup, such as
            a parsed email message.
        header_name: Name of the header to extract.

    Returns:
        The decoded header with newlines replaced by spaces and surrounding
        whitespace stripped, or ``""`` when the header is missing or cannot
        be parsed.
    """
    from email.errors import HeaderParseError

    try:
        raw_value = message[header_name]
    except (KeyError, TypeError):
        return ""
    if raw_value is None:
        return ""

    # Extraction is best effort: a malformed header yields "" instead of an
    # exception reaching the caller.
    try:
        fragments = decode_header(raw_value)
    except (HeaderParseError, TypeError, ValueError):
        return ""

    pieces = []
    for payload, charset in fragments:
        if payload is None:
            continue
        if isinstance(payload, bytes):
            try:
                # Honour the charset declared in the encoded word.
                payload = payload.decode(charset or "utf-8", errors="ignore")
            except LookupError:
                # Unknown charset label: fall back to lenient UTF-8.
                payload = payload.decode("utf-8", errors="ignore")
        pieces.append(payload)

    header = "".join(pieces)
    return header.replace("\n", " ").strip()
42
+
43
+ # fn to predict from email file
44
def classify_file(file):
    """Classify an email file from its From/X-MailFrom/To/Reply-To/Subject headers.

    The file is parsed with ``email.message_from_file``; the five headers are
    extracted with ``get_header`` and forwarded to ``classify_form``.

    Args:
        file: Either a path string or an object exposing the path of the
            file on disk as ``.name``.

    Returns:
        The result of ``classify_form`` for the extracted headers.
    """
    # Accept a plain path as well as an object carrying the path in ``.name``.
    path = file if isinstance(file, str) else file.name

    # The ``with`` block closes the handle once parsing is done. Bytes that
    # are not valid UTF-8 are replaced instead of aborting the parse.
    with open(path, encoding="utf-8", errors="replace") as handle:
        message = message_from_file(handle)

    return classify_form(
        mailfrom=get_header(message, "From"),
        x_mailfrom=get_header(message, "X-MailFrom"),
        to=get_header(message, "To"),
        reply_to=get_header(message, "Reply-To"),
        subject=get_header(message, "Subject"),
    )
54
+
55
+
56
# Title and short description strings for the demo.
title = "Email Spam Classifier"
description = "\nSpam or ham ?\n"

# Top-level Blocks container that the rest of the UI is attached to.
demo = gr.Blocks()
62
+
63
# Interface that classifies a header block pasted as raw text.
raw_interface = gr.Interface(
    fn=classify_raw,
    inputs=gr.Textbox(
        label="Formatted Email Header",
        # Example showing the one-header-per-line layout.
        placeholder="\n".join(
            [
                "From: Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>",
                "X-MailFrom: Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>",
                "To: net7 <net7@bde.enseeiht.fr>",
                "Reply-To: Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>",
                "Subject: Re: Demande d'un H24 net7",
            ]
        ),
    ),
    outputs="json",
)
77
+
78
# Interface with one text box per header; the boxes are passed to
# classify_form positionally, so their order matches its parameters.
form_interface = gr.Interface(
    fn=classify_form,
    inputs=[
        gr.Textbox(label=field_label, placeholder=example)
        for field_label, example in (
            ("From", "Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>"),
            ("X-MailFrom", "Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>"),
            ("To", "net7 <net7@bde.enseeiht.fr>"),
            ("Reply-To", "Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>"),
            ("Subject", "Re: Demande d'un H24 net7"),
        )
    ],
    outputs="json",
)
104
+
105
# Interface that classifies an uploaded file; the picker is limited to ".eml".
file_interface = gr.Interface(
    fn=classify_file,
    inputs=gr.File(label="Email File", file_types=[".eml"]),
    outputs="json",
)
113
+
114
# Assemble the page: heading, description, then one tab per input mode.
with demo:
    # Render the module-level title and description, which nothing else
    # in the layout displays.
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    gr.TabbedInterface(
        interface_list=[
            raw_interface,
            form_interface,
            file_interface,
        ],
        tab_names=[
            "Raw Text",
            "Form",
            "File",
        ],
    )

demo.launch()