acmc commited on
Commit
bf9e30f
1 Parent(s): 7e73556

Adapting to GCP

Browse files
Files changed (3) hide show
  1. app.py +107 -32
  2. utils.py +79 -27
  3. validation.py +10 -10
app.py CHANGED
@@ -10,7 +10,6 @@ from utils import (
10
  )
11
  from validation import (
12
  check_format_errors,
13
- check_token_counts,
14
  estimate_cost,
15
  get_distributions,
16
  )
@@ -22,44 +21,79 @@ def convert_to_dataset(files, do_spelling_correction, progress):
22
  for file in progress.tqdm(files, desc="Processing files"):
23
  if modified_dataset is None:
24
  # First file
25
- modified_dataset = process_chat_file(file, do_spelling_correction=do_spelling_correction)
 
 
26
  else:
27
  # Concatenate the datasets
28
- this_file_dataset = process_chat_file(file, do_spelling_correction=do_spelling_correction)
 
 
29
  modified_dataset = datasets.concatenate_datasets(
30
  [modified_dataset, this_file_dataset]
31
  )
32
  return modified_dataset
33
 
34
 
35
- def file_upload_callback(files, system_prompt, do_spelling_correction, validation_split, progress=gr.Progress()):
 
 
 
 
 
 
 
 
 
36
  print(f"Processing {files}")
37
  full_system_prompt = f"""You are a chatbot. Your goal is to simulate realistic, natural chat conversations as if you were me.
38
  # Task
39
- A participant can send multiple messages in a row, delimited by '\"', in the following schema:
40
- {{string}}[]. Your answer always needs to be JSON compliant. Always start your answer with [\"
41
  # Information about me
42
  You should use the following information about me to answer:
43
- {system_prompt}
44
- # Example
45
- [{{\"role\":\"user\",\"content\":\"[\"Hello!\",\"How are you?\"]\"}},{{\"role\":\"assistant\",\"content\":\"[\"Hi!\",\"I'm doing great.\",\"What about you?\"]\"}},{{\"role\":\"user\",\"content\":\"[\"I'm doing well.\",\"Have you been travelling?\"]\"}}]
46
- Response:
47
- [{{\"role\":\"assistant\",\"content\":\"[\"Yes, I've been to many places.\",\"I love travelling.\"]\"}}]"""
48
-
49
- # Avoid using the full system prompt for now, as it is too long and increases the cost of the training
50
- full_system_prompt = system_prompt
51
- dataset = convert_to_dataset(files=files, progress=progress, do_spelling_correction=do_spelling_correction)
 
 
52
  training_examples_ds = transform_conversations_dataset_into_training_examples(
53
- conversations_ds=dataset, system_prompt=full_system_prompt
 
 
 
 
54
  )
55
 
56
  # Split into training and validation datasets (80% and 20%)
57
- training_examples_ds = training_examples_ds.train_test_split(test_size=validation_split, seed=42)
58
- training_examples_ds, validation_examples_ds = training_examples_ds["train"], training_examples_ds["test"]
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- format_errors = check_format_errors(training_examples_ds)
61
- distributions = get_distributions(training_examples_ds)
62
- cost_stats = estimate_cost(training_examples_ds)
 
 
 
 
 
 
63
 
64
  stats = {
65
  "Format Errors": format_errors,
@@ -76,8 +110,7 @@ Response:
76
 
77
  fig_num_assistant_tokens_per_example_plot = plt.figure()
78
  num_assistant_tokens_per_example_plot = plt.hist(
79
- distributions["assistant_message_lens"],
80
- bins=20
81
  )
82
 
83
  # The DownloadFile component requires a path to the file, it can't accept a buffer to keep the file in memory.
@@ -99,7 +132,7 @@ Response:
99
  stats,
100
  fig_num_messages_distribution_plot,
101
  fig_num_total_tokens_per_example_plot,
102
- fig_num_assistant_tokens_per_example_plot
103
  )
104
 
105
 
@@ -151,6 +184,24 @@ with gr.Blocks(theme=theme) as demo:
151
  value="""Aldan is an AI researcher who loves to play around with AI systems, travelling and learning new things.""",
