Spaces:
Sleeping
Sleeping
File size: 11,710 Bytes
adb15f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
def format_phi_chat(messages, dataset_config):
    """Format messages according to phi-4's chat template and dataset config.

    Args:
        messages: Sequence of message dicts with ``"role"`` and ``"content"``
            keys. Entries that are not dicts, or that lack ``"content"``, are
            skipped with a warning.
        dataset_config: Dict whose ``["data_formatting"]["roles"]`` entry maps
            role names to ``str.format`` templates containing ``{content}``.
            Built-in "System:/Human:/Assistant:" templates are used when that
            entry is absent.

    Returns:
        The concatenated prompt string with surrounding whitespace stripped.
        Ordering is:
            1. system-role messages, in their original order;
            2. the first message whose text contains ``"[RESEARCH INTRODUCTION]"``,
               rendered with the system template (it and any equal copies are
               removed from the regular flow);
            3. every other message in its original order.
        Unknown roles are rendered with the system template in place.
    """
    default_system = "System: {content}\n\n"
    default_human = "Human: {content}\n\n"
    default_assistant = "Assistant: {content}\n\n"

    # Get role templates from config
    roles = dataset_config.get("data_formatting", {}).get("roles", {
        "system": default_system,
        "human": default_human,
        "user": default_human,
        "assistant": default_assistant,
    })
    system_template = roles.get("system", default_system)
    user_template = roles.get("user", roles.get("human", default_human))
    assistant_template = roles.get("assistant", default_assistant)

    # Handle research introduction metadata first. Only string content is
    # searched so a None / non-text "content" value cannot raise TypeError here.
    metadata = next(
        (msg for msg in messages
         if isinstance(msg, dict)
         and isinstance(msg.get("content"), str)
         and "[RESEARCH INTRODUCTION]" in msg["content"]),
        None,
    )
    intro = ""
    if metadata:
        intro = system_template.format(content=metadata["content"])
        messages = [msg for msg in messages if msg != metadata]

    # System messages are collected separately (in order) and emitted first;
    # prepending them one at a time would reverse their relative order.
    system_parts = []
    body_parts = []
    for message in messages:
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue
        # Coerce so a None / non-string role cannot crash .lower().
        role = str(message.get("role") or "").lower()
        content = message.get("content")
        if content is None:
            content = ""

        if role in ("human", "user"):
            body_parts.append(user_template.format(content=content))
        elif role in ("assistant", "bot"):
            body_parts.append(assistant_template.format(content=content))
        elif role == "system":
            system_parts.append(system_template.format(content=content))
        else:
            # Default to system formatting for unknown roles, kept in place.
            logger.warning(f"Unknown role '{role}' - treating as system message")
            body_parts.append(system_template.format(content=content))

    return ("".join(system_parts) + intro + "".join(body_parts)).strip()
