bayes-or-spam / app.py
tbitai's picture
LLM model
8937106 verified
raw
history blame
3.81 kB
import gradio as gr
from huggingface_hub import hf_hub_download
import json
import tensorflow as tf
import numpy as np
from sentence_transformers import SentenceTransformer
# Load models.
# Each artifact is fetched from the Hugging Face Hub (hf_hub_download returns a
# local file path) and loaded once at import time, so every request reuses it.

# Bayes: a JSON mapping from token to spam probability (see predict_bayes).
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale-dependent default encoding.
model_probs_path = hf_hub_download(repo_id="tbitai/bayes-enron1-spam", filename="probs.json")
with open(model_probs_path, encoding="utf-8") as f:
    model_probs = json.load(f)

# NN: a Keras model that is called directly on raw email text (see predict_nn).
nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
nn_model = tf.keras.models.load_model(nn_model_path)

# GISTy: SentenceTransformer embeddings scored by a Keras model (see predict_llm).
st_model = SentenceTransformer("avsolatorio/GIST-large-Embedding-v0")
llm_model_path = hf_hub_download(repo_id="tbitai/gisty-enron1-spam", filename="gisty-enron1-spam.keras")
llm_model = tf.keras.models.load_model(llm_model_path)
# Utils for Bayes

# Key in model_probs holding the fallback probability used for tokens that
# have no entry of their own (see predict_bayes).
UNK = '[UNK]'


def tokenize(text):
    """Return the list of word tokens in *text*.

    Delegates to Keras' ``text_to_word_sequence`` with its default options,
    so tokens are split and normalized exactly as that helper does.
    """
    to_words = tf.keras.preprocessing.text.text_to_word_sequence
    return to_words(text)
def combine(probs):
    """Combine per-token probabilities into a single spam probability.

    Computes ``P / (P + Q)``, where ``P`` is the product of the probabilities
    and ``Q`` is the product of their complements ``1 - p``.

    Args:
        probs: Sequence of probabilities in [0, 1].

    Returns:
        ``0`` if any probability is exactly zero. ``0.5`` if ``P + Q`` is zero,
        e.g. when both products underflow. Otherwise ``P / (P + Q)``. An empty
        *probs* gives 0.5 because both empty products are 1.
    """
    # A zero already forces P == 0. Returning early also avoids 0 / 0 when a
    # zero appears together with a one (which forces Q == 0).
    if 0 in probs:
        return 0
    spam_product = np.prod(probs)
    ham_product = np.prod([1 - p for p in probs])
    total = spam_product + ham_product
    # Floating-point underflow can still drive both products to 0.0; treat
    # them as equally small instead of dividing by zero.
    if total == 0:
        return 0.5
    return spam_product / total
def get_interesting_probs(probs, intr_threshold):
    """Return the most "interesting" probabilities, most interesting first.

    A probability is more interesting the further it lies from the neutral
    0.5, i.e. the more strongly it points to spam or to ham. Equally
    interesting values keep their input order (``sorted`` is stable, including
    with ``reverse=True``).

    Args:
        probs: Iterable of per-token probabilities.
        intr_threshold: How many probabilities to keep. Numeric non-int values
            (e.g. ``15.0`` from a UI slider) are truncated with ``int()``
            because slice bounds must be integers. ``None`` keeps all of them.

    Returns:
        A new list holding the first *intr_threshold* entries of *probs*
        ordered by decreasing interestingness.
    """
    if intr_threshold is not None:
        # Slicing with a float raises TypeError, so normalize to int first.
        intr_threshold = int(intr_threshold)
    return sorted(probs, key=lambda p: abs(p - 0.5), reverse=True)[:intr_threshold]


# Default number of interesting tokens to combine (the value Graham used, per
# the UI help text).
DEFAULT_INTR_THRESHOLD = 15
def unbias(p):
    """Return ``2p / (1 + p)``, the bias-corrected version of probability *p*.

    Algebraically this inverts ``p = s / (s + 2h)`` into ``s / (s + h)``: it
    undoes a probability computed with the ham count doubled. That is
    presumably the "Graham's bias" the UI's Unbias option refers to. It maps
    0 to 0, 1 to 1 and 0.5 to 2/3.
    """
    doubled = 2 * p
    return doubled / (1 + p)
# Predict functions


