"""
Not used as part of the streamlit app, but run offline to prepare training data for fine-tuning.
"""
import argparse | |
import os | |
import stat | |
import pandas as pd | |
from abc import ABC, abstractmethod | |
from copy import copy | |
from random import choice, shuffle | |
from time import time | |
from typing import Tuple, Generator | |
from src.common import data_dir, join_items_comma_and, random_true_false, pop_n | |
from src.datatypes import * | |
parser = argparse.ArgumentParser(
    prog="prep_finetuning",
    description="Fine tune a llama 2 model and push to hugging face hub for serving",
)
# All CLI options are mandatory; register them table-style to keep the
# flag/help pairs easy to scan and extend.
_CLI_OPTIONS = [
    ("-base_model", "The base model to use"),
    ("-products_db", "The products sqlite to train on"),
    ("-fine_tuned_model", "The target model name in hugging face hub"),
    ("-hf_user", "The hugging face user to write the model to the hub"),
    ("-hf_token", "The hugging face token to write the model to the hub"),
]
for _flag, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, required=True, help=_help)
args = parser.parse_args()
class TrainingDataGenerator(ABC):
    """
    Abstract class to generate fine-tuning training data. Implemented
    as a generator to minimise passing around large lists unnecessarily.
    """

    # Fixed: the docstring says "required to be implemented" and the class
    # derives from ABC, but the method was missing @abstractmethod, so the
    # contract was never enforced and the base class was instantiable.
    @abstractmethod
    def generate(self) -> Generator[Tuple[str, str], None, None]:
        """
        Required to be implemented by the generator implementation.
        :return: should yield pairs of training data as question, answer
        """
        ...
class CategoryDataGenerator(TrainingDataGenerator):
    """
    Concrete implementation to build training data about product categories.
    """

    def generate(self) -> Generator[Tuple[str, str], None, None]:
        """
        :return: yields (question, answer) pairs about the category catalogue
        """
        # 1. First build "what do you offer" type queries, five shuffled variants
        cat_names = [c.name for c in Category.all.values()]
        category_synonyms = ["types", "categories", "kinds"]
        for _ in range(5):
            shuffle(cat_names)
            cat = choice(category_synonyms)
            q = f"What {cat} of products do you offer?"
            a = f"ElectroHome offers {join_items_comma_and(cat_names)}."
            yield q, a
        # 2. Now build some "product in category" type queries, consuming each
        # category's product list three names at a time
        for c in Category.all.values():
            prod_names = [p.name for p in c.products]
            total_prod_count = len(prod_names)
            working_prod_names = copy(prod_names)
            shuffle(working_prod_names)
            while len(working_prod_names) > 0:
                ans_product_names = pop_n(working_prod_names, 3)
                q = f"What {c.name} do you have?"
                # Fixed typo in the emitted training text: "lookinh" -> "looking".
                # NOTE(review): c.name[:-1] is a naive singularisation (drops the
                # trailing "s") - assumes all category names are simple plurals.
                a = (f"We have {total_prod_count} {c.name}. "
                     f"For example {join_items_comma_and(ans_product_names)}. "
                     f"What are you looking for in a {c.name[:-1]}, and I can help to guide you?")
                yield q, a
class ProductDescriptionDataGenerator(TrainingDataGenerator):
    """
    Concrete implementation to build training data from the marketing description and price.
    """

    # Placeholder "#" is substituted with the product name per question.
    _QUESTION_TEMPLATES = [
        "Tell me about the #.",
        "Describe the # for me.",
        "What can you tell me about the #?"
    ]

    def generate(self) -> Generator[Tuple[str, str], None, None]:
        """
        :return: yields one (question, answer) pair per product
        """
        for product in Product.all.values():
            question = choice(self._QUESTION_TEMPLATES).replace("#", product.name)
            # Mix up the paths to include price or not and at the start/end of the response
            if random_true_false():
                answer = product.description
            elif random_true_false():
                answer = f"{product.description} It typically retails for ${product.price}."
            else:
                answer = f"The {product.name} typically retails for ${product.price}. {product.description}"
            yield question, answer
