A Vo committed on
Commit 40884a0 β€’ 1 Parent(s): a48cee6

Added eval metrics, comments

Files changed (3)
  1. .gitattributes +1 -0
  2. app.py +187 -30
  3. reddit_cleansed_data.csv +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.csv filter=lfs diff=lfs merge=lfs -text
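
This rule routes the ~20 MB `reddit_cleansed_data.csv` added in this commit through Git LFS, so only a small pointer file (see the bottom of this diff) is stored in the repository itself. A line of this form is what `git lfs track "*.csv"` appends to `.gitattributes`.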
app.py CHANGED
@@ -1,16 +1,23 @@
 # Imports
 # Core Imports
 import torch
+
 # Model-related Imports
 from transformers import BartTokenizer, BartForConditionalGeneration # fine-tuned BART model
 from transformers import AutoTokenizer, AutoModelForTokenClassification # restore punct
 from transformers import pipeline # restore punct
 import gradio as gr
 
+# Evaluation Imports
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import pandas as pd
+import string
+
 
 
 # Instantiate model to restore punctuation
-print("1/4 - Instantiating model to restore punctuation")
+print("1/7 - Instantiating model to restore punctuation")
 
 punct_model_path = "felflare/bert-restore-punctuation"
 # Load punct tokenizer and model
@@ -21,7 +28,7 @@ punct_restorer = pipeline("token-classification", model=punct_model, tokenizer=p
 
 
 # Instantiate fine-tuned horror BART model
-print("2/4 - Instantiating two-sentence horror generation model")
+print("2/7 - Instantiating two-sentence horror generation model")
 
 model_path = 'voacado/bart-two-sentence-horror'
 # Load tokenizer and model
@@ -30,8 +37,108 @@ model = BartForConditionalGeneration.from_pretrained(model_path)
 
 
 
+# Load data for evaluation metrics
+print("3/7 - Reading in data")
+data = pd.read_csv("./reddit_cleansed_data.csv")
+data['weighted_score'] = data['score'] + (10 * data['num_comments']) + (100 * data['gilded_count'])
+dataset_stories = (data['title'] + ' ' + data['selftext']).to_list()
+
+
+
+# Instantiate evaluation metrics - Cosine Similarity with TF-IDF
+print("4/7 - Instantiating evaluation metrics - Cosine Similarity with TF-IDF")
+# Pre-vectorize dataset
+vectorizer = TfidfVectorizer()
+dataset_matrix = vectorizer.fit_transform(dataset_stories)
+
+def eval_cosine_similarity(input_sentence: str) -> [str, str]:
+    """
+    Evaluate cosine similarity between input sentence and each story in the dataset.
+
+    Args:
+        input_sentence (str): user story (first sentence)
+
+    Returns:
+        [str, str]: most similar story, weighted score
+    """
+    # Vectorize input sentence using the existing vocab
+    input_vec = vectorizer.transform([input_sentence])
+    # Get cosine similarity
+    similarities = cosine_similarity(input_vec, dataset_matrix)
+    # Find most similar story
+    most_similar_story_idx = similarities.argmax()
+    most_similar_story = dataset_stories[most_similar_story_idx]
+    # Get weighted score of most similar story
+    weighted_score = data['weighted_score'][most_similar_story_idx]
+
+    return most_similar_story, weighted_score
+
+
+
+# Instantiate evaluation metrics - Jaccard Similarity
+print("5/7 - Instantiating evaluation metrics - Jaccard Similarity")
+def tokenize(text: str):
+    """
+    Convert text to lowercase and remove punctuation, then tokenize.
+
+    Args:
+        text (str): user story
+
+    Returns:
+        set: set of tokens
+    """
+    text = text.lower()
+    text = text.translate(str.maketrans('', '', string.punctuation))
+    tokens = text.split()
+    return set(tokens)
+
+def jaccard_similarity(set1: set, set2: set):
+    """
+    Calculate Jaccard similarity between two sets.
+
+    Args:
+        set1 (set): user_tokens
+        set2 (set): story_tokens
+
+    Returns:
+        float: Jaccard similarity
+    """
+    intersection = set1.intersection(set2)
+    union = set1.union(set2)
+    return len(intersection) / len(union)
+
+def eval_jaccard_similarity(input_sentence: str) -> [str, str]:
+    """
+    Evaluate Jaccard similarity between input sentence and each story in the dataset.
+
+    Args:
+        input_sentence (str): user story (first sentence)
+
+    Returns:
+        [str, str]: most similar story, weighted score
+    """
+    # Tokenize the user story
+    user_tokens = tokenize(input_sentence)
+
+    # Initialize variables to find the most similar story
+    max_similarity = 0
+    most_similar_story = ''
+
+    # Compare with each story in the dataset
+    for story in dataset_stories:
+        story_tokens = tokenize(story)
+        similarity = jaccard_similarity(user_tokens, story_tokens)
+        if similarity > max_similarity:
+            max_similarity = similarity
+            most_similar_story = story
+            max_score = data['weighted_score'][dataset_stories.index(story)]
+
+    return most_similar_story, max_score
+
+
+
 # Set up inference
-print("3/4 - Setting parameters for inference")
+print("6/7 - Setting parameters for inference")
 
 # Set the model to evaluation mode
 model.eval()
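
The pattern the hunk above sets up is the standard one for repeated similarity queries: fit a `TfidfVectorizer` once over the corpus, cache the document-term matrix, and score each incoming sentence with `cosine_similarity` against it. A minimal self-contained sketch of that retrieval step, with an invented two-story corpus standing in for `dataset_stories` and `data['weighted_score']`:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Toy stand-ins for dataset_stories and data['weighted_score'] (illustrative only)
stories = [
    "I heard a knock at the door. I live alone.",
    "The baby monitor crackled with laughter. We never had a baby.",
]
weighted_scores = [120, 450]

vectorizer = TfidfVectorizer()
dataset_matrix = vectorizer.fit_transform(stories)  # fit once, at startup

input_vec = vectorizer.transform(["There was a knock at my door last night."])
similarities = cosine_similarity(input_vec, dataset_matrix)  # shape (1, n_stories)
best_idx = similarities.argmax()
print(stories[best_idx], weighted_scores[best_idx])
```

Two caveats in the committed Jaccard path: `max_score` is assigned only inside the `if similarity > max_similarity:` branch, so an input sharing no tokens with any story would raise a `NameError` (initializing `max_score = 0` next to `max_similarity` guards this), and `dataset_stories.index(story)` rescans the list on every improvement where `enumerate` would supply the index directly. As an aside, the `-> [str, str]` annotations would conventionally be written `-> tuple[str, float]`.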
@@ -40,10 +147,20 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
 # Restore punct
-def restore_punctuation(text, restorer):
+def restore_punctuation(text: str, restorer: pipeline) -> str:
+    """
+    Restore punctuation to text.
+
+    Args:
+        text (str): full story (first and second sentences)
+        restorer (pipeline): model that restores punctuation
+
+    Returns:
+        str: punctuated text (based on input)
+    """
     # Use the model to predict punctuation
     punctuated_output = restorer(text)
-    punctuated_text = []
+    punct_text = []
 
     # Define punctuation marks (note: not including left-side because we want space still)
     punctuation_marks = ["!", "?", ".", "-", ":", ";", "'", "’", ",", ")", "]", "}", "…", "”", "’’", "''"]
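
For context on the loop the next hunk modifies: `restorer` is a Hugging Face token-classification pipeline, and each prediction it yields is a dict whose `'word'` and `'entity'` keys drive the casing and spacing logic. A quick sketch of that output shape (label semantics such as `LABEL_0` = capitalize are specific to the `felflare/bert-restore-punctuation` checkpoint):

```python
from transformers import pipeline

restorer = pipeline("token-classification",
                    model="felflare/bert-restore-punctuation")
preds = restorer("my name is wolfgang and i live in berlin")
# Each element is a dict along the lines of:
#   {'word': 'my', 'entity': 'LABEL_0', 'score': 0.99, 'index': 1, ...}
for elem in preds[:3]:
    print(elem['word'], elem['entity'])
```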
@@ -53,65 +170,105 @@ def restore_punctuation(text, restorer):
 
         # If token is punctuation, append to previous token
         if cur_token in punctuation_marks:
-            punctuated_text[-1] += cur_token
+            punct_text[-1] += cur_token
 
         # If previous token is quotations, append to previous token
-        elif punctuated_text and punctuated_text[-1] in ["'", "’", "β€œ", "β€˜", "β€˜β€˜", "β€œβ€œ"]:
-            punctuated_text[-1] += cur_token
+        elif punct_text and punct_text[-1] in ["'", "’", "β€œ", "β€˜", "β€˜β€˜", "β€œβ€œ"]:
+            punct_text[-1] += cur_token
 
         # If token is a contraction or a quote, append to previous token (no space)
         elif cur_token.lower() in ["s", "t", "re", "ve", "ll", "d", "m"]:
             # Remove space for contractions
-            punctuated_text[-1] += cur_token
+            punct_text[-1] += cur_token
 
         # if prediction is LABEL_0, token should be capitalized
         elif elem.get('entity') == 'LABEL_0':
-            punctuated_text.append(cur_token.capitalize())
+            punct_text.append(cur_token.capitalize())
 
         # else if prediction is LABEL_1, token should be lowercase
         # elif elem.get('entity') == 'LABEL_1':
         else:
-            punctuated_text.append(cur_token)
+            punct_text.append(cur_token)
 
     # If there's no period at the end of the story, add one
-    if punctuated_text[-1][-1] != '.':
-        punctuated_text[-1] = punctuated_text[-1] + '.'
+    if punct_text[-1][-1] != '.':
+        punct_text[-1] = punct_text[-1] + '.'
 
-    return ' '.join(punctuated_text)
+    return ' '.join(punct_text)
 
-def generate_text(input_text):
-    # Encode the input text
-    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
+def generate_text(input_text: str, full_sentence: str) -> [str, str, float, str, float]:
+    """
+    Generate the second sentence of the horror story given the first (input_text).
 
-    # Generate text
-    with torch.no_grad():
-        output_ids = model.generate(input_ids, max_length=50)
+    Args:
+        input_text (str): first sentence of the horror story
+        full_sentence (str): full story (first and second sentences)
 
-    # Decode the generated text
-    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-    # Restore punctuation
-    generated_text_punct = restore_punctuation(generated_text, punct_restorer)
+    Returns:
+        gen_text_punct (str): second sentence of the horror story
+        similar_story_cosine (str): most similar story (cosine similarity)
+        cosine_score (float): score of most similar story (cosine similarity)
+        similar_story_jaccard (str): most similar story (Jaccard similarity)
+        jaccard_score (float): score of most similar story (Jaccard similarity)
+    """
+    # If user only enters first sentence, generate second sentence
+    if not full_sentence:
+        # Encode the input text
+        input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
+
+        # Generate text
+        with torch.no_grad():
+            output_ids = model.generate(input_ids, max_length=50)
+
+        # Decode the generated text
+        gen_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        # Restore punctuation
+        gen_text_punct = restore_punctuation(gen_text, punct_restorer)
+        full_sentence = input_text + ' ' + gen_text_punct
+    else:
+        gen_text_punct = "N/A"
 
-    return generated_text_punct
+    # Calculate Cosine and Jaccard similarity
+    similar_story_cosine, cosine_score = eval_cosine_similarity(full_sentence)
+    similar_story_jaccard, jaccard_score = eval_jaccard_similarity(full_sentence)
+
+    return gen_text_punct, similar_story_cosine, cosine_score, similar_story_jaccard, jaccard_score
 
 
 
 # Create gradio demo
-print("4/4 - Launching demo")
+print("7/7 - Launching demo")
 
 title = "πŸ‘» 🫣 Generate a Two-Sentence Horror Story 😱 πŸ‘»"
 description = """
 <center>The bot was trained to generate two-sentence horror stories based on r/TwoSentenceHorror. <i>Spooky!</i></center>
 """
 
-article = "Check out [the subreddit](https://www.reddit.com/r/TwoSentenceHorror) that this demo is based off of. Or, check out the dataset [here](https://www.kaggle.com/datasets/voanthony/two-sentence-horror-jan-2015-apr-2023)."
+article = """
+Check out [the subreddit](https://www.reddit.com/r/TwoSentenceHorror) that this demo is based off of. Or, check out the dataset [here](https://www.kaggle.com/datasets/voanthony/two-sentence-horror-jan-2015-apr-2023).
+
+The language model is fine-tuned from ['facebook/bart-base'](https://huggingface.co/facebook/bart-base). We import, then update the weights for the model to generate two-sentence horror stories. The model is fine-tuned over 3 epochs to avoid catastrophic forgetting. We also use a separate model (['felflare/bert-restore-punctuation'](https://huggingface.co/felflare/bert-restore-punctuation?text=My+name+is+wolfgang+and+I+live+in+berlin)) to restore punctuation.
+
+For evaluation, the generated story is compared to the most similar Reddit post (using either cosine or Jaccard similarity). The score of the most similar post is also returned. The score is calculated as the sum of the post score, 10 * number of comments, and 100 * number of gilds. The score is used as a proxy for the popularity of the post.
+
+Users may also enter an entire story in the second input prompt rather than generating the remainder of the story. This will be used for evaluation metrics and no story will be generated.
+"""
 
 
 demo = gr.Interface(
     fn=generate_text,
-    inputs=gr.Textbox(lines=4, placeholder="Enter the first sentence of your horror story here...", label="First Sentence"),
-    outputs=gr.Textbox(lines=4, label="Second Sentence"),
+    inputs=[
+        gr.Textbox(lines=4, placeholder="Enter the first sentence of your horror story here...", label="First Sentence"),
+        gr.Textbox(lines=4, placeholder="Or, enter full story for evaluation here...", label="Eval - Full Story")
+    ],
+    outputs=[
+        gr.Textbox(lines=4, label="Generated Second Sentence"),
+        gr.Textbox(lines=3, label="Cosine Similarity - Sentence"),
+        gr.Textbox(lines=1, label="Cosine Similarity - Post Score"),
+        gr.Textbox(lines=3, label="Jaccard Similarity - Sentence"),
+        gr.Textbox(lines=1, label="Jaccard Similarity - Post Score")
+    ],
     title=title,
     description=description,
     article=article,
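
End to end, the generation path inside `generate_text` is the stock encoder-decoder recipe from transformers; a minimal sketch of just that path, outside Gradio, using the checkpoint named in this commit (CPU-only for brevity):

```python
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

model_path = 'voacado/bart-two-sentence-horror'
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)
model.eval()

first = "I always hear scratching from inside the walls at night."
input_ids = tokenizer.encode(first, return_tensors='pt')

with torch.no_grad():
    output_ids = model.generate(input_ids, max_length=50)

# Raw, uncased output; app.py then passes this through restore_punctuation
second = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(second)
```

On the interface side, `gr.Interface` maps the two `inputs` textboxes positionally onto `generate_text(input_text, full_sentence)` and the five `outputs` onto its returned tuple, which is why the evaluation-only path still returns a placeholder "N/A" for the generated sentence.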
reddit_cleansed_data.csv ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4107fb5ebe3aa92bd7cc775fcdccab5d07f45bce613f184bb8dd0f4ed808e628
+size 20222577
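
These three lines are the Git LFS pointer stored in place of the actual file: the 20,222,577-byte CSV itself lives in LFS storage and is fetched on checkout, courtesy of the `*.csv` rule added to `.gitattributes` above.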