GradApplicationDocuments committed on
Commit
286ab7c
•
1 Parent(s): 0d1625a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import string

# Ordered (old, new) substitutions applied to the lower-cased text.
# The order is significant: e.g. the "program: master of science in data
# analytics ..." rule must run before "data analytics" is rewritten.
#
# NOTE(review): these rules are kept exactly as they were because the models
# were presumably trained on text cleaned this way - confirm against the
# training pipeline before changing any of them. Known oddities:
#   * (" ", " ") is a no-op as written; it was probably meant to collapse
#     double spaces.
#   * "[Name]" contains an upper-case letter but is applied after lower(),
#     so it can never match.
#   * format_text() joins spaCy tokens with spaces first, which presumably
#     separates punctuation from words, so rules ending in "," may never
#     match either.
_BOILERPLATE_REPLACEMENTS = (
    (" ", " "),
    ("[Name]", ""),
    ("[your name]", ""),
    ("\n your name", ""),
    ("dear admissions committee,", ""),
    ("sincerely,", ""),
    ("[university's name]", "fordham"),
    ("dear sir/madam,", ""),
    ("\u2013 statement of intent ", ""),
    ("program: master of science in data analytics name of applicant: ", ""),
    ("data analytics", "data science"),
    ("| \u200b", ""),
    ("m.s. in data science at lincoln center ", ""),
)

# Translation table that deletes every ASCII punctuation character.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def strip_boilerplate(text: str) -> str:
    """Normalise a document string after entity filtering.

    Removes the literal "REDACTED", lower-cases the text, applies the
    ordered boilerplate substitutions, deletes ASCII punctuation and trims
    surrounding whitespace.

    Args:
        text: The text to clean.

    Returns:
        The cleaned text (may be empty).
    """
    cleaned = text.replace("REDACTED", "").lower()
    for old, new in _BOILERPLATE_REPLACEMENTS:
        cleaned = cleaned.replace(old, new)
    return cleaned.translate(_PUNCTUATION_TABLE).strip()


def format_confidence(probability: float) -> str:
    """Render a probability in [0, 1] as a percentage with two decimals.

    Fixed-width formatting avoids float artefacts such as
    "55.120000000000005%" that plain ``probability * 100`` produces.
    """
    return f"{probability * 100:.2f}%"


def pred_str(prediction):
    """Map the predicted class index to the label shown in the UI.

    Class 0 is reported as human-made; any other value as AI-revised.
    """
    if prediction == 0:
        return "Human-made 🤷‍♂️🤷‍♀️"
    return "Revised with AI 🦾"