152
  )
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  do_spelling_correction = gr.Checkbox(
155
  label="Do Spelling Correction (English)",
156
  info="Check this box if you want to perform spelling correction on the chat messages before generating the training examples.",
@@ -168,23 +219,41 @@ with gr.Blocks(theme=theme) as demo:
168
 
169
  submit = gr.Button(value="Submit", variant="primary")
170
 
171
- output_file = gr.DownloadButton(label="Download Generated Training Examples", visible=False, variant="primary")
172
- output_file_validation = gr.DownloadButton(label="Download Generated Validation Examples", visible=False, variant="secondary")
 
 
 
 
 
 
173
  # output_example = gr.JSON(label="Example Training Example")
174
 
175
  with gr.Group():
176
  # Statistics about the dataset
177
  gr.Markdown("## Statistics")
178
  written_stats = gr.JSON()
179
- num_messages_distribution_plot = gr.Plot(label="Number of Messages Distribution")
180
- num_total_tokens_per_example_plot = gr.Plot(label="Total Number of Tokens per Example")
 
 
 
 
181
  num_assistant_tokens_per_example_plot = gr.Plot(
182
  label="Number of Assistant Tokens per Example"
183
  )
184
 
185
  submit.click(
186
  file_upload_callback,
187
- inputs=[input_files, system_prompt, do_spelling_correction, validation_split],
 
 
 
 
 
 
 
 
188
  outputs=[
189
  output_file,
190
  output_file,
@@ -194,11 +263,17 @@ with gr.Blocks(theme=theme) as demo:
194
  num_messages_distribution_plot,
195
  num_total_tokens_per_example_plot,
196
  num_assistant_tokens_per_example_plot,
197
- ]
198
  )
199
 
200
- output_file.click(remove_file_and_hide_button, inputs=[output_file], outputs=[output_file])
201
- output_file_validation.click(remove_file_and_hide_button, inputs=[output_file_validation], outputs=[output_file_validation])
 
 
 
 
 
 
202
 
203
  if __name__ == "__main__":
204
  demo.launch()
 
10
  )
11
  from validation import (
12
  check_format_errors,
 
13
  estimate_cost,
14
  get_distributions,
15
  )
 
21
  for file in progress.tqdm(files, desc="Processing files"):
22
  if modified_dataset is None:
23
  # First file
24
+ modified_dataset = process_chat_file(
25
+ file, do_spelling_correction=do_spelling_correction
26
+ )
27
  else:
28
  # Concatenate the datasets
29
+ this_file_dataset = process_chat_file(
30
+ file, do_spelling_correction=do_spelling_correction
31
+ )
32
  modified_dataset = datasets.concatenate_datasets(
33
  [modified_dataset, this_file_dataset]
34
  )
35
  return modified_dataset
36
 
37
 
38
+ def file_upload_callback(
39
+ files,
40
+ system_prompt,
41
+ do_spelling_correction,
42
+ validation_split,
43
+ user_role,
44
+ model_role,
45
+ whatsapp_name,
46
+ progress=gr.Progress(),
47
+ ):
48
  print(f"Processing {files}")
49
  full_system_prompt = f"""You are a chatbot. Your goal is to simulate realistic, natural chat conversations as if you were me.
50
  # Task
51
+ The {model_role} and the {user_role} can send multiple messages in a row, as a JSON list of strings. Your answer always needs to be JSON compliant. The strings are delimited by double quotes ("). The strings are separated by a comma (,). The list is delimited by square brackets ([, ]). Always start your answer with [", and close it with "]. Do not write anything else in your answer after "].
 
52
  # Information about me
53
  You should use the following information about me to answer:
54
+ {system_prompt}"""
55
+ # Example
56
+ # [{{\"role\":\"user\",\"content\":\"[\"Hello!\",\"How are you?\"]\"}},{{\"role\":\"assistant\",\"content\":\"[\"Hi!\",\"I'm doing great.\",\"What about you?\"]\"}},{{\"role\":\"user\",\"content\":\"[\"I'm doing well.\",\"Have you been travelling?\"]\"}}]
57
+ # Response:
58
+ # [{{\"role\":\"assistant\",\"content\":\"[\"Yes, I've been to many places.\",\"I love travelling.\"]\"}}]"""
59
+
60
+ # # Avoid using the full system prompt for now, as it is too long and increases the cost of the training
61
+ # full_system_prompt = system_prompt
62
+ dataset = convert_to_dataset(
63
+ files=files, progress=progress, do_spelling_correction=do_spelling_correction
64
+ )
65
  training_examples_ds = transform_conversations_dataset_into_training_examples(
66
+ conversations_ds=dataset,
67
+ system_prompt=full_system_prompt,
68
+ user_role=user_role,
69
+ model_role=model_role,
70
+ whatsapp_name=whatsapp_name,
71
  )
