Upload folder using huggingface_hub
Files changed:
- run_transformers_training.py (+199 -169)
- transformers_config.json (+6 -5)
run_transformers_training.py
CHANGED
@@ -123,30 +123,22 @@ def load_env_variables():
     os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
 
 def load_configs(base_path):
-    """Load …
-    configs = {}
-
+    """Load configuration from transformers_config.json file."""
     # Using a single consolidated config file
-    config_file = …
+    config_file = base_path
 
-    file_path = os.path.join(base_path, config_file)
     try:
-        with open(…
+        with open(config_file, "r") as f:
             config = json.load(f)
-
-
-        configs["hardware"] = config.get("hardware", {})
-        configs["dataset"] = config.get("dataset", {})
-        logger.info(f"Loaded consolidated configuration from {file_path}")
+        logger.info(f"Loaded configuration from {config_file}")
+        return config
     except Exception as e:
         logger.error(f"Error loading {config_file}: {e}")
         raise
-
-    return configs
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
-    parser.add_argument("--config_dir", …
+    parser.add_argument("--config", type=str, default="transformers_config.json", help="Path to configuration file")
     return parser.parse_args()
 
 def load_model_and_tokenizer(config):
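Note: the net effect of this hunk is that load_configs() stops building a per-section dict and returns the parsed JSON directly, while the CLI flag changes from a config directory to a single file path. A minimal sketch of the new flow (file name is the default from the diff):

    import json

    def load_configs(base_path):
        """Load configuration from transformers_config.json file."""
        # base_path is now the config file itself, not a directory
        with open(base_path, "r") as f:
            return json.load(f)

    # New call site (see the main() hunk below): all sections live in one file
    transformers_config = load_configs("transformers_config.json")
    hardware_config = transformers_config.get("hardware", {})
    dataset_config = transformers_config.get("dataset", {})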
@@ -157,8 +149,8 @@ def load_model_and_tokenizer(config):
         logger.error("Please ensure unsloth is in requirements.txt")
         raise ImportError("Unsloth is required for this training setup")
 
-    # Get model name correctly from …
-    model_name = config.get("…
+    # Get model name correctly from config
+    model_name = config.get("model_name") or config.get("model", {}).get("name")
     logger.info(f"Loading model: {model_name}")
 
     if not model_name:
@@ -166,14 +158,12 @@ def load_model_and_tokenizer(config):
 
     logger.info("Using Unsloth optimizations with pre-quantized model")
 
-    # Check for flash attention
+    # Check for flash attention
     use_flash_attention = config.get("use_flash_attention", True)
-
-
-        logger.…
-    except ImportError:
+    if use_flash_attention and not find_spec("flash_attn"):
+        logger.warning("flash-attn not found. Will continue without flash attention.")
+        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
         use_flash_attention = False
-        logger.warning("Flash attention not available, falling back to standard attention")
 
     # First detect if we have a GPU
     if torch.cuda.is_available():
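Note: the removed try/except import probe is replaced by importlib.util.find_spec, which reports whether flash_attn is installed without importing it. A small sketch of the pattern, assuming the script imports find_spec at module level (the import itself is outside this hunk):

    from importlib.util import find_spec

    use_flash_attention = True
    # find_spec returns None when the package is absent, so no exception handling is needed
    if use_flash_attention and not find_spec("flash_attn"):
        print("flash-attn not installed; continuing with standard attention")
        use_flash_attention = False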
@@ -321,13 +311,24 @@ def load_dataset_with_mapping(dataset_config):
 
     # Add prompt_number field that increments based on original order
     def add_prompt_numbers(examples, indices):
-        # Defensive check to ensure indices is not None
+        # Defensive check to ensure indices is not None and is iterable
         if indices is None:
             logger.warning("Warning: indices is None in add_prompt_numbers, using empty list")
             indices = []
+        elif isinstance(indices, int):
+            # Handle case where indices is a single integer
+            logger.warning(f"Warning: indices is an integer ({indices}) in add_prompt_numbers, converting to list")
+            indices = [indices]
 
-        # …
-
+        # Ensure indices is always a list/iterable
+        try:
+            # Create a new field with the dataset index as the prompt number, starting at 1
+            examples["prompt_number"] = [idx + 1 for idx in indices]  # Adding 1 to make it 1-indexed
+        except TypeError:
+            # Fallback for non-iterable types
+            logger.warning(f"Warning: non-iterable indices in add_prompt_numbers: {type(indices)}, using default")
+            examples["prompt_number"] = [1] * len(next(iter(examples.values())))
+
         return examples
 
     # Add prompt numbers to the dataset based on original order
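Note: for the list comprehension over indices to receive a list, add_prompt_numbers has to be applied as a batched map with indices; the actual map call sits outside this hunk. A minimal usage sketch with the datasets library:

    from datasets import Dataset

    dataset = Dataset.from_dict({"text": ["a", "b", "c"]})

    def add_prompt_numbers(examples, indices):
        # batched=True passes a dict of columns; with_indices=True passes absolute row indices
        examples["prompt_number"] = [idx + 1 for idx in indices]
        return examples

    dataset = dataset.map(add_prompt_numbers, with_indices=True, batched=True)
    print(dataset["prompt_number"])  # [1, 2, 3]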
@@ -358,37 +359,73 @@ def load_dataset_with_mapping(dataset_config):
         dataset = Dataset.from_list(updated_examples)
         logger.info(f"Successfully added prompt_number field using fallback method")
 
-        # …
-
-
+        # Rename 'id' to 'article_id' if it exists
+        if 'id' in dataset.column_names and 'article_id' not in dataset.column_names:
+            logger.info("Renaming 'id' column to 'article_id'")
+            dataset = dataset.rename_column('id', 'article_id')
+
+        # Reorder columns to make prompt_number first if it exists
+        if 'prompt_number' in dataset.column_names:
+            logger.info("Reordering columns to place prompt_number first")
+            # Get current column names
+            current_columns = dataset.column_names
+            # Create new column order with prompt_number first
+            new_column_order = ['prompt_number'] + [col for col in current_columns if col != 'prompt_number']
+            # Reorder columns
+            dataset = dataset.select_columns(new_column_order)
+
+        # Verify all new column names for logging
+        logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
+        logger.info(f"Dataset columns: {dataset.column_names}")
+
+        # Verify dataset is not empty
+        if len(dataset) == 0:
+            logger.error("Dataset is empty! This will cause errors during training.")
+            raise ValueError("Empty dataset loaded")
+
+        # Check for required columns
+        required_columns = ['conversations']
+        for col in required_columns:
             if col not in dataset.column_names:
-                …
-
-                logger.info("Converting 'text' field to 'conversations' format")
-
-                def convert_text_to_conversations(example):
-                    # Check if text is already a list of conversation turns
-                    if isinstance(example.get("text"), list):
-                        return {"conversations": example["text"]}
-                    # Otherwise, create a simple conversation with the text as user message
-                    else:
-                        return {
-                            "conversations": [
-                                {"role": "user", "content": str(example.get("text", ""))}
-                            ]
-                        }
-
-                dataset = dataset.map(convert_text_to_conversations)
-            else:
-                logger.warning(f"Expected column '{col}' not found in dataset")
+                logger.error(f"Required column '{col}' not found in dataset!")
+                raise ValueError(f"Required column '{col}' missing from dataset")
 
-        # …
-
+        # Verify expected columns exist
+        expected_columns = {"article_id", "conversations", "prompt_number"}
+        missing_columns = expected_columns - set(dataset.column_names)
+        if missing_columns:
+            logger.warning(f"Some expected columns are missing: {missing_columns}")
+
+        # If "conversations" is missing but "text" exists, attempt conversion
+        if "conversations" not in dataset.column_names and "text" in dataset.column_names:
+            logger.info("Converting 'text' field to 'conversations' format")
+
+            def convert_text_to_conversations(example):
+                # Check if text is already a list of conversation turns
+                if isinstance(example.get("text"), list):
+                    example["conversations"] = example["text"]
+                # Otherwise, create a simple conversation with the text as user message
+                else:
+                    example["conversations"] = [
+                        {"role": "user", "content": str(example.get("text", ""))}
+                    ]
+                return example
+
+            dataset = dataset.map(convert_text_to_conversations)
+            logger.info("Successfully converted 'text' to 'conversations'")
 
-        # …
+        # Verify data ordering requirements
         processing_config = dataset_config.get("dataset", {}).get("processing", {})
         data_loading_config = dataset_config.get("data_loading", {})
 
+        # Check if sorting is required
+        sort_by_article_id = processing_config.get("sort_by_article_id", False)
+        if sort_by_article_id and 'article_id' in dataset.column_names:
+            logger.info("Sorting dataset by article_id as specified in config")
+            dataset = dataset.sort("article_id")
+            sorted_ids = [example['article_id'] for example in dataset.select(range(min(5, len(dataset))))]
+            logger.info(f"First few article_ids after sorting: {sorted_ids}")
+
         # Flag consolidation - we only need one flag to control sequence preservation
         # Default to True to ensure safety
         preserve_sequence = processing_config.get("preserve_entry_sequence", True)
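Note: the rename and reorder steps above use stock Dataset operations; a toy run on a dataset shaped like the expected input:

    from datasets import Dataset

    dataset = Dataset.from_dict({
        "id": [10, 11],
        "conversations": [[{"role": "user", "content": "hi"}]] * 2,
        "prompt_number": [1, 2],
    })

    dataset = dataset.rename_column("id", "article_id")
    order = ["prompt_number"] + [c for c in dataset.column_names if c != "prompt_number"]
    dataset = dataset.select_columns(order)
    print(dataset.column_names)  # ['prompt_number', 'article_id', 'conversations']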
@@ -450,17 +487,18 @@ def load_dataset_with_mapping(dataset_config):
                     logger.warning(f"Error accessing dataset at index {i}: {e}")
 
             if sample_examples:
-
-
+                id_field = 'article_id' if 'article_id' in dataset.column_names else 'id'
+                if all(isinstance(example.get(id_field, ''), (int, str)) for example in sample_examples):
+                    sample_ids = [example.get(id_field, '') for example in sample_examples if id_field in example]
 
                 if sample_ids and all(isinstance(id, int) or (isinstance(id, str) and id.isdigit()) for id in sample_ids):
                     numeric_ids = [int(id) if isinstance(id, str) else id for id in sample_ids]
                     if len(numeric_ids) > 1:
                         is_ordered = all(numeric_ids[i] <= numeric_ids[i+1] for i in range(len(numeric_ids)-1))
                         if not is_ordered:
-                            logger.warning("WARNING: Sample …
+                            logger.warning(f"WARNING: Sample {id_field}s are not in sequential order.")
                         else:
-                            logger.info("Sample …
+                            logger.info(f"Sample {id_field}s appear to be in sequential order.")
             except Exception as e:
                 logger.warning(f"Error checking ID sequence: {e}")
         except Exception as e:
@@ -472,19 +510,19 @@ def load_dataset_with_mapping(dataset_config):
         # Safely get first few samples
         first_few_indices = range(min(5, len(dataset)))
         sample_prompt_numbers = []
-
+        sample_article_ids = []
 
         for i in first_few_indices:
             try:
                 example = dataset[i]
                 if 'prompt_number' in example:
                     sample_prompt_numbers.append(example['prompt_number'])
-                if '…
-
+                if 'article_id' in example:
+                    sample_article_ids.append(example['article_id'])
             except Exception as e:
                 logger.warning(f"Error accessing sample at index {i}: {e}")
 
-        logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, IDs: {…
+        logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, Article IDs: {sample_article_ids}")
 
         # Log conversation structure without full content
         if len(dataset) > 0:
@@ -510,6 +548,74 @@ def load_dataset_with_mapping(dataset_config):
 
         logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
         logger.info(f"Dataset columns: {dataset.column_names}")
+
+        # Verify dataset is not empty
+        if len(dataset) == 0:
+            logger.error("Dataset is empty! Cannot proceed with training.")
+            return dataset
+
+        # Check for required columns
+        required_cols = ['conversations', 'prompt_number']
+        for col in required_cols:
+            if col not in dataset.column_names:
+                logger.error(f"Required column '{col}' missing from dataset. Cannot proceed with training.")
+                return dataset
+
+        # Validate at least one sample can be processed
+        try:
+            if len(dataset) > 0:
+                sample = dataset[0]
+                if 'conversations' not in sample or not sample['conversations']:
+                    logger.error("First sample has no conversations! Data format may be incorrect.")
+                    return dataset
+                if not isinstance(sample['conversations'], list):
+                    logger.error(f"Conversations field should be a list but got {type(sample['conversations'])}")
+                    return dataset
+        except Exception as e:
+            logger.error(f"Error validating first sample: {e}")
+            return dataset
+
+        # Add metadata if specified
+        metadata_config = dataset_config.get("data_formatting", {}).get("metadata_handling", {})
+        if metadata_config:
+            include_article_id = metadata_config.get("include_article_id", False)
+            include_prompt_number = metadata_config.get("include_prompt_number", False)
+            metadata_format = metadata_config.get("metadata_format", "")
+
+            if (include_article_id or include_prompt_number) and metadata_format:
+                logger.info("Adding metadata to conversations")
+
+                def add_metadata(example):
+                    if not example.get("conversations"):
+                        return example
+
+                    # Prepare metadata
+                    metadata = metadata_format
+                    if include_article_id and "article_id" in example:
+                        metadata = metadata.replace("{article_id}", str(example.get("article_id", "")))
+                    if include_prompt_number and "prompt_number" in example:
+                        metadata = metadata.replace("{prompt_number}", str(example.get("prompt_number", "")))
+
+                    # Add system message with metadata if not empty
+                    if metadata.strip():
+                        if example["conversations"] and isinstance(example["conversations"], list):
+                            # Check if first message is already a system message
+                            if (isinstance(example["conversations"][0], dict) and
+                                example["conversations"][0].get("role") == "system"):
+                                # Append to existing system message
+                                example["conversations"][0]["content"] += f"\n\nMetadata: {metadata}"
+                            else:
+                                # Add new system message at the beginning
+                                example["conversations"].insert(0, {
+                                    "role": "system",
+                                    "content": f"Metadata: {metadata}"
+                                })
+
+                    return example
+
+                dataset = dataset.map(add_metadata)
+                logger.info("Metadata added to conversations")
+
         return dataset
 
     except Exception as e:
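Note: with the metadata_format added to transformers_config.json below ("Article ID: {article_id} | Prompt: {prompt_number}"), add_metadata prepends a system turn. A worked example of the transformation on one record:

    example = {
        "article_id": 42,
        "prompt_number": 1,
        "conversations": [{"role": "user", "content": "Summarize the paper."}],
    }

    metadata = "Article ID: {article_id} | Prompt: {prompt_number}"
    metadata = metadata.replace("{article_id}", str(example["article_id"]))
    metadata = metadata.replace("{prompt_number}", str(example["prompt_number"]))

    example["conversations"].insert(0, {"role": "system", "content": f"Metadata: {metadata}"})
    # example["conversations"][0] is now:
    # {'role': 'system', 'content': 'Metadata: Article ID: 42 | Prompt: 1'}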
@@ -752,13 +858,13 @@ class LoggingCallback(TrainerCallback):
                         is_sequence_maintained = False
 
                     # Also compare IDs as a backup check
-                    elif ('…
-                          '…
-                          orig_sample['…
-                          current_sample['…
+                    elif ('article_id' in orig_sample and
+                          'article_id' in current_sample and
+                          orig_sample['article_id'] is not None and
+                          current_sample['article_id'] is not None):
 
-                        if orig_sample['…
-                            log_info(f"WARNING: Sequence integrity compromised! Sample {i} …
+                        if orig_sample['article_id'] != current_sample['article_id']:
+                            log_info(f"WARNING: Sequence integrity compromised! Sample {i} article_id changed from {orig_sample['article_id']} to {current_sample['article_id']}")
                             is_sequence_maintained = False
 
                     # Compare input fingerprints
@@ -899,12 +1005,11 @@ def check_dependencies():
         missing_packages.append("peft>=0.9.0")
 
     # Optional packages - don't add to missing list, just log
-
-        import flash_attn
+    if find_spec("flash_attn"):
         logger.info("flash-attn found. Flash attention will be used for faster training.")
-
+    else:
         logger.warning("flash-attn not found. Training will work but may be slower.")
-
+        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
 
     # If critical packages are missing, exit with instructions
     if missing_packages:
@@ -918,115 +1023,44 @@ def check_dependencies():
 
 def main():
     # Set up logging
-
-
-    # Log hardware information
-    log_info(f"Hardware detection: CUDA {'available' if CUDA_AVAILABLE else 'not available'}")
-    if CUDA_AVAILABLE:
-        log_info(f"Found {NUM_GPUS} GPUs")
-        for i in range(NUM_GPUS):
-            log_info(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
-    else:
-        log_info("Running on CPU (training will be very slow)")
+    logger.info("Starting training process")
 
     # Parse arguments
    args = parse_args()
 
+    # Load environment variables
+    load_env_variables()
+
+    # Load configuration
+    try:
+        transformers_config = load_configs(args.config)
+        hardware_config = transformers_config.get("hardware", {})
+        dataset_config = transformers_config.get("dataset", {})
+        logger.info("Configuration loaded successfully")
+    except Exception as e:
+        logger.error(f"Error loading configuration: {e}")
+        return 1
+
     # Check dependencies
     if not check_dependencies():
         logger.error("Aborting due to missing critical dependencies")
         return 1
 
-    # Load environment variables
-    load_env_variables()
-
     # Check if we're in distributed mode
     is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1
     if is_distributed:
-
+        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}")
     else:
         log_info("Running in non-distributed mode (single process)")
 
-    # Load all configurations - do this once
-    try:
-        configs = load_configs(args.config_dir)
-
-        # Extract specific configs immediately after loading
-        if not configs:
-            logger.error("Failed to load configuration")
-            return 1
-
-        # Store configurations in clear variables
-        transformers_config = configs.get("transformers", {})
-        hardware_config = configs.get("hardware", {})
-        dataset_config = configs.get("dataset", {})
-
-        # Verify configuration sections exist
-        if not transformers_config:
-            logger.error("transformers_config.json not found or invalid")
-            return 1
-
-        if not hardware_config:
-            logger.warning("Hardware configuration section not found in transformers_config.json. Using default hardware configuration.")
-
-        if not dataset_config:
-            logger.error("Dataset configuration section not found in transformers_config.json")
-            return 1
-
-        # Validate model configuration
-        model_name = (transformers_config.get("model", {}).get("name") or
-                      transformers_config.get("model_name_or_path") or
-                      transformers_config.get("model_name"))
-
-        if not model_name:
-            logger.error("Model name not specified in configuration")
-            logger.error("Please ensure 'name' is specified under 'model' in transformers_config.json")
-            return 1
-
-        log_info(f"Using model: {model_name}")
-        log_info("All configurations loaded successfully")
-
-        # Apply hardware-specific settings if available
-        if hardware_config:
-            # Get training optimizations from hardware config
-            training_opts = hardware_config.get("training_optimizations", {})
-
-            # Apply batch size and gradient accumulation settings
-            if training_opts.get("per_device_batch_size") and transformers_config.get("training"):
-                batch_size = training_opts.get("per_device_batch_size")
-                transformers_config["training"]["per_device_train_batch_size"] = batch_size
-                log_info(f"Applied hardware-optimized batch size: {batch_size}")
-
-            if training_opts.get("gradient_accumulation_steps") and transformers_config.get("training"):
-                grad_steps = training_opts.get("gradient_accumulation_steps")
-                transformers_config["training"]["gradient_accumulation_steps"] = grad_steps
-                log_info(f"Applied hardware-optimized gradient accumulation: {grad_steps}")
-
-            # Apply memory optimizations
-            memory_opts = training_opts.get("memory_optimizations", {})
-            if memory_opts.get("use_gradient_checkpointing") is not None and transformers_config.get("training"):
-                grad_ckpt = memory_opts.get("use_gradient_checkpointing")
-                transformers_config["training"]["gradient_checkpointing"] = grad_ckpt
-                log_info(f"Applied hardware-optimized gradient checkpointing: {grad_ckpt}")
-
-            # Apply system settings
-            system_settings = hardware_config.get("system_settings", {})
-            if system_settings.get("dataloader_num_workers") is not None:
-                workers = system_settings.get("dataloader_num_workers")
-                log_info(f"Using {workers} dataloader workers from hardware config")
-
-            # Get distribution strategy
-            multi_gpu_strategy = training_opts.get("multi_gpu_strategy", "data_parallel")
-            log_info(f"Hardware config specifies {multi_gpu_strategy} for multi-GPU training")
-
-    except Exception as e:
-        logger.error(f"Error loading configurations: {e}")
-        return 1
-
     # Set random seed for reproducibility
     seed = transformers_config.get("seed", 42)
     set_seed(seed)
-
+    logger.info(f"Set random seed to {seed}")
+
+    # Load model and tokenizer using the consolidated config
+    model, tokenizer = load_model_and_tokenizer(transformers_config)
 
     # Empty CUDA cache to ensure clean state
     if CUDA_AVAILABLE:
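Note: the rewritten main() assumes all sections live in the one JSON document. Read together with the accessors above and the config diff below, the file has roughly this shape (a sketch; only keys touched by this commit are shown, and the model id is a placeholder):

    transformers_config = {
        "model_name": "org/model-id",  # placeholder; load_model_and_tokenizer also accepts {"model": {"name": ...}}
        "seed": 42,
        "hardware": {},  # optional tuning section; absent means defaults
        "dataset": {
            "dataset": {"processing": {"sort_by_article_id": True}},
            "data_formatting": {"metadata_handling": {}},
        },
    }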
@@ -1043,17 +1077,13 @@ def main():
         log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
 
     try:
-        log_info("Loading …
-
-        log_info("…
+        log_info("Loading dataset...")
+        dataset = load_dataset_with_mapping(dataset_config)
+        log_info(f"Dataset loaded with {len(dataset)} examples")
 
-        # …
-
-
-        dataset = load_dataset_with_mapping(dataset_config)
-        log_info(f"Dataset loaded with {len(dataset)} examples")
-    except Exception as e:
-        logger.error(f"Error loading dataset: {e}")
+        # Minimal validation before proceeding
+        if dataset is None or len(dataset) == 0:
+            logger.error("Dataset is empty or None! Cannot proceed with training.")
             return 1
 
     # Create data collator
transformers_config.json
CHANGED
@@ -134,10 +134,11 @@
       "name": "George-API/cognitive-data",
       "split": "train",
       "column_mapping": {
-        "conversations": "text"
+        "conversations": "text",
+        "article_id": "id"
       },
       "processing": {
-        "…
+        "sort_by_article_id": true,
         "maintain_paper_order": true,
         "preserve_entry_sequence": true,
         "max_seq_length": 2048
@@ -152,9 +153,9 @@
         "user": "Human: {content}\n\n"
       },
       "metadata_handling": {
-        "…
-        "…
-        "metadata_format": "…
+        "include_article_id": true,
+        "include_prompt_number": true,
+        "metadata_format": "Article ID: {article_id} | Prompt: {prompt_number}"
       }
     },
     "data_loading": {