Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- dataset_config.json +4 -4
- run_transformers_training.py +116 -8
dataset_config.json
CHANGED
@@ -3,8 +3,7 @@
|
|
3 |
"name": "George-API/cognitive-data",
|
4 |
"split": "train",
|
5 |
"column_mapping": {
|
6 |
-
"
|
7 |
-
"id": "id"
|
8 |
},
|
9 |
"processing": {
|
10 |
"sort_by_id": true,
|
@@ -17,7 +16,8 @@
|
|
17 |
"roles": {
|
18 |
"system": "System: {content}\n\n",
|
19 |
"human": "Human: {content}\n\n",
|
20 |
-
"assistant": "Assistant: {content}\n\n"
|
|
|
21 |
},
|
22 |
"metadata_handling": {
|
23 |
"include_paper_id": true,
|
@@ -29,7 +29,7 @@
|
|
29 |
"batch_size": 24,
|
30 |
"shuffle": false,
|
31 |
"drop_last": false,
|
32 |
-
"num_workers":
|
33 |
"pin_memory": true,
|
34 |
"prefetch_factor": 4
|
35 |
},
|
|
|
3 |
"name": "George-API/cognitive-data",
|
4 |
"split": "train",
|
5 |
"column_mapping": {
|
6 |
+
"conversations": "text"
|
|
|
7 |
},
|
8 |
"processing": {
|
9 |
"sort_by_id": true,
|
|
|
16 |
"roles": {
|
17 |
"system": "System: {content}\n\n",
|
18 |
"human": "Human: {content}\n\n",
|
19 |
+
"assistant": "Assistant: {content}\n\n",
|
20 |
+
"user": "Human: {content}\n\n"
|
21 |
},
|
22 |
"metadata_handling": {
|
23 |
"include_paper_id": true,
|
|
|
29 |
"batch_size": 24,
|
30 |
"shuffle": false,
|
31 |
"drop_last": false,
|
32 |
+
"num_workers": 4,
|
33 |
"pin_memory": true,
|
34 |
"prefetch_factor": 4
|
35 |
},
|
run_transformers_training.py
CHANGED
@@ -208,15 +208,51 @@ def load_dataset_with_mapping(dataset_config):
|
|
208 |
logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
|
209 |
dataset = load_dataset(dataset_name, split=dataset_split)
|
210 |
|
211 |
-
# Map columns if specified
|
212 |
column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {})
|
213 |
if column_mapping:
|
214 |
-
logger.info(f"
|
215 |
|
216 |
-
#
|
|
|
217 |
for target, source in column_mapping.items():
|
218 |
if source in dataset.column_names:
|
219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
# Sort dataset if required
|
222 |
sort_by_id = dataset_config.get("dataset", {}).get("processing", {}).get("sort_by_id", False)
|
@@ -227,8 +263,14 @@ def load_dataset_with_mapping(dataset_config):
|
|
227 |
# Log the first few IDs to verify sorting
|
228 |
sample_ids = [example['id'] for example in dataset.select(range(min(5, len(dataset))))]
|
229 |
logger.info(f"First few IDs after sorting: {sample_ids}")
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
|
|
232 |
return dataset
|
233 |
|
234 |
except Exception as e:
|
@@ -243,11 +285,13 @@ def format_phi_chat(messages, dataset_config):
|
|
243 |
roles = dataset_config.get("data_formatting", {}).get("roles", {
|
244 |
"system": "System: {content}\n\n",
|
245 |
"human": "Human: {content}\n\n",
|
|
|
246 |
"assistant": "Assistant: {content}\n\n"
|
247 |
})
|
248 |
|
249 |
# Handle research introduction metadata first
|
250 |
-
metadata = next((msg for msg in messages if
|
|
|
251 |
if metadata:
|
252 |
system_template = roles.get("system", "System: {content}\n\n")
|
253 |
formatted_chat = system_template.format(content=metadata['content'])
|
@@ -255,20 +299,29 @@ def format_phi_chat(messages, dataset_config):
|
|
255 |
|
256 |
# Process remaining messages
|
257 |
for message in messages:
|
|
|
|
|
|
|
|
|
258 |
role = message.get("role", "").lower()
|
259 |
content = message.get("content", "")
|
260 |
|
261 |
# Format based on role
|
262 |
if role == "human" or role == "user":
|
263 |
-
template = roles.get("human", "Human: {content}\n\n")
|
264 |
formatted_chat += template.format(content=content)
|
265 |
-
elif role == "assistant":
|
266 |
template = roles.get("assistant", "Assistant: {content}\n\n")
|
267 |
formatted_chat += template.format(content=content)
|
268 |
elif role == "system":
|
269 |
# For system messages, prepend them
|
270 |
template = roles.get("system", "System: {content}\n\n")
|
271 |
formatted_chat = template.format(content=content) + formatted_chat
|
|
|
|
|
|
|
|
|
|
|
272 |
|
273 |
return formatted_chat.strip()
|
274 |
|
@@ -284,8 +337,56 @@ class SimpleDataCollator:
|
|
284 |
self.include_metadata = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_paper_id", True)
|
285 |
self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
|
286 |
self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
|
|
|
287 |
logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
|
288 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
289 |
def __call__(self, features):
|
290 |
batch = {"input_ids": [], "attention_mask": [], "labels": []}
|
291 |
|
@@ -293,7 +394,12 @@ class SimpleDataCollator:
|
|
293 |
try:
|
294 |
# Get ID and conversation fields
|
295 |
paper_id = example.get("id", "")
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
297 |
|
298 |
if not conversation:
|
299 |
self.stats["skipped"] += 1
|
@@ -346,10 +452,12 @@ class SimpleDataCollator:
|
|
346 |
logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
|
347 |
logger.info(f"Token count: {len(inputs['input_ids'])}")
|
348 |
logger.info(f"Content preview:\n{formatted_content[:500]}...")
|
|
|
349 |
else:
|
350 |
self.stats["skipped"] += 1
|
351 |
except Exception as e:
|
352 |
logger.warning(f"Error processing example: {str(e)[:100]}...")
|
|
|
353 |
self.stats["skipped"] += 1
|
354 |
continue
|
355 |
|
|
|
208 |
logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
|
209 |
dataset = load_dataset(dataset_name, split=dataset_split)
|
210 |
|
211 |
+
# Map columns if specified - with checks to avoid conflicts
|
212 |
column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {})
|
213 |
if column_mapping:
|
214 |
+
logger.info(f"Checking column mapping: {column_mapping}")
|
215 |
|
216 |
+
# Only apply mappings for columns that need renaming and don't already exist
|
217 |
+
safe_mappings = {}
|
218 |
for target, source in column_mapping.items():
|
219 |
if source in dataset.column_names:
|
220 |
+
# Skip if target already exists and is not the same as source
|
221 |
+
if target in dataset.column_names and target != source:
|
222 |
+
logger.warning(f"Cannot rename '{source}' to '{target}' - target column already exists")
|
223 |
+
else:
|
224 |
+
safe_mappings[source] = target
|
225 |
+
|
226 |
+
# Apply safe renames
|
227 |
+
if safe_mappings:
|
228 |
+
logger.info(f"Applying safe column mapping: {safe_mappings}")
|
229 |
+
for source, target in safe_mappings.items():
|
230 |
+
if source != target: # Only rename if names are different
|
231 |
+
dataset = dataset.rename_column(source, target)
|
232 |
+
|
233 |
+
# Verify expected columns exist
|
234 |
+
expected_columns = {"id", "conversations"}
|
235 |
+
for col in expected_columns:
|
236 |
+
if col not in dataset.column_names:
|
237 |
+
# If "conversations" is missing but "text" exists, it might need conversion
|
238 |
+
if col == "conversations" and "text" in dataset.column_names:
|
239 |
+
logger.info("Converting 'text' field to 'conversations' format")
|
240 |
+
|
241 |
+
def convert_text_to_conversations(example):
|
242 |
+
# Check if text is already a list of conversation turns
|
243 |
+
if isinstance(example.get("text"), list):
|
244 |
+
return {"conversations": example["text"]}
|
245 |
+
# Otherwise, create a simple conversation with the text as user message
|
246 |
+
else:
|
247 |
+
return {
|
248 |
+
"conversations": [
|
249 |
+
{"role": "user", "content": str(example.get("text", ""))}
|
250 |
+
]
|
251 |
+
}
|
252 |
+
|
253 |
+
dataset = dataset.map(convert_text_to_conversations)
|
254 |
+
else:
|
255 |
+
logger.warning(f"Expected column '{col}' not found in dataset")
|
256 |
|
257 |
# Sort dataset if required
|
258 |
sort_by_id = dataset_config.get("dataset", {}).get("processing", {}).get("sort_by_id", False)
|
|
|
263 |
# Log the first few IDs to verify sorting
|
264 |
sample_ids = [example['id'] for example in dataset.select(range(min(5, len(dataset))))]
|
265 |
logger.info(f"First few IDs after sorting: {sample_ids}")
|
266 |
+
|
267 |
+
# Log example of conversations structure to verify format
|
268 |
+
if "conversations" in dataset.column_names:
|
269 |
+
sample_conv = dataset["conversations"][0] if len(dataset) > 0 else []
|
270 |
+
logger.info(f"Example conversation structure: {sample_conv}")
|
271 |
|
272 |
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
|
273 |
+
logger.info(f"Dataset columns: {dataset.column_names}")
|
274 |
return dataset
|
275 |
|
276 |
except Exception as e:
|
|
|
285 |
roles = dataset_config.get("data_formatting", {}).get("roles", {
|
286 |
"system": "System: {content}\n\n",
|
287 |
"human": "Human: {content}\n\n",
|
288 |
+
"user": "Human: {content}\n\n",
|
289 |
"assistant": "Assistant: {content}\n\n"
|
290 |
})
|
291 |
|
292 |
# Handle research introduction metadata first
|
293 |
+
metadata = next((msg for msg in messages if isinstance(msg, dict) and
|
294 |
+
"[RESEARCH INTRODUCTION]" in msg.get("content", "")), None)
|
295 |
if metadata:
|
296 |
system_template = roles.get("system", "System: {content}\n\n")
|
297 |
formatted_chat = system_template.format(content=metadata['content'])
|
|
|
299 |
|
300 |
# Process remaining messages
|
301 |
for message in messages:
|
302 |
+
if not isinstance(message, dict) or "content" not in message:
|
303 |
+
logger.warning(f"Skipping invalid message format: {message}")
|
304 |
+
continue
|
305 |
+
|
306 |
role = message.get("role", "").lower()
|
307 |
content = message.get("content", "")
|
308 |
|
309 |
# Format based on role
|
310 |
if role == "human" or role == "user":
|
311 |
+
template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
|
312 |
formatted_chat += template.format(content=content)
|
313 |
+
elif role == "assistant" or role == "bot":
|
314 |
template = roles.get("assistant", "Assistant: {content}\n\n")
|
315 |
formatted_chat += template.format(content=content)
|
316 |
elif role == "system":
|
317 |
# For system messages, prepend them
|
318 |
template = roles.get("system", "System: {content}\n\n")
|
319 |
formatted_chat = template.format(content=content) + formatted_chat
|
320 |
+
else:
|
321 |
+
# Default to system for unknown roles
|
322 |
+
logger.warning(f"Unknown role '{role}' - treating as system message")
|
323 |
+
template = roles.get("system", "System: {content}\n\n")
|
324 |
+
formatted_chat += template.format(content=content)
|
325 |
|
326 |
return formatted_chat.strip()
|
327 |
|
|
|
337 |
self.include_metadata = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_paper_id", True)
|
338 |
self.include_chunk = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("include_chunk_number", True)
|
339 |
self.metadata_format = dataset_config.get("data_formatting", {}).get("metadata_handling", {}).get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
|
340 |
+
self.roles = dataset_config.get("data_formatting", {}).get("roles", {})
|
341 |
logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")
|
342 |
|
343 |
+
def normalize_conversation(self, conversation):
|
344 |
+
"""Normalize conversation format to ensure consistent structure."""
|
345 |
+
normalized = []
|
346 |
+
|
347 |
+
# Handle non-list or empty inputs
|
348 |
+
if not isinstance(conversation, list):
|
349 |
+
logger.warning(f"Conversation is not a list: {type(conversation)}")
|
350 |
+
if hasattr(conversation, 'items'): # It's a dict-like object
|
351 |
+
conversation = [conversation]
|
352 |
+
else:
|
353 |
+
return []
|
354 |
+
|
355 |
+
for turn in conversation:
|
356 |
+
# Skip empty or None entries
|
357 |
+
if not turn:
|
358 |
+
continue
|
359 |
+
|
360 |
+
# Handle string entries (convert to user message)
|
361 |
+
if isinstance(turn, str):
|
362 |
+
normalized.append({"role": "user", "content": turn})
|
363 |
+
continue
|
364 |
+
|
365 |
+
# Handle dict-like entries
|
366 |
+
if not isinstance(turn, dict) and hasattr(turn, 'get'):
|
367 |
+
# Convert to dict
|
368 |
+
turn = {k: turn.get(k) for k in ['role', 'content'] if hasattr(turn, 'get') and turn.get(k) is not None}
|
369 |
+
|
370 |
+
# Ensure both role and content exist
|
371 |
+
if not isinstance(turn, dict) or 'role' not in turn or 'content' not in turn:
|
372 |
+
logger.warning(f"Skipping malformatted conversation turn: {turn}")
|
373 |
+
continue
|
374 |
+
|
375 |
+
# Normalize role field
|
376 |
+
role = turn.get('role', '').lower()
|
377 |
+
if role == 'user' or role == 'human':
|
378 |
+
role = 'user'
|
379 |
+
elif role == 'assistant' or role == 'bot':
|
380 |
+
role = 'assistant'
|
381 |
+
|
382 |
+
# Add normalized turn
|
383 |
+
normalized.append({
|
384 |
+
"role": role,
|
385 |
+
"content": str(turn.get('content', ''))
|
386 |
+
})
|
387 |
+
|
388 |
+
return normalized
|
389 |
+
|
390 |
def __call__(self, features):
|
391 |
batch = {"input_ids": [], "attention_mask": [], "labels": []}
|
392 |
|
|
|
394 |
try:
|
395 |
# Get ID and conversation fields
|
396 |
paper_id = example.get("id", "")
|
397 |
+
|
398 |
+
# Handle conversation field - could be under 'conversations' or 'text'
|
399 |
+
conversation = example.get("conversations", example.get("text", []))
|
400 |
+
|
401 |
+
# Normalize conversation format
|
402 |
+
conversation = self.normalize_conversation(conversation)
|
403 |
|
404 |
if not conversation:
|
405 |
self.stats["skipped"] += 1
|
|
|
452 |
logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
|
453 |
logger.info(f"Token count: {len(inputs['input_ids'])}")
|
454 |
logger.info(f"Content preview:\n{formatted_content[:500]}...")
|
455 |
+
logger.info(f"Conversation structure: {conversation[:2]}...")
|
456 |
else:
|
457 |
self.stats["skipped"] += 1
|
458 |
except Exception as e:
|
459 |
logger.warning(f"Error processing example: {str(e)[:100]}...")
|
460 |
+
logger.warning(f"Problematic example: {str(example)[:200]}...")
|
461 |
self.stats["skipped"] += 1
|
462 |
continue
|
463 |
|