class SimpleDataCollator:
    """Collate pre-audited chat examples into padded causal-LM batches.

    Each feature is expected to be a dict with an optional ``"id"`` and a
    ``"conversations"`` list that is handed unchanged to the tokenizer's chat
    template. Labels are a copy of the input ids. Padded positions get
    ``pad_token_id`` in ``input_ids``, 0 in ``attention_mask`` and -100 in
    ``labels``. Examples that are empty or fail to tokenize are counted as
    skipped rather than aborting the batch.
    """

    def __init__(self, tokenizer, dataset_config):
        self.tokenizer = tokenizer
        self.dataset_config = dataset_config
        # Cumulative counters across every batch produced by this collator.
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        # Use id 0 for padding when the tokenizer defines no pad token.
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
        logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
        logger.info("Using exact dataset structure without reformatting")
        # Only reported in the log; the tensors built in __call__ are created
        # without an explicit device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"SimpleDataCollator using device: {self.device}")

    @staticmethod
    def _as_token_list(encoded):
        """Return a fresh list of token ids from a tokenizer result.

        ``tokenizer(text)`` returns a mapping (``BatchEncoding``) keyed by
        ``"input_ids"`` rather than a bare list, so any mapping result is
        unwrapped before use.
        """
        if hasattr(encoded, "keys"):
            encoded = encoded["input_ids"]
        return list(encoded)

    def _tokenize(self, conversations, paper_id):
        """Tokenize one conversation, falling back to plain text if the chat template fails."""
        try:
            # Let tokenizer handle the content with the model's chat template
            encoded = self.tokenizer.apply_chat_template(
                conversations,
                return_tensors=None,
                add_generation_prompt=False
            )
        except Exception as chat_error:
            logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}")
            # Basic representation: message contents separated by blank lines.
            conversation_text = "".join(
                msg.get('content', '') + "\n\n"
                for msg in conversations
                if isinstance(msg, dict) and 'content' in msg
            )
            encoded = self.tokenizer(
                conversation_text,
                add_special_tokens=True,
                return_tensors=None
            )
        return self._as_token_list(encoded)

    def __call__(self, features):
        """Tokenize, truncate and right-pad ``features`` into a batch of long tensors.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors of
            shape ``(num_kept_examples, longest_kept_length)``. If no example
            survives, a ``(1, 1)`` all-zero dummy batch is returned instead.
        """
        batch = {"input_ids": [], "attention_mask": [], "labels": []}
        validation_cfg = self.dataset_config.get("validation", {})
        log_samples = validation_cfg.get("log_samples", 3)
        log_interval = validation_cfg.get("log_interval", 100)
        processed_before = self.stats["processed"]

        for example in features:
            try:
                paper_id = example.get("id", "")
                # Conversations should already contain role and content.
                conversations = example.get("conversations", [])
                if not conversations:
                    self.stats["skipped"] += 1
                    continue

                inputs = self._tokenize(conversations, paper_id)

                # Apply length cap if needed (shouldn't be necessary for pre-audited data)
                if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
                    logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
                    inputs = inputs[:self.max_seq_length]

                if not inputs:
                    self.stats["skipped"] += 1
                    continue

                batch["input_ids"].append(inputs)
                batch["attention_mask"].append([1] * len(inputs))
                # Causal LM: labels are a copy of the inputs.
                batch["labels"].append(inputs.copy())
                self.stats["processed"] += 1
                self.stats["total_tokens"] += len(inputs)

                # Debug logging for first few examples
                if self.stats["processed"] <= log_samples:
                    logger.info(f"Example {self.stats['processed']}:")
                    logger.info(f"Paper ID: {paper_id}")
                    logger.info(f"Token count: {len(inputs)}")
                    logger.info(f"Conversation entries: {len(conversations)}")
            except Exception as e:
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                logger.warning(f"Problematic example ID: {example.get('id', 'unknown')}")
                self.stats["skipped"] += 1

        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                "labels": torch.zeros((1, 1), dtype=torch.long)
            }

        # Right-pad every row to the longest kept example.
        max_length = max(len(ids) for ids in batch["input_ids"])
        for ids, mask, labels in zip(batch["input_ids"], batch["attention_mask"], batch["labels"]):
            padding_length = max_length - len(ids)
            if padding_length > 0:
                ids.extend([self.pad_token_id] * padding_length)
                mask.extend([0] * padding_length)
                labels.extend([-100] * padding_length)

        tensors = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}

        # Log whenever this batch pushed the cumulative count past a multiple of
        # log_interval. An exact-multiple test misses boundaries when the batch
        # size does not divide the interval; log_interval <= 0 disables logging.
        processed = self.stats["processed"]
        if log_interval > 0 and processed // log_interval > processed_before // log_interval:
            logger.info(f"Data collator stats: processed={self.stats['processed']}, "
                        f"skipped={self.stats['skipped']}, "
                        f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")
        return tensors
class LoggingCallback(TrainerCallback):
    """Trainer callback that writes compact progress and memory lines through ``log_info``.

    Loss is reported every 50 global steps, or when 5 minutes have passed since
    the last report. Per-GPU memory is reported at most every 15 minutes during
    training, plus once at the start and once at the end of training.
    """

    # Reporting thresholds used by on_step_end; override on a subclass to tune.
    loss_log_every_steps = 50
    loss_log_every_seconds = 300
    memory_log_every_seconds = 900

    def __init__(self):
        self.last_log_time = time.time()
        self.last_memory_log_time = time.time()

    @staticmethod
    def _latest_loss(log_history):
        """Return the newest ``'loss'`` value in ``log_history``, or ``'N/A'``.

        The last history entry need not be a training-step log. Evaluation
        results or an end-of-training summary may use other keys, such as
        ``eval_loss`` or ``train_loss``. So search backwards instead of only
        reading ``log_history[-1]``.
        """
        for entry in reversed(log_history or []):
            if isinstance(entry, dict) and "loss" in entry:
                return entry["loss"]
        return "N/A"

    @staticmethod
    def _gpu_memory_with_peak():
        """Return 'GPU i: <allocated>MB (max: <peak allocated>MB)' for each visible GPU, comma-separated."""
        return ", ".join(
            f"GPU {i}: {torch.cuda.memory_allocated(i) / 1024**2:.1f}MB "
            f"(max: {torch.cuda.max_memory_allocated(i) / 1024**2:.1f}MB)"
            for i in range(torch.cuda.device_count())
        )

    def on_step_end(self, args, state, control, **kwargs):
        current_time = time.time()

        # Log loss every N steps or after the time threshold, whichever comes first.
        step_due = state.global_step % self.loss_log_every_steps == 0
        time_due = current_time - self.last_log_time > self.loss_log_every_seconds
        if step_due or time_due:
            if state.log_history:
                loss = self._latest_loss(state.log_history)
                # Use simple formatting for better HF Space log compatibility
                log_info(f"Step {state.global_step}: Loss {loss}")
            else:
                log_info(f"Step {state.global_step}: No loss data available")
            self.last_log_time = current_time

        # Log allocated/reserved memory on the slower memory cadence.
        if current_time - self.last_memory_log_time > self.memory_log_every_seconds:
            if torch.cuda.is_available():
                memory_info = []
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**2
                    reserved = torch.cuda.memory_reserved(i) / 1024**2
                    memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB")
                log_info(f"Memory usage - {', '.join(memory_info)}")
            self.last_memory_log_time = current_time

    def on_train_begin(self, args, state, control, **kwargs):
        log_info("=== Training is starting ===")
        # Log important training parameters for visibility
        log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {max(1, torch.cuda.device_count())} GPUs")
        log_info(f"Learning rate: {args.learning_rate}")
        log_info(f"Epochs: {args.num_train_epochs}")
        if torch.cuda.is_available():
            log_info(f"Initial memory usage - {self._gpu_memory_with_peak()}")

    def on_train_end(self, args, state, control, **kwargs):
        log_info("=== Training completed ===")
        if torch.cuda.is_available():
            log_info(f"Final memory usage - {self._gpu_memory_with_peak()}")
        log_info(f"Total steps: {state.global_step}")
        log_info(f"Final loss: {self._latest_loss(state.log_history)}")