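"""
Fine-tune HuggingFaceH4/zephyr-7b-beta with KTO (Kahneman-Tversky Optimization)
on a processed UltraFeedback dataset, optionally with LoRA adapters, logging
metrics to Weights & Biases.
"""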
import torch
from dataclasses import dataclass
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
from kto_dataset_processor import process_dataset_ultrafeedback
from datetime import datetime
import wandb

####################################
# CONFIGURATION
####################################

@dataclass
class ScriptArguments:
    """
    Configuration for the script.
    """
    process_dataset_func: callable = process_dataset_ultrafeedback  # process_dataset function from kto_dataset_processor.py
    checkpoint_path: str = None  # Checkpoint path (currently unused)
    push_to_hub: bool = False  # Whether to push the model to the Hugging Face hub
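
# Model and LoRA settings. "all-linear" asks PEFT to attach LoRA adapters to the
# model's linear layers; lora_r and lora_alpha control adapter rank and scaling.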
@dataclass
class ModelArguments(ModelConfig):
    """
    Configuration for the model.
    """
    model_name: str = "HuggingFaceH4/zephyr-7b-beta"
    use_peft: bool = True
    lora_target_modules: str = "all-linear"
    lora_r: int = 16
    lora_alpha: int = 16
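
# KTO training hyperparameters. Like DPO, KTO is typically trained with a very
# low learning rate (5e-7 here), a cosine schedule, and a short warmup.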
@dataclass
class TrainingArguments(KTOConfig):
    """
    Configuration for the KTO trainer.
    """
    output_dir: str = f"kto_{ModelArguments.model_name.split('/')[-1]}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"  # repo basename keeps output_dir flat
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 4  # Highest that runs well
    learning_rate: float = 5e-7
    lr_scheduler_type: str = "cosine"
    gradient_accumulation_steps: int = 1
    logging_steps: int = 10
    eval_steps: int = 500  # only takes effect if the evaluation strategy is set to "steps"
    warmup_ratio: float = 0.1
    bf16: bool = True
    logging_first_step: bool = True

# Initialize configurations
script_args = ScriptArguments()
training_args = TrainingArguments()
model_args = ModelArguments()

####################################
# HELPER FUNCTIONS
####################################

def load_model_and_tokenizer(model_args):
    """
    Load a model and tokenizer from a specified path.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
    )

    # Set pad token if missing
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Setup chat format if not present
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)

    return model, tokenizer
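
# The helpers below are kept for reference but currently disabled: they extend
# the tokenizer vocabulary with tokens found in the dataset and resize the model
# embeddings accordingly.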
# def find_unknown_tokens(tokenizer, texts):
#     """
#     Identify tokens in the dataset that are not in the tokenizer's vocabulary.
#     """
#     all_tokens = set()
#     for text in texts:
#         tokens = tokenizer.tokenize(text)
#         all_tokens.update(tokens)
#     vocab = set(tokenizer.get_vocab().keys())
#     unknown_tokens = all_tokens - vocab
#     return unknown_tokens

# def add_tokens_to_tokenizer(tokenizer, model, dataset):
#     """
#     Extend the tokenizer's vocabulary with missing tokens and resize the model embeddings.
#     """
#     # Extract all texts from the dataset
#     texts = [example["completion"] for example in dataset["train"]]
#     # Identify unknown tokens
#     unknown_tokens = find_unknown_tokens(tokenizer, texts)
#     print(f"Found {len(unknown_tokens)} unknown tokens: {list(unknown_tokens)[:10]}...")
#     # Add unknown tokens to tokenizer
#     tokenizer.add_tokens(list(unknown_tokens))
#     model.resize_token_embeddings(len(tokenizer))
#     print(f"Tokenizer vocabulary size after extension: {len(tokenizer)}")

####################################
# MAIN LOGIC
####################################

def main():
    # Initialize wandb
    wandb.init(project="kto")

    # Load models and tokenizer
    print("Loading models and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    peft_config = get_peft_config(model_args)
    # With a PEFT config, KTOTrainer derives the frozen reference model from the
    # base weights itself, so an explicit ref_model is only loaded otherwise.
    ref_model = None
    if peft_config is None:
        ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")

    # Load and process datasets using external function
    print("Processing dataset...")
    dataset = script_args.process_dataset_func()
    print("Dataset processed.")

    # # Extend tokenizer with missing tokens
    # print("Adding unknown tokens to tokenizer...")
    # add_tokens_to_tokenizer(tokenizer, model, dataset)
    # print("Tokenizer updated.")
    # Initialize trainer
    print("Initializing trainer...")
    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
        peft_config=peft_config,
    )

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Log evaluation metrics to wandb. trainer.evaluate() returns keys prefixed
    # with "eval_" (e.g. "eval_loss"); per-step training metrics such as loss,
    # grad_norm, and learning_rate are already reported during trainer.train()
    # when the wandb integration is active.
    wandb.log(metrics)

    # Save model and optionally push to hub
    trainer.save_model(training_args.output_dir)
    if script_args.push_to_hub:
        trainer.push_to_hub()

    print("Process completed.")

    # Finish wandb run
    wandb.finish()


if __name__ == "__main__":
    main()
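
# Example launch (assuming this file is saved as, e.g., kto_train.py):
#   python kto_train.py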