def predict_bayes(text, intr_threshold, unbiased=False):
    """Return the Bayes model's spam probability for *text*.

    Each token of *text* is looked up in ``model_probs``. Tokens without an
    entry use the ``UNK`` fallback. When *unbiased* is true, ``unbias`` is
    applied to known tokens only; the ``UNK`` fallback is used unchanged. The
    *intr_threshold* most interesting probabilities are then combined.

    Args:
        text: Email text to classify.
        intr_threshold: Number of most interesting token probabilities to combine.
        unbiased: Whether to apply ``unbias`` to known-token probabilities.

    Returns:
        The combined probability produced by ``combine``.
    """
    def token_prob(word):
        # Unknown token: use the shared fallback probability as-is.
        if word not in model_probs:
            return model_probs[UNK]
        known = model_probs[word]
        return unbias(known) if unbiased else known

    token_probs = [token_prob(word) for word in tokenize(text)]
    return combine(get_interesting_probs(token_probs, intr_threshold))
def predict_nn(text):
    """Return the NN model's spam probability for *text*.

    The raw string is passed to ``nn_model`` as a batch of one, and the scalar
    at ``[0][0]`` of the output is returned as a NumPy value.
    """
    batch = np.array([text])
    scores = nn_model(batch)
    return scores[0][0].numpy()
def predict_llm(text):
    """Return the GISTy model's spam probability for *text*.

    *text* is first embedded with the SentenceTransformer ``st_model``. The
    embedding is wrapped as a batch of one and scored by ``llm_model``, and the
    scalar at ``[0][0]`` of the output is returned as a NumPy value.
    """
    batch = np.array([st_model.encode(text)])
    scores = llm_model(batch)
    return scores[0][0].numpy()
# Display names of the selectable models. They fill the UI dropdown and are
# the values predict() dispatches on.
BAYES = "Bayes Enron1 spam"
NN = "NN Enron1 spam"
LLM = "GISTy Enron1 spam"
MODELS = [BAYES, NN, LLM]
def predict(model, input_txt, unbiased, intr_threshold):
    """Return the spam probability of *input_txt* under the selected *model*.

    Args:
        model: One of the names in ``MODELS`` (``BAYES``, ``NN`` or ``LLM``).
        input_txt: Email text to classify.
        unbiased: Bayes only. Whether to unbias known-token probabilities.
        intr_threshold: Bayes only. Number of most interesting tokens to combine.
            NN and LLM ignore this and *unbiased*, and the UI examples pass
            ``None`` for them.

    Raises:
        ValueError: If *model* is not one of the known model names.
    """
    if model == BAYES:
        return predict_bayes(input_txt, unbiased=unbiased, intr_threshold=intr_threshold)
    if model == NN:
        return predict_nn(input_txt)
    if model == LLM:
        return predict_llm(input_txt)
    # Fail loudly instead of implicitly returning None, which would yield no
    # prediction and no explanation.
    raise ValueError(f"Unknown model {model!r}; expected one of {MODELS}")
# UI

# Example emails, reused across several example rows below.
enron_email = "Enron actuals for June 26, 2000"
nerissa_email = "Stop the aging clock\nNerissa"

# Components are built in the order gr.Interface lists them.
_model_dropdown = gr.Dropdown(choices=MODELS, value=BAYES, label="Model")
_email_box = gr.TextArea(label="Email")
# The extra inputs only affect the Bayes model, so they sit in a collapsed accordion.
_bayes_accordion = gr.Accordion("Additional configuration for Bayes", open=False)
_unbias_checkbox = gr.Checkbox(label="Unbias", info="Correct Graham's bias?")
_threshold_slider = gr.Slider(
    minimum=1,
    maximum=DEFAULT_INTR_THRESHOLD + 5,
    step=1,
    value=DEFAULT_INTR_THRESHOLD,
    label="Interestingness threshold",
    info=f"How many of the most interesting words to select in the probability calculation? ({DEFAULT_INTR_THRESHOLD} for Graham)",
)
_probability_out = gr.Number(label="Spam probability")

# Each row matches predict()'s parameters: (model, input_txt, unbiased, intr_threshold).
# NN and LLM ignore the Bayes-only settings, hence None.
_examples = [
    [BAYES, enron_email, False, DEFAULT_INTR_THRESHOLD],
    [BAYES, nerissa_email, False, DEFAULT_INTR_THRESHOLD],
    [BAYES, nerissa_email, True, DEFAULT_INTR_THRESHOLD],
    [NN, enron_email, None, None],
    [LLM, enron_email, None, None],
]

demo = gr.Interface(
    fn=predict,
    inputs=[_model_dropdown, _email_box],
    additional_inputs_accordion=_bayes_accordion,
    additional_inputs=[_unbias_checkbox, _threshold_slider],
    outputs=[_probability_out],
    title="Bayes or Spam?",
    description="Choose your model, and predict if your email is a spam! 📨",
    examples=_examples,
    article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
)

if __name__ == "__main__":
    demo.launch()