# Import spaces first to ensure GPU resources are managed correctly
import spaces

# Import necessary libraries
import os
import json
import logging
import time
import torch
import bitsandbytes as bnb
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    filename='training_log.log',
    filemode='w',
    format='%(name)s - %(levelname)s - %(message)s'
)
logging.info("Started the script")

# Load the Hugging Face API token from environment variables
HF_API_TOKEN = os.getenv('HF_API_TOKEN')

# Load the dataset
file_path = 'best_training_data.json'  # Adjust path as needed
logging.info(f"Loading dataset from {file_path}")
try:
    with open(file_path, 'r') as file:
        data = json.load(file)
    logging.info("Dataset loaded successfully")
except Exception as e:
    logging.error(f"Failed to load dataset: {e}")

# Convert the dataset to Hugging Face Dataset format
try:
    dataset = Dataset.from_dict({"text": [entry["text"] for entry in data]})
    logging.info("Dataset converted to Hugging Face Dataset format")
except Exception as e:
    logging.error(f"Failed to convert dataset: {e}")

# Initialize the tokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained("SweatyCrayfish/llama-3-8b-quantized", token=HF_API_TOKEN)
    logging.info("Tokenizer loaded successfully")

    # Add a padding token if not already present
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        logging.info("Padding token added to the tokenizer")
        tokenizer.save_pretrained('.')
except Exception as e:
    logging.error(f"Failed to load or configure tokenizer: {e}")

# Tokenize the dataset. return_tensors is omitted: datasets.map stores lists,
# and the data collator converts batches to tensors at training time.
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding='max_length', max_length=1024)

try:
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    logging.info("Dataset tokenized successfully")
except Exception as e:
    logging.error(f"Failed to tokenize the dataset: {e}")

# Set up the 4-bit (NF4) quantization configuration
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the Llama 3 8B model with quantization
try:
    model = AutoModelForCausalLM.from_pretrained(
        "SweatyCrayfish/llama-3-8b-quantized",
        quantization_config=nf4_config,
        token=HF_API_TOKEN,
        device_map="auto"
    )
    model.resize_token_embeddings(len(tokenizer))
    model.gradient_checkpointing_enable()
    model.config.use_cache = False  # Disable use_cache when using gradient checkpointing
    logging.info("Model initialized and resized embeddings")

    # Set up LoRA
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
    )
    model = get_peft_model(model, lora_config)
    logging.info("LoRA configuration applied to the model")

    # Ensure only floating-point parameters require gradients
    # (the 4-bit quantized weights are not floating point and stay frozen)
    for param in model.parameters():
        if param.dtype in [torch.float16, torch.float32, torch.bfloat16, torch.complex64, torch.complex128]:
            param.requires_grad = True
    logging.info("Model parameters configured for gradient computation")
except Exception as e:
    logging.error(f"Failed to initialize the model: {e}")

# Set up training arguments
try:
    training_args = TrainingArguments(
        output_dir="training_results",
        evaluation_strategy="no",         # Disable evaluation
save_strategy="epoch", # Save only at the end of each epoch learning_rate=2e-4, per_device_train_batch_size=5, gradient_accumulation_steps=4, num_train_epochs=12, weight_decay=0.01, save_total_limit=1, logging_dir="training_logs", logging_steps=50, fp16=False, bf16=True, load_best_model_at_end=False, # Do not load the best model greater_is_better=False, report_to="none" # Disable reporting to external services ) logging.info("Training arguments configured successfully") except Exception as e: logging.error(f"Failed to configure training arguments: {e}") # Initialize the Trainer try: data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset, data_collator=data_collator ) logging.info("Trainer initialized successfully") except Exception as e: logging.error(f"Failed to initialize the Trainer: {e}") # Implementing 120-Second Segmented Training @spaces.GPU(duration=120) def segmented_train(trainer): start_time = time.time() while time.time() - start_time < 120: try: trainer.train() except torch.cuda.OutOfMemoryError as e: logging.error(f"Out of memory error: {e}") break except Exception as e: logging.error(f"Training error: {e}") break trainer.save_state() try: segmented_train(trainer) logging.info("Model training completed successfully") except Exception as e: logging.error(f"Training failed: {e}") import traceback traceback.print_exc() # Save the Model try: model.save_pretrained("llama3-8b-chat-finetuned-final-version") tokenizer.save_pretrained("llama3-8b-chat-finetuned-final-version") logging.info("Final fine-tuned model and tokenizer saved successfully") except Exception as e: logging.error(f"Failed to save the final fine-tuned model: {e}") # Inference Function @spaces.GPU def generate_response(prompt, model, tokenizer, max_length=128, min_length=20, temperature=0.7, top_k=50, top_p=0.9): try: inputs = tokenizer(prompt, return_tensors="pt").to("cuda") with torch.no_grad(): outputs = model.generate( inputs.input_ids, max_length=max_length, min_length=min_length, do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=1.3, no_repeat_ngram_size=3, eos_token_id=tokenizer.eos_token_id ) response = tokenizer.decode(outputs[0], skip_special_tokens=True) return response except Exception as e: logging.error(f"Failed to generate response: {e}") return "" # Example Usage prompt = "bro did u talk with DK today" response = generate_response(prompt, model, tokenizer) print(response) logging.info(f"Generated response: {response}")