""" Not used as part of the streamlit app, but run offline to prepare training for fine-tuning. """ import argparse import os import stat import pandas as pd from abc import ABC, abstractmethod from copy import copy from random import choice, shuffle from time import time from typing import Tuple, Generator from src.common import data_dir, join_items_comma_and, random_true_false, pop_n from src.datatypes import * parser = argparse.ArgumentParser(prog="prep_finetuning", description="Fine tune a llama 2 model and push to hugging face hub for serving") parser.add_argument('-base_model', required=True, help="The base model to use") parser.add_argument('-products_db', required=True, help="The products sqlite to train on") parser.add_argument('-fine_tuned_model', required=True, help="The target model name in hugging face hub") parser.add_argument('-hf_user', required=True, help="The hugging face user to write the model to the hub") parser.add_argument('-hf_token', required=True, help="The hugging face token to write the model to the hub") args = parser.parse_args() class TrainingDataGenerator(ABC): """ Abstract class to generate fine-tuning training data. Implemented as a generator to minimise passing around large lists unnecessarily """ @abstractmethod def generate(self) -> Generator[Tuple[str, str], None, None]: """ Required to be implemented by the generator implementation. :return: should yield pairs of training data as question, answer """ pass class CategoryDataGenerator(TrainingDataGenerator): """ Concrete implementation to build training data about product categories """ def generate(self) -> Generator[Tuple[str, str], None, None]: # 1. First build "what do you offer" type queries cat_names = [c.name for c in Category.all.values()] category_synonyms = ["types", "categories", "kinds"] for _ in range(5): shuffle(cat_names) cat = choice(category_synonyms) q = f"What {cat} of products do you offer?" a = f"ElectroHome offers {join_items_comma_and(cat_names)}." yield q, a # 2. Now build some "product in category" type queries for c in Category.all.values(): prod_names = [p.name for p in c.products] total_prod_count = len(prod_names) for _ in range(1): working_prod_names = copy(prod_names) shuffle(working_prod_names) while len(working_prod_names) > 0: ans_product_names = pop_n(working_prod_names, 3) q = f"What {c.name} do you have?" a = f"We have {total_prod_count} {c.name}. For example {join_items_comma_and(ans_product_names)}. What are you lookinh for in a {c.name[:-1]}, and I can help to guide you?" yield q, a class ProductDescriptionDataGenerator(TrainingDataGenerator): """ Concrete implementation to build training data from the marketing description and price """ def generate(self) -> Generator[Tuple[str, str], None, None]: for p in Product.all.values(): question_templates = [ "Tell me about the #.", "Describe the # for me.", "What can you tell me about the #?" ] q = choice(question_templates).replace("#", p.name) # Mix up the paths to include price or not and at the start/end of the response if random_true_false(): a = p.description else: if random_true_false(): a = f"{p.description} It typically retails for ${p.price}." else: a = f"The {p.name} typically retails for ${p.price}. {p.description}" yield q, a class PriceDataGenerator(TrainingDataGenerator): """ Concrete implementation to build training data just for pricing """ def generate(self) -> Generator[Tuple[str, str], None, None]: for p in Product.all.values(): question_templates = [ "How much is the #?", "What does the # cost?", "What's the price of the #?" ] q = choice(question_templates).replace("#", p.name) answer_templates = [ "It typically retails for $#.", "Our recommended retail price is $#, but do check with your stockist.", "The list price is $# but check in our online store if there are any offers on (www.electrohome.com)." ] a = choice(answer_templates).replace("#", str(p.price)) yield q, a class FeatureDataGenerator(TrainingDataGenerator): """ Concrete implementation to build training data just for pricing """ def generate(self) -> Generator[Tuple[str, str], None, None]: # 1. First generate Q&A for what features are available by category for c in Category.all.values(): cat_features = [f.name for f in c.features] for _ in range(1): working_cat_features = copy(cat_features) shuffle(working_cat_features) while len(working_cat_features) > 0: some_features = pop_n(working_cat_features, 3) feature_clause = join_items_comma_and(some_features) question_templates = [ "What features do your # have?", "What should I think about when I'm considering #?", "What sort of things differentiate your #?" ] q = choice(question_templates).replace("#", c.name) answer_templates = [ "Our # have features like ##.", "You might want to consider things like ## which # offer.", "# have lots of different features, like ##." ] a = choice(answer_templates).replace("##", feature_clause).replace("#", c.name) yield q, a # 2. Now generate questions the other way around - i.e. search products by feature for f in Feature.all.values(): cat_name = f.category.name prod_names = [p.name for p in f.products] for _ in range(1): working_prod_names = copy(prod_names) while len(working_prod_names) > 0: some_prods = pop_n(working_prod_names, 3) if len(some_prods) > 1: # Single product examples mess up some trainind data prod_clause = join_items_comma_and(some_prods) q = f"Which {cat_name} offer {f.name}?" answer_templates = [ "## are # which offer ###.", "## have ###.", "We have some great # which offer ### including our ##." ] a = choice(answer_templates).replace("###", f.name).replace("##", prod_clause).replace("#", cat_name) yield q, a else: q = f"Which {cat_name} offer {f.name}?" a = f"The {some_prods[0]} has {f.name}." yield q, a def training_string_from_q_and_a(q: str, a: str, sys_prompt: str = None) -> str: """ Build the single llama formatted training string from a question answer pair """ return f'User: {q}\nBot: {a}' def fine_tuning_out_dir(out_model: str) -> str: """ Utility to generate the full path to the output directory, creating it if it is not there """ out_dir = os.path.join(data_dir, 'fine_tuning', out_model) if not os.path.exists(out_dir): os.makedirs(out_dir) return out_dir def generate_dataset(out_model: str) -> int: """ Coordinator to build the training data. Generates all available Q&A pairs then formats them to llama format and saves them to the training csv file. :return Count of lines written to the training data """ training_file = os.path.join(fine_tuning_out_dir(out_model), 'train.csv') lines = [] generators = [ CategoryDataGenerator(), ProductDescriptionDataGenerator(), PriceDataGenerator(), FeatureDataGenerator() ] line_count = 0 for g in generators: for q, a in g.generate(): line = training_string_from_q_and_a(q, a) lines.append(line) line_count += 1 df = pd.DataFrame(lines, columns=['text']) df.to_csv(training_file, index=False) return line_count def generate_training_scripts(out_model: str, hf_user: str, hf_token: str) -> None: """ Generates the shell script to run to actually train the model """ shell_file = os.path.join(fine_tuning_out_dir(out_model), 'train.zsh') with open(shell_file, "w") as f: f.write("#!/bin/zsh\n\n") f.write("# DO NOT COMMIT THIS FILE TO GIT AS IT CONTAINS THE HUGGING FACE WRITE TOKEN FOR THE REPO\n\n") f.write('echo "STARTING TRAINING AND PUSH TO HUB"\n') f.write("start=$(date +%s)\n") f.write(f"autotrain llm --train --project-name {out_model} --model meta-llama/Llama-2-7b-chat-hf --data-path . --peft --lr 2e-4 --batch-size 12 --epochs 3 --trainer sft --merge-adapter --push-to-hub --username {hf_user} --token {hf_token}\n") f.write("end=$(date +%s)\n") f.write('echo "TRAINING AND PUSH TOOK $(($end-$start)) seconds"') stats = os.stat(shell_file) os.chmod(shell_file, stats.st_mode | stat.S_IEXEC) if __name__ == "__main__": start = time() if args.products_db == DataLoader.active_db: DataLoader.load_data() else: DataLoader.set_db_name(args.products_db) line_count = generate_dataset(args.fine_tuned_model) generate_training_scripts(args.fine_tuned_model, args.hf_user, args.hf_token) end = time() elapsed = (int((end - start) * 10)) / 10 # round to 1dp print(f"Generated {line_count} training examples and the training script in {elapsed} seconds.")