class PriceDataGenerator(TrainingDataGenerator):
    """
    Concrete implementation to build training data just for pricing.
    """

    # "#" is substituted with the product name (questions) or price (answers).
    _QUESTION_TEMPLATES = [
        "How much is the #?",
        "What does the # cost?",
        "What's the price of the #?"
    ]
    _ANSWER_TEMPLATES = [
        "It typically retails for $#.",
        "Our recommended retail price is $#, but do check with your stockist.",
        "The list price is $# but check in our online store if there are any offers on (www.electrohome.com)."
    ]

    def generate(self) -> Generator[Tuple[str, str], None, None]:
        """
        :return: yields one (question, answer) pricing pair per product
        """
        for product in Product.all.values():
            question = choice(self._QUESTION_TEMPLATES).replace("#", product.name)
            answer = choice(self._ANSWER_TEMPLATES).replace("#", str(product.price))
            yield question, answer
class FeatureDataGenerator(TrainingDataGenerator):
    """
    Concrete implementation to build training data about product features:
    both "what features does category X have" and "which products offer
    feature Y" style pairs.
    """

    def generate(self) -> Generator[Tuple[str, str], None, None]:
        """
        :return: yields (question, answer) pairs about features
        """
        # 1. First generate Q&A for what features are available by category
        for c in Category.all.values():
            cat_features = [f.name for f in c.features]
            for _ in range(1):
                working_cat_features = copy(cat_features)
                shuffle(working_cat_features)
                # Consume the shuffled list three features at a time so every
                # feature appears in exactly one answer per pass
                while len(working_cat_features) > 0:
                    some_features = pop_n(working_cat_features, 3)
                    feature_clause = join_items_comma_and(some_features)
                    question_templates = [
                        "What features do your # have?",
                        "What should I think about when I'm considering #?",
                        "What sort of things differentiate your #?"
                    ]
                    q = choice(question_templates).replace("#", c.name)
                    answer_templates = [
                        "Our # have features like ##.",
                        "You might want to consider things like ## which # offer.",
                        "# have lots of different features, like ##."
                    ]
                    # Replace "##" before "#": the longer placeholder must be
                    # substituted first or the "#" pass would corrupt it
                    a = choice(answer_templates).replace("##", feature_clause).replace("#", c.name)
                    yield q, a
        # 2. Now generate questions the other way around - i.e. search products by feature
        for f in Feature.all.values():
            cat_name = f.category.name
            prod_names = [p.name for p in f.products]
            for _ in range(1):
                # NOTE(review): unlike part 1 this working list is not shuffled
                # before being consumed - confirm whether that is intentional
                working_prod_names = copy(prod_names)
                while len(working_prod_names) > 0:
                    some_prods = pop_n(working_prod_names, 3)
                    if len(some_prods) > 1:  # Single product examples mess up some training data
                        prod_clause = join_items_comma_and(some_prods)
                        q = f"Which {cat_name} offer {f.name}?"
                        answer_templates = [
                            "## are # which offer ###.",
                            "## have ###.",
                            "We have some great # which offer ### including our ##."
                        ]
                        # Substitute longest placeholder first: "###" (feature),
                        # then "##" (product list), then "#" (category name)
                        a = choice(answer_templates).replace("###", f.name).replace("##", prod_clause).replace("#", cat_name)
                        yield q, a
                    else:
                        # Lone remaining product gets a simpler single-item answer
                        q = f"Which {cat_name} offer {f.name}?"
                        a = f"The {some_prods[0]} has {f.name}."
                        yield q, a
def training_string_from_q_and_a(q: str, a: str, sys_prompt: str = None) -> str:
    """
    Build the single llama formatted training string from a question
    answer pair.

    :param q: the user question
    :param a: the bot answer
    :param sys_prompt: optional system prompt; when supplied it is prepended
        as a "System:" line. (Fixed: this parameter was previously accepted
        but silently ignored.)
    :return: the formatted training example
    """
    if sys_prompt:
        return f'System: {sys_prompt}\nUser: {q}\nBot: {a}'
    return f'User: {q}\nBot: {a}'
