File size: 3,805 Bytes
850dd49
cc21853
 
 
 
8937106
850dd49
5f92565
 
 
cc21853
 
 
 
5f92565
 
 
8937106
 
 
 
5f92565
 
 
cc21853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f92565
 
cc21853
 
 
5f92565
 
 
cc21853
 
6441204
 
 
 
 
 
 
 
 
cc21853
 
 
5f92565
 
f409800
8937106
 
 
 
cc21853
 
5f92565
8937106
cc21853
 
d1a418a
cc21853
 
5f92565
 
8937106
 
5f92565
 
 
850dd49
 
 
f8dfc8e
cc21853
5f92565
 
83db710
5f92565
cc21853
f409800
 
 
f8dfc8e
 
850dd49
8937106
f8dfc8e
8937106
80b487d
 
8937106
 
f8dfc8e
9bd0746
850dd49
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
from huggingface_hub import hf_hub_download
import json
import tensorflow as tf
import numpy as np
from sentence_transformers import SentenceTransformer


# Load models

# Bayes model: a JSON table that `predict_bayes` indexes by word to get a probability.
model_probs_path = hf_hub_download(repo_id="tbitai/bayes-enron1-spam", filename="probs.json")
# JSON text is UTF-8 by specification; pass the encoding explicitly instead of
# relying on the platform's locale default.
with open(model_probs_path, encoding="utf-8") as f:
    model_probs = json.load(f)

# Keras model that `predict_nn` calls directly on the raw email text.
nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
nn_model = tf.keras.models.load_model(nn_model_path)

# Sentence-embedding model, plus the Keras model that `predict_llm` applies to
# the embeddings it produces.
st_model = SentenceTransformer("avsolatorio/GIST-large-Embedding-v0")
llm_model_path = hf_hub_download(repo_id="tbitai/gisty-enron1-spam", filename="gisty-enron1-spam.keras")
llm_model = tf.keras.models.load_model(llm_model_path)


# Utils for Bayes

# Key in `model_probs` whose probability `predict_bayes` uses for any word
# that has no entry of its own in the table.
UNK = '[UNK]'

def tokenize(text):
    """Split *text* into words.

    Delegates to Keras' ``text_to_word_sequence`` with its default settings.
    """
    to_word_sequence = tf.keras.preprocessing.text.text_to_word_sequence
    return to_word_sequence(text)

def combine(probs):
    """Combine individual probabilities into a single one.

    Computes ``prod(p) / (prod(p) + prod(1 - p))`` over *probs*.

    Returns:
        ``0`` as soon as any probability equals zero, ``0.5`` when both
        products vanish, and the combined ratio otherwise.
    """
    # One zero probability decides the outcome outright.
    for p in probs:
        if p == 0:
            return 0

    product = np.prod(probs)
    complement_product = np.prod(1 - np.asarray(probs))
    total = product + complement_product

    if total == 0:
        # Floating point arithmetic can still drive both products to zero;
        # treat them as equally small in that case.
        return 0.5
    return product / total

def get_interesting_probs(probs, intr_threshold):
    """Return the *intr_threshold* probabilities that lie farthest from 0.5.

    The result is a new list ordered by decreasing distance from 0.5;
    *probs* itself is left unmodified.
    """
    def distance_from_neutral(p):
        return abs(p - 0.5)

    ranked = list(probs)
    ranked.sort(key=distance_from_neutral, reverse=True)
    return ranked[:intr_threshold]

# Default number of most interesting word probabilities fed into the Bayes
# combination; the UI slider starts at this value and describes it as Graham's choice.
DEFAULT_INTR_THRESHOLD = 15

def unbias(p):
    """Map a probability *p* to ``2p / (p + 1)``.

    Applied to word probabilities when the "Unbias" option is enabled.
    """
    numerator = 2 * p
    denominator = p + 1
    return numerator / denominator


# Predict functions