72
 
73
  # Split into training and validation datasets (80% and 20%)
74
+ training_examples_ds = training_examples_ds.train_test_split(
75
+ test_size=validation_split, seed=42
76
+ )
77
+ training_examples_ds, validation_examples_ds = (
78
+ training_examples_ds["train"],
79
+ training_examples_ds["test"],
80
+ )
81
+ training_examples_ds = training_examples_ds#.select(
82
+ # range(min(250, len(training_examples_ds)))
83
+ #)
84
+ validation_examples_ds = validation_examples_ds.select(
85
+ range(min(200, len(validation_examples_ds)))
86
+ )
87
 
88
+ format_errors = check_format_errors(
89
+ training_examples_ds, user_role=user_role, model_role=model_role
90
+ )
91
+ distributions = get_distributions(
92
+ training_examples_ds, user_role=user_role, model_role=model_role
93
+ )
94
+ cost_stats = estimate_cost(
95
+ training_examples_ds, user_role=user_role, model_role=model_role
96
+ )
97
 
98
  stats = {
99
  "Format Errors": format_errors,
 
110
 
111
  fig_num_assistant_tokens_per_example_plot = plt.figure()
112
  num_assistant_tokens_per_example_plot = plt.hist(
113
+ distributions["assistant_message_lens"], bins=20
 
114
  )
115
 
116
  # The DownloadFile component requires a path to the file, it can't accept a buffer to keep the file in memory.
 
132
  stats,
133
  fig_num_messages_distribution_plot,
134
  fig_num_total_tokens_per_example_plot,
135
+ fig_num_assistant_tokens_per_example_plot,
136
  )
137
 
138
 
 
184
  value="""Aldan is an AI researcher who loves to play around with AI systems, travelling and learning new things.""",
185
  )
186
 