def fine_tuning_out_dir(out_model: str) -> str:
    """
    Utility to generate the full path to the output directory, creating it
    (and any missing parents) if it is not there.

    :param out_model: target model name, used as the leaf directory name
    :return: the path under data_dir/fine_tuning/
    """
    out_dir = os.path.join(data_dir, 'fine_tuning', out_model)
    # exist_ok avoids the check-then-create race of the previous
    # os.path.exists() + os.makedirs() pair
    os.makedirs(out_dir, exist_ok=True)
    return out_dir
def generate_dataset(out_model: str) -> int:
    """
    Coordinator to build the training data. Generates all available Q&A pairs
    then formats them to llama format and saves them to the training csv file.
    :return Count of lines written to the training data
    """
    training_file = os.path.join(fine_tuning_out_dir(out_model), 'train.csv')
    generators = (
        CategoryDataGenerator(),
        ProductDescriptionDataGenerator(),
        PriceDataGenerator(),
        FeatureDataGenerator(),
    )
    # Flatten every generator's (question, answer) pairs into formatted rows
    lines = [
        training_string_from_q_and_a(q, a)
        for g in generators
        for q, a in g.generate()
    ]
    pd.DataFrame(lines, columns=['text']).to_csv(training_file, index=False)
    return len(lines)
def generate_training_scripts(out_model: str, hf_user: str, hf_token: str,
                              base_model: str = "meta-llama/Llama-2-7b-chat-hf") -> None:
    """
    Generates the shell script to run to actually train the model.

    :param out_model: autotrain project/model name (also the output directory leaf)
    :param hf_user: hugging face user the merged model is pushed to
    :param hf_token: hugging face write token (embedded in the generated
        script - the script itself warns against committing it)
    :param base_model: the base model to fine tune. Defaults to the value
        that was previously hard-coded, so existing callers are unchanged;
        the CLI's required -base_model argument can now be passed through
        instead of being ignored.
    """
    shell_file = os.path.join(fine_tuning_out_dir(out_model), 'train.zsh')
    with open(shell_file, "w") as f:
        f.write("#!/bin/zsh\n\n")
        f.write("# DO NOT COMMIT THIS FILE TO GIT AS IT CONTAINS THE HUGGING FACE WRITE TOKEN FOR THE REPO\n\n")
        f.write('echo "STARTING TRAINING AND PUSH TO HUB"\n')
        f.write("start=$(date +%s)\n")
        f.write(f"autotrain llm --train --project-name {out_model} --model {base_model} --data-path . --peft --lr 2e-4 --batch-size 12 --epochs 3 --trainer sft --merge-adapter --push-to-hub --username {hf_user} --token {hf_token}\n")
        f.write("end=$(date +%s)\n")
        f.write('echo "TRAINING AND PUSH TOOK $(($end-$start)) seconds"')
    # Mark the generated script as executable for the owner
    stats = os.stat(shell_file)
    os.chmod(shell_file, stats.st_mode | stat.S_IEXEC)
if __name__ == "__main__":
    # Offline entry point: load the product data, emit the training CSV and
    # the autotrain shell script, and report how long it all took.
    start = time()
    # NOTE(review): when the requested db is not the active one, only the db
    # name is set here - presumably DataLoader.set_db_name also (re)loads the
    # data, otherwise the generators below would see stale or empty data;
    # confirm against src.datatypes.
    if args.products_db == DataLoader.active_db:
        DataLoader.load_data()
    else:
        DataLoader.set_db_name(args.products_db)
    line_count = generate_dataset(args.fine_tuned_model)
    # NOTE(review): args.base_model from the CLI is not passed through here -
    # verify whether the training script should receive it.
    generate_training_scripts(args.fine_tuned_model, args.hf_user, args.hf_token)
    end = time()
    elapsed = (int((end - start) * 10)) / 10  # round to 1dp
    print(f"Generated {line_count} training examples and the training script in {elapsed} seconds.")