zetavg committed on
Commit 4a6324a · unverified · 1 Parent(s): 6ba132a

format code

llama_lora/lib/finetune.py CHANGED
@@ -87,6 +87,7 @@ def train(
     # os.environ["WANDB_PROJECT"] = wandb_project
     # if wandb_run_name:
     #     os.environ["WANDB_RUN_NAME"] = wandb_run_name
+
     if wandb_watch:
         os.environ["WANDB_WATCH"] = wandb_watch
     if wandb_log_model:
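
For context, the lines touched here configure Weights & Biases through environment variables before training starts. A minimal, self-contained sketch of that pattern follows; the helper name `configure_wandb_env` and the body of the `wandb_log_model` branch are illustrative assumptions, not code from this commit:

import os

def configure_wandb_env(wandb_project="", wandb_run_name="",
                        wandb_watch="", wandb_log_model=""):
    # Only set a variable when a non-empty value is given, so the
    # wandb defaults stay in effect otherwise.
    if wandb_project:
        os.environ["WANDB_PROJECT"] = wandb_project    # commented out in the diff above
    if wandb_run_name:
        os.environ["WANDB_RUN_NAME"] = wandb_run_name  # commented out in the diff above
    if wandb_watch:
        os.environ["WANDB_WATCH"] = wandb_watch
    if wandb_log_model:
        # Assumed assignment; the hunk is truncated right after this `if`.
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model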
llama_lora/ui/finetune_ui.py CHANGED
@@ -79,9 +79,6 @@ def load_sample_dataset_to_text_input(format):
         return gr.Code.update(value=sample_plain_text_value)
 
 
-
-
-
 def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
                         dataset_plain_text_input_variables_separator,
                         dataset_plain_text_input_and_output_separator,
@@ -153,7 +150,8 @@ def refresh_preview(
         prompter=prompter
     )
 
-    train_data = prompter.get_train_data_from_dataset(data, max_preview_count)
+    train_data = prompter.get_train_data_from_dataset(
+        data, max_preview_count)
 
     data_count = len(data)
 
@@ -172,18 +170,6 @@ def refresh_preview(
     ]
     preview_data = [d + v for d, v in zip(preview_data, variables)]
 
-
-    # if preview_show_actual_prompt:
-    #     headers = headers + ["Prompt (actual input)"]
-    #     rendered = [prompter.generate_prompt(
-    #         item['variables']) for item in data[:max_preview_count]]
-    #     preview_data = result = [d + [i]
-    #                              for d, i in zip(preview_data, rendered)]
-
-    #     headers = headers + ["Completion (output)"]
-    #     preview_data = result = [pd + [d['output']]
-    #                              for pd, d in zip(preview_data, data[:max_preview_count])]
-
     preview_info_message = f"The dataset has about {data_count} item(s)."
     if data_count > max_preview_count:
         preview_info_message += f" Previewing the first {max_preview_count}."
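
The reflowed `get_train_data_from_dataset` call only changes formatting. For reference, a minimal sketch of the preview summary built a few lines below it; the standalone helper name and the 10-item default for `max_preview_count` are assumptions for the example:

def build_preview_info_message(data, max_preview_count=10):
    # Same logic as in refresh_preview(): report the item count and, when the
    # dataset is larger than the preview, say how many items are shown.
    data_count = len(data)
    message = f"The dataset has about {data_count} item(s)."
    if data_count > max_preview_count:
        message += f" Previewing the first {max_preview_count}."
    return message

print(build_preview_info_message(list(range(42))))
# The dataset has about 42 item(s). Previewing the first 10.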
llama_lora/utils/prompter.py CHANGED
@@ -139,18 +139,21 @@ class Prompter(object):
         if self.template_name == "None":
             return ["prompt"]
         elif "variables" in self.template:
-            return self.template.get("variables")
+            return self.template['variables']
         else:
             return ["instruction", "input"]
 
     def get_train_data_from_dataset(self, data, only_first_n_items=None):
         if self.template_module:
-            if hasattr(self.template_module, "get_train_data_list_from_dataset"):
+            if hasattr(self.template_module,
+                       "get_train_data_list_from_dataset"):
                 data = self.template_module.get_train_data_list_from_dataset(
                     data)
                 if only_first_n_items:
                     data = data[:only_first_n_items]
-            return list(itertools.chain(*list(map(self.template_module.get_train_data, data))))
+            return list(itertools.chain(*list(
+                map(self.template_module.get_train_data, data)
+            )))
 
         if only_first_n_items:
             data = data[:only_first_n_items]
@@ -179,7 +182,11 @@ def get_prompt_name(variables, variable_names):
 
 
 def variables_to_dict(variables, variable_names):
-    return {key: (variables[i] if i < len(variables) and variables[i] is not None else '') for i, key in enumerate(variable_names)}
+    return {
+        key: (variables[i] if i < len(variables)
+              and variables[i] is not None else '')
+        for i, key in enumerate(variable_names)
+    }
 
 
 def process_json_dataset(data):
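
The reformatted `variables_to_dict` comprehension keeps its original behavior: positional values are paired with variable names, and any missing or `None` value becomes an empty string. A quick standalone check, with sample values made up for illustration:

def variables_to_dict(variables, variable_names):
    # Map the i-th positional value to the i-th variable name,
    # falling back to '' when the value is absent or None.
    return {
        key: (variables[i] if i < len(variables)
              and variables[i] is not None else '')
        for i, key in enumerate(variable_names)
    }

print(variables_to_dict(
    ["Translate to French", None],
    ["instruction", "input", "extra"],
))
# {'instruction': 'Translate to French', 'input': '', 'extra': ''}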