zetavg committed: format code

Files changed:
- llama_lora/lib/finetune.py (+1, -0)
- llama_lora/ui/finetune_ui.py (+2, -16)
- llama_lora/utils/prompter.py (+11, -4)
llama_lora/lib/finetune.py
CHANGED

@@ -87,6 +87,7 @@ def train(
     # os.environ["WANDB_PROJECT"] = wandb_project
     # if wandb_run_name:
     #     os.environ["WANDB_RUN_NAME"] = wandb_run_name
+
     if wandb_watch:
         os.environ["WANDB_WATCH"] = wandb_watch
     if wandb_log_model:
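Context note: the hunk above sits in the W&B setup block of train(), which forwards optional settings to wandb through environment variables. A minimal standalone sketch of that pattern (the example values are assumptions, not taken from this commit):

import os

def configure_wandb(wandb_watch="", wandb_log_model=""):
    # Set the variables only when a non-empty value is given, so wandb's
    # defaults stay in effect otherwise -- the same gating as in train().
    if wandb_watch:
        os.environ["WANDB_WATCH"] = wandb_watch          # e.g. "gradients"
    if wandb_log_model:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model  # e.g. "true"

configure_wandb(wandb_watch="gradients")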
llama_lora/ui/finetune_ui.py
CHANGED

@@ -79,9 +79,6 @@ def load_sample_dataset_to_text_input(format):
     return gr.Code.update(value=sample_plain_text_value)
 
 
-
-
-
 def get_data_from_input(load_dataset_from, dataset_text, dataset_text_format,
                         dataset_plain_text_input_variables_separator,
                         dataset_plain_text_input_and_output_separator,

@@ -153,7 +150,8 @@ def refresh_preview(
         prompter=prompter
     )
 
-    train_data = prompter.get_train_data_from_dataset(data, max_preview_count)
+    train_data = prompter.get_train_data_from_dataset(
+        data, max_preview_count)
 
     data_count = len(data)
 

@@ -172,18 +170,6 @@ def refresh_preview(
     ]
     preview_data = [d + v for d, v in zip(preview_data, variables)]
 
-
-    # if preview_show_actual_prompt:
-    #     headers = headers + ["Prompt (actual input)"]
-    #     rendered = [prompter.generate_prompt(
-    #         item['variables']) for item in data[:max_preview_count]]
-    #     preview_data = result = [d + [i]
-    #                              for d, i in zip(preview_data, rendered)]
-
-    #     headers = headers + ["Completion (output)"]
-    #     preview_data = result = [pd + [d['output']]
-    #                              for pd, d in zip(preview_data, data[:max_preview_count])]
-
     preview_info_message = f"The dataset has about {data_count} item(s)."
     if data_count > max_preview_count:
         preview_info_message += f" Previewing the first {max_preview_count}."
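Context note: after this change, refresh_preview wraps the prompter.get_train_data_from_dataset call across two lines and drops the commented-out "actual prompt" columns, keeping only the zip-based row building. A self-contained sketch of the surviving row-building pattern (the sample data is hypothetical):

max_preview_count = 3
data = [
    {"variables": ["Explain LoRA", ""], "output": "LoRA is ..."},
    {"variables": ["Translate", "Hello"], "output": "Bonjour"},
]

# One row per previewed item, then append the variable columns with zip,
# mirroring: preview_data = [d + v for d, v in zip(preview_data, variables)]
preview_data = [[i + 1] for i in range(len(data[:max_preview_count]))]
variables = [item["variables"] for item in data[:max_preview_count]]
preview_data = [d + v for d, v in zip(preview_data, variables)]

data_count = len(data)
preview_info_message = f"The dataset has about {data_count} item(s)."
if data_count > max_preview_count:
    preview_info_message += f" Previewing the first {max_preview_count}."

print(preview_data)          # [[1, 'Explain LoRA', ''], [2, 'Translate', 'Hello']]
print(preview_info_message)  # The dataset has about 2 item(s).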
llama_lora/utils/prompter.py
CHANGED

@@ -139,18 +139,21 @@ class Prompter(object):
         if self.template_name == "None":
             return ["prompt"]
         elif "variables" in self.template:
-            return self.template
+            return self.template['variables']
         else:
             return ["instruction", "input"]
 
     def get_train_data_from_dataset(self, data, only_first_n_items=None):
         if self.template_module:
-            if hasattr(self.template_module, "get_train_data_list_from_dataset"):
+            if hasattr(self.template_module,
+                       "get_train_data_list_from_dataset"):
                 data = self.template_module.get_train_data_list_from_dataset(
                     data)
             if only_first_n_items:
                 data = data[:only_first_n_items]
-            return list(itertools.chain(*list(map(self.template_module.get_train_data, data))))
+            return list(itertools.chain(*list(
+                map(self.template_module.get_train_data, data)
+            )))
 
         if only_first_n_items:
             data = data[:only_first_n_items]

@@ -179,7 +182,11 @@ def get_prompt_name(variables, variable_names):
 
 
 def variables_to_dict(variables, variable_names):
-    return {key: (variables[i] if i < len(variables) and variables[i] is not None else '') for i, key in enumerate(variable_names)}
+    return {
+        key: (variables[i] if i < len(variables)
+              and variables[i] is not None else '')
+        for i, key in enumerate(variable_names)
+    }
 
 
 def process_json_dataset(data):
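Context note: two behaviors in these hunks are worth spelling out. get_train_data_from_dataset flattens the per-item record lists with itertools.chain, and the variables_to_dict comprehension maps any out-of-range index or None value to the empty string. A runnable sketch of both (get_train_data here is a hypothetical stand-in for template_module.get_train_data):

import itertools

def variables_to_dict(variables, variable_names):
    # Same comprehension as the reformatted function: an index past the end
    # of `variables`, or a None value, yields ''.
    return {
        key: (variables[i] if i < len(variables)
              and variables[i] is not None else '')
        for i, key in enumerate(variable_names)
    }

print(variables_to_dict(["Explain LoRA", None], ["instruction", "input"]))
# -> {'instruction': 'Explain LoRA', 'input': ''}

def get_train_data(item):
    # A template module may emit several training records per dataset item.
    return [{"text": item}, {"text": item.upper()}]

data = ["a", "b"]
flat = list(itertools.chain(*list(map(get_train_data, data))))
print(flat)  # 4 records: chain(*...) flattens the list of per-item lists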