import torch
import gradio as gr
from transformers import pipeline
from textwrap import dedent
from email import message_from_file
from email.header import decode_header

# select device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# load the fine-tuned spam classification model
pipe = pipeline(
    model="1aurent/distilbert-base-multilingual-cased-finetuned-email-spam",
    device=device,
)


# predict from raw, already-formatted header text
def classify_raw(text):
    # returns a list of {"label", "score"} dicts for both classes
    return pipe(text, top_k=2)


# predict from individual form inputs by reassembling a header block
def classify_form(mailfrom, x_mailfrom, to, reply_to, subject):
    text = dedent(
        f"""
        From: {mailfrom}
        X-MailFrom: {x_mailfrom}
        To: {to}
        Reply-To: {reply_to}
        Subject: {subject}
        """
    ).strip()
    return pipe(text, top_k=2)


# helper to extract and decode a header from an email message
def get_header(message, header_name: str) -> str:
    raw_value = message[header_name]
    if raw_value is None:
        return ""
    try:
        parts = []
        for payload, _ in decode_header(raw_value):
            if isinstance(payload, bytes):
                payload = payload.decode(errors="ignore")
            parts.append(payload)
        return "".join(parts).replace("\n", " ").strip()
    except Exception:
        return ""


# predict from an uploaded .eml file
def classify_file(file):
    with open(file.name) as f:
        message = message_from_file(f)
    return classify_form(
        mailfrom=get_header(message, "From"),
        x_mailfrom=get_header(message, "X-MailFrom"),
        to=get_header(message, "To"),
        reply_to=get_header(message, "Reply-To"),
        subject=get_header(message, "Subject"),
    )


title = "Email Spam Classifier"
description = """
Spam or ham?
"""

demo = gr.Blocks()

raw_interface = gr.Interface(
    fn=classify_raw,
    inputs=gr.Textbox(
        label="Formatted Email Header",
        lines=5,
        placeholder=dedent(
            """
            From: Laurent Fainsin
            X-MailFrom: Laurent Fainsin
            To: net7
            Reply-To: Laurent Fainsin
            Subject: Re: Demande d'un H24 net7
            """
        ).strip(),
    ),
    outputs="json",
    api_name="predict_raw_text",
)

form_interface = gr.Interface(
    fn=classify_form,
    inputs=[
        gr.Textbox(
            label="From",
            placeholder="Laurent Fainsin",
        ),
        gr.Textbox(
            label="X-MailFrom",
            placeholder="Laurent Fainsin",
        ),
        gr.Textbox(
            label="To",
            placeholder="net7",
        ),
        gr.Textbox(
            label="Reply-To",
            placeholder="Laurent Fainsin",
        ),
        gr.Textbox(
            label="Subject",
            placeholder="Re: Demande d'un H24 net7",
        ),
    ],
    outputs="json",
    api_name="predict_form",
)

file_interface = gr.Interface(
    fn=classify_file,
    inputs=gr.File(
        label="Email File",
        file_types=[".eml"],
    ),
    outputs="json",
    api_name="predict_file",
)

with demo:
    gr.TabbedInterface(
        interface_list=[raw_interface, form_interface, file_interface],
        tab_names=["Raw Text", "Form", "File"],
        title=title,
    )

demo.launch()
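
# --- Usage sketch (not executed by the app) ---------------------------------
# Once the app is running, the endpoints declared via `api_name` above can be
# queried programmatically with `gradio_client`. The URL below assumes a local
# launch on the default port; substitute the actual Space URL if deployed.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(
#       "From: Laurent Fainsin\nSubject: Re: Demande d'un H24 net7",
#       api_name="/predict_raw_text",
#   )
#   print(result)  # spam/ham labels with scores (exact shape depends on the gradio version)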