# Last change by 1aurent in commit e8b6f39: "Update app.py"
import torch
import gradio as gr
from datasets import load_dataset
from transformers import pipeline
from textwrap import dedent
from email import message_from_file
from email.header import decode_header
# Run inference on the first CUDA GPU when one is available, else on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build the text-classification pipeline once at import time; the prediction
# handlers below all call this shared `pipe` object.
pipe = pipeline(
    model="1aurent/distilbert-base-multilingual-cased-finetuned-email-spam",
    device=device,
)
def classify_raw(text):
    """Classify a block of already-formatted email header lines.

    Args:
        text: Header lines exactly as entered in the "Raw Text" tab.

    Returns:
        The result of calling the shared ``pipe`` on ``text`` with
        ``top_k=2`` (presumably the scores for both labels — verify
        against the model card).
    """
    predictions = pipe(text, top_k=2)
    return predictions
def classify_form(mailfrom, x_mailfrom, to, reply_to, subject):
    """Classify an email from individually supplied header values.

    The values are assembled into ``Name: value`` lines, one header per
    line, in the order From, X-MailFrom, To, Reply-To, Subject. The result
    is then passed to the shared ``pipe``.

    The text is built line by line instead of via ``dedent`` on an
    f-string. Interpolating user values before ``dedent`` lets a value
    containing a newline defeat the common-indent detection and leak
    indentation into the model input.

    Args:
        mailfrom: Value of the ``From`` header.
        x_mailfrom: Value of the ``X-MailFrom`` header.
        to: Value of the ``To`` header.
        reply_to: Value of the ``Reply-To`` header.
        subject: Value of the ``Subject`` header.
        ``None`` for any field is treated as an empty value.

    Returns:
        The result of ``pipe(text, top_k=2)`` for the assembled text.
    """
    fields = (
        ("From", mailfrom),
        ("X-MailFrom", x_mailfrom),
        ("To", to),
        ("Reply-To", reply_to),
        ("Subject", subject),
    )
    lines = []
    for name, value in fields:
        value = "" if value is None else str(value)
        # Flatten embedded line breaks so one field cannot spill onto (or
        # forge) additional header lines.
        value = " ".join(value.splitlines())
        lines.append(f"{name}: {value}")
    # Trim the outer ends, as the original dedent(...).strip() did.
    text = "\n".join(lines).strip()
    return pipe(text, top_k=2)
def get_header(message, header_name: str) -> str:
    """Return a decoded, single-line header value from an email message.

    RFC 2047 encoded words are decoded with the charset they declare, and
    every decoded chunk is kept. For example,
    ``=?utf-8?q?Jane?= <jane@example.com>`` yields
    ``"Jane <jane@example.com>"`` rather than only its last chunk.

    Args:
        message: A parsed ``email.message.Message``.
        header_name: Name of the header to read, e.g. ``"From"``.

    Returns:
        The decoded header with line breaks replaced by spaces and outer
        whitespace stripped. Returns ``""`` when the header is absent or
        cannot be decoded.
    """
    from email.errors import HeaderParseError

    raw = message[header_name]
    if raw is None:
        # Header not present in the message.
        return ""
    try:
        chunks = decode_header(raw)
    except (HeaderParseError, LookupError, UnicodeError):
        # Malformed encoded word (e.g. broken base64): best-effort empty value.
        return ""

    parts = []
    for payload, charset in chunks:
        if isinstance(payload, bytes):
            try:
                payload = payload.decode(charset or "utf-8", errors="ignore")
            except LookupError:
                # Unknown charset label (e.g. "unknown-8bit"): fall back to UTF-8.
                payload = payload.decode("utf-8", errors="ignore")
        parts.append(payload)
    header = "".join(parts)
    # Unfold multi-line (folded) headers into a single line.
    return " ".join(header.splitlines()).strip()
def classify_file(file):
    """Classify an uploaded ``.eml`` message from its routing headers.

    Args:
        file: The value Gradio passes for the ``gr.File`` input. The
            upload's path is read from its ``.name`` attribute, and a plain
            path string is accepted as well.

    Returns:
        The result of ``classify_form`` for the From, X-MailFrom, To,
        Reply-To and Subject headers extracted from the message.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if file is None:
        raise gr.Error("Please upload an .eml file.")
    # Prefer the wrapper's `.name` path; fall back to the value itself in
    # case it is already a path string.
    path = getattr(file, "name", file)
    # Pin the encoding instead of relying on the process locale, and replace
    # undecodable bytes so a non-UTF-8 message cannot crash the request.
    # The `with` block closes the file as soon as parsing is done.
    with open(path, encoding="utf-8", errors="replace") as fh:
        message = message_from_file(fh)
    return classify_form(
        mailfrom=get_header(message, "From"),
        x_mailfrom=get_header(message, "X-MailFrom"),
        to=get_header(message, "To"),
        reply_to=get_header(message, "Reply-To"),
        subject=get_header(message, "Subject"),
    )
# App name and a one-line tagline for the demo.
title = "Email Spam Classifier"
description = "\nSpam or ham ?\n"
# Top-level Blocks app hosting the tabbed interfaces; `title` is used as the
# browser tab title.
demo = gr.Blocks(title=title)
# "Raw Text" tab: the user pastes the header lines directly.
raw_interface = gr.Interface(
    fn=classify_raw,
    api_name="predict_raw_text",
    outputs="json",
    inputs=gr.Textbox(
        lines=5,
        label="Formatted Email Header",
        placeholder="\n".join(
            [
                "From: Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>",
                "X-MailFrom: Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>",
                "To: net7 <net7@bde.enseeiht.fr>",
                "Reply-To: Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>",
                "Subject: Re: Demande d'un H24 net7",
            ]
        ),
    ),
)
# "Form" tab: one textbox per header, in the same order as the parameters
# of classify_form.
form_interface = gr.Interface(
    fn=classify_form,
    inputs=[
        gr.Textbox(label=field_label, placeholder=example)
        for field_label, example in (
            ("From", "Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>"),
            ("X-MailFrom", "Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>"),
            ("To", "net7 <net7@bde.enseeiht.fr>"),
            ("Reply-To", "Laurent Fainsin <laurent.fainsin@etu.inp-n7.fr>"),
            ("Subject", "Re: Demande d'un H24 net7"),
        )
    ],
    outputs="json",
    api_name="predict_form",
)
# "File" tab: upload a whole .eml message; classify_file extracts the headers.
file_interface = gr.Interface(
    fn=classify_file,
    api_name="predict_file",
    outputs="json",
    inputs=gr.File(file_types=[".eml"], label="Email File"),
)
with demo:
    # Render the module-level title and tagline above the tabs; without this
    # they are never shown anywhere on the page.
    gr.Markdown(f"# {title}\n{description}")
    gr.TabbedInterface(
        interface_list=[raw_interface, form_interface, file_interface],
        tab_names=["Raw Text", "Form", "File"],
    )
demo.launch()