def main():
    """
    Creates a Streamlit web app that classifies a given body of text as either human-made or AI-generated,
    using a pre-trained model.

    The user picks one of four models (two scikit-learn baselines loaded
    with joblib, two BERT-style models loaded from the Hugging Face hub),
    submits a text, and can then request an explanation of the prediction
    (eli5/LIME for the baselines, SHAP for the transformer models).
    """
    # Third-party imports are kept local to main(), as in the original file.
    import streamlit as st
    import joblib
    import scipy
    import scipy.special  # imported explicitly; `import scipy` alone may not expose it
    import spacy
    from transformers import AutoTokenizer
    import torch
    from eli5.lime import TextExplainer
    from eli5.lime.samplers import MaskingTextSampler
    import eli5
    import shap
    from custom_models import HF_DistilBertBasedModelAppDocs, HF_BertBasedModelAppDocs

    # Page configuration is issued before any other Streamlit call.
    st.set_page_config(page_title="AI/Human GradAppDocs", page_icon="🤖", layout="wide")

    # Prefer the current resource cache; fall back to the legacy decorator
    # on Streamlit versions that do not provide it.
    if hasattr(st, "cache_resource"):
        cache_resource = st.cache_resource
    else:
        cache_resource = st.cache(allow_output_mutation=True, suppress_st_warning=True)

    bert_option = "BERT-based model"
    distilbert_option = "DistilBERT-based model (BERT light)"
    transformer_options = (bert_option, distilbert_option)

    models_available = {
        "Logistic Regression": "models/baseline_model_lr.joblib",
        "Naive Bayes": "models/baseline_model_nb.joblib",
        distilbert_option: "GradApplicationDocsApp/HF_DistilBertBasedModelAppDocs3",
        bert_option: "GradApplicationDocsApp/HF_BertBasedModelAppDocs3",
    }

    # device to run DL model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @cache_resource
    def load_spacy():
        # Cached so the pipeline is not reloaded on every script rerun.
        return spacy.load("en_core_web_sm")

    @cache_resource
    def load_tokenizer(option):
        checkpoint = "bert-base-uncased" if option == bert_option else "distilbert-base-uncased"
        return AutoTokenizer.from_pretrained(checkpoint, padding='max_length', max_length=512, truncation=True)

    @cache_resource
    def load_model(option):
        model_cls = HF_BertBasedModelAppDocs if option == bert_option else HF_DistilBertBasedModelAppDocs
        return model_cls.from_pretrained(models_available[option]).to(device)

    @cache_resource
    def load_sklearn_model(path):
        # Cached so the joblib file is not re-read on every script rerun.
        return joblib.load(path)

    nlp = load_spacy()

    def format_text(text: str) -> str:
        """
        Drop tokens that spaCy tags as PERSON or DATE entities, join the
        remaining tokens with single spaces, then apply strip_boilerplate().
        """
        doc = nlp(text)
        kept = " ".join(token.text for token in doc if token.ent_type_ not in ("PERSON", "DATE"))
        return strip_boilerplate(kept)

    def nb_lr(model, clean_text):
        """Classify already-cleaned text with a scikit-learn model.

        Returns the predicted class and its probability rounded to 4 places.
        """
        prediction = model.predict([clean_text]).item()
        predict_proba = round(model.predict_proba([clean_text]).squeeze()[prediction].item(), 4)
        return prediction, predict_proba

    def encode(tokenizer, texts):
        """Tokenize texts to fixed-length (512) tensors placed on `device`."""
        batch = tokenizer(texts, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
        return batch["input_ids"].to(device), batch["attention_mask"].to(device)

    def torch_pred(tokenizer, model, clean_text):
        """Classify already-cleaned text with a BERT/DistilBERT based model.

        Returns the predicted class and its softmax probability rounded to
        4 places. The text is not cleaned again here: the caller passes the
        output of format_text().
        """
        input_ids, attention_mask = encode(tokenizer, [clean_text])
        with torch.inference_mode():
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
        _, prediction = torch.max(logits, 1)
        prediction = prediction.item()
        predict_proba = round(torch.softmax(logits, 1).cpu().squeeze().tolist()[prediction], 4)
        return prediction, predict_proba

    def show_result():
        """Render the prediction currently stored in the session state."""
        state = st.session_state
        st.markdown(
            f"I think this text is: **:{state['color_pred']}[{state['prediction']}]** "
            f"(Confidence: {format_confidence(state['predict_proba'])})"
        )

    # Streamlit app:

    st.title("Academic Application Document Classifier")
    st.header("Is it human-made 📝 or Enhanced with AI 🤖 ? ")

    st.markdown('AI-generated content has reached an unprecedented level of realism. The models on this website focus on identifying AI-enhanced application materials, such as Statements of Intent (SOI) and Letters of Recommendation (LOR). These models were trained using real-world SOIs and LORs, alongside a revised version of each that has been generated through AI.')

    # Check the model to use
    def restore_prediction_state():
        # A stored prediction belongs to the previously selected model.
        if "prediction" in st.session_state:
            del st.session_state.prediction

    option = st.selectbox("Select a model to use:", list(models_available), on_change=restore_prediction_state)
    uses_transformer = option in transformer_options

    # Load the selected trained model
    if uses_transformer:
        tokenizer = load_tokenizer(option)
        model = load_model(option)
    else:
        tokenizer = None
        model = load_sklearn_model(models_available[option])

    text = st.text_area("Enter either a statement of intent or a letter of recommendation:")

    #Hide footer "made with streamlit"
    hide_st_style = """
    <style>
    footer {visibility: hidden;}
    header {visibility: hidden;}
    </style>
    """
    st.markdown(hide_st_style, unsafe_allow_html=True)

    # Use model
    if st.button("Let's check this text!"):
        if not text.strip():
            st.error("Please enter some text")
        else:
            with st.spinner("Wait for the magic 🪄🔮"):
                # Clean once; the same cleaned text is classified here and
                # explained later, so both refer to the same model input.
                clean_text = format_text(text)
                if not clean_text:
                    result = None
                elif uses_transformer:
                    result = torch_pred(tokenizer, model, clean_text)
                    st.session_state["torch"] = True
                else:
                    result = nb_lr(model, clean_text)
                    st.session_state["sklearn"] = True

            if result is None:
                # Nothing is left after cleaning (e.g. only punctuation).
                st.error("Please enter some text")
            else:
                prediction, predict_proba = result
                # Store the result in session state
                st.session_state["color_pred"] = "blue" if prediction == 0 else "red"
                st.session_state["prediction"] = pred_str(prediction)
                st.session_state["predict_proba"] = predict_proba
                st.session_state["text"] = text
                st.session_state["clean_text"] = clean_text

                # Print result
                show_result()

    elif "prediction" in st.session_state:
        # Display the stored result if available
        show_result()

    if st.button("Model Explanation"):
        # Check if there's a classified text in the session state
        if "clean_text" in st.session_state and "prediction" in st.session_state:
            explained_text = st.session_state["clean_text"]

            if not uses_transformer:
                with st.spinner('Wait for it 💭...'):
                    explainer = TextExplainer(sampler=MaskingTextSampler())
                    explainer.fit(explained_text, model.predict_proba)
                    html = eli5.format_as_html(explainer.explain_prediction(target_names=["Human", "AI"]))
                st.markdown('<span style="color:green"><strong>Green:</strong> Contributes to decision | </span><span style="color:red"><strong>Red:</strong> Opposite</span>', unsafe_allow_html=True)
            else:
                with st.spinner('Wait for it 💭... BERT-based model explanations take around 4-10 minutes. In case you want to abort, please refresh the page.'):
                    # TORCH EXPLAINER PRED FUNC (USES logits)
                    def f(x):
                        # Inputs go to the model's device with the attention
                        # mask, matching what torch_pred() feeds the model.
                        input_ids, attention_mask = encode(tokenizer, [str(v) for v in x])
                        with torch.inference_mode():
                            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                        scores = scipy.special.softmax(outputs.cpu().numpy(), axis=-1)
                        return scipy.special.logit(scores[:, 1])  # use one vs rest logit units

                    # build an explainer using a token masker
                    explainer = shap.Explainer(f, tokenizer)
                    shap_values = explainer([explained_text], fixed_context=1)
                    html = shap.plots.text(shap_values, display=False)
                st.markdown('<span style="color:blue"><strong>Blue:</strong> Contributes to "human" | </span><span style="color:red"><strong>Red:</strong> Contributes to "AI"</span>', unsafe_allow_html=True)
            # Render HTML
            st.components.v1.html(html, height=500, scrolling=True)
        else:
            st.error("Please enter some text and click 'Let's check!' before requesting an explanation.")


if __name__ == "__main__":
    main()