George-API committed on
Commit 5b6d8f0 · verified · 1 Parent(s): 1cf4e07

Upload folder using huggingface_hub

run_transformers_training.py CHANGED
@@ -123,30 +123,22 @@ def load_env_variables():
     os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
 
 def load_configs(base_path):
-    """Load all configuration from a single consolidated file."""
-    configs = {}
-
+    """Load configuration from transformers_config.json file."""
     # Using a single consolidated config file
-    config_file = "transformers_config.json"
+    config_file = base_path
 
-    file_path = os.path.join(base_path, config_file)
     try:
-        with open(file_path, "r") as f:
+        with open(config_file, "r") as f:
             config = json.load(f)
-            # Extract sections into separate config dictionaries for compatibility
-            configs["transformers"] = config
-            configs["hardware"] = config.get("hardware", {})
-            configs["dataset"] = config.get("dataset", {})
-            logger.info(f"Loaded consolidated configuration from {file_path}")
+        logger.info(f"Loaded configuration from {config_file}")
+        return config
     except Exception as e:
        logger.error(f"Error loading {config_file}: {e}")
        raise
-
-    return configs
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
-    parser.add_argument("--config_dir", type=str, default=".", help="Directory containing configuration files")
+    parser.add_argument("--config", type=str, default="transformers_config.json", help="Path to configuration file")
     return parser.parse_args()
 
 def load_model_and_tokenizer(config):
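
Note: this hunk changes the load_configs contract: it now takes the path of the config file itself rather than a directory, and returns the parsed dict whole, so callers read "hardware" and "dataset" as plain sub-keys. A minimal standalone sketch of the new behavior (illustrative only, not part of the commit):

import json

def load_configs(base_path):
    # base_path is now the config file path, not a directory
    with open(base_path, "r") as f:
        return json.load(f)

config = load_configs("transformers_config.json")
hardware_config = config.get("hardware", {})  # was configs["hardware"]
dataset_config = config.get("dataset", {})    # was configs["dataset"]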
@@ -157,8 +149,8 @@ def load_model_and_tokenizer(config):
         logger.error("Please ensure unsloth is in requirements.txt")
         raise ImportError("Unsloth is required for this training setup")
 
-    # Get model name correctly from nested config structure
-    model_name = config.get("model", {}).get("name") or config.get("model_name_or_path") or config.get("model_name")
+    # Get model name correctly from config
+    model_name = config.get("model_name") or config.get("model", {}).get("name")
     logger.info(f"Loading model: {model_name}")
 
     if not model_name:
@@ -166,14 +158,12 @@ def load_model_and_tokenizer(config):
 
     logger.info("Using Unsloth optimizations with pre-quantized model")
 
-    # Check for flash attention without importing it directly
+    # Check for flash attention
     use_flash_attention = config.get("use_flash_attention", True)
-    try:
-        import flash_attn
-        logger.info("Flash attention detected and will be used")
-    except ImportError:
+    if use_flash_attention and not find_spec("flash_attn"):
+        logger.warning("flash-attn not found. Will continue without flash attention.")
+        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
         use_flash_attention = False
-        logger.warning("Flash attention not available, falling back to standard attention")
 
     # First detect if we have a GPU
     if torch.cuda.is_available():
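
Note: the replacement probes for flash-attn with find_spec instead of importing it, so the availability check can no longer trigger the package's import-time initialization. The corresponding import is not shown in this diff; assuming it comes from importlib.util, a self-contained equivalent of the check is:

from importlib.util import find_spec

use_flash_attention = True
if use_flash_attention and not find_spec("flash_attn"):
    # find_spec returns None when the package is not installed,
    # without actually importing it
    print("flash-attn not found; continuing without flash attention")
    use_flash_attention = False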
@@ -321,13 +311,24 @@ def load_dataset_with_mapping(dataset_config):
 
         # Add prompt_number field that increments based on original order
         def add_prompt_numbers(examples, indices):
-            # Defensive check to ensure indices is not None
+            # Defensive check to ensure indices is not None and is iterable
             if indices is None:
                 logger.warning("Warning: indices is None in add_prompt_numbers, using empty list")
                 indices = []
+            elif isinstance(indices, int):
+                # Handle case where indices is a single integer
+                logger.warning(f"Warning: indices is an integer ({indices}) in add_prompt_numbers, converting to list")
+                indices = [indices]
 
-            # Create a new field with the dataset index as the prompt number, starting at 1
-            examples["prompt_number"] = [idx + 1 for idx in indices]  # Adding 1 to make it 1-indexed
+            # Ensure indices is always a list/iterable
+            try:
+                # Create a new field with the dataset index as the prompt number, starting at 1
+                examples["prompt_number"] = [idx + 1 for idx in indices]  # Adding 1 to make it 1-indexed
+            except TypeError:
+                # Fallback for non-iterable types
+                logger.warning(f"Warning: non-iterable indices in add_prompt_numbers: {type(indices)}, using default")
+                examples["prompt_number"] = [1] * len(next(iter(examples.values())))
+
             return examples
 
         # Add prompt numbers to the dataset based on original order
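
Note: the new isinstance(indices, int) branch covers datasets.Dataset.map called with with_indices=True but batched=False, where the mapper receives a single index rather than a list. The actual call site is outside this hunk; a hypothetical batched usage that exercises the list path:

from datasets import Dataset

ds = Dataset.from_dict({"conversations": [["hi"], ["hello"], ["hey"]]})

def add_prompt_numbers(examples, indices):
    examples["prompt_number"] = [idx + 1 for idx in indices]  # 1-indexed
    return examples

ds = ds.map(add_prompt_numbers, with_indices=True, batched=True)
print(ds["prompt_number"])  # [1, 2, 3]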
@@ -358,37 +359,73 @@ def load_dataset_with_mapping(dataset_config):
             dataset = Dataset.from_list(updated_examples)
             logger.info(f"Successfully added prompt_number field using fallback method")
 
-        # Verify expected columns exist
-        expected_columns = {"id", "conversations"}
-        for col in expected_columns:
+        # Rename 'id' to 'article_id' if it exists
+        if 'id' in dataset.column_names and 'article_id' not in dataset.column_names:
+            logger.info("Renaming 'id' column to 'article_id'")
+            dataset = dataset.rename_column('id', 'article_id')
+
+        # Reorder columns to make prompt_number first if it exists
+        if 'prompt_number' in dataset.column_names:
+            logger.info("Reordering columns to place prompt_number first")
+            # Get current column names
+            current_columns = dataset.column_names
+            # Create new column order with prompt_number first
+            new_column_order = ['prompt_number'] + [col for col in current_columns if col != 'prompt_number']
+            # Reorder columns
+            dataset = dataset.select_columns(new_column_order)
+
+        # Verify all new column names for logging
+        logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
+        logger.info(f"Dataset columns: {dataset.column_names}")
+
+        # Verify dataset is not empty
+        if len(dataset) == 0:
+            logger.error("Dataset is empty! This will cause errors during training.")
+            raise ValueError("Empty dataset loaded")
+
+        # Check for required columns
+        required_columns = ['conversations']
+        for col in required_columns:
             if col not in dataset.column_names:
-                # If "conversations" is missing but "text" exists, it might need conversion
-                if col == "conversations" and "text" in dataset.column_names:
-                    logger.info("Converting 'text' field to 'conversations' format")
-
-                    def convert_text_to_conversations(example):
-                        # Check if text is already a list of conversation turns
-                        if isinstance(example.get("text"), list):
-                            return {"conversations": example["text"]}
-                        # Otherwise, create a simple conversation with the text as user message
-                        else:
-                            return {
-                                "conversations": [
-                                    {"role": "user", "content": str(example.get("text", ""))}
-                                ]
-                            }
-
-                    dataset = dataset.map(convert_text_to_conversations)
-                else:
-                    logger.warning(f"Expected column '{col}' not found in dataset")
+                logger.error(f"Required column '{col}' not found in dataset!")
+                raise ValueError(f"Required column '{col}' missing from dataset")
 
-        # Note: Explicitly NOT sorting the dataset to preserve original order
-        logger.info("Preserving original dataset order (no sorting)")
+        # Verify expected columns exist
+        expected_columns = {"article_id", "conversations", "prompt_number"}
+        missing_columns = expected_columns - set(dataset.column_names)
+        if missing_columns:
+            logger.warning(f"Some expected columns are missing: {missing_columns}")
+
+        # If "conversations" is missing but "text" exists, attempt conversion
+        if "conversations" not in dataset.column_names and "text" in dataset.column_names:
+            logger.info("Converting 'text' field to 'conversations' format")
+
+            def convert_text_to_conversations(example):
+                # Check if text is already a list of conversation turns
+                if isinstance(example.get("text"), list):
+                    example["conversations"] = example["text"]
+                # Otherwise, create a simple conversation with the text as user message
+                else:
+                    example["conversations"] = [
+                        {"role": "user", "content": str(example.get("text", ""))}
+                    ]
+                return example
+
+            dataset = dataset.map(convert_text_to_conversations)
+            logger.info("Successfully converted 'text' to 'conversations'")
 
-        # Check data ordering requirements
+        # Verify data ordering requirements
         processing_config = dataset_config.get("dataset", {}).get("processing", {})
         data_loading_config = dataset_config.get("data_loading", {})
 
+        # Check if sorting is required
+        sort_by_article_id = processing_config.get("sort_by_article_id", False)
+        if sort_by_article_id and 'article_id' in dataset.column_names:
+            logger.info("Sorting dataset by article_id as specified in config")
+            dataset = dataset.sort("article_id")
+            sorted_ids = [example['article_id'] for example in dataset.select(range(min(5, len(dataset))))]
+            logger.info(f"First few article_ids after sorting: {sorted_ids}")
+
         # Flag consolidation - we only need one flag to control sequence preservation
         # Default to True to ensure safety
         preserve_sequence = processing_config.get("preserve_entry_sequence", True)
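
Note: rename_column and select_columns are standard datasets.Dataset methods, and select_columns returns the columns in the order of the list passed to it, which is how this hunk moves prompt_number to the front. A toy illustration with invented data:

from datasets import Dataset

ds = Dataset.from_dict({"id": [10, 11], "conversations": [["a"], ["b"]], "prompt_number": [1, 2]})
ds = ds.rename_column("id", "article_id")
ds = ds.select_columns(["prompt_number", "article_id", "conversations"])
print(ds.column_names)  # ['prompt_number', 'article_id', 'conversations']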
@@ -450,17 +487,18 @@ def load_dataset_with_mapping(dataset_config):
                     logger.warning(f"Error accessing dataset at index {i}: {e}")
 
             if sample_examples:
-                if all(isinstance(example.get('id', ''), (int, str)) for example in sample_examples):
-                    sample_ids = [example.get('id', '') for example in sample_examples if 'id' in example]
+                id_field = 'article_id' if 'article_id' in dataset.column_names else 'id'
+                if all(isinstance(example.get(id_field, ''), (int, str)) for example in sample_examples):
+                    sample_ids = [example.get(id_field, '') for example in sample_examples if id_field in example]
 
                     if sample_ids and all(isinstance(id, int) or (isinstance(id, str) and id.isdigit()) for id in sample_ids):
                         numeric_ids = [int(id) if isinstance(id, str) else id for id in sample_ids]
                         if len(numeric_ids) > 1:
                             is_ordered = all(numeric_ids[i] <= numeric_ids[i+1] for i in range(len(numeric_ids)-1))
                             if not is_ordered:
-                                logger.warning("WARNING: Sample IDs are not in sequential order.")
+                                logger.warning(f"WARNING: Sample {id_field}s are not in sequential order.")
                             else:
-                                logger.info("Sample IDs appear to be in sequential order.")
+                                logger.info(f"Sample {id_field}s appear to be in sequential order.")
         except Exception as e:
             logger.warning(f"Error checking ID sequence: {e}")
     except Exception as e:
@@ -472,19 +510,19 @@ def load_dataset_with_mapping(dataset_config):
         # Safely get first few samples
         first_few_indices = range(min(5, len(dataset)))
         sample_prompt_numbers = []
-        sample_ids = []
+        sample_article_ids = []
 
         for i in first_few_indices:
             try:
                 example = dataset[i]
                 if 'prompt_number' in example:
                     sample_prompt_numbers.append(example['prompt_number'])
-                if 'id' in example:
-                    sample_ids.append(example['id'])
+                if 'article_id' in example:
+                    sample_article_ids.append(example['article_id'])
             except Exception as e:
                 logger.warning(f"Error accessing sample at index {i}: {e}")
 
-        logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, IDs: {sample_ids}")
+        logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, Article IDs: {sample_article_ids}")
 
         # Log conversation structure without full content
         if len(dataset) > 0:
@@ -510,6 +548,74 @@ def load_dataset_with_mapping(dataset_config):
 
         logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
         logger.info(f"Dataset columns: {dataset.column_names}")
+
+        # Verify dataset is not empty
+        if len(dataset) == 0:
+            logger.error("Dataset is empty! Cannot proceed with training.")
+            return dataset
+
+        # Check for required columns
+        required_cols = ['conversations', 'prompt_number']
+        for col in required_cols:
+            if col not in dataset.column_names:
+                logger.error(f"Required column '{col}' missing from dataset. Cannot proceed with training.")
+                return dataset
+
+        # Validate at least one sample can be processed
+        try:
+            if len(dataset) > 0:
+                sample = dataset[0]
+                if 'conversations' not in sample or not sample['conversations']:
+                    logger.error("First sample has no conversations! Data format may be incorrect.")
+                    return dataset
+                if not isinstance(sample['conversations'], list):
+                    logger.error(f"Conversations field should be a list but got {type(sample['conversations'])}")
+                    return dataset
+        except Exception as e:
+            logger.error(f"Error validating first sample: {e}")
+            return dataset
+
+        # Add metadata if specified
+        metadata_config = dataset_config.get("data_formatting", {}).get("metadata_handling", {})
+        if metadata_config:
+            include_article_id = metadata_config.get("include_article_id", False)
+            include_prompt_number = metadata_config.get("include_prompt_number", False)
+            metadata_format = metadata_config.get("metadata_format", "")
+
+            if (include_article_id or include_prompt_number) and metadata_format:
+                logger.info("Adding metadata to conversations")
+
+                def add_metadata(example):
+                    if not example.get("conversations"):
+                        return example
+
+                    # Prepare metadata
+                    metadata = metadata_format
+                    if include_article_id and "article_id" in example:
+                        metadata = metadata.replace("{article_id}", str(example.get("article_id", "")))
+                    if include_prompt_number and "prompt_number" in example:
+                        metadata = metadata.replace("{prompt_number}", str(example.get("prompt_number", "")))
+
+                    # Add system message with metadata if not empty
+                    if metadata.strip():
+                        if example["conversations"] and isinstance(example["conversations"], list):
+                            # Check if first message is already a system message
+                            if (isinstance(example["conversations"][0], dict) and
+                                    example["conversations"][0].get("role") == "system"):
+                                # Append to existing system message
+                                example["conversations"][0]["content"] += f"\n\nMetadata: {metadata}"
+                            else:
+                                # Add new system message at the beginning
+                                example["conversations"].insert(0, {
+                                    "role": "system",
+                                    "content": f"Metadata: {metadata}"
+                                })
+
+                    return example
+
+                dataset = dataset.map(add_metadata)
+                logger.info("Metadata added to conversations")
+
         return dataset
 
     except Exception as e:
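
Note: add_metadata fills the metadata_format template from transformers_config.json with plain str.replace calls and then prepends a system message (or extends an existing one). A worked illustration of the transformation on one invented example:

metadata_format = "Article ID: {article_id} | Prompt: {prompt_number}"

example = {
    "article_id": 42,
    "prompt_number": 7,
    "conversations": [{"role": "user", "content": "Summarize the paper."}],
}

metadata = metadata_format
metadata = metadata.replace("{article_id}", str(example["article_id"]))
metadata = metadata.replace("{prompt_number}", str(example["prompt_number"]))

# No existing system message, so one is inserted at the front
example["conversations"].insert(0, {"role": "system", "content": f"Metadata: {metadata}"})
print(example["conversations"][0])
# {'role': 'system', 'content': 'Metadata: Article ID: 42 | Prompt: 7'}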
@@ -752,13 +858,13 @@ class LoggingCallback(TrainerCallback):
                     is_sequence_maintained = False
 
                 # Also compare IDs as a backup check
-                elif ('id' in orig_sample and
-                      'id' in current_sample and
-                      orig_sample['id'] is not None and
-                      current_sample['id'] is not None):
+                elif ('article_id' in orig_sample and
+                      'article_id' in current_sample and
+                      orig_sample['article_id'] is not None and
+                      current_sample['article_id'] is not None):
 
-                    if orig_sample['id'] != current_sample['id']:
-                        log_info(f"WARNING: Sequence integrity compromised! Sample {i} ID changed from {orig_sample['id']} to {current_sample['id']}")
+                    if orig_sample['article_id'] != current_sample['article_id']:
+                        log_info(f"WARNING: Sequence integrity compromised! Sample {i} article_id changed from {orig_sample['article_id']} to {current_sample['article_id']}")
                         is_sequence_maintained = False
 
                 # Compare input fingerprints
@@ -899,12 +1005,11 @@ def check_dependencies():
         missing_packages.append("peft>=0.9.0")
 
     # Optional packages - don't add to missing list, just log
-    try:
-        import flash_attn
+    if find_spec("flash_attn"):
         logger.info("flash-attn found. Flash attention will be used for faster training.")
-    except ImportError:
+    else:
         logger.warning("flash-attn not found. Training will work but may be slower.")
-        # Don't add to missing packages since it's optional and can cause build issues
+        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
 
     # If critical packages are missing, exit with instructions
     if missing_packages:
@@ -918,115 +1023,44 @@ def check_dependencies():
 
 def main():
     # Set up logging
-    log_info("Starting Phi-4 fine-tuning process")
-
-    # Log hardware information
-    log_info(f"Hardware detection: CUDA {'available' if CUDA_AVAILABLE else 'not available'}")
-    if CUDA_AVAILABLE:
-        log_info(f"Found {NUM_GPUS} GPUs")
-        for i in range(NUM_GPUS):
-            log_info(f" GPU {i}: {torch.cuda.get_device_name(i)}")
-    else:
-        log_info("Running on CPU (training will be very slow)")
+    logger.info("Starting training process")
 
     # Parse arguments
     args = parse_args()
 
+    # Load environment variables
+    load_env_variables()
+
+    # Load configuration
+    try:
+        transformers_config = load_configs(args.config)
+        hardware_config = transformers_config.get("hardware", {})
+        dataset_config = transformers_config.get("dataset", {})
+        logger.info("Configuration loaded successfully")
+    except Exception as e:
+        logger.error(f"Error loading configuration: {e}")
+        return 1
+
     # Check dependencies
     if not check_dependencies():
         logger.error("Aborting due to missing critical dependencies")
         return 1
 
-    # Load environment variables
-    load_env_variables()
-
     # Check if we're in distributed mode
     is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1
     if is_distributed:
-        log_info(f"Running in distributed mode with world size: {os.environ.get('WORLD_SIZE')}")
+        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}")
    else:
        log_info("Running in non-distributed mode (single process)")
 
-    # Load all configurations - do this once
-    try:
-        configs = load_configs(args.config_dir)
-
-        # Extract specific configs immediately after loading
-        if not configs:
-            logger.error("Failed to load configuration")
-            return 1
-
-        # Store configurations in clear variables
-        transformers_config = configs.get("transformers", {})
-        hardware_config = configs.get("hardware", {})
-        dataset_config = configs.get("dataset", {})
-
-        # Verify configuration sections exist
-        if not transformers_config:
-            logger.error("transformers_config.json not found or invalid")
-            return 1
-
-        if not hardware_config:
-            logger.warning("Hardware configuration section not found in transformers_config.json. Using default hardware configuration.")
-
-        if not dataset_config:
-            logger.error("Dataset configuration section not found in transformers_config.json")
-            return 1
-
-        # Validate model configuration
-        model_name = (transformers_config.get("model", {}).get("name") or
-                      transformers_config.get("model_name_or_path") or
-                      transformers_config.get("model_name"))
-
-        if not model_name:
-            logger.error("Model name not specified in configuration")
-            logger.error("Please ensure 'name' is specified under 'model' in transformers_config.json")
-            return 1
-
-        log_info(f"Using model: {model_name}")
-        log_info("All configurations loaded successfully")
-
-        # Apply hardware-specific settings if available
-        if hardware_config:
-            # Get training optimizations from hardware config
-            training_opts = hardware_config.get("training_optimizations", {})
-
-            # Apply batch size and gradient accumulation settings
-            if training_opts.get("per_device_batch_size") and transformers_config.get("training"):
-                batch_size = training_opts.get("per_device_batch_size")
-                transformers_config["training"]["per_device_train_batch_size"] = batch_size
-                log_info(f"Applied hardware-optimized batch size: {batch_size}")
-
-            if training_opts.get("gradient_accumulation_steps") and transformers_config.get("training"):
-                grad_steps = training_opts.get("gradient_accumulation_steps")
-                transformers_config["training"]["gradient_accumulation_steps"] = grad_steps
-                log_info(f"Applied hardware-optimized gradient accumulation: {grad_steps}")
-
-            # Apply memory optimizations
-            memory_opts = training_opts.get("memory_optimizations", {})
-            if memory_opts.get("use_gradient_checkpointing") is not None and transformers_config.get("training"):
-                grad_ckpt = memory_opts.get("use_gradient_checkpointing")
-                transformers_config["training"]["gradient_checkpointing"] = grad_ckpt
-                log_info(f"Applied hardware-optimized gradient checkpointing: {grad_ckpt}")
-
-            # Apply system settings
-            system_settings = hardware_config.get("system_settings", {})
-            if system_settings.get("dataloader_num_workers") is not None:
-                workers = system_settings.get("dataloader_num_workers")
-                log_info(f"Using {workers} dataloader workers from hardware config")
-
-            # Get distribution strategy
-            multi_gpu_strategy = training_opts.get("multi_gpu_strategy", "data_parallel")
-            log_info(f"Hardware config specifies {multi_gpu_strategy} for multi-GPU training")
-
-    except Exception as e:
-        logger.error(f"Error loading configurations: {e}")
-        return 1
-
     # Set random seed for reproducibility
     seed = transformers_config.get("seed", 42)
     set_seed(seed)
-    log_info(f"Set random seed to {seed} for reproducibility")
+    logger.info(f"Set random seed to {seed}")
+
+    # Load model and tokenizer using the consolidated config
+    model, tokenizer = load_model_and_tokenizer(transformers_config)
 
     # Empty CUDA cache to ensure clean state
     if CUDA_AVAILABLE:
@@ -1043,17 +1077,13 @@ def main():
         log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
 
     try:
-        log_info("Loading model and tokenizer...")
-        model, tokenizer = load_model_and_tokenizer(transformers_config)
-        log_info("Model and tokenizer loaded successfully")
+        log_info("Loading dataset...")
+        dataset = load_dataset_with_mapping(dataset_config)
+        log_info(f"Dataset loaded with {len(dataset)} examples")
 
-        # Load dataset with proper mapping
-        try:
-            log_info(f"Loading dataset from {dataset_config.get('dataset', {}).get('name', '')}")
-            dataset = load_dataset_with_mapping(dataset_config)
-            log_info(f"Dataset loaded with {len(dataset)} examples")
-        except Exception as e:
-            logger.error(f"Error loading dataset: {e}")
+        # Minimal validation before proceeding
+        if dataset is None or len(dataset) == 0:
+            logger.error("Dataset is empty or None! Cannot proceed with training.")
             return 1
 
         # Create data collator
 
transformers_config.json CHANGED
@@ -134,10 +134,11 @@
       "name": "George-API/cognitive-data",
       "split": "train",
       "column_mapping": {
-        "conversations": "text"
+        "conversations": "text",
+        "article_id": "id"
       },
       "processing": {
-        "sort_by_id": true,
+        "sort_by_article_id": true,
         "maintain_paper_order": true,
         "preserve_entry_sequence": true,
         "max_seq_length": 2048
@@ -152,9 +153,9 @@
         "user": "Human: {content}\n\n"
       },
       "metadata_handling": {
-        "include_paper_id": true,
-        "include_chunk_number": true,
-        "metadata_format": "Paper ID: {paper_id} | Chunk: {chunk_number}"
+        "include_article_id": true,
+        "include_prompt_number": true,
+        "metadata_format": "Article ID: {article_id} | Prompt: {prompt_number}"
       }
     },
     "data_loading": {
 