def predict_bayes(text, intr_threshold, unbiased=False):
    """Score *text* with the Bayes word-probability table.

    Args:
        text: Email text to classify.
        intr_threshold: How many of the most interesting word probabilities
            to keep for the combination.
        unbiased: When true, pass each known word's probability through
            ``unbias``. The ``UNK`` fallback probability is used as stored.

    Returns:
        The combined probability computed by ``combine``.
    """
    probs = []
    for word in tokenize(text):
        try:
            p = model_probs[word]
        except KeyError:
            # Word not in the table: fall back to the unknown-word entry.
            p = model_probs[UNK]
        else:
            if unbiased:
                p = unbias(p)
        probs.append(p)
    return combine(get_interesting_probs(probs, intr_threshold))

def predict_nn(text):
    """Score *text* with ``nn_model``.

    The text is wrapped in a one-element array, and the first value of the
    first output row is returned after conversion with ``.numpy()``.
    """
    batch = np.array([text])
    output = nn_model(batch)
    return output[0][0].numpy()

def predict_llm(text):
    """Score *text* with ``llm_model`` applied to its sentence embedding.

    The embedding from ``st_model.encode`` is wrapped in a one-element array,
    and the first value of the first output row is returned after conversion
    with ``.numpy()``.
    """
    batch = np.array([st_model.encode(text)])
    output = llm_model(batch)
    return output[0][0].numpy()

# Display names of the selectable models; `predict` dispatches on these values.
BAYES = "Bayes Enron1 spam"
NN = "NN Enron1 spam"
LLM = "GISTy Enron1 spam"

# Dropdown choices, in display order.
MODELS = [BAYES, NN, LLM]

def predict(model, input_txt, unbiased, intr_threshold):
    """Run the model selected by *model* on *input_txt*.

    Args:
        model: One of the names in ``MODELS``.
        input_txt: Email text to classify.
        unbiased: Bayes-only option, ignored by the other models.
        intr_threshold: Bayes-only option, ignored by the other models.

    Returns:
        The selected model's prediction, or ``None`` if *model* matches
        none of the known names.
    """
    if model == BAYES:
        return predict_bayes(input_txt, intr_threshold, unbiased=unbiased)
    if model == NN:
        return predict_nn(input_txt)
    if model == LLM:
        return predict_llm(input_txt)
    return None


# UI

# Gradio interface around `predict`. The components in `inputs` followed by
# those in `additional_inputs` appear in the same order as `predict`'s
# parameters: model, email text, unbias flag, interestingness threshold.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(choices=MODELS, value=BAYES, label="Model"),
        gr.TextArea(label="Email"),
    ],
    # Bayes-only settings live in a collapsed accordion.
    additional_inputs_accordion=gr.Accordion("Additional configuration for Bayes", open=False),
    additional_inputs=[
        gr.Checkbox(label="Unbias", info="Correct Graham's bias?"),
        gr.Slider(minimum=1, maximum=DEFAULT_INTR_THRESHOLD + 5, step=1, value=DEFAULT_INTR_THRESHOLD, 
                  label="Interestingness threshold", 
                  info=f"How many of the most interesting words to select in the probability calculation? ({DEFAULT_INTR_THRESHOLD} for Graham)"),
    ],
    outputs=[gr.Number(label="Spam probability")],
    title="Bayes or Spam?",
    description="Choose your model, and predict if your email is a spam! 📨",
    # Each row is [model, email, unbias, threshold]. `enron_email` and
    # `nerissa_email` are bound with `:=` on first use so later rows can reuse
    # the same text. The NN and LLM rows pass None for the Bayes-only settings,
    # which `predict` does not read for those models.
    examples=[
        [BAYES, enron_email := "Enron actuals for June 26, 2000", False, DEFAULT_INTR_THRESHOLD],
        [BAYES, nerissa_email := "Stop the aging clock\nNerissa", False, DEFAULT_INTR_THRESHOLD],
        [BAYES, nerissa_email, True, DEFAULT_INTR_THRESHOLD],
        [NN, enron_email, None, None],
        [LLM, enron_email, None, None],
    ],
    article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
)

# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()