Spaces:
Running
Running
File size: 3,805 Bytes
850dd49 cc21853 8937106 850dd49 5f92565 cc21853 5f92565 8937106 5f92565 cc21853 5f92565 cc21853 5f92565 cc21853 6441204 cc21853 5f92565 f409800 8937106 cc21853 5f92565 8937106 cc21853 d1a418a cc21853 5f92565 8937106 5f92565 850dd49 f8dfc8e cc21853 5f92565 83db710 5f92565 cc21853 f409800 f8dfc8e 850dd49 8937106 f8dfc8e 8937106 80b487d 8937106 f8dfc8e 9bd0746 850dd49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import gradio as gr
from huggingface_hub import hf_hub_download
import json
import tensorflow as tf
import numpy as np
from sentence_transformers import SentenceTransformer
# Load models
# Bayes model: a JSON table of per-word spam probabilities, looked up by token
# (predict_bayes also reads an UNK fallback entry from it).
model_probs_path = hf_hub_download(repo_id="tbitai/bayes-enron1-spam", filename="probs.json")
# JSON text is UTF-8 (RFC 8259); be explicit instead of relying on open()'s
# platform/locale-dependent default encoding.
with open(model_probs_path, encoding="utf-8") as f:
    model_probs = json.load(f)
# Neural-network classifier; predict_nn feeds it the raw email text.
nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
nn_model = tf.keras.models.load_model(nn_model_path)
# "GISTy" model: sentence-transformer embeddings scored by a Keras classifier.
st_model = SentenceTransformer("avsolatorio/GIST-large-Embedding-v0")
llm_model_path = hf_hub_download(repo_id="tbitai/gisty-enron1-spam", filename="gisty-enron1-spam.keras")
llm_model = tf.keras.models.load_model(llm_model_path)
# Utils for Bayes
# Lookup key of the fallback probability used for words missing from model_probs.
UNK = '[UNK]'


def tokenize(text):
    """Split *text* into a list of word tokens.

    Delegates to Keras' ``text_to_word_sequence`` with its default filter,
    lowercasing and separator settings.
    """
    splitter = tf.keras.preprocessing.text.text_to_word_sequence
    return splitter(text)
def combine(probs):
    """Combine per-word spam probabilities into a single spam probability.

    Computes ``prod(p) / (prod(p) + prod(1 - p))`` over *probs*.

    Returns:
        ``0`` if any probability is exactly zero; ``0.5`` if the two products
        sum to zero (e.g. both underflow); otherwise the ratio above. An empty
        *probs* yields 0.5, since both empty products are 1.
    """
    # A single zero probability vetoes all the other evidence.
    if not all(p != 0 for p in probs):
        return 0
    spam_evidence = np.prod(probs)
    ham_evidence = np.prod([1 - p for p in probs])
    total = spam_evidence + ham_evidence
    # Floating-point underflow can still drive both products to zero; treat
    # them as equally small in that case.
    if total == 0:
        return 0.5
    return spam_evidence / total
def get_interesting_probs(probs, intr_threshold):
    """Select the most "interesting" probabilities, i.e. those farthest from 0.5.

    Args:
        probs: Per-word spam probabilities.
        intr_threshold: How many probabilities to keep. Non-integral numbers
            (e.g. ``15.0`` delivered by a UI slider or an API client) are
            truncated to ``int``; ``None`` keeps all of them.

    Returns:
        A new list ordered most-interesting first; ties keep their input order.
    """
    # Slice bounds must be integers: ``xs[:15.0]`` raises TypeError.
    limit = None if intr_threshold is None else int(intr_threshold)
    return sorted(probs, key=lambda p: abs(p - 0.5), reverse=True)[:limit]


# Number of interesting words Graham uses; also the UI slider's default.
DEFAULT_INTR_THRESHOLD = 15
def unbias(p):
    """Correct a word probability for Graham's bias (the UI's "Unbias" option).

    Maps *p* to ``2p / (p + 1)``; 0 and 1 are fixed points.
    """
    doubled = 2 * p
    return doubled / (1 + p)
