zetavg commited on
Commit
6ba132a
·
unverified ·
2 Parent(s): f18eda8 68255ee

Merge branch 'dev-2' of github.com:zetavg/llama-lora into dev-2

Browse files
.gitignore CHANGED
@@ -3,4 +3,5 @@ __pycache__/
3
  /venv
4
  .vscode
5
 
 
6
  /data
 
3
  /venv
4
  .vscode
5
 
6
+ /wandb
7
  /data
README.md CHANGED
@@ -60,13 +60,14 @@ file_mounts:
60
  setup: |
61
  git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
62
  cd llama_lora_tuner && pip install -r requirements.lock.txt
 
63
  cd ..
64
  echo 'Dependencies installed.'
65
 
66
  # Start the app.
67
  run: |
68
  echo 'Starting...'
69
- python llama_lora_tuner/app.py --data_dir='/data' --base_model='decapoda-research/llama-7b-hf' --share
70
  ```
71
 
72
  Then launch a cluster to run the task:
 
60
  setup: |
61
  git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
62
  cd llama_lora_tuner && pip install -r requirements.lock.txt
63
+ pip install wandb
64
  cd ..
65
  echo 'Dependencies installed.'
66
 
67
  # Start the app.
68
  run: |
69
  echo 'Starting...'
70
+ python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key "$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model='decapoda-research/llama-7b-hf' --share
71
  ```
72
 
73
  Then launch a cluster to run the task:
app.py CHANGED
@@ -5,21 +5,37 @@ import fire
5
  import gradio as gr
6
 
7
  from llama_lora.globals import Global
 
8
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
9
  from llama_lora.utils.data import init_data_dir
10
 
11
 
 
12
  def main(
13
- load_8bit: bool = False,
14
  base_model: str = "",
15
  data_dir: str = "",
16
  # Allows to listen on all interfaces by providing '0.0.0.0'.
17
  server_name: str = "127.0.0.1",
18
  share: bool = False,
19
  skip_loading_base_model: bool = False,
 
20
  ui_show_sys_info: bool = True,
21
  ui_dev_mode: bool = False,
 
 
22
  ):
 
 
 
 
 
 
 
 
 
 
 
 
23
  base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
24
  data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
