zetavg committed
Commit 0537112
Parent: 570c043

support .py prompt template
llama_lora/lib/finetune.py CHANGED
@@ -162,6 +162,8 @@ def train(
 
     # If train_dataset_data is a list, convert it to datasets.Dataset
     if isinstance(train_dataset_data, list):
+        with open(os.path.join(output_dir, "train_data_samples.json"), 'w') as file:
+            json.dump(list(train_dataset_data[:100]), file, indent=2)
         train_dataset_data = Dataset.from_list(train_dataset_data)
 
     if resume_from_checkpoint:
@@ -221,7 +223,7 @@ def train(
         optim="adamw_torch",
         evaluation_strategy="steps" if val_set_size > 0 else "no",
         save_strategy="steps",
-        eval_steps=200 if val_set_size > 0 else None,
+        eval_steps=save_steps if val_set_size > 0 else None,
         save_steps=save_steps,
         output_dir=output_dir,
        save_total_limit=save_total_limit,
@@ -260,6 +262,14 @@ def train(
        }
        json.dump(finetune_params, finetune_params_json_file, indent=2)
 
+    # Not working, will only give us ["prompt", "completion", "input_ids", "attention_mask", "labels"]
+    # if train_data:
+    #     with open(os.path.join(output_dir, "train_dataset_samples.json"), 'w') as file:
+    #         json.dump(list(train_data[:100]), file, indent=2)
+    # if val_data:
+    #     with open(os.path.join(output_dir, "eval_dataset_samples.json"), 'w') as file:
+    #         json.dump(list(val_data[:100]), file, indent=2)
+
    model.config.use_cache = False
 
    old_state_dict = model.state_dict
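
For context: Dataset here is datasets.Dataset from the Hugging Face datasets library, and Dataset.from_list (available in datasets >= 2.6) builds an in-memory dataset from a list of dicts, which is why the first 100 raw records can still be dumped as plain JSON right before the conversion. A minimal sketch, with hypothetical records shaped like the {"prompt", "completion"} pairs this project produces:

from datasets import Dataset

# Illustrative records; the real ones are built by the prompter
# (see llama_lora/utils/prompter.py below).
train_dataset_data = [
    {"prompt": "### Instruction:\nAdd 1+1\n\n### Response:\n", "completion": "2"},
    {"prompt": "### Instruction:\nAdd 2+2\n\n### Response:\n", "completion": "4"},
]

dataset = Dataset.from_list(train_dataset_data)
print(dataset.column_names)  # -> ['prompt', 'completion']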
llama_lora/ui/finetune_ui.py CHANGED
@@ -79,56 +79,50 @@ def load_sample_dataset_to_text_input(format):
     return gr.Code.update(value=sample_plain_text_value)
 
 
-def process_json_dataset(data, only_first_n_items=None):
-    if not isinstance(data, list):
-        raise ValueError("The dataset is not an array of objects.")
-
-    if only_first_n_items is not None:
-        data = data[:only_first_n_items]
-
-    first_item = get_val_from_arr(data, 0, None)
-
-    if first_item is None:
-        raise ValueError("The dataset is empty.")
-    if not isinstance(first_item, dict):
-        raise ValueError("The dataset is not an array of objects.")
-
-    # Convert OpenAI fine-tuning dataset to LLaMA LoRA style
-    if "completion" in first_item and "output" not in first_item:
-        data = [
-            {"output" if k == "completion" else k: v for k, v in d.items()}
-            for d in data]
-        first_item = get_val_from_arr(data, 0, None)
-
-    # Flatten Stanford Alpaca style instances
-    if "instances" in first_item and isinstance(first_item["instances"], list):
-        data = [
-            {"output" if k == "completion" else k: v for k, v in d.items()}
-            for d in data]
-        flattened_data = []
-        for item in data:
-            for instance in item["instances"]:
-                d = {k: v for k, v in item.items() if k != "instances"}
-                d.update(instance)
-                flattened_data.append(d)
-        data = flattened_data
-        first_item = get_val_from_arr(data, 0, None)
-
-    if "output" not in first_item:
-        raise ValueError(
-            "The data does not contains an \"output\" or \"completion\".")
-
-    # Put all variables under the "variables" key if it does not exists
-    if "variables" not in first_item:
-        data = [
-            {
-                "variables":
-                {k: v for k, v in d.items() if k != "output"},
-                "output":
-                d["output"]
-            }
-            for d in data
-        ]
+def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
+                        dataset_plain_text_input_variables_separator,
+                        dataset_plain_text_input_and_output_separator,
+                        dataset_plain_text_data_separator,
+                        dataset_from_data_dir, prompter):
+    if load_dataset_from == "Text Input":
+        if dataset_text_format == "JSON":
+            data = json.loads(dataset_text)
+
+        elif dataset_text_format == "JSON Lines":
+            lines = dataset_text.split('\n')
+            data = []
+            for i, line in enumerate(lines):
+                line_number = i + 1
+                try:
+                    data.append(json.loads(line))
+                except Exception as e:
+                    raise ValueError(
+                        f"Error parsing JSON on line {line_number}: {e}")
+
+        else:  # Plain Text
+            data = parse_plain_text_input(
+                dataset_text,
+                (
+                    dataset_plain_text_input_variables_separator or
+                    default_dataset_plain_text_input_variables_separator
+                ).replace("\\n", "\n"),
+                (
+                    dataset_plain_text_input_and_output_separator or
+                    default_dataset_plain_text_input_and_output_separator
+                ).replace("\\n", "\n"),
+                (
+                    dataset_plain_text_data_separator or
+                    default_dataset_plain_text_data_separator
+                ).replace("\\n", "\n"),
+                prompter.get_variable_names()
+            )
+
+    else:  # Load dataset from data directory
+        data = get_dataset_content(dataset_from_data_dir)
+
     return data
 
 
@@ -144,75 +138,59 @@ def refresh_preview(
     preview_show_actual_prompt,
 ):
     try:
-        max_preview_count = 100
+        max_preview_count = 30
         prompter = Prompter(template)
         variable_names = prompter.get_variable_names()
 
-        if load_dataset_from == "Text Input":
-            if dataset_text_format == "JSON":
-                data = json.loads(dataset_text)
-                data = process_json_dataset(data)
-
-            elif dataset_text_format == "JSON Lines":
-                lines = dataset_text.split('\n')
-                data = []
-                for i, line in enumerate(lines):
-                    line_number = i + 1
-                    try:
-                        data.append(json.loads(line))
-                    except Exception as e:
-                        raise ValueError(
-                            f"Error parsing JSON on line {line_number}: {e}")
-
-                data = process_json_dataset(data)
-
-            else:  # Plain Text
-                data = parse_plain_text_input(
-                    dataset_text,
-                    (
-                        dataset_plain_text_input_variables_separator or
-                        default_dataset_plain_text_input_variables_separator
-                    ).replace("\\n", "\n"),
-                    (
-                        dataset_plain_text_input_and_output_separator or
-                        default_dataset_plain_text_input_and_output_separator
-                    ).replace("\\n", "\n"),
-                    (
-                        dataset_plain_text_data_separator or
-                        default_dataset_plain_text_data_separator
-                    ).replace("\\n", "\n"),
-                    variable_names
-                )
+        data = get_data_from_input(
+            load_dataset_from=load_dataset_from,
+            dataset_text=dataset_text,
+            dataset_text_format=dataset_text_format,
+            dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
+            dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
+            dataset_plain_text_data_separator=dataset_plain_text_data_separator,
+            dataset_from_data_dir=dataset_from_data_dir,
+            prompter=prompter
+        )
 
-        else:  # Load dataset from data directory
-            data = get_dataset_content(dataset_from_data_dir)
-            data = process_json_dataset(data)
+        train_data = prompter.get_train_data_from_dataset(data, max_preview_count)
 
         data_count = len(data)
-        headers = variable_names
+
+        headers = ['Prompt', 'Completion']
         preview_data = [
-            [item['variables'].get(name, "") for name in variable_names]
-            for item in data[:max_preview_count]
+            [item.get("prompt", ""), item.get("completion", "")]
+            for item in train_data
         ]
 
-        if preview_show_actual_prompt:
-            headers = headers + ["Prompt (actual input)"]
-            rendered = [prompter.generate_prompt(
-                item['variables']) for item in data[:max_preview_count]]
-            preview_data = result = [d + [i]
-                                     for d, i in zip(preview_data, rendered)]
+        if not prompter.template_module:
+            variable_names = prompter.get_variable_names()
+            headers += [f"Variable: {variable_name}" for variable_name in variable_names]
+            variables = [
+                [item.get(f"_var_{name}", "") for name in variable_names]
+                for item in train_data
+            ]
+            preview_data = [d + v for d, v in zip(preview_data, variables)]
 
-        headers = headers + ["Completion (output)"]
-        preview_data = result = [pd + [d['output']]
-                                 for pd, d in zip(preview_data, data[:max_preview_count])]
+        # if preview_show_actual_prompt:
+        #     headers = headers + ["Prompt (actual input)"]
+        #     rendered = [prompter.generate_prompt(
+        #         item['variables']) for item in data[:max_preview_count]]
+        #     preview_data = result = [d + [i]
+        #                              for d, i in zip(preview_data, rendered)]
 
-        preview_info_message = f"The dataset has a total of {data_count} item(s)."
+        # headers = headers + ["Completion (output)"]
+        # preview_data = result = [pd + [d['output']]
+        #                          for pd, d in zip(preview_data, data[:max_preview_count])]
+
+        preview_info_message = f"The dataset has about {data_count} item(s)."
         if data_count > max_preview_count:
            preview_info_message += f" Previewing the first {max_preview_count}."
 
        info_message = f"{data_count} item(s)."
        if load_dataset_from == "Data Dir":
-            info_message = "This dataset contains " + info_message
+            info_message = "This dataset contains about " + info_message
        update_message = gr.Markdown.update(info_message, visible=True)
 
        return gr.Dataframe.update(value={'data': preview_data, 'headers': headers}), gr.Markdown.update(preview_info_message), update_message, update_message
@@ -288,57 +266,24 @@ def do_train(
     unload_models()  # Need RAM for training
 
     prompter = Prompter(template)
-    variable_names = prompter.get_variable_names()
-
-    if load_dataset_from == "Text Input":
-        if dataset_text_format == "JSON":
-            data = json.loads(dataset_text)
-            data = process_json_dataset(data)
-
-        elif dataset_text_format == "JSON Lines":
-            lines = dataset_text.split('\n')
-            data = []
-            for i, line in enumerate(lines):
-                line_number = i + 1
-                try:
-                    data.append(json.loads(line))
-                except Exception as e:
-                    raise ValueError(
-                        f"Error parsing JSON on line {line_number}: {e}")
-
-            data = process_json_dataset(data)
-
-        else:  # Plain Text
-            data = parse_plain_text_input(
-                dataset_text,
-                (
-                    dataset_plain_text_input_variables_separator or
-                    default_dataset_plain_text_input_variables_separator
-                ).replace("\\n", "\n"),
-                (
-                    dataset_plain_text_input_and_output_separator or
-                    default_dataset_plain_text_input_and_output_separator
-                ).replace("\\n", "\n"),
-                (
-                    dataset_plain_text_data_separator or
-                    default_dataset_plain_text_data_separator
-                ).replace("\\n", "\n"),
-                variable_names
-            )
+    # variable_names = prompter.get_variable_names()
+
+    data = get_data_from_input(
+        load_dataset_from=load_dataset_from,
+        dataset_text=dataset_text,
+        dataset_text_format=dataset_text_format,
+        dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
+        dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
+        dataset_plain_text_data_separator=dataset_plain_text_data_separator,
+        dataset_from_data_dir=dataset_from_data_dir,
+        prompter=prompter
+    )
 
-    else:  # Load dataset from data directory
-        data = get_dataset_content(dataset_from_data_dir)
-        data = process_json_dataset(data)
+    train_data = prompter.get_train_data_from_dataset(data)
 
-    data_count = len(data)
+    data_count = len(train_data)
     evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)
 
-    train_data = [
-        {
-            'prompt': prompter.generate_prompt(d['variables']),
-            'completion': d['output']}
-        for d in data]
-
     def get_progress_text(epoch, epochs, last_loss):
         progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
         if last_loss is not None:
@@ -449,20 +394,21 @@ Train data (first 10):
         'dataset_rows': len(train_data),
         'timestamp': time.time(),
 
-        'max_seq_length': max_seq_length,
-        'train_on_inputs': train_on_inputs,
+        # These will be saved in another JSON file by the train function
+        # 'max_seq_length': max_seq_length,
+        # 'train_on_inputs': train_on_inputs,
 
-        'micro_batch_size': micro_batch_size,
-        'gradient_accumulation_steps': gradient_accumulation_steps,
-        'epochs': epochs,
-        'learning_rate': learning_rate,
+        # 'micro_batch_size': micro_batch_size,
+        # 'gradient_accumulation_steps': gradient_accumulation_steps,
+        # 'epochs': epochs,
+        # 'learning_rate': learning_rate,
 
-        'evaluate_data_percentage': evaluate_data_percentage,
+        # 'evaluate_data_percentage': evaluate_data_percentage,
 
-        'lora_r': lora_r,
-        'lora_alpha': lora_alpha,
-        'lora_dropout': lora_dropout,
-        'lora_target_modules': lora_target_modules,
+        # 'lora_r': lora_r,
+        # 'lora_alpha': lora_alpha,
+        # 'lora_dropout': lora_dropout,
+        # 'lora_target_modules': lora_target_modules,
     }
     json.dump(info, info_json_file, indent=2)
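
Both refresh_preview and do_train now consume the flat records returned by prompter.get_train_data_from_dataset (added in llama_lora/utils/prompter.py below). A sketch of the record shape the preview code expects, inferred from the keys it reads (values are illustrative):

# For JSON templates, each source variable is echoed under a "_var_<name>" key
# so the preview can render it in its own column; .py templates (where
# prompter.template_module is set) emit only "prompt" and "completion".
example_record = {
    "prompt": "### Instruction:\nSummarize the text\n\n### Response:\n",
    "completion": "A short summary.",
    "_var_instruction": "Summarize the text",  # JSON templates only
}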
 
llama_lora/utils/data.py CHANGED
@@ -30,7 +30,7 @@ def copy_sample_data_if_not_exists(source, destination):
 def get_available_template_names():
     templates_directory_path = os.path.join(Global.data_dir, "templates")
     all_files = os.listdir(templates_directory_path)
-    return [os.path.splitext(filename)[0] for filename in all_files if fnmatch.fnmatch(filename, "*.json")]
+    return [filename[:-5] if filename.endswith(".json") else filename for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.py")]
 
 
 def get_available_dataset_names():
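
A note on the suffix handling above: str.rstrip strips a character set rather than a literal suffix ("session.json".rstrip(".json") yields "sessi"), and os.path.splitext would also drop ".py", which Prompter.__init__ needs to see in order to route the name to the Python template loader (it appends ".json" only to extension-less names). A small helper equivalent to the list-comprehension expression; strip_json_ext is an illustrative name, not part of the codebase:

def strip_json_ext(filename):
    # Drop only a literal ".json" suffix; keep ".py" intact so the template
    # name still carries the extension the Prompter dispatches on.
    return filename[:-5] if filename.endswith(".json") else filename

assert strip_json_ext("alpaca.json") == "alpaca"
assert strip_json_ext("session.json") == "session"
assert strip_json_ext("user_prompt.py") == "user_prompt.py"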
llama_lora/utils/prompter.py CHANGED
@@ -5,13 +5,15 @@ From https://github.com/tloen/alpaca-lora/blob/main/utils/prompter.py
 
 import json
 import os.path as osp
+import importlib.util
+import itertools
 from typing import Union, List
 
 from ..globals import Global
 
 
 class Prompter(object):
-    __slots__ = ("template_name", "template", "_verbose")
+    __slots__ = ("template_name", "template", "template_module", "_verbose")
 
     def __init__(self, template_name: str = "", verbose: bool = False):
         self._verbose = verbose
@@ -21,12 +23,41 @@ class Prompter(object):
             self.template_name = "None"
             return
         self.template_name = template_name
+        self.template_module = None
 
-        file_name = osp.join(Global.data_dir, "templates",
-                             f"{template_name}.json")
-        if not osp.exists(file_name):
-            raise ValueError(f"Can't read {file_name}")
-        with open(file_name) as fp:
+        base_filename, ext = osp.splitext(template_name)
+        if ext == "":
+            filename = base_filename + ".json"
+        else:
+            filename = base_filename + ext
+
+        file_path = osp.join(Global.data_dir, "templates", filename)
+
+        if not osp.exists(file_path):
+            raise ValueError(f"Can't read {file_path}")
+
+        if ext == ".py":
+            template_module_spec = importlib.util.spec_from_file_location(
+                "template_module", file_path)
+            template_module = importlib.util.module_from_spec(
+                template_module_spec)
+            template_module_spec.loader.exec_module(template_module)
+            self.template_module = template_module
+
+            if not hasattr(template_module, "variables"):
+                raise ValueError(
+                    "The template module does not have a \"variables\" attribute.")
+
+            self.template = {
+                'variables': template_module.variables
+            }
+
+            if hasattr(template_module, "response_split"):
+                self.template["response_split"] = template_module.response_split
+
+            return
+
+        with open(file_path) as fp:
             self.template = json.load(fp)
         if self._verbose:
             print(
@@ -47,23 +78,31 @@ class Prompter(object):
             res = variables.get("prompt", "")
         elif "variables" in self.template:
             variable_names = self.template.get("variables")
-            if type(variables) == dict:
-                variables = [variables.get(name, None)
-                             for name in variable_names]
-            if "default" not in self.template:
-                raise ValueError(
-                    f"The template {self.template_name} has \"variables\" defined but does not has a default prompt defined. Please do it like: '\"default\": \"prompt_with_instruction\"' to handle cases when a matching prompt can't be found.")
-            default_prompt_name = self.template.get("default")
-            if default_prompt_name not in self.template:
-                raise ValueError(
-                    f"The template {self.template_name} has \"default\" set to \"{default_prompt_name}\" but it's not defined. Please do it like: '\"{default_prompt_name}\": \"...\".")
-            prompt_name = get_prompt_name(variables, variable_names)
-            prompt_template = self.template.get(default_prompt_name)
-            if prompt_name in self.template:
-                prompt_template = self.template.get(prompt_name)
-
-            res = prompt_template.format(
-                **variables_to_dict(variables, variable_names))
+            if self.template_module:
+                if type(variables) == list:
+                    variables = {k: v for k, v in zip(
+                        variable_names, variables)}
+
+                res = self.template_module.get_prompt(variables)
+            else:
+                if type(variables) == dict:
+                    variables = [variables.get(name, None)
+                                 for name in variable_names]
+
+                if "default" not in self.template:
+                    raise ValueError(
+                        f"The template {self.template_name} has \"variables\" defined but does not have a default prompt defined. Please do it like: '\"default\": \"prompt_with_instruction\"' to handle cases when a matching prompt can't be found.")
+                default_prompt_name = self.template.get("default")
+                if default_prompt_name not in self.template:
+                    raise ValueError(
+                        f"The template {self.template_name} has \"default\" set to \"{default_prompt_name}\" but it's not defined. Please do it like: '\"{default_prompt_name}\": \"...\".")
+                prompt_name = get_prompt_name(variables, variable_names)
+                prompt_template = self.template.get(default_prompt_name)
+                if prompt_name in self.template:
+                    prompt_template = self.template.get(prompt_name)
+
+                res = prompt_template.format(
+                    **variables_to_dict(variables, variable_names))
 
         else:
             if type(variables) == dict:
@@ -104,6 +143,30 @@ class Prompter(object):
         else:
             return ["instruction", "input"]
 
+    def get_train_data_from_dataset(self, data, only_first_n_items=None):
+        if self.template_module:
+            if hasattr(self.template_module, "get_train_data_list_from_dataset"):
+                data = self.template_module.get_train_data_list_from_dataset(
+                    data)
+            if only_first_n_items:
+                data = data[:only_first_n_items]
+            return list(itertools.chain(*list(map(self.template_module.get_train_data, data))))
+
+        if only_first_n_items:
+            data = data[:only_first_n_items]
+
+        data = process_json_dataset(data)
+
+        train_data = [
+            {
+                'prompt': self.generate_prompt(d['variables']),
+                'completion': d['output'],
+                **{"_var_" + k: v for k, v in d['variables'].items()}
+            }
+            for d in data]
+
+        return train_data
+
 
 def get_val(arr, index, default=None):
     return arr[index] if -len(arr) <= index < len(arr) else default
@@ -117,3 +180,57 @@ def get_prompt_name(variables, variable_names):
 
 def variables_to_dict(variables, variable_names):
     return {key: (variables[i] if i < len(variables) and variables[i] is not None else '') for i, key in enumerate(variable_names)}
+
+
+def process_json_dataset(data):
+    if not isinstance(data, list):
+        raise ValueError("The dataset is not an array of objects.")
+
+    first_item = get_val_from_arr(data, 0, None)
+
+    if first_item is None:
+        raise ValueError("The dataset is empty.")
+    if not isinstance(first_item, dict):
+        raise ValueError("The dataset is not an array of objects.")
+
+    # Convert OpenAI fine-tuning dataset to LLaMA LoRA style
+    if "completion" in first_item and "output" not in first_item:
+        data = [
+            {"output" if k == "completion" else k: v for k, v in d.items()}
+            for d in data]
+        first_item = get_val_from_arr(data, 0, None)
+
+    # Flatten Stanford Alpaca style instances
+    if "instances" in first_item and isinstance(first_item["instances"], list):
+        data = [
+            {"output" if k == "completion" else k: v for k, v in d.items()}
+            for d in data]
+        flattened_data = []
+        for item in data:
+            for instance in item["instances"]:
+                d = {k: v for k, v in item.items() if k != "instances"}
+                d.update(instance)
+                flattened_data.append(d)
+        data = flattened_data
+        first_item = get_val_from_arr(data, 0, None)
+
+    if "output" not in first_item:
+        raise ValueError(
+            "The data does not contain an \"output\" or \"completion\" field.")
+
+    # Put all variables under the "variables" key if it does not exist
+    if "variables" not in first_item:
+        data = [
+            {
+                "variables":
+                {k: v for k, v in d.items() if k != "output"},
+                "output":
+                d["output"]
+            }
+            for d in data
+        ]
+    return data
+
+
+def get_val_from_arr(arr, index, default=None):
+    return arr[index] if -len(arr) <= index < len(arr) else default
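
Putting the pieces together: a .py prompt template is a module exposing at least a variables list; get_prompt is called for rendering, and response_split, get_train_data, and get_train_data_list_from_dataset are picked up when present. A hypothetical data/templates/example.py, inferred from the attributes the loader checks (file name, prompt format, and the "output" dataset key are assumptions, not part of this commit):

# data/templates/example.py -- hypothetical template module.

# Required: Prompter.__init__ raises if "variables" is missing.
variables = ["instruction", "input"]

# Optional: stored into self.template["response_split"] when present.
response_split = "### Response:"


def get_prompt(variables):
    # Called by Prompter.generate_prompt with a dict keyed by the names
    # above (list inputs are zipped into a dict first).
    instruction = variables.get("instruction", "")
    inp = variables.get("input", "")
    if inp:
        return (f"### Instruction:\n{instruction}\n\n"
                f"### Input:\n{inp}\n\n### Response:\n")
    return f"### Instruction:\n{instruction}\n\n### Response:\n"


def get_train_data(item):
    # Called by get_train_data_from_dataset via map() + itertools.chain(),
    # so it must return a *list* of {"prompt", "completion"} records; one
    # dataset item may expand to several training records.
    return [{
        "prompt": get_prompt(item),
        "completion": item.get("output", ""),
    }]

Such a template is selected by its full file name, e.g. Prompter("example.py"), since extension-less template names default to ".json".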