187
+ whatsapp_name = gr.Textbox(
188
+ label="Your WhatsApp Name",
189
+ placeholder="Your WhatsApp Name",
190
+ info="Enter your WhatsApp name as it appears in your profile. It needs to match exactly your name. If you're unsure, you can check the chat messages to see it.",
191
+ )
192
+
193
+ user_role = gr.Textbox(
194
+ label="Role for User",
195
+ info="This is a technical parameter. If you don't know what to write, just type 'user'.",
196
+ value="user",
197
+ )
198
+
199
+ model_role = gr.Textbox(
200
+ label="Role for Model",
201
+ info="This is a technical parameter. If you don't know what to write, just type 'model'.",
202
+ value="model",
203
+ )
204
+
205
  do_spelling_correction = gr.Checkbox(
206
  label="Do Spelling Correction (English)",
207
  info="Check this box if you want to perform spelling correction on the chat messages before generating the training examples.",
 
219
 
220
  submit = gr.Button(value="Submit", variant="primary")
221
 
222
+ output_file = gr.DownloadButton(
223
+ label="Download Generated Training Examples", visible=False, variant="primary"
224
+ )
225
+ output_file_validation = gr.DownloadButton(
226
+ label="Download Generated Validation Examples",
227
+ visible=False,
228
+ variant="secondary",
229
+ )
230
  # output_example = gr.JSON(label="Example Training Example")
231
 
232
  with gr.Group():
233
  # Statistics about the dataset
234
  gr.Markdown("## Statistics")
235
  written_stats = gr.JSON()
236
+ num_messages_distribution_plot = gr.Plot(
237
+ label="Number of Messages Distribution"
238
+ )
239
+ num_total_tokens_per_example_plot = gr.Plot(
240
+ label="Total Number of Tokens per Example"
241
+ )
242
  num_assistant_tokens_per_example_plot = gr.Plot(
243
  label="Number of Assistant Tokens per Example"
244
  )
245
 
246
  submit.click(
247
  file_upload_callback,
248
+ inputs=[
249
+ input_files,
250
+ system_prompt,
251
+ do_spelling_correction,
252
+ validation_split,
253
+ user_role,
254
+ model_role,
255
+ whatsapp_name,
256
+ ],
257
  outputs=[
258
  output_file,
259
  output_file,
 
263
  num_messages_distribution_plot,
264
  num_total_tokens_per_example_plot,
265
  num_assistant_tokens_per_example_plot,
266
+ ],
267
  )
268
 
269
+ output_file.click(
270
+ remove_file_and_hide_button, inputs=[output_file], outputs=[output_file]
271
+ )
272
+ output_file_validation.click(
273
+ remove_file_and_hide_button,
274
+ inputs=[output_file_validation],
275
+ outputs=[output_file_validation],
276
+ )
277
 
278
  if __name__ == "__main__":
279
  demo.launch()
utils.py CHANGED
@@ -35,8 +35,9 @@ def process_line(example):
35
  # %%
36
  # Now, create message groups ('conversations')
37
  # The idea is to group messages that are close in time
38
- # We'll use a 240 minute threshold
39
- MINUTES_THRESHOLD = 240
 
40
 
41
 
42
  def group_messages(messages_iterable):
@@ -67,8 +68,9 @@ def printable_conversation(conversation):
67
  import spacy
68
  import contextualSpellCheck
69
  from spellchecker import SpellChecker
 
70
  spell = SpellChecker()
71
- #nlp = spacy.load("es_core_news_sm")
72
  nlp = spacy.load("en_core_web_sm")
73
 
74
 
@@ -262,8 +264,10 @@ def process_chat_file(file, do_spelling_correction, do_reordering=False):
262
  # Generate the dataset
263
  conversations_ds = datasets.Dataset.from_dict({"conversations": groups})
264
 
265
- # Filter out conversations with less than 10 messages
266
- conversations_ds = conversations_ds.filter(lambda x: len(x["conversations"]) >= 10)
 
 
267
 
268
  conversations_ds_without_whatsapp_annotations = conversations_ds.map(
269
  remove_whatapp_annotations,
@@ -296,8 +300,12 @@ def process_chat_file(file, do_spelling_correction, do_reordering=False):
296
  return changed_contact_name_ds
297
 
298
 
 
 
 
 
299
  def transform_conversations_dataset_into_training_examples(
300
- conversations_ds, system_prompt
301
  ):
302
  """
303
  Takes in a dataset with conversations and returns a dataset with training examples.
@@ -317,26 +325,70 @@ def transform_conversations_dataset_into_training_examples(
317
  ```
318
  """
319
 
320
- def process_one_example(example):
321
- messages = [{"role": "system", "content": [system_prompt]}]
322
- for msg in example["conversations"]:
323
- converted_role = "assistant" if msg["contact_name"] == "Aldi" else "user"
324
- if converted_role == messages[-1]["role"]:
325
- messages[-1]["content"] += [msg["message"]]
326
- else:
327
- messages.append({"role": converted_role, "content": [msg["message"]]})
328
- return {
329
- "messages": [
330
- {
331
- "role": m["role"],
332
- "content": json.dumps(m["content"], ensure_ascii=False),
333
- }
334
- for m in messages
335
- ]
336
- }
337
-
338
- return conversations_ds.map(
339
- process_one_example,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  remove_columns=["conversations"],
341
- num_proc=os.cpu_count() - 1,
 
 
 
 
 
 
 
342
  )
 
 
 
35
  # %%
36
  # Now, create message groups ('conversations')
37
  # The idea is to group messages that are close in time
38
+ # We'll use a 180 minute threshold
39
+ MINUTES_THRESHOLD = 180
40
+ MIN_MESSAGES_THRESHOLD = 5
41
 
42
 
43
  def group_messages(messages_iterable):
 
68
  import spacy
69
  import contextualSpellCheck
70
  from spellchecker import SpellChecker
71
+
72
  spell = SpellChecker()
73
+ # nlp = spacy.load("es_core_news_sm")
74
  nlp = spacy.load("en_core_web_sm")
75
 
76
 
 
264
  # Generate the dataset
265
  conversations_ds = datasets.Dataset.from_dict({"conversations": groups})
266
 
267
+ # Filter out conversations with less than 5 messages
268
+ conversations_ds = conversations_ds.filter(
269
+ lambda x: len(x["conversations"]) >= MIN_MESSAGES_THRESHOLD
270
+ )
271
 
272
  conversations_ds_without_whatsapp_annotations = conversations_ds.map(
273
  remove_whatapp_annotations,
 
300
  return changed_contact_name_ds
301
 
302
 
303
+ SPLIT_CONVERSATION_THRESHOLD = 40
304
+ MAX_CHARACTERS_PER_MESSAGE = 10000 # Max is 8,192 tokens (https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about#sample-datasets)
305
+
306
+
307
  def transform_conversations_dataset_into_training_examples(
308
+ conversations_ds, system_prompt, user_role, model_role, whatsapp_name
309
  ):
310
  """
311
  Takes in a dataset with conversations and returns a dataset with training examples.
 
325
  ```
326
  """
327
 
328
+ def process_examples(examples):
329
+ processed_examples = []
330
+ for conversation in examples["conversations"]:
331
+ messages = [{"role": "system", "content": [system_prompt]}]
332
+ counter = 0
333
+ for msg in conversation:
334
+ converted_role = (
335
+ model_role if msg["contact_name"] == whatsapp_name else user_role
336
+ )
337
+ if (
338
+ counter > SPLIT_CONVERSATION_THRESHOLD
339
+ and converted_role == user_role
340
+ ):
341
+ processed_examples.append(
342
+ {
343
+ "messages": [
344
+ {
345
+ "role": m["role"],
346
+ "content": json.dumps(
347
+ m["content"], ensure_ascii=False
348
+ ),
349
+ }
350
+ for m in messages
351
+ ]
352
+ }
353
+ )
354
+ messages = [{"role": "system", "content": [system_prompt]}]
355
+ counter = 0
356
+ if converted_role == messages[-1]["role"]:
357
+ messages[-1]["content"] += [msg["message"]]
358
+ else:
359
+ messages.append(
360
+ {"role": converted_role, "content": [msg["message"]]}
361
+ )
362
+ counter += 1
363
+ if len(messages) >= MIN_MESSAGES_THRESHOLD:
364
+ processed_examples.append(
365
+ {
366
+ "messages": [
367
+ {
368
+ "role": m["role"],
369
+ "content": json.dumps(m["content"], ensure_ascii=False),
370
+ }
371
+ for m in messages
372
+ ]
373
+ }
374
+ )
375
+ # Before returning, flatten the list of dictionaries into a dictionary of lists
376
+ flattened_examples = {}
377
+ for key in processed_examples[0].keys():
378
+ flattened_examples[key] = [d[key] for d in processed_examples]
379
+ return flattened_examples
380
+
381
+ processed_examples = conversations_ds.map(
382
+ process_examples,
383
  remove_columns=["conversations"],
384
+ # num_proc=os.cpu_count() - 1,
385
+ batched=True,
386
+ )
387
+
388
+ examples_filtered_by_length = processed_examples.filter(
389
+ lambda x: all(
390
+ [len(m["content"]) < MAX_CHARACTERS_PER_MESSAGE for m in x["messages"]]
391
+ )
392
  )
393
+
394
+ return examples_filtered_by_length
validation.py CHANGED
@@ -3,7 +3,7 @@ from collections import defaultdict
3
  import tiktoken
4
 
5
 
6
- def check_format_errors(train_dataset):
7
  """
8
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
9
  """
@@ -27,7 +27,7 @@ def check_format_errors(train_dataset):
27
  if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
28
  format_errors["message_unrecognized_key"] += 1
29
 
30
- if message.get("role", None) not in ("system", "user", "assistant", "function"):
31
  format_errors["unrecognized_role"] += 1
32
 
33
  content = message.get("content", None)
@@ -36,7 +36,7 @@ def check_format_errors(train_dataset):
36
  if (not content and not function_call) or not isinstance(content, str):
37
  format_errors["missing_content"] += 1
38
 
39
- if not any(message.get("role", None) == "assistant" for message in messages):
40
  format_errors["example_missing_assistant_message"] += 1
41
 
42
  if format_errors:
@@ -48,7 +48,7 @@ def check_format_errors(train_dataset):
48
 
49
  return format_errors if format_errors else {}
50
 
51
- def get_distributions(train_dataset):
52
  """
53
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
54
 
@@ -72,7 +72,7 @@ def get_distributions(train_dataset):
72
  def num_assistant_tokens_from_messages(messages):
73
  num_tokens = 0
74
  for message in messages:
75
- if message["role"] == "assistant":
76
  num_tokens += len(encoding.encode(message["content"]))
77
  return num_tokens
78
 
@@ -87,7 +87,7 @@ def get_distributions(train_dataset):
87
  messages = ex["messages"]
88
  if not any(message["role"] == "system" for message in messages):
89
  n_missing_system += 1
90
- if not any(message["role"] == "user" for message in messages):
91
  n_missing_user += 1
92
  n_messages.append(len(messages))
93
  convo_lens.append(num_tokens_from_messages(messages))
@@ -102,7 +102,7 @@ def get_distributions(train_dataset):
102
  }
103
 
104
 
105
- def check_token_counts(train_dataset):
106
  """
107
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
108
  """
@@ -115,7 +115,7 @@ def check_token_counts(train_dataset):
115
 
116
 
117
  # Warnings and tokens counts
118
- distributions = get_distributions(train_dataset)
119
  n_missing_system = distributions["n_missing_system"]
120
  n_missing_user = distributions["n_missing_user"]
121
  n_messages = distributions["n_messages"]
@@ -135,11 +135,11 @@ def check_token_counts(train_dataset):
135
  return
136
 
137
 
138
- def estimate_cost(train_dataset):
139
  """
140
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
141
  """
142
- distributions = get_distributions(train_dataset)
143
  n_missing_system = distributions["n_missing_system"]
144
  n_missing_user = distributions["n_missing_user"]
145
  n_messages = distributions["n_messages"]
 
3
  import tiktoken
4
 
5
 
6
+ def check_format_errors(train_dataset, user_role, model_role):
7
  """
8
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
9
  """
 
27
  if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
28
  format_errors["message_unrecognized_key"] += 1
29
 
30
+ if message.get("role", None) not in ["system", user_role, model_role]:
31
  format_errors["unrecognized_role"] += 1
32
 
33
  content = message.get("content", None)
 
36
  if (not content and not function_call) or not isinstance(content, str):
37
  format_errors["missing_content"] += 1
38
 
39
+ if not any(message.get("role", None) == model_role for message in messages):
40
  format_errors["example_missing_assistant_message"] += 1
41
 
42
  if format_errors:
 
48
 
49
  return format_errors if format_errors else {}
50
 
51
+ def get_distributions(train_dataset, user_role, model_role):
52
  """
53
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
54
 
 
72
  def num_assistant_tokens_from_messages(messages):
73
  num_tokens = 0
74
  for message in messages:
75
+ if message["role"] == model_role:
76
  num_tokens += len(encoding.encode(message["content"]))
77
  return num_tokens
78
 
 
87
  messages = ex["messages"]
88
  if not any(message["role"] == "system" for message in messages):
89
  n_missing_system += 1
90
+ if not any(message["role"] == user_role for message in messages):
91
  n_missing_user += 1
92
  n_messages.append(len(messages))
93
  convo_lens.append(num_tokens_from_messages(messages))
 
102
  }
103
 
104
 
105
+ def check_token_counts(train_dataset, user_role, model_role):
106
  """
107
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
108
  """
 
115
 
116
 
117
  # Warnings and tokens counts
118
+ distributions = get_distributions(train_dataset, user_role=user_role, model_role=model_role)
119
  n_missing_system = distributions["n_missing_system"]
120
  n_missing_user = distributions["n_missing_user"]
121
  n_messages = distributions["n_messages"]
 
135
  return
136
 
137
 
138
+ def estimate_cost(train_dataset, user_role, model_role):
139
  """
140
  Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
141
  """
142
+ distributions = get_distributions(train_dataset, user_role=user_role, model_role=model_role)
143
  n_missing_system = distributions["n_missing_system"]
144
  n_missing_user = distributions["n_missing_user"]
145
  n_messages = distributions["n_messages"]