# Predict functions
def predict_bayes(text, intr_threshold, unbiased=False):
    """Score *text* with the naive-Bayes word-probability model.

    Each token's probability is read from ``model_probs`` and, when
    *unbiased* is true, passed through ``unbias``. Tokens missing from the
    table use the ``UNK`` entry as-is (it is not unbiased). The
    *intr_threshold* probabilities farthest from 0.5 are then merged with
    ``combine``.
    """
    def word_prob(word):
        # NOTE(review): the UNK fallback is never unbiased — confirm intended.
        if word not in model_probs:
            return model_probs[UNK]
        prob = model_probs[word]
        return unbias(prob) if unbiased else prob

    probs = [word_prob(word) for word in tokenize(text)]
    return combine(get_interesting_probs(probs, intr_threshold))
def predict_nn(text):
    """Return the NN model's spam score for a single email text."""
    # The model is invoked on a batch containing just this one raw string.
    batch = np.array([text])
    output = nn_model(batch)
    return output[0][0].numpy()
def predict_llm(text):
    """Return the GISTy model's spam score for a single email text.

    The text is embedded with the sentence-transformer ``st_model`` and the
    embedding is scored by ``llm_model`` as a batch of one.
    """
    batch = np.array([st_model.encode(text)])
    return llm_model(batch)[0][0].numpy()
# Display names of the selectable models (the choices of the "Model" dropdown).
BAYES = "Bayes Enron1 spam"
NN = "NN Enron1 spam"
LLM = "GISTy Enron1 spam"
MODELS = [BAYES, NN, LLM]
def predict(model, input_txt, unbiased, intr_threshold):
    """Return the spam probability of *input_txt* according to *model*.

    Args:
        model: One of the names in ``MODELS`` (``BAYES``, ``NN`` or ``LLM``).
        input_txt: Raw email text.
        unbiased: Bayes only — whether to apply ``unbias`` to word probabilities.
        intr_threshold: Bayes only — how many of the most interesting word
            probabilities to combine.

    Returns:
        The spam probability as a plain Python ``float``.

    Raises:
        gr.Error: If *model* is not one of the known model names.
    """
    if model == BAYES:
        prob = predict_bayes(input_txt, unbiased=unbiased, intr_threshold=intr_threshold)
    elif model == NN:
        prob = predict_nn(input_txt)
    elif model == LLM:
        prob = predict_llm(input_txt)
    else:
        # Without this branch an unknown name would fall through and return
        # None, silently leaving the output empty.
        raise gr.Error(f"Unknown model: {model!r}")
    # combine() can return the int 0 and the Keras paths return NumPy values
    # via .numpy(); normalize so callers always receive a Python float.
    return float(prob)
# UI
# Example emails; each is reused by several example rows below.
enron_email = "Enron actuals for June 26, 2000"
nerissa_email = "Stop the aging clock\nNerissa"

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(choices=MODELS, value=BAYES, label="Model"),
        gr.TextArea(label="Email"),
    ],
    # The additional inputs are only read by the Bayes branch of predict().
    additional_inputs_accordion=gr.Accordion("Additional configuration for Bayes", open=False),
    additional_inputs=[
        gr.Checkbox(label="Unbias", info="Correct Graham's bias?"),
        gr.Slider(
            minimum=1,
            maximum=DEFAULT_INTR_THRESHOLD + 5,
            step=1,
            value=DEFAULT_INTR_THRESHOLD,
            label="Interestingness threshold",
            info=f"How many of the most interesting words to select in the probability calculation? ({DEFAULT_INTR_THRESHOLD} for Graham)",
        ),
    ],
    outputs=[gr.Number(label="Spam probability")],
    title="Bayes or Spam?",
    description="Choose your model, and predict if your email is a spam! 📨",
    # Row layout: [model, email, unbias, interestingness threshold]; the last
    # two are None for models that ignore them.
    examples=[
        [BAYES, enron_email, False, DEFAULT_INTR_THRESHOLD],
        [BAYES, nerissa_email, False, DEFAULT_INTR_THRESHOLD],
        [BAYES, nerissa_email, True, DEFAULT_INTR_THRESHOLD],
        [NN, enron_email, None, None],
        [LLM, enron_email, None, None],
    ],
    article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
)

if __name__ == "__main__":
    demo.launch()