ACMC committed on
Commit
7e73556
0 Parent(s):

initial commit

Files changed (7)
  1. .gitattributes +35 -0
  2. .gitignore +2 -0
  3. README.md +12 -0
  4. app.py +204 -0
  5. requirements.txt +13 -0
  6. utils.py +342 -0
  7. validation.py +174 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ *.jsonl
+ __pycache__/
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: WhatsApp Chats Finetuning Formatter
+ emoji: 👀
+ colorFrom: pink
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 4.20.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,204 @@
+ # %%
+ from uuid import uuid4
+ import gradio as gr
+ import datasets
+ import json
+ import io
+ from utils import (
+     process_chat_file,
+     transform_conversations_dataset_into_training_examples,
+ )
+ from validation import (
+     check_format_errors,
+     check_token_counts,
+     estimate_cost,
+     get_distributions,
+ )
+ import matplotlib.pyplot as plt
+
+
+ def convert_to_dataset(files, do_spelling_correction, progress):
+     modified_dataset = None
+     for file in progress.tqdm(files, desc="Processing files"):
+         if modified_dataset is None:
+             # First file
+             modified_dataset = process_chat_file(file, do_spelling_correction=do_spelling_correction)
+         else:
+             # Concatenate the datasets
+             this_file_dataset = process_chat_file(file, do_spelling_correction=do_spelling_correction)
+             modified_dataset = datasets.concatenate_datasets(
+                 [modified_dataset, this_file_dataset]
+             )
+     return modified_dataset
+
+
+ def file_upload_callback(files, system_prompt, do_spelling_correction, validation_split, progress=gr.Progress()):
+     print(f"Processing {files}")
+     full_system_prompt = f"""You are a chatbot. Your goal is to simulate realistic, natural chat conversations as if you were me.
+ # Task
+ A participant can send multiple messages in a row, delimited by '\"', in the following schema:
+ {{string}}[]. Your answer always needs to be JSON compliant. Always start your answer with [\"
+ # Information about me
+ You should use the following information about me to answer:
+ {system_prompt}
+ # Example
+ [{{\"role\":\"user\",\"content\":\"[\"Hello!\",\"How are you?\"]\"}},{{\"role\":\"assistant\",\"content\":\"[\"Hi!\",\"I'm doing great.\",\"What about you?\"]\"}},{{\"role\":\"user\",\"content\":\"[\"I'm doing well.\",\"Have you been travelling?\"]\"}}]
+ Response:
+ [{{\"role\":\"assistant\",\"content\":\"[\"Yes, I've been to many places.\",\"I love travelling.\"]\"}}]"""
+
+     # Avoid using the full system prompt for now: it is long, which increases the cost of training
+     full_system_prompt = system_prompt
+     dataset = convert_to_dataset(files=files, progress=progress, do_spelling_correction=do_spelling_correction)
+     training_examples_ds = transform_conversations_dataset_into_training_examples(
+         conversations_ds=dataset, system_prompt=full_system_prompt
+     )
+
+     # Split into training and validation datasets (the split size is controlled by validation_split)
+     training_examples_ds = training_examples_ds.train_test_split(test_size=validation_split, seed=42)
+     training_examples_ds, validation_examples_ds = training_examples_ds["train"], training_examples_ds["test"]
+
+     format_errors = check_format_errors(training_examples_ds)
+     distributions = get_distributions(training_examples_ds)
+     cost_stats = estimate_cost(training_examples_ds)
+
+     stats = {
+         "Format Errors": format_errors,
+         "Number of examples missing system message": distributions["n_missing_system"],
+         "Number of examples missing user message": distributions["n_missing_user"],
+         "Cost Statistics": cost_stats,
+     }
+
+     fig_num_messages_distribution_plot = plt.figure()
+     num_messages_distribution_plot = plt.hist(distributions["n_messages"], bins=20)
+
+     fig_num_total_tokens_per_example_plot = plt.figure()
+     num_total_tokens_per_example_plot = plt.hist(distributions["convo_lens"], bins=20)
+
+     fig_num_assistant_tokens_per_example_plot = plt.figure()
+     num_assistant_tokens_per_example_plot = plt.hist(
+         distributions["assistant_message_lens"],
+         bins=20
+     )
+
+     # The DownloadButton component requires a path to a file; it can't accept a buffer that keeps the file in memory.
+     # Therefore, we need to write the data to a file and then pass the path to the DownloadButton.
+     # However, if different users are using the app at the same time, we need to make sure that each file is unique AND that no user can access the file of another user.
+     # A UUID gives us a unique, unguessable file name.
+     uuid = str(uuid4())
+     file_path = f"training_examples_{uuid}.jsonl"
+     training_examples_ds.to_json(path_or_buf=file_path, force_ascii=False)
+
+     file_path_validation = f"validation_examples_{uuid}.jsonl"
+     validation_examples_ds.to_json(path_or_buf=file_path_validation, force_ascii=False)
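+     # Note: these per-user files are deleted again in remove_file_and_hide_button,
+     # which is wired to the download buttons' click events at the bottom of this file.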
+
+     return (
+         file_path,
+         gr.update(visible=True),
+         file_path_validation,
+         gr.update(visible=True),
+         stats,
+         fig_num_messages_distribution_plot,
+         fig_num_total_tokens_per_example_plot,
+         fig_num_assistant_tokens_per_example_plot
+     )
+
+
+ def remove_file_and_hide_button(file_path):
+     import os
+
+     try:
+         os.remove(file_path)
+     except Exception as e:
+         print(f"Error removing file {file_path}: {e}")
+
+     return gr.update(visible=False)
+
+
+ theme = gr.themes.Default(primary_hue="cyan", secondary_hue="fuchsia")
+
+ with gr.Blocks(theme=theme) as demo:
+     gr.Markdown(
+         """
+         # WhatsApp Chat to Dataset Converter
+         Upload your WhatsApp chat files and convert them into a Dataset.
+         """
+     )
+     gr.Markdown(
+         """
+         ## Instructions
+         1. Click on the "Upload WhatsApp Chat Files" button.
+         2. Select the WhatsApp chat files you want to convert.
+         3. Write a prompt about yourself to give context to the training examples.
+         4. Click on the "Submit" button.
+         5. Wait for the process to finish.
+         6. Download the generated training examples as a JSONL file.
+         7. Use the training examples to train your own model.
+         """
+     )
+
+     input_files = gr.File(
+         label="Upload WhatsApp Chat Files",
+         type="filepath",
+         file_count="multiple",
+         file_types=["txt"],
+     )
+
+     system_prompt = gr.Textbox(
+         label="System Prompt",
+         placeholder="Background information about you.",
+         lines=5,
+         info="Enter the system prompt to be used when generating the training examples. This is the background information about you that the training examples will build on.",
+         value="""Aldan is an AI researcher who loves to play around with AI systems, travelling and learning new things.""",
+     )
+
+     do_spelling_correction = gr.Checkbox(
+         label="Do Spelling Correction (English)",
+         info="Check this box if you want to perform spelling correction on the chat messages before generating the training examples.",
+     )
+
+     # Allow the user to choose the validation split size
+     validation_split = gr.Slider(
+         minimum=0.0,
+         maximum=0.5,
+         value=0.2,
+         interactive=True,
+         label="Validation Split",
+         info="Choose the fraction of the dataset to be used for validation. For example, if you choose 0.2, 20% of the dataset will be used for validation and 80% for training.",
+     )
+
+     submit = gr.Button(value="Submit", variant="primary")
+
+     output_file = gr.DownloadButton(label="Download Generated Training Examples", visible=False, variant="primary")
+     output_file_validation = gr.DownloadButton(label="Download Generated Validation Examples", visible=False, variant="secondary")
+     # output_example = gr.JSON(label="Example Training Example")
+
+     with gr.Group():
+         # Statistics about the dataset
+         gr.Markdown("## Statistics")
+         written_stats = gr.JSON()
+         num_messages_distribution_plot = gr.Plot(label="Number of Messages Distribution")
+         num_total_tokens_per_example_plot = gr.Plot(label="Total Number of Tokens per Example")
+         num_assistant_tokens_per_example_plot = gr.Plot(
+             label="Number of Assistant Tokens per Example"
+         )
+
+     submit.click(
+         file_upload_callback,
+         inputs=[input_files, system_prompt, do_spelling_correction, validation_split],
+         outputs=[
+             output_file,
+             output_file,
+             output_file_validation,
+             output_file_validation,
+             written_stats,
+             num_messages_distribution_plot,
+             num_total_tokens_per_example_plot,
+             num_assistant_tokens_per_example_plot,
+         ]
+     )
+
+     output_file.click(remove_file_and_hide_button, inputs=[output_file], outputs=[output_file])
+     output_file_validation.click(remove_file_and_hide_button, inputs=[output_file_validation], outputs=[output_file_validation])
+
+ if __name__ == "__main__":
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ contextualSpellCheck==0.4.4
+ datasets==2.18.0
+ es-core-news-sm @ https://github.com/explosion/spacy-models/releases/download/es_core_news_sm-3.7.0/es_core_news_sm-3.7.0-py3-none-any.whl#sha256=61e6e5530941f5880166855f09f60d7e6ba79ec1e8e45f96244bdb1eb169eb1d
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
+ gradio==4.20.1
+ matplotlib==3.8.3
+ numpy==1.26.4
+ pandas==2.2.1
+ spacy==3.7.4
+ tiktoken==0.6.0
+ torch==2.2.1
+ transformers==4.38.2
+ pyspellchecker==0.8.1
utils.py ADDED
@@ -0,0 +1,342 @@
+ import datasets
+ import datetime
+ import os
+ import json
+
+ import re
+
+ exp = re.compile(
+     r"(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+), (?P<hour>\d+):(?P<minute>\d+) - (?P<contact_name>.+): (?P<message>.+)"
+ )
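+ # A chat line this regex matches looks like this (hypothetical sample):
+ #   "2/5/23, 14:32 - Jane Doe: See you soon!"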
+
+
+ def process_line(example):
+     # The lines have this format: m/d/yy, hh:mm - <person>: <msg> (month first, matching the regex above)
+     try:
+         groups = exp.match(example["text"]).groupdict()
+         # WhatsApp exports often use 2-digit years; normalize them so the timestamp conversion works
+         year = int(groups["year"])
+         if year < 100:
+             year += 2000
+         timestamp = datetime.datetime(
+             year,
+             int(groups["month"]),
+             int(groups["day"]),
+             int(groups["hour"]),
+             int(groups["minute"]),
+         ).timestamp()
+         return {
+             "message": groups["message"],
+             "contact_name": groups["contact_name"],
+             "timestamp": timestamp,
+         }
+     except Exception as e:
+         print(e)
+         print(example["text"])
+         raise e
+
+
+ # %%
+ # Now, create message groups ('conversations').
+ # The idea is to group messages that are close in time.
+ # We'll use a 240-minute threshold.
+ MINUTES_THRESHOLD = 240
+
+
+ def group_messages(messages_iterable):
+     groups = []
+     current_group = [next(messages_iterable)]
+     for message in messages_iterable:
+         assert len(current_group) > 0  # We should never have an empty group
+         if (
+             message["timestamp"] - current_group[-1]["timestamp"]
+             < MINUTES_THRESHOLD * 60
+         ):
+             current_group.append(message)
+         else:
+             groups.append(current_group)
+             current_group = [message]
+     groups.append(current_group)
+     return groups
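+ # For example, two messages 10 minutes apart land in the same group, while a gap
+ # larger than MINUTES_THRESHOLD (240 min) starts a new group, i.e. a new conversation.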
+
+
+ def printable_conversation(conversation):
+     return "\n".join(
+         [f"{message['contact_name']}: {message['message']}" for message in conversation]
+     )
+
+
+ # %%
+ # Use spacy to spell check the messages
+ import spacy
+ import contextualSpellCheck
+ from spellchecker import SpellChecker
+ spell = SpellChecker()
+ # nlp = spacy.load("es_core_news_sm")
+ nlp = spacy.load("en_core_web_sm")
+
+
+ def spell_check_conversation(conversation):
+     for i, message in enumerate(conversation["conversations"]):
+         # Use the spell checker to split the message into words
+         words = spell.split_words(message["message"])
+         print(f"Words: {words}")
+         corrected_message = []
+         for word in words:
+             correction = spell.correction(word)
+             if (correction is not None) and (correction != word):
+                 print(f"Spell check: {word} -> {correction}")
+                 corrected_message.append(correction)
+             else:
+                 corrected_message.append(word)
+
+         print(f"Corrected message: {corrected_message}")
+         joined_message = " ".join(corrected_message)
+         conversation["conversations"][i]["message"] = joined_message
+
+     return conversation
+
+
+ def spell_check_conversation_spacy(conversation):
+     # Add the contextual spellchecker pipe only once; add_pipe raises if it already exists
+     if "contextual spellchecker" not in nlp.pipe_names:
+         nlp.add_pipe(
+             "contextual spellchecker",
+             config={
+                 "model_name": "bert-base-multilingual-uncased",
+                 "max_edit_dist": 2,
+             },
+         )
+     docs = list(nlp.pipe([msg["message"] for msg in conversation["conversations"]]))
+     for i, doc in enumerate(docs):
+         if doc._.performed_spellCheck:
+             print(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}")
+             conversation["conversations"][i]["message"] = doc._.outcome_spellCheck
+
+     return conversation
+
+
+ def remove_whatsapp_annotations(conversation):
+     """
+     Removes the following WhatsApp annotations from the messages:
+     - <This message was edited>
+     """
+     for message in conversation["conversations"]:
+         message["message"] = re.sub(
+             r"<This message was edited>", "", message["message"]
+         )
+     return conversation
+
+
+ # %%
+ """
+ Sometimes, people write concurrently in the same conversation. We'll try to detect that and reorder the messages.
+ For example, if we have a conversation like this:
+ A: Hi
+ A: How are you?
+ B: Hi
+ B: I'm fine, thanks
+ A: I'm fine too
+ We'll reorder it to:
+ A: Hi
+ B: Hi
+ A: How are you?
+ B: I'm fine, thanks
+ A: I'm fine too
+
+ To do it, we'll use a BERT model with the next-sentence-prediction head. We'll use the first message as the first sentence and the second message as the second sentence. If the model predicts that the reversed order is clearly more likely, we'll swap the messages.
+ """
+
+ from transformers import AutoTokenizer, AutoModelForNextSentencePrediction
+ import torch
+
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+ model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
+ if torch.cuda.is_available():
+     model.cuda()
+
+
+ def swap_messages_if_needed(message1, message2):
+     # If the messages have the same contact, we don't swap them
+     if message1["contact_name"] == message2["contact_name"]:
+         return message1, message2
+     # The timestamps must differ by less than 2 minutes. First, convert to datetime
+     datetime1 = datetime.datetime.fromtimestamp(message1["timestamp"])
+     datetime2 = datetime.datetime.fromtimestamp(message2["timestamp"])
+     if (datetime2 - datetime1).total_seconds() > 2 * 60:
+         return message1, message2
+     # If one of the messages has fewer than 3 words, we don't swap them
+     if len(message1["message"].split()) < 3 or len(message2["message"].split()) < 3:
+         return message1, message2
+     # We'll use the first message as the first sentence, and the second message as the second sentence
+     inputs = tokenizer(message1["message"], message2["message"], return_tensors="pt")
+     reverse_inputs = tokenizer(
+         message2["message"], message1["message"], return_tensors="pt"
+     )
+     # Join them in a single batch
+     joined_inputs = torch.cat([inputs["input_ids"], reverse_inputs["input_ids"]], dim=0)
+     if torch.cuda.is_available():
+         joined_inputs = joined_inputs.cuda()
+     with torch.no_grad():
+         outputs = model(input_ids=joined_inputs)
+     # The first element of the output holds the logits for each class (label 0 = "is the next sentence")
+     logits = outputs[0]
+     # Apply softmax
+     logits = torch.softmax(logits, dim=1)
+     # We have two probabilities: the probability of 1 -> 2, and the probability of 2 -> 1.
+     # We take the difference and only swap when the reversed order is clearly more likely.
+     swap = logits[0, 0] - logits[1, 0] < -0.2
+     if swap:
+         # Swap the messages
+         print(f"YES Swapping messages: {message1['message']} <-> {message2['message']}")
+         return message2, message1
+     else:
+         # print(f"NOT swapping messages: {message1['message']} <-> {message2['message']}")
+         return message1, message2
+
+
+ def swap_messages_if_needed_in_conversation(conversation):
+     # Walk the conversation, checking each consecutive pair of messages
+     if len(conversation) <= 2:
+         return conversation
+     new_conversation = [
+         conversation[0],
+         conversation[1],
+     ]  # We'll always keep the first message in the same position
+     for i in range(2, len(conversation)):
+         message1 = new_conversation[-1]
+         message2 = conversation[i]
+         message1, message2 = swap_messages_if_needed(message1, message2)
+         new_conversation[-1] = message1
+         new_conversation.append(message2)
+
+     # print(f"\nOriginal conversation:\n{printable_conversation(conversation)}")
+     # print(f"\nNew conversation:\n{printable_conversation(new_conversation)}")
+     return new_conversation
+
+
+ test_conversation = [
+     {"message": "Hola!", "contact_name": "A", "timestamp": 1},
+     {
+         "message": "Está todo bien, gracias por preguntar!",
+         "contact_name": "B",
+         "timestamp": 2,
+     },
+     {
+         "message": "Hola, qué tal estás? Espero que vaya todo bien por España.",
+         "contact_name": "A",
+         "timestamp": 3,
+     },
+ ]
+ # print(swap_messages_if_needed_in_conversation(test_conversation))
+
+ # %%
+ # Normalize contact names before generating training examples
+ import os
+
+
+ # For the contact_name, rewrite everything that is not 'Aldi' to 'Other'
+ def rewrite_contact_name(conversation):
+     for message in conversation["conversations"]:
+         if message["contact_name"] != "Aldi":
+             message["contact_name"] = "Other"
+     return conversation
+
+
+ # %%
+ def process_chat_file(file, do_spelling_correction, do_reordering=False):
+     """
+     Process a chat file and return a dataset with the conversations.
+     """
+     ds = (
+         datasets.load_dataset("text", data_files=[file])["train"]
+         .filter(
+             # Has to begin with date, time, and contact name, and contain at least a ':' symbol
+             lambda x: re.match(
+                 r"^\d{1,2}/\d{1,2}/\d{1,2},\s\d{2}:\d{2}\s-\s.+:", x["text"]
+             )
+         )
+         .map(process_line, remove_columns=["text"])
+     )
+
+     # Filter out messages that just say '<Media omitted>'
+     ds = ds.filter(lambda x: x["message"] != "<Media omitted>")
+
+     groups = group_messages(iter(ds))
+     # Generate the dataset
+     conversations_ds = datasets.Dataset.from_dict({"conversations": groups})
+
+     # Filter out conversations with fewer than 10 messages
+     conversations_ds = conversations_ds.filter(lambda x: len(x["conversations"]) >= 10)
+
+     conversations_ds_without_whatsapp_annotations = conversations_ds.map(
+         remove_whatsapp_annotations,
+         num_proc=os.cpu_count() - 1,
+     )
+
+     if do_spelling_correction:
+         spell_checked_conversations_ds = (
+             conversations_ds_without_whatsapp_annotations.map(spell_check_conversation)
+         )
+     else:
+         spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations
+
+     if do_reordering:
+         reordered_conversations_ds = spell_checked_conversations_ds.map(
+             swap_messages_if_needed_in_conversation
+         )
+     else:
+         reordered_conversations_ds = spell_checked_conversations_ds
+
+     changed_contact_name_ds = reordered_conversations_ds.map(
+         rewrite_contact_name
+     )  # , num_proc=os.cpu_count() - 1)
+
+     # Filter out conversations with only one contact
+     changed_contact_name_ds = changed_contact_name_ds.filter(
+         lambda x: len(set([msg["contact_name"] for msg in x["conversations"]])) > 1
+     )
+
+     return changed_contact_name_ds
+
+
+ def transform_conversations_dataset_into_training_examples(
+     conversations_ds, system_prompt
+ ):
+     """
+     Takes in a dataset with conversations and returns a dataset with training examples.
+
+     The input dataset contains a single column (conversations), with each row being a list of messages with this format:
+     ```
+     [{'contact_name': 'Aldi', 'message': <message>, 'timestamp': <time>}, {'contact_name': 'Other', 'message': <message>, 'timestamp': <time>}, ... ]
+     ```
+
+     Each row will be converted to fit the format of the training examples.
+
+     The training examples have the following format:
+     ```
+     {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
+     {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "William Shakespeare"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
+     {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "384,400 kilometers"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
+     ```
+     """
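+     # Note: unlike the plain-string contents in the example above, the implementation
+     # below packs each turn's messages into a JSON-encoded list of strings, and merges
+     # consecutive messages from the same speaker into a single turn.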
+
+     def process_one_example(example):
+         messages = [{"role": "system", "content": [system_prompt]}]
+         for msg in example["conversations"]:
+             converted_role = "assistant" if msg["contact_name"] == "Aldi" else "user"
+             if converted_role == messages[-1]["role"]:
+                 messages[-1]["content"] += [msg["message"]]
+             else:
+                 messages.append({"role": converted_role, "content": [msg["message"]]})
+         return {
+             "messages": [
+                 {
+                     "role": m["role"],
+                     "content": json.dumps(m["content"], ensure_ascii=False),
+                 }
+                 for m in messages
+             ]
+         }
+
+     return conversations_ds.map(
+         process_one_example,
+         remove_columns=["conversations"],
+         num_proc=os.cpu_count() - 1,
+     )
validation.py ADDED
@@ -0,0 +1,174 @@
+ import numpy as np
+ from collections import defaultdict
+ import tiktoken
+
+
+ def check_format_errors(train_dataset):
+     """
+     Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
+     """
+     # Format error checks
+     format_errors = defaultdict(int)
+
+     for ex in train_dataset:
+         if not isinstance(ex, dict):
+             format_errors["data_type"] += 1
+             continue
+
+         messages = ex.get("messages", None)
+         if not messages:
+             format_errors["missing_messages_list"] += 1
+             continue
+
+         for message in messages:
+             if "role" not in message or "content" not in message:
+                 format_errors["message_missing_key"] += 1
+
+             if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
+                 format_errors["message_unrecognized_key"] += 1
+
+             if message.get("role", None) not in ("system", "user", "assistant", "function"):
+                 format_errors["unrecognized_role"] += 1
+
+             content = message.get("content", None)
+             function_call = message.get("function_call", None)
+
+             if (not content and not function_call) or not isinstance(content, str):
+                 format_errors["missing_content"] += 1
+
+         if not any(message.get("role", None) == "assistant" for message in messages):
+             format_errors["example_missing_assistant_message"] += 1
+
+     if format_errors:
+         print("Found errors:")
+         for k, v in format_errors.items():
+             print(f"{k}: {v}")
+     else:
+         print("No errors found")
+
+     return format_errors if format_errors else {}
+
+ def get_distributions(train_dataset):
+     """
+     Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
+
+     Gets the distributions of the number of messages per example, the total number of tokens per example, and the number of assistant tokens per example.
+     """
+     encoding = tiktoken.get_encoding("cl100k_base")
+
+     # not exact!
+     # simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+     def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
+         num_tokens = 0
+         for message in messages:
+             num_tokens += tokens_per_message
+             for key, value in message.items():
+                 num_tokens += len(encoding.encode(value))
+                 if key == "name":
+                     num_tokens += tokens_per_name
+         num_tokens += 3
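+         # The extra 3 tokens account for every reply being primed with
+         # <|start|>assistant<|message|>, per the OpenAI token-counting notebook linked above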
+         return num_tokens
+
+     def num_assistant_tokens_from_messages(messages):
+         num_tokens = 0
+         for message in messages:
+             if message["role"] == "assistant":
+                 num_tokens += len(encoding.encode(message["content"]))
+         return num_tokens
+
+     n_missing_system = 0
+     n_missing_user = 0
+     n_messages = []
+     convo_lens = []
+     assistant_message_lens = []
+
+     for ex in train_dataset:
+         messages = ex["messages"]
+         if not any(message["role"] == "system" for message in messages):
+             n_missing_system += 1
+         if not any(message["role"] == "user" for message in messages):
+             n_missing_user += 1
+         n_messages.append(len(messages))
+         convo_lens.append(num_tokens_from_messages(messages))
+         assistant_message_lens.append(num_assistant_tokens_from_messages(messages))
+
+     return {
+         "n_missing_system": n_missing_system,
+         "n_missing_user": n_missing_user,
+         "n_messages": n_messages,
+         "convo_lens": convo_lens,
+         "assistant_message_lens": assistant_message_lens
+     }
+
+
+ def check_token_counts(train_dataset):
+     """
+     Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
+     """
+     def print_distribution(values, name):
+         print(f"\n#### Distribution of {name}:")
+         print(f"min / max: {min(values)}, {max(values)}")
+         print(f"mean / median: {np.mean(values)}, {np.median(values)}")
+         print(f"p10 / p90: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")
+
+     # Warnings and token counts
+     distributions = get_distributions(train_dataset)
+     n_missing_system = distributions["n_missing_system"]
+     n_missing_user = distributions["n_missing_user"]
+     n_messages = distributions["n_messages"]
+     convo_lens = distributions["convo_lens"]
+     assistant_message_lens = distributions["assistant_message_lens"]
+
+     print("Num examples missing system message:", n_missing_system)
+     print("Num examples missing user message:", n_missing_user)
+     print_distribution(n_messages, "num_messages_per_example")
+     print_distribution(convo_lens, "num_total_tokens_per_example")
+     print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
+     n_too_long = sum(l > 4096 for l in convo_lens)
+     print(
+         f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning"
+     )
+
+     return
+
+
+ def estimate_cost(train_dataset):
+     """
+     Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
+     """
+     distributions = get_distributions(train_dataset)
+     n_missing_system = distributions["n_missing_system"]
+     n_missing_user = distributions["n_missing_user"]
+     n_messages = distributions["n_messages"]
+     convo_lens = distributions["convo_lens"]
+     assistant_message_lens = distributions["assistant_message_lens"]
+
+     # Pricing and default n_epochs estimate
+     MAX_TOKENS_PER_EXAMPLE = 4096
+
+     TARGET_EPOCHS = 3
+     MIN_TARGET_EXAMPLES = 100
+     MAX_TARGET_EXAMPLES = 25000
+     MIN_DEFAULT_EPOCHS = 1
+     MAX_DEFAULT_EPOCHS = 25
+
+     n_epochs = TARGET_EPOCHS
+     n_train_examples = len(train_dataset)
+     if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
+         n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
+     elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
+         n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)
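+     # Worked example: with 20 training examples, 20 * 3 = 60 < 100 (MIN_TARGET_EXAMPLES),
+     # so n_epochs = min(25, 100 // 20) = 5.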
+
+     n_billing_tokens_in_dataset = sum(
+         min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens
+     )
+
+     return {
+         "Estimated number of tokens in dataset": n_billing_tokens_in_dataset,
+         f"Estimated number of tokens that will be billed (assuming {n_epochs} training epochs)": n_epochs * n_billing_tokens_in_dataset
+     }