Spaces:
Running
Running
from typing import List, Optional, Tuple, Union

import seaborn as sns
from datasets import Dataset, DatasetDict, load_dataset
from transformers import AutoTokenizer

import my_model.config.fine_tuning_config as config
from my_model.LLAMA2.LLAMA2_model import Llama2ModelManager
from my_model.utilities.gen_utilities import is_pycharm
class FinetuningDataHandler:
    """
    Handles the data used for fine-tuning LLaMA-2 Chat models.

    The handler loads a CSV dataset, counts the tokens of every sample, drops the samples whose token count
    exceeds `max_token_count`, plots the token count distribution before and after that filtering, and finally
    splits the remaining samples into training and evaluation sets. Capping the token count per sample is meant
    to keep samples within the model's token constraints and to bound GPU RAM usage during fine-tuning.

    Attributes:
        tokenizer (Optional[AutoTokenizer]): Tokenizer used for counting tokens; loaded lazily when None.
        dataset_file (str): File path to the CSV dataset.
        max_token_count (int): Maximum allowable token count per data sample.

    Methods:
        load_llm_tokenizer: Loads the LLM tokenizer and adds special tokens, if not already loaded.
        load_dataset: Loads the dataset from `dataset_file`.
        plot_tokens_count_distribution: Plots the distribution of token counts in the dataset.
        filter_dataset_by_indices: Keeps only the 'train' samples at the given indices.
        get_token_counts: Calculates token counts for each sample in the dataset.
        prepare_dataset: Counts tokens, filters, plots before/after distributions and splits the dataset.
        split_dataset_for_train_eval: Divides the dataset into training and evaluation sets.
        inspect_prepare_split_data: Entry point that runs the whole preparation and splitting process.
    """

    def __init__(self, tokenizer: Optional[AutoTokenizer] = None, dataset_file: str = config.DATASET_FILE,
                 max_token_count: Optional[int] = None) -> None:
        """
        Initializes the FinetuningDataHandler class.

        Args:
            tokenizer (Optional[AutoTokenizer]): Tokenizer to use for tokenizing the dataset. Defaults to None,
                in which case it is loaded on demand by `load_llm_tokenizer`.
            dataset_file (str): Path to the dataset file. Defaults to config.DATASET_FILE.
            max_token_count (Optional[int]): Maximum token count allowed per sample. Defaults to None, which
                selects config.MAX_TOKEN_COUNT.
        """
        self.tokenizer = tokenizer  # The tokenizer used for processing the dataset.
        self.dataset_file = dataset_file  # Path to the fine-tuning dataset file.
        # Samples with more tokens than this are filtered out; the project configuration supplies the default.
        self.max_token_count = config.MAX_TOKEN_COUNT if max_token_count is None else max_token_count

    def load_llm_tokenizer(self) -> None:
        """
        Loads the LLM tokenizer and adds special tokens, if not already loaded.

        If the tokenizer is already loaded, this method does nothing.

        Returns:
            None
        """
        if self.tokenizer is not None:
            return

        llm_manager = Llama2ModelManager()  # Initialize Llama2 model manager.
        # Only the tokenizer is needed for the data inspection, not the model itself.
        self.tokenizer = llm_manager.load_tokenizer()
        # Add special tokens specific to the LLAMA2 vocab; presumably this updates the tokenizer loaded above
        # (confirm against Llama2ModelManager).
        llm_manager.add_special_tokens()

    def load_dataset(self) -> DatasetDict:
        """
        Loads the dataset from `self.dataset_file`. The dataset is expected to be in CSV format.

        Returns:
            DatasetDict: The loaded dataset. No split is requested, so the library returns a dictionary of
                splits; the rest of this class reads its 'train' split.
        """
        return load_dataset('csv', data_files=self.dataset_file)

    def plot_tokens_count_distribution(self, token_counts: List[int], title: str = "Token Count Distribution") -> None:
        """
        Plots a histogram of the token counts in the dataset for visualization purposes.

        Args:
            token_counts (List[int]): List of token counts, each count representing the number of tokens in a
                dataset sample.
            title (str): Title for the plot, highlighting the nature of the distribution.

        Returns:
            None
        """
        if is_pycharm():  # Ensuring compatibility with PyCharm's environment for interactive plot.
            import matplotlib  # The import is kept here intentionally.
            matplotlib.use('TkAgg')  # Set the backend to 'TkAgg'

        # The import is kept here intentionally: it has to come after the optional backend selection above.
        import matplotlib.pyplot as plt

        sns.set_style("whitegrid")
        plt.figure(figsize=(15, 6))
        plt.hist(token_counts, bins=50, color='#3498db', edgecolor='black')
        plt.title(title, fontsize=16)
        plt.xlabel("Number of Tokens", fontsize=14)
        plt.ylabel("Number of Samples", fontsize=14)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.tight_layout()
        plt.show()

    def filter_dataset_by_indices(self, dataset: DatasetDict, valid_indices: List[int]) -> Dataset:
        """
        Filters the dataset based on a list of valid indices. This method is used to exclude
        data samples that have a token count exceeding the specified maximum token count.

        Args:
            dataset (DatasetDict): The dataset to be filtered; it must contain a 'train' split.
            valid_indices (List[int]): Indices of samples with token counts within the limit.

        Returns:
            Dataset: The 'train' split restricted to the samples at `valid_indices`.
        """
        return dataset['train'].select(valid_indices)  # Select only samples with valid indices based on token count.

    def get_token_counts(self, dataset: Union[Dataset, DatasetDict]) -> List[int]:
        """
        Calculates and returns the token counts for each sample in the dataset.

        The samples are read from the 'text' field of the 'train' split when a dictionary of splits is given,
        or from the 'text' field of the dataset itself when a single split is given (as is the case after
        filtering). `self.tokenizer` must already be loaded.

        Args:
            dataset (Union[Dataset, DatasetDict]): The dataset for which to count tokens.

        Returns:
            List[int]: List of token counts per sample in the dataset.
        """
        # Check the container type before the membership test: DatasetDict is a dict, so `in` is a key lookup
        # there, whereas on a flat Dataset `in` would fall back to iterating over every row.
        if isinstance(dataset, dict) and 'train' in dataset:
            texts = dataset["train"]["text"]
        else:
            # After filtering the samples with unacceptable token count, the dataset is
            # already `dataset = dataset['train']`.
            texts = dataset["text"]

        return [len(self.tokenizer.tokenize(text)) for text in texts]

    def prepare_dataset(self) -> Tuple[Dataset, Dataset]:
        """
        Prepares the dataset for fine-tuning by counting the tokens of each sample and filtering out samples
        that exceed the maximum used context window (configurable through max_token_count).
        It also visualizes the token count distribution before and after filtering.

        Returns:
            Tuple[Dataset, Dataset]: The train and evaluate datasets, post-filtering.
        """
        dataset = self.load_dataset()
        self.load_llm_tokenizer()

        # Count tokens in each dataset sample before filtering.
        token_counts_before_filtering = self.get_token_counts(dataset)

        # Plot token count distribution before filtering for visualization.
        self.plot_tokens_count_distribution(token_counts_before_filtering, "Token Count Distribution Before Filtration")

        # Identify valid indices based on max token count.
        valid_indices = [i for i, count in enumerate(token_counts_before_filtering) if count <= self.max_token_count]

        # Filter the dataset to exclude samples with excessive token counts.
        filtered_dataset = self.filter_dataset_by_indices(dataset, valid_indices)

        token_counts_after_filtering = self.get_token_counts(filtered_dataset)
        self.plot_tokens_count_distribution(token_counts_after_filtering, "Token Count Distribution After Filtration")

        return self.split_dataset_for_train_eval(filtered_dataset)  # split the dataset into training and evaluation.

    def split_dataset_for_train_eval(self, dataset: Dataset) -> Tuple[Dataset, Dataset]:
        """
        Splits the dataset into training and evaluation datasets, shuffling with the configured seed.

        Args:
            dataset (Dataset): The dataset to split.

        Returns:
            Tuple[Dataset, Dataset]: The split training and evaluation datasets.
        """
        split_data = dataset.train_test_split(test_size=config.TEST_SIZE, shuffle=True, seed=config.SEED)
        train_data, eval_data = split_data['train'], split_data['test']
        return train_data, eval_data

    def inspect_prepare_split_data(self) -> Tuple[Dataset, Dataset]:
        """
        Orchestrates the process of inspecting, preparing, and splitting the dataset for fine-tuning.

        Returns:
            Tuple[Dataset, Dataset]: The prepared training and evaluation datasets.
        """
        return self.prepare_dataset()
# Example usage
if __name__ == "__main__":
    # Manual check of the data preparation pipeline; it is left disabled on purpose.
    # Please uncomment the below lines to test the data prep.
    # data_handler = FinetuningDataHandler()
    # fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()
    # print(fine_tuning_data_train, fine_tuning_data_eval)
    pass