25
  assert (
@@ -34,12 +50,22 @@ def main(
34
  Global.data_dir = os.path.abspath(data_dir)
35
  Global.load_8bit = load_8bit
36
 
 
 
 
 
 
 
 
37
  Global.ui_dev_mode = ui_dev_mode
38
  Global.ui_show_sys_info = ui_show_sys_info
39
 
40
  os.makedirs(data_dir, exist_ok=True)
41
  init_data_dir()
42
 
 
 
 
43
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
44
  main_page()
45
 
 
5
  import gradio as gr
6
 
7
  from llama_lora.globals import Global
8
+ from llama_lora.models import prepare_base_model
9
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
10
  from llama_lora.utils.data import init_data_dir
11
 
12
 
13
+
14
  def main(
 
15
  base_model: str = "",
16
  data_dir: str = "",
17
  # Allows to listen on all interfaces by providing '0.0.0.0'.
18
  server_name: str = "127.0.0.1",
19
  share: bool = False,
20
  skip_loading_base_model: bool = False,
21
+ load_8bit: bool = False,
22
  ui_show_sys_info: bool = True,
23
  ui_dev_mode: bool = False,
24
+ wandb_api_key: str = "",
25
+ wandb_project: str = "",
26
  ):
27
+ '''
28
+ Start the LLaMA-LoRA Tuner UI.
29
+
30
+ :param base_model: (required) The name of the default base model to use.
31
+ :param data_dir: (required) The path to the directory to store data.
32
+ :param server_name: Allows to listen on all interfaces by providing '0.0.0.0'.
33
+ :param share: Create a public Gradio URL.
34
+
35
+ :param wandb_api_key: The API key for Weights & Biases. Setting either this or `wandb_project` will enable Weights & Biases.
36
+ :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
37
+ '''
38
+
39
  base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
40
  data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
41
  assert (
 
50
  Global.data_dir = os.path.abspath(data_dir)
51
  Global.load_8bit = load_8bit
52
 
53
+ if len(wandb_api_key) > 0:
54
+ Global.enable_wandb = True
55
+ Global.wandb_api_key = wandb_api_key
56
+ if len(wandb_project) > 0:
57
+ Global.enable_wandb = True
58
+ Global.wandb_project = wandb_project
59
+
60
  Global.ui_dev_mode = ui_dev_mode
61
  Global.ui_show_sys_info = ui_show_sys_info
62
 
63
  os.makedirs(data_dir, exist_ok=True)
64
  init_data_dir()
65
 
66
+ if (not skip_loading_base_model) and (not ui_dev_mode):
67
+ prepare_base_model(base_model)
68
+
69
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
70
  main_page()
71
 
llama_lora/globals.py CHANGED
@@ -40,6 +40,11 @@ class Global:
40
  gpu_total_cores = None # GPU total cores
41
  gpu_total_memory = None
42
 
 
 
 
 
 
43
  # UI related
44
  ui_title: str = "LLaMA-LoRA Tuner"
45
  ui_emoji: str = "🦙🎛️"
 
40
  gpu_total_cores = None # GPU total cores
41
  gpu_total_memory = None
42
 
43
+ # WandB
44
+ enable_wandb = False
45
+ wandb_api_key = None
46
+ default_wandb_project = "llama-lora-tuner"
47
+
48
  # UI related
49
  ui_title: str = "LLaMA-LoRA Tuner"
50
  ui_emoji: str = "🦙🎛️"
llama_lora/lib/finetune.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import sys
 
3
  from typing import Any, List
4
 
5
  import json
@@ -50,8 +51,66 @@ def train(
50
  save_total_limit: int = 3,
51
  logging_steps: int = 10,
52
  # logging
53
- callbacks: List[Any] = []
 
 
 
 
 
 
 
 
54
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  if os.path.exists(output_dir):
56
  if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
57
  raise ValueError(
@@ -138,6 +197,8 @@ def train(
138
 
139
  # If train_dataset_data is a list, convert it to datasets.Dataset
140
  if isinstance(train_dataset_data, list):
 
 
141
  train_dataset_data = Dataset.from_list(train_dataset_data)
142
 
143
  if resume_from_checkpoint:
@@ -197,15 +258,15 @@ def train(
197
  optim="adamw_torch",
198
  evaluation_strategy="steps" if val_set_size > 0 else "no",
199
  save_strategy="steps",
200
- eval_steps=200 if val_set_size > 0 else None,
201
  save_steps=save_steps,
202
  output_dir=output_dir,
203
  save_total_limit=save_total_limit,
204
  load_best_model_at_end=True if val_set_size > 0 else False,
205
  ddp_find_unused_parameters=False if ddp else None,
206
  group_by_length=group_by_length,
207
- # report_to="wandb" if use_wandb else None,
208
- # run_name=wandb_run_name if use_wandb else None,
209
  ),
210
  data_collator=transformers.DataCollatorForSeq2Seq(
211
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
@@ -217,24 +278,16 @@ def train(
217
  os.makedirs(output_dir)
218
  with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
219
  json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
220
- with open(os.path.join(output_dir, "finetune_params.json"), 'w') as finetune_params_json_file:
221
- finetune_params = {
222
- 'micro_batch_size': micro_batch_size,
223
- 'gradient_accumulation_steps': gradient_accumulation_steps,
224
- 'num_train_epochs': num_train_epochs,
225
- 'learning_rate': learning_rate,
226
- 'cutoff_len': cutoff_len,
227
- 'lora_r': lora_r,
228
- 'lora_alpha': lora_alpha,
229
- 'lora_dropout': lora_dropout,
230
- 'lora_target_modules': lora_target_modules,
231
- 'train_on_inputs': train_on_inputs,
232
- 'group_by_length': group_by_length,
233
- 'save_steps': save_steps,
234
- 'save_total_limit': save_total_limit,
235
- 'logging_steps': logging_steps,
236
- }
237
- json.dump(finetune_params, finetune_params_json_file, indent=2)
238
 
239
  model.config.use_cache = False
240
 
 
1
  import os
2
  import sys
3
+ import importlib
4
  from typing import Any, List
5
 
6
  import json
 
51
  save_total_limit: int = 3,
52
  logging_steps: int = 10,
53
  # logging
54
+ callbacks: List[Any] = [],
55
+ # wandb params
56
+ wandb_api_key = None,
57
+ wandb_project: str = "",
58
+ wandb_group = None,
59
+ wandb_run_name: str = "",
60
+ wandb_tags: List[str] = [],
61
+ wandb_watch: str = "false", # options: false | gradients | all
62
+ wandb_log_model: str = "true", # options: false | true
63
  ):
64
+ # for logging
65
+ finetune_args = {
66
+ 'micro_batch_size': micro_batch_size,
67
+ 'gradient_accumulation_steps': gradient_accumulation_steps,
68
+ 'num_train_epochs': num_train_epochs,
69
+ 'learning_rate': learning_rate,
70
+ 'cutoff_len': cutoff_len,
71
+ 'lora_r': lora_r,
72
+ 'lora_alpha': lora_alpha,
73
+ 'lora_dropout': lora_dropout,
74
+ 'lora_target_modules': lora_target_modules,
75
+ 'train_on_inputs': train_on_inputs,
76
+ 'group_by_length': group_by_length,
77
+ 'save_steps': save_steps,
78
+ 'save_total_limit': save_total_limit,
79
+ 'logging_steps': logging_steps,
80
+ }
81
+
82
+ if wandb_api_key:
83
+ os.environ["WANDB_API_KEY"] = wandb_api_key
84
+
85
+ # wandb: WARNING Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to https://wandb.me/wandb-init.
86
+ # if wandb_project:
87
+ # os.environ["WANDB_PROJECT"] = wandb_project
88
+ # if wandb_run_name:
89
+ # os.environ["WANDB_RUN_NAME"] = wandb_run_name
90
+ if wandb_watch:
91
+ os.environ["WANDB_WATCH"] = wandb_watch
92
+ if wandb_log_model:
93
+ os.environ["WANDB_LOG_MODEL"] = wandb_log_model
94
+ use_wandb = (wandb_project and len(wandb_project) > 0) or (
95
+ "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
96
+ )
97
+ if use_wandb:
98
+ os.environ['WANDB_MODE'] = "online"
99
+ wandb = importlib.import_module("wandb")
100
+ wandb.init(
101
+ project=wandb_project,
102
+ resume="auto",
103
+ group=wandb_group,
104
+ name=wandb_run_name,
105
+ tags=wandb_tags,
106
+ reinit=True,
107
+ magic=True,
108
+ config={'finetune_args': finetune_args},
109
+ # id=None # used for resuming
110
+ )
111
+ else:
112
+ os.environ['WANDB_MODE'] = "disabled"
113
+
114
  if os.path.exists(output_dir):
115
  if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
116
  raise ValueError(
 
197
 
198
  # If train_dataset_data is a list, convert it to datasets.Dataset
199
  if isinstance(train_dataset_data, list):
200
+ with open(os.path.join(output_dir, "train_data_samples.json"), 'w') as file:
201
+ json.dump(list(train_dataset_data[:100]), file, indent=2)
202
  train_dataset_data = Dataset.from_list(train_dataset_data)
203
 
204
  if resume_from_checkpoint:
 
258
  optim="adamw_torch",
259
  evaluation_strategy="steps" if val_set_size > 0 else "no",
260
  save_strategy="steps",
261
+ eval_steps=save_steps if val_set_size > 0 else None,
262
  save_steps=save_steps,
263
  output_dir=output_dir,
264
  save_total_limit=save_total_limit,
265
  load_best_model_at_end=True if val_set_size > 0 else False,
266
  ddp_find_unused_parameters=False if ddp else None,
267
  group_by_length=group_by_length,
268
+ report_to="wandb" if use_wandb else None,
269
+ run_name=wandb_run_name if use_wandb else None,
270
  ),
271
  data_collator=transformers.DataCollatorForSeq2Seq(
272
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
 
278
  os.makedirs(output_dir)
279
  with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
280
  json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
281
+ with open(os.path.join(output_dir, "finetune_args.json"), 'w') as finetune_args_json_file:
282
+ json.dump(finetune_args, finetune_args_json_file, indent=2)
283
+
284
+ # Not working, will only give us ["prompt", "completion", "input_ids", "attention_mask", "labels"]
285
+ # if train_data:
286
+ # with open(os.path.join(output_dir, "train_dataset_samples.json"), 'w') as file:
287
+ # json.dump(list(train_data[:100]), file, indent=2)
288
+ # if val_data:
289
+ # with open(os.path.join(output_dir, "eval_dataset_samples.json"), 'w') as file:
290
+ # json.dump(list(val_data[:100]), file, indent=2)
 
 
 
 
 
 
 
 
291
 
292
  model.config.use_cache = False
293
 
llama_lora/lib/get_device.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_device():
5
+ device ="cpu"
6
+ if torch.cuda.is_available():
7
+ device = "cuda"
8
+
9
+ try:
10
+ if torch.backends.mps.is_available():
11
+ device = "mps"
12
+ except: # noqa: E722
13
+ pass
14
+
15
+ return device
llama_lora/lib/inference.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+
4
+ from .get_device import get_device
5
+ from .streaming_generation_utils import Iteratorize, Stream
6
+
7
+
8
+ def generate(
9
+ # model
10
+ model,
11
+ tokenizer,
12
+ # input
13
+ prompt,
14
+ generation_config,
15
+ max_new_tokens,
16
+ stopping_criteria=[],
17
+ # output options
18
+ stream_output=False
19
+ ):
20
+ device = get_device()
21
+
22
+ inputs = tokenizer(prompt, return_tensors="pt")
23
+ input_ids = inputs["input_ids"].to(device)
24
+ generate_params = {
25
+ "input_ids": input_ids,
26
+ "generation_config": generation_config,
27
+ "return_dict_in_generate": True,
28
+ "output_scores": True,
29
+ "max_new_tokens": max_new_tokens,
30
+ "stopping_criteria": transformers.StoppingCriteriaList() + stopping_criteria
31
+ }
32
+
33
+ if stream_output:
34
+ # Stream the reply 1 token at a time.
35
+ # This is based on the trick of using 'stopping_criteria' to create an iterator,
36
+ # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
37
+
38
+ def generate_with_callback(callback=None, **kwargs):
39
+ kwargs["stopping_criteria"].insert(
40
+ 0,
41
+ Stream(callback_func=callback)
42
+ )
43
+ with torch.no_grad():
44
+ model.generate(**kwargs)
45
+
46
+ def generate_with_streaming(**kwargs):
47
+ return Iteratorize(
48
+ generate_with_callback, kwargs, callback=None
49
+ )
50
+
51
+ with generate_with_streaming(**generate_params) as generator:
52
+ for output in generator:
53
+ decoded_output = tokenizer.decode(output, skip_special_tokens=True)
54
+ yield decoded_output, output
55
+ if output[-1] in [tokenizer.eos_token_id]:
56
+ break
57
+ return # early return for stream_output
58
+
59
+ # Without streaming
60
+ with torch.no_grad():
61
+ generation_output = model.generate(**generate_params)
62
+ output = generation_output.sequences[0]
63
+ decoded_output = tokenizer.decode(output, skip_special_tokens=True)
64
+ yield decoded_output, output
65
+ return
llama_lora/{utils/callbacks.py → lib/streaming_generation_utils.py} RENAMED
File without changes
llama_lora/models.py CHANGED
@@ -8,19 +8,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer
8
  from peft import PeftModel
9
 
10
  from .globals import Global
11
-
12
-
13
- def get_device():
14
- if torch.cuda.is_available():
15
- return "cuda"
16
- else:
17
- return "cpu"
18
-
19
- try:
20
- if torch.backends.mps.is_available():
21
- return "mps"
22
- except: # noqa: E722
23
- pass
24
 
25
 
26
  def get_new_base_model(base_model_name):
@@ -60,9 +48,10 @@ def get_new_base_model(base_model_name):
60
  base_model_name, device_map={"": device}, low_cpu_mem_usage=True
61
  )
62
 
63
- model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
64
- model.config.bos_token_id = 1
65
- model.config.eos_token_id = 2
 
66
 
67
  return model
68
 
 
8
  from peft import PeftModel
9
 
10
  from .globals import Global
11
+ from .lib.get_device import get_device
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  def get_new_base_model(base_model_name):
 
48
  base_model_name, device_map={"": device}, low_cpu_mem_usage=True
49
  )
50
 
51
+ tokenizer = get_tokenizer(base_model_name)
52
+ model.config.pad_token_id = tokenizer.pad_token_id = 0
53
+ model.config.bos_token_id = tokenizer.bos_token_id = 1
54
+ model.config.eos_token_id = tokenizer.eos_token_id = 2
55
 
56
  return model
57
 
llama_lora/ui/finetune_ui.py CHANGED
@@ -79,56 +79,50 @@ def load_sample_dataset_to_text_input(format):
79
  return gr.Code.update(value=sample_plain_text_value)
80
 
81
 
82
- def process_json_dataset(data, only_first_n_items=None):
83
- if not isinstance(data, list):
84
- raise ValueError("The dataset is not an array of objects.")
85
-
86
- if only_first_n_items is not None:
87
- data = data[:only_first_n_items]
88
-
89
- first_item = get_val_from_arr(data, 0, None)
90
-
91
- if first_item is None:
92
- raise ValueError("The dataset is empty.")
93
- if not isinstance(first_item, dict):
94
- raise ValueError("The dataset is not an array of objects.")
95
-
96
- # Convert OpenAI fine-tuning dataset to LLaMA LoRA style
97
- if "completion" in first_item and "output" not in first_item:
98
- data = [
99
- {"output" if k == "completion" else k: v for k, v in d.items()}
100
- for d in data]
101
- first_item = get_val_from_arr(data, 0, None)
102
-
103
- # Flatten Stanford Alpaca style instances
104
- if "instances" in first_item and isinstance(first_item["instances"], list):
105
- data = [
106
- {"output" if k == "completion" else k: v for k, v in d.items()}
107
- for d in data]
108
- flattened_data = []
109
- for item in data:
110
- for instance in item["instances"]:
111
- d = {k: v for k, v in item.items() if k != "instances"}
112
- d.update(instance)
113
- flattened_data.append(d)
114
- data = flattened_data
115
- first_item = get_val_from_arr(data, 0, None)
116
-
117
- if "output" not in first_item:
118
- raise ValueError(
119
- "The data does not contains an \"output\" or \"completion\".")
120
-
121
- # Put all variables under the "variables" key if it does not exists
122
- if "variables" not in first_item:
123
- data = [
124
- {
125
- "variables":
126
- {k: v for k, v in d.items() if k != "output"},
127
- "output":
128
- d["output"]
129
- }
130
- for d in data
131
- ]
132
  return data
133
 
134
 
@@ -144,75 +138,59 @@ def refresh_preview(
144
  preview_show_actual_prompt,
145
  ):
146
  try:
147
- max_preview_count = 100
148
  prompter = Prompter(template)
149
  variable_names = prompter.get_variable_names()
150
 
151
- if load_dataset_from == "Text Input":
152
- if dataset_text_format == "JSON":
153
- data = json.loads(dataset_text)
154
- data = process_json_dataset(data)
155
-
156
- elif dataset_text_format == "JSON Lines":
157
- lines = dataset_text.split('\n')
158
- data = []
159
- for i, line in enumerate(lines):
160
- line_number = i + 1
161
- try:
162
- data.append(json.loads(line))
163
- except Exception as e:
164
- raise ValueError(
165
- f"Error parsing JSON on line {line_number}: {e}")
166
-
167
- data = process_json_dataset(data)
168
-
169
- else: # Plain Text
170
- data = parse_plain_text_input(
171
- dataset_text,
172
- (
173
- dataset_plain_text_input_variables_separator or
174
- default_dataset_plain_text_input_variables_separator
175
- ).replace("\\n", "\n"),
176
- (
177
- dataset_plain_text_input_and_output_separator or
178
- default_dataset_plain_text_input_and_output_separator
179
- ).replace("\\n", "\n"),
180
- (
181
- dataset_plain_text_data_separator or
182
- default_dataset_plain_text_data_separator
183
- ).replace("\\n", "\n"),
184
- variable_names
185
- )
186
 
187
- else: # Load dataset from data directory
188
- data = get_dataset_content(dataset_from_data_dir)
189
- data = process_json_dataset(data)
190
 
191
  data_count = len(data)
192
- headers = variable_names
 
193
  preview_data = [
194
- [item['variables'].get(name, "") for name in variable_names]
195
- for item in data[:max_preview_count]
196
  ]
197
 
198
- if preview_show_actual_prompt:
199
- headers = headers + ["Prompt (actual input)"]
200
- rendered = [prompter.generate_prompt(
201
- item['variables']) for item in data[:max_preview_count]]
202
- preview_data = result = [d + [i]
203
- for d, i in zip(preview_data, rendered)]
 
 
 
204
 
205
- headers = headers + ["Completion (output)"]
206
- preview_data = result = [pd + [d['output']]
207
- for pd, d in zip(preview_data, data[:max_preview_count])]
 
 
 
208
 
209
- preview_info_message = f"The dataset has a total of {data_count} item(s)."
 
 
 
 
210
  if data_count > max_preview_count:
211
  preview_info_message += f" Previewing the first {max_preview_count}."
212
 
213
  info_message = f"{data_count} item(s)."
214
  if load_dataset_from == "Data Dir":
215
- info_message = "This dataset contains " + info_message
216
  update_message = gr.Markdown.update(info_message, visible=True)
217
 
218
  return gr.Dataframe.update(value={'data': preview_data, 'headers': headers}), gr.Markdown.update(preview_info_message), update_message, update_message
@@ -288,57 +266,24 @@ def do_train(
288
  unload_models() # Need RAM for training
289
 
290
  prompter = Prompter(template)
291
- variable_names = prompter.get_variable_names()
292
-
293
- if load_dataset_from == "Text Input":
294
- if dataset_text_format == "JSON":
295
- data = json.loads(dataset_text)
296
- data = process_json_dataset(data)
297
-
298
- elif dataset_text_format == "JSON Lines":
299
- lines = dataset_text.split('\n')
300
- data = []
301
- for i, line in enumerate(lines):
302
- line_number = i + 1
303
- try:
304
- data.append(json.loads(line))
305
- except Exception as e:
306
- raise ValueError(
307
- f"Error parsing JSON on line {line_number}: {e}")
308
-
309
- data = process_json_dataset(data)
310
-
311
- else: # Plain Text
312
- data = parse_plain_text_input(
313
- dataset_text,
314
- (
315
- dataset_plain_text_input_variables_separator or
316
- default_dataset_plain_text_input_variables_separator
317
- ).replace("\\n", "\n"),
318
- (
319
- dataset_plain_text_input_and_output_separator or
320
- default_dataset_plain_text_input_and_output_separator
321
- ).replace("\\n", "\n"),
322
- (
323
- dataset_plain_text_data_separator or
324
- default_dataset_plain_text_data_separator
325
- ).replace("\\n", "\n"),
326
- variable_names
327
- )
328
 
329
- else: # Load dataset from data directory
330
- data = get_dataset_content(dataset_from_data_dir)
331
- data = process_json_dataset(data)
332
 
333
- data_count = len(data)
334
  evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)
335
 
336
- train_data = [
337
- {
338
- 'prompt': prompter.generate_prompt(d['variables']),
339
- 'completion': d['output']}
340
- for d in data]
341
-
342
  def get_progress_text(epoch, epochs, last_loss):
343
  progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
344
  if last_loss is not None:
@@ -449,26 +394,33 @@ Train data (first 10):
449
  'dataset_rows': len(train_data),
450
  'timestamp': time.time(),
451
 
452
- 'max_seq_length': max_seq_length,
453
- 'train_on_inputs': train_on_inputs,
 
454
 
455
- 'micro_batch_size': micro_batch_size,
456
- 'gradient_accumulation_steps': gradient_accumulation_steps,
457
- 'epochs': epochs,
458
- 'learning_rate': learning_rate,
459
 
460
- 'evaluate_data_percentage': evaluate_data_percentage,
461
 
462
- 'lora_r': lora_r,
463
- 'lora_alpha': lora_alpha,
464
- 'lora_dropout': lora_dropout,
465
- 'lora_target_modules': lora_target_modules,
466
  }
467
  json.dump(info, info_json_file, indent=2)
468
 
469
  if not should_training_progress_track_tqdm:
470
  progress(0, desc="Train starting...")
471
 
 
 
 
 
 
 
472
  train_output = Global.train_fn(
473
  base_model, # base_model
474
  tokenizer, # tokenizer
@@ -491,7 +443,12 @@ Train data (first 10):
491
  save_steps, # save_steps
492
  save_total_limit, # save_total_limit
493
  logging_steps, # logging_steps
494
- training_callbacks # callbacks
 
 
 
 
 
495
  )
496
 
497
  logs_str = "\n".join([json.dumps(log)
 
79
  return gr.Code.update(value=sample_plain_text_value)
80
 
81
 
82
+
83
+
84
+
85
+ def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
86
+ dataset_plain_text_input_variables_separator,
87
+ dataset_plain_text_input_and_output_separator,
88
+ dataset_plain_text_data_separator,
89
+ dataset_from_data_dir, prompter):
90
+ if load_dataset_from == "Text Input":
91
+ if dataset_text_format == "JSON":
92
+ data = json.loads(dataset_text)
93
+
94
+ elif dataset_text_format == "JSON Lines":
95
+ lines = dataset_text.split('\n')
96
+ data = []
97
+ for i, line in enumerate(lines):
98
+ line_number = i + 1
99
+ try:
100
+ data.append(json.loads(line))
101
+ except Exception as e:
102
+ raise ValueError(
103
+ f"Error parsing JSON on line {line_number}: {e}")
104
+
105
+ else: # Plain Text
106
+ data = parse_plain_text_input(
107
+ dataset_text,
108
+ (
109
+ dataset_plain_text_input_variables_separator or
110
+ default_dataset_plain_text_input_variables_separator
111
+ ).replace("\\n", "\n"),
112
+ (
113
+ dataset_plain_text_input_and_output_separator or
114
+ default_dataset_plain_text_input_and_output_separator
115
+ ).replace("\\n", "\n"),
116
+ (
117
+ dataset_plain_text_data_separator or
118
+ default_dataset_plain_text_data_separator
119
+ ).replace("\\n", "\n"),
120
+ prompter.get_variable_names()
121
+ )
122
+
123
+ else: # Load dataset from data directory
124
+ data = get_dataset_content(dataset_from_data_dir)
125
+
 
 
 
 
 
 
126
  return data
127
 
128
 
 
138
  preview_show_actual_prompt,
139
  ):
140
  try:
141
+ max_preview_count = 30
142
  prompter = Prompter(template)
143
  variable_names = prompter.get_variable_names()
144
 
145
+ data = get_data_from_input(
146
+ load_dataset_from=load_dataset_from,
147
+ dataset_text=dataset_text,
148
+ dataset_text_format=dataset_text_format,
149
+ dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
150
+ dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
151
+ dataset_plain_text_data_separator=dataset_plain_text_data_separator,
152
+ dataset_from_data_dir=dataset_from_data_dir,
153
+ prompter=prompter
154
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ train_data = prompter.get_train_data_from_dataset(data, max_preview_count)
 
 
157
 
158
  data_count = len(data)
159
+
160
+ headers = ['Prompt', 'Completion']
161
  preview_data = [
162
+ [item.get("prompt", ""), item.get("completion", "")]
163
+ for item in train_data
164
  ]
165
 
166
+ if not prompter.template_module:
167
+ variable_names = prompter.get_variable_names()
168
+ headers += [f"Variable: {variable_name}" for variable_name in variable_names]
169
+ variables = [
170
+ [item.get(f"_var_{name}", "") for name in variable_names]
171
+ for item in train_data
172
+ ]
173
+ preview_data = [d + v for d, v in zip(preview_data, variables)]
174
+
175
 
176
+ # if preview_show_actual_prompt:
177
+ # headers = headers + ["Prompt (actual input)"]
178
+ # rendered = [prompter.generate_prompt(
179
+ # item['variables']) for item in data[:max_preview_count]]
180
+ # preview_data = result = [d + [i]
181
+ # for d, i in zip(preview_data, rendered)]
182
 
183
+ # headers = headers + ["Completion (output)"]
184
+ # preview_data = result = [pd + [d['output']]
185
+ # for pd, d in zip(preview_data, data[:max_preview_count])]
186
+
187
+ preview_info_message = f"The dataset has about {data_count} item(s)."
188
  if data_count > max_preview_count:
189
  preview_info_message += f" Previewing the first {max_preview_count}."
190
 
191
  info_message = f"{data_count} item(s)."
192
  if load_dataset_from == "Data Dir":
193
+ info_message = "This dataset contains about " + info_message
194
  update_message = gr.Markdown.update(info_message, visible=True)
195
 
196
  return gr.Dataframe.update(value={'data': preview_data, 'headers': headers}), gr.Markdown.update(preview_info_message), update_message, update_message
 
266
  unload_models() # Need RAM for training
267
 
268
  prompter = Prompter(template)
269
+ # variable_names = prompter.get_variable_names()
270
+
271
+ data = get_data_from_input(
272
+ load_dataset_from=load_dataset_from,
273
+ dataset_text=dataset_text,
274
+ dataset_text_format=dataset_text_format,
275
+ dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
276
+ dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
277
+ dataset_plain_text_data_separator=dataset_plain_text_data_separator,
278
+ dataset_from_data_dir=dataset_from_data_dir,
279
+ prompter=prompter
280
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
+ train_data = prompter.get_train_data_from_dataset(data)
 
 
283
 
284
+ data_count = len(train_data)
285
  evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)
286
 
 
 
 
 
 
 
287
  def get_progress_text(epoch, epochs, last_loss):
288
  progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
289
  if last_loss is not None:
 
394
  'dataset_rows': len(train_data),
395
  'timestamp': time.time(),
396
 
397
+ # These will be saved in another JSON file by the train function
398
+ # 'max_seq_length': max_seq_length,
399
+ # 'train_on_inputs': train_on_inputs,
400
 
401
+ # 'micro_batch_size': micro_batch_size,
402
+ # 'gradient_accumulation_steps': gradient_accumulation_steps,
403
+ # 'epochs': epochs,
404
+ # 'learning_rate': learning_rate,
405
 
406
+ # 'evaluate_data_percentage': evaluate_data_percentage,
407
 
408
+ # 'lora_r': lora_r,
409
+ # 'lora_alpha': lora_alpha,
410
+ # 'lora_dropout': lora_dropout,
411
+ # 'lora_target_modules': lora_target_modules,
412
  }
413
  json.dump(info, info_json_file, indent=2)
414
 
415
  if not should_training_progress_track_tqdm:
416
  progress(0, desc="Train starting...")
417
 
418
+ wandb_group = template
419
+ wandb_tags = [f"template:{template}"]
420
+ if load_dataset_from == "Data Dir" and dataset_from_data_dir:
421
+ wandb_group += f"/{dataset_from_data_dir}"
422
+ wandb_tags.append(f"dataset:{dataset_from_data_dir}")
423
+
424
  train_output = Global.train_fn(
425
  base_model, # base_model
426
  tokenizer, # tokenizer
 
443
  save_steps, # save_steps
444
  save_total_limit, # save_total_limit
445
  logging_steps, # logging_steps
446
+ training_callbacks, # callbacks
447
+ Global.wandb_api_key, # wandb_api_key
448
+ Global.default_wandb_project if Global.enable_wandb else None, # wandb_project
449
+ wandb_group, # wandb_group
450
+ model_name, # wandb_run_name
451
+ wandb_tags # wandb_tags
452
  )
453
 
454
  logs_str = "\n".join([json.dumps(log)
llama_lora/ui/inference_ui.py CHANGED
@@ -8,12 +8,12 @@ from transformers import GenerationConfig
8
 
9
  from ..globals import Global
10
  from ..models import get_model, get_tokenizer, get_device
 
11
  from ..utils.data import (
12
  get_available_template_names,
13
  get_available_lora_model_names,
14
  get_info_of_available_lora_model)
15
  from ..utils.prompter import Prompter
16
- from ..utils.callbacks import Iteratorize, Stream
17
 
18
  device = get_device()
19
 
@@ -103,8 +103,6 @@ def do_inference(
103
  tokenizer = get_tokenizer(base_model_name)
104
  model = get_model(base_model_name, lora_model_name)
105
 
106
- inputs = tokenizer(prompt, return_tensors="pt")
107
- input_ids = inputs["input_ids"].to(device)
108
  generation_config = GenerationConfig(
109
  temperature=temperature,
110
  top_p=top_p,
@@ -113,26 +111,56 @@ def do_inference(
113
  num_beams=num_beams,
114
  )
115
 
116
- generate_params = {
117
- "input_ids": input_ids,
118
- "generation_config": generation_config,
119
- "return_dict_in_generate": True,
120
- "output_scores": True,
121
- "max_new_tokens": max_new_tokens,
122
- }
123
-
124
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
125
  if Global.should_stop_generating:
126
  return True
127
  return False
128
 
129
  Global.should_stop_generating = False
130
- generate_params.setdefault(
131
- "stopping_criteria", transformers.StoppingCriteriaList()
132
- )
133
- generate_params["stopping_criteria"].append(
134
- ui_generation_stopping_criteria
135
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  if stream_output:
138
  # Stream the reply 1 token at a time.
 
8
 
9
  from ..globals import Global
10
  from ..models import get_model, get_tokenizer, get_device
11
+ from ..lib.inference import generate
12
  from ..utils.data import (
13
  get_available_template_names,
14
  get_available_lora_model_names,
15
  get_info_of_available_lora_model)
16
  from ..utils.prompter import Prompter
 
17
 
18
  device = get_device()
19
 
 
103
  tokenizer = get_tokenizer(base_model_name)
104
  model = get_model(base_model_name, lora_model_name)
105
 
 
 
106
  generation_config = GenerationConfig(
107
  temperature=temperature,
108
  top_p=top_p,
 
111
  num_beams=num_beams,
112
  )
113
 
 
 
 
 
 
 
 
 
114
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
115
  if Global.should_stop_generating:
116
  return True
117
  return False
118
 
119
  Global.should_stop_generating = False
120
+
121
+ generation_args = {
122
+ 'model': model,
123
+ 'tokenizer': tokenizer,
124
+ 'prompt': prompt,
125
+ 'generation_config': generation_config,
126
+ 'max_new_tokens': max_new_tokens,
127
+ 'stopping_criteria': [ui_generation_stopping_criteria],
128
+ 'stream_output': stream_output
129
+ }
130
+
131
+ for (decoded_output, output) in generate(**generation_args):
132
+ raw_output_str = None
133
+ if show_raw:
134
+ raw_output_str = str(output)
135
+ response = prompter.get_response(decoded_output)
136
+
137
+ if Global.should_stop_generating:
138
+ return
139
+
140
+ yield (
141
+ gr.Textbox.update(
142
+ value=response, lines=inference_output_lines),
143
+ raw_output_str)
144
+
145
+ if Global.should_stop_generating:
146
+ # If the user stops the generation, and then clicks the
147
+ # generation button again, they may mysteriously landed
148
+ # here, in the previous, should-be-stopped generation
149
+ # function call, with the new generation function not be
150
+ # called at all. To workaround this, we yield a message
151
+ # and setting lines=1, and if the front-end JS detects
152
+ # that lines has been set to 1 (rows="1" in HTML),
153
+ # it will automatically click the generate button again
154
+ # (gr.Textbox.update() does not support updating
155
+ # elem_classes or elem_id).
156
+ # [WORKAROUND-UI01]
157
+ yield (
158
+ gr.Textbox.update(
159
+ value="Please retry", lines=1),
160
+ None)
161
+
162
+ return
163
+
164
 
165
  if stream_output:
166
  # Stream the reply 1 token at a time.
llama_lora/utils/data.py CHANGED
@@ -30,7 +30,7 @@ def copy_sample_data_if_not_exists(source, destination):
30
  def get_available_template_names():
31
  templates_directory_path = os.path.join(Global.data_dir, "templates")
32
  all_files = os.listdir(templates_directory_path)
33
- return [os.path.splitext(filename)[0] for filename in all_files if fnmatch.fnmatch(filename, "*.json")]
34
 
35
 
36
  def get_available_dataset_names():
 
30
  def get_available_template_names():
31
  templates_directory_path = os.path.join(Global.data_dir, "templates")
32
  all_files = os.listdir(templates_directory_path)
33
+ return [filename.rstrip(".json") for filename in all_files if fnmatch.fnmatch(filename, "*.json") or fnmatch.fnmatch(filename, "*.py")]
34
 
35
 
36
  def get_available_dataset_names():
llama_lora/utils/prompter.py CHANGED
@@ -5,13 +5,15 @@ From https://github.com/tloen/alpaca-lora/blob/main/utils/prompter.py
5
 
6
  import json
7
  import os.path as osp
 
 
8
  from typing import Union, List
9
 
10
  from ..globals import Global
11
 
12
 
13
  class Prompter(object):
14
- __slots__ = ("template_name", "template", "_verbose")
15
 
16
  def __init__(self, template_name: str = "", verbose: bool = False):
17
  self._verbose = verbose
@@ -21,12 +23,41 @@ class Prompter(object):
21
  self.template_name = "None"
22
  return
23
  self.template_name = template_name
 
24
 
25
- file_name = osp.join(Global.data_dir, "templates",
26
- f"{template_name}.json")
27
- if not osp.exists(file_name):
28
- raise ValueError(f"Can't read {file_name}")
29
- with open(file_name) as fp:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  self.template = json.load(fp)
31
  if self._verbose:
32
  print(
@@ -47,23 +78,31 @@ class Prompter(object):
47
  res = variables.get("prompt", "")
48
  elif "variables" in self.template:
49
  variable_names = self.template.get("variables")
50
- if type(variables) == dict:
51
- variables = [variables.get(name, None)
52
- for name in variable_names]
53
- if "default" not in self.template:
54
- raise ValueError(
55
- f"The template {self.template_name} has \"variables\" defined but does not has a default prompt defined. Please do it like: '\"default\": \"prompt_with_instruction\"' to handle cases when a matching prompt can't be found.")
56
- default_prompt_name = self.template.get("default")
57
- if default_prompt_name not in self.template:
58
- raise ValueError(
59
- f"The template {self.template_name} has \"default\" set to \"{default_prompt_name}\" but it's not defined. Please do it like: '\"{default_prompt_name}\": \"...\".")
60
- prompt_name = get_prompt_name(variables, variable_names)
61
- prompt_template = self.template.get(default_prompt_name)
62
- if prompt_name in self.template:
63
- prompt_template = self.template.get(prompt_name)
64
 
65
- res = prompt_template.format(
66
- **variables_to_dict(variables, variable_names))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  else:
69
  if type(variables) == dict:
@@ -104,6 +143,30 @@ class Prompter(object):
104
  else:
105
  return ["instruction", "input"]
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  def get_val(arr, index, default=None):
109
  return arr[index] if -len(arr) <= index < len(arr) else default
@@ -117,3 +180,57 @@ def get_prompt_name(variables, variable_names):
117
 
118
  def variables_to_dict(variables, variable_names):
119
  return {key: (variables[i] if i < len(variables) and variables[i] is not None else '') for i, key in enumerate(variable_names)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  import json
7
  import os.path as osp
8
+ import importlib
9
+ import itertools
10
  from typing import Union, List
11
 
12
  from ..globals import Global
13
 
14
 
15
  class Prompter(object):
16
+ __slots__ = ("template_name", "template", "template_module", "_verbose")
17
 
18
  def __init__(self, template_name: str = "", verbose: bool = False):
19
  self._verbose = verbose
 
23
  self.template_name = "None"
24
  return
25
  self.template_name = template_name
26
+ self.template_module = None
27
 
28
+ base_filename, ext = osp.splitext(template_name)
29
+ if ext == "":
30
+ filename = base_filename + ".json"
31
+ else:
32
+ filename = base_filename + ext
33
+
34
+ file_path = osp.join(Global.data_dir, "templates", filename)
35
+
36
+ if not osp.exists(file_path):
37
+ raise ValueError(f"Can't read {file_path}")
38
+
39
+ if ext == ".py":
40
+ template_module_spec = importlib.util.spec_from_file_location(
41
+ "template_module", file_path)
42
+ template_module = importlib.util.module_from_spec(
43
+ template_module_spec)
44
+ template_module_spec.loader.exec_module(template_module)
45
+ self.template_module = template_module
46
+
47
+ if not hasattr(template_module, "variables"):
48
+ raise ValueError(
49
+ "The template module does not have a \"variables\" attribute.")
50
+
51
+ self.template = {
52
+ 'variables': template_module.variables
53
+ }
54
+
55
+ if hasattr(template_module, "response_split"):
56
+ self.template["response_split"] = template_module.response_split
57
+
58
+ return
59
+
60
+ with open(file_path) as fp:
61
  self.template = json.load(fp)
62
  if self._verbose:
63
  print(
 
78
  res = variables.get("prompt", "")
79
  elif "variables" in self.template:
80
  variable_names = self.template.get("variables")
81
+ if self.template_module:
82
+ if type(variables) == list:
83
+ variables = {k: v for k, v in zip(
84
+ variable_names, variables)}
 
 
 
 
 
 
 
 
 
 
85
 
86
+ res = self.template_module.get_prompt(variables)
87
+ else:
88
+ if type(variables) == dict:
89
+ variables = [variables.get(name, None)
90
+ for name in variable_names]
91
+
92
+ if "default" not in self.template:
93
+ raise ValueError(
94
+ f"The template {self.template_name} has \"variables\" defined but does not has a default prompt defined. Please do it like: '\"default\": \"prompt_with_instruction\"' to handle cases when a matching prompt can't be found.")
95
+ default_prompt_name = self.template.get("default")
96
+ if default_prompt_name not in self.template:
97
+ raise ValueError(
98
+ f"The template {self.template_name} has \"default\" set to \"{default_prompt_name}\" but it's not defined. Please do it like: '\"{default_prompt_name}\": \"...\".")
99
+ prompt_name = get_prompt_name(variables, variable_names)
100
+ prompt_template = self.template.get(default_prompt_name)
101
+ if prompt_name in self.template:
102
+ prompt_template = self.template.get(prompt_name)
103
+
104
+ res = prompt_template.format(
105
+ **variables_to_dict(variables, variable_names))
106
 
107
  else:
108
  if type(variables) == dict:
 
143
  else:
144
  return ["instruction", "input"]
145
 
146
+ def get_train_data_from_dataset(self, data, only_first_n_items=None):
147
+ if self.template_module:
148
+ if hasattr(self.template_module, "get_train_data_list_from_dataset"):
149
+ data = self.template_module.get_train_data_list_from_dataset(
150
+ data)
151
+ if only_first_n_items:
152
+ data = data[:only_first_n_items]
153
+ return list(itertools.chain(*list(map(self.template_module.get_train_data, data))))
154
+
155
+ if only_first_n_items:
156
+ data = data[:only_first_n_items]
157
+
158
+ data = process_json_dataset(data)
159
+
160
+ train_data = [
161
+ {
162
+ 'prompt': self.generate_prompt(d['variables']),
163
+ 'completion': d['output'],
164
+ **{"_var_" + k: v for k, v in d['variables'].items()}
165
+ }
166
+ for d in data]
167
+
168
+ return train_data
169
+
170
 
171
  def get_val(arr, index, default=None):
172
  return arr[index] if -len(arr) <= index < len(arr) else default
 
180
 
181
  def variables_to_dict(variables, variable_names):
182
  return {key: (variables[i] if i < len(variables) and variables[i] is not None else '') for i, key in enumerate(variable_names)}
183
+
184
+
185
+ def process_json_dataset(data):
186
+ if not isinstance(data, list):
187
+ raise ValueError("The dataset is not an array of objects.")
188
+
189
+ first_item = get_val_from_arr(data, 0, None)
190
+
191
+ if first_item is None:
192
+ raise ValueError("The dataset is empty.")
193
+ if not isinstance(first_item, dict):
194
+ raise ValueError("The dataset is not an array of objects.")
195
+
196
+ # Convert OpenAI fine-tuning dataset to LLaMA LoRA style
197
+ if "completion" in first_item and "output" not in first_item:
198
+ data = [
199
+ {"output" if k == "completion" else k: v for k, v in d.items()}
200
+ for d in data]
201
+ first_item = get_val_from_arr(data, 0, None)
202
+
203
+ # Flatten Stanford Alpaca style instances
204
+ if "instances" in first_item and isinstance(first_item["instances"], list):
205
+ data = [
206
+ {"output" if k == "completion" else k: v for k, v in d.items()}
207
+ for d in data]
208
+ flattened_data = []
209
+ for item in data:
210
+ for instance in item["instances"]:
211
+ d = {k: v for k, v in item.items() if k != "instances"}
212
+ d.update(instance)
213
+ flattened_data.append(d)
214
+ data = flattened_data
215
+ first_item = get_val_from_arr(data, 0, None)
216
+
217
+ if "output" not in first_item:
218
+ raise ValueError(
219
+ "The data does not contains an \"output\" or \"completion\".")
220
+
221
+ # Put all variables under the "variables" key if it does not exists
222
+ if "variables" not in first_item:
223
+ data = [
224
+ {
225
+ "variables":
226
+ {k: v for k, v in d.items() if k != "output"},
227
+ "output":
228
+ d["output"]
229
+ }
230
+ for d in data
231
+ ]
232
+ return data
233
+
234
+
235
+ def get_val_from_arr(arr, index, default=None):
236
+ return arr[index] if -len(arr) <= index < len(arr) else default