import os |
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
from datasets import load_dataset
|
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders |
|
import math |
|
import re |
|
from datetime import datetime |
|
from contextlib import nullcontext |
|
from collections import defaultdict |
|
import logging |
|
import random |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format='%(asctime)s - %(levelname)s - %(message)s', |
|
force=True |
|
) |
|
|
|
|
|
CONFIG = { |
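    # Model architecture and sequence settings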
|
|
|
"dim": 768, |
|
"n_layers": 16, |
|
"n_heads": 16, |
|
"ff_dim": 3072, |
|
|
|
|
|
"dropout": 0.1, |
|
"max_seq_len": 512, |
|
"vocab_size": 32000, |
|
|
|
|
|
"batch_size": 12, |
|
"checkpoint_interval": 2000, |
|
"debug_interval": 400, |
|
|
|
"datasets": ["daily_dialog", "empathetic_dialogues", "blended_skill_talk", "AlekseyKorshuk/persona-chat"], |
|
"tokenizer_name": "hrom_tokenizer.json", |
|
"checkpoint_dir": "checkpoints", |
|
|
|
"tokenizer_train_samples_per_dataset": 100000, |
|
"learning_rate": 1e-5, |
|
"warmup_steps": 1000, |
|
"max_turns": 8, |
|
"max_checkpoints": 5, |
|
"num_epochs": 30, |
|
"grad_accum_steps": 16 |
|
} |
|
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module): |
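    # Rotary positional embeddings (RoPE): token positions are injected by rotating pairs
    # of query/key channels by an angle proportional to the position index, so relative
    # offsets show up directly in the attention dot products.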
|
def __init__(self, dim): |
|
super().__init__() |
|
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) |
|
self.register_buffer("inv_freq", inv_freq) |
|
|
|
    def forward(self, seq_len):
        # Zero-length sequences produce an empty (0, dim) tensor.
        if seq_len == 0:
            return torch.empty((0, self.inv_freq.shape[0] * 2),
                               device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate the frequencies so the embedding spans the full head dimension.
        return torch.cat((freqs, freqs), dim=-1)
|
|
|
def rotate_half(x): |
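    # Map the two halves of the last dimension (x1, x2) -> (-x2, x1); together with the
    # cos/sin terms this realizes the RoPE rotation.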
|
x1, x2 = x.chunk(2, dim=-1) |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
def apply_rotary_pos_emb(pos, t): |
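    # `pos` comes from RotaryEmbedding with shape (seq_len, head_dim); `t` is a query or
    # key tensor of shape (batch, n_heads, seq_len, head_dim). `pos` is broadcast over the
    # batch and head dimensions before the rotation is applied.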
|
|
|
pos = pos.to(t.device, dtype=t.dtype) |
|
pos = pos.unsqueeze(0).unsqueeze(1) |
|
tensor_seq_len = t.shape[2] |
|
pos_seq_len = pos.shape[2] |
|
|
|
if pos_seq_len < tensor_seq_len: |
|
logging.warning(f"RoPE Warning: pos sequence length ({pos_seq_len}) is shorter than tensor sequence length ({tensor_seq_len}). Using truncated tensor length for RoPE.") |
|
|
|
|
|
t_rotated = t[:, :, :pos_seq_len, :] |
|
pos = pos[:, :, :pos_seq_len, :] |
|
|
|
|
|
cos_pos = pos.cos() |
|
sin_pos = pos.sin() |
|
t_rotated = (t_rotated * cos_pos) + (rotate_half(t_rotated) * sin_pos) |
|
|
|
|
|
t_unrotated = t[:, :, pos_seq_len:, :] |
|
return torch.cat([t_rotated, t_unrotated], dim=2) |
|
|
|
elif pos_seq_len > tensor_seq_len: |
|
pos = pos[:, :, :tensor_seq_len, :] |
|
|
|
|
|
if pos.shape[-1] != t.shape[-1]: |
|
logging.error(f"Mismatched dimensions for RoPE: pos ({pos.shape[-1]}) vs t ({t.shape[-1]})") |
|
raise ValueError("Rotary embedding dimension must match head dimension.") |
|
|
|
cos_pos = pos.cos() |
|
sin_pos = pos.sin() |
|
rotated_t = (t * cos_pos) + (rotate_half(t) * sin_pos) |
|
return rotated_t |
|
|
|
|
|
class SwiGLU(nn.Module): |
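    # Gated feed-forward activation. The input (of width 2 * ff_dim) is split into a value
    # half and a gate half. Note: despite the name, the gate uses GELU (a GeGLU-style
    # variant); substituting nn.functional.silu would give the standard SwiGLU.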
|
def forward(self, x): |
|
x, gate = x.chunk(2, dim=-1) |
|
return x * nn.functional.gelu(gate) |
|
|
|
class HROMAttention(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
self.dim = CONFIG["dim"] |
|
self.n_heads = CONFIG["n_heads"] |
|
self.head_dim = self.dim // self.n_heads |
|
if self.dim % self.n_heads != 0: |
|
raise ValueError("dim must be divisible by n_heads") |
|
self.qkv = nn.Linear(self.dim, 3 * self.dim) |
|
self.proj = nn.Linear(self.dim, self.dim) |
|
self.rotary = RotaryEmbedding(self.head_dim) |
|
self.dropout = nn.Dropout(CONFIG["dropout"]) |
|
|
|
def forward(self, x, mask=None): |
|
B, T, C = x.shape |
|
qkv = self.qkv(x) |
|
qkv = qkv.reshape(B, T, 3, self.n_heads, self.head_dim) |
|
q, k, v = qkv.unbind(2) |
|
q = q.transpose(1, 2) |
|
k = k.transpose(1, 2) |
|
v = v.transpose(1, 2) |
|
|
|
pos = self.rotary(T) |
|
|
|
q = apply_rotary_pos_emb(pos, q) |
|
k = apply_rotary_pos_emb(pos, k) |
|
|
|
attn_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim)) |
|
if mask is not None: |
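            # `mask` is additive (0 where attention is allowed, -inf where it is blocked);
            # broadcast it to (B, 1, T, T) so it applies to every head.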
|
|
|
if mask.dim() == 2: |
|
mask = mask.unsqueeze(1).unsqueeze(2) |
|
elif mask.dim() == 3: |
|
mask = mask.unsqueeze(1) |
|
|
|
attn_scores = attn_scores + mask |
|
|
|
attn_probs = torch.softmax(attn_scores.float(), dim=-1).to(dtype=x.dtype) |
|
attn_probs = self.dropout(attn_probs) |
|
|
|
output = attn_probs @ v |
|
output = output.transpose(1, 2).reshape(B, T, self.dim) |
|
return self.proj(output) |
|
|
|
|
|
class HROMBlock(nn.Module): |
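    # Pre-norm transformer block: LayerNorm -> self-attention -> residual add, then
    # LayerNorm -> gated feed-forward -> residual add, with dropout on each branch.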
|
def __init__(self): |
|
super().__init__() |
|
self.attn = HROMAttention() |
|
self.ff = nn.Sequential( |
|
nn.Linear(CONFIG["dim"], 2 * CONFIG["ff_dim"]), |
|
SwiGLU(), |
|
nn.Linear(CONFIG["ff_dim"], CONFIG["dim"]) |
|
) |
|
self.norm1 = nn.LayerNorm(CONFIG["dim"]) |
|
self.norm2 = nn.LayerNorm(CONFIG["dim"]) |
|
self.dropout = nn.Dropout(CONFIG["dropout"]) |
|
|
|
def forward(self, x, mask=None): |
|
|
|
normed_x = self.norm1(x) |
|
attn_output = self.attn(normed_x, mask) |
|
x = x + self.dropout(attn_output) |
|
|
|
normed_x = self.norm2(x) |
|
ff_output = self.ff(normed_x) |
|
x = x + self.dropout(ff_output) |
|
return x |
|
|
|
class HROM(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
self.embed = nn.Embedding(CONFIG["vocab_size"], CONFIG["dim"]) |
|
self.blocks = nn.ModuleList([HROMBlock() for _ in range(CONFIG["n_layers"])]) |
|
self.norm = nn.LayerNorm(CONFIG["dim"]) |
|
self.head = nn.Linear(CONFIG["dim"], CONFIG["vocab_size"]) |
|
self.dropout = nn.Dropout(CONFIG["dropout"]) |
|
self.apply(self._init_weights) |
|
|
|
def _init_weights(self, module): |
|
if isinstance(module, nn.Linear): |
|
torch.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
torch.nn.init.zeros_(module.bias) |
|
elif isinstance(module, nn.Embedding): |
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
elif isinstance(module, nn.LayerNorm): |
|
torch.nn.init.zeros_(module.bias) |
|
torch.nn.init.ones_(module.weight) |
|
|
|
def forward(self, input_ids, attention_mask=None): |
|
B, T = input_ids.shape |
|
x = self.embed(input_ids) |
|
x = self.dropout(x) |
|
|
|
|
|
        # Build an additive attention mask: a causal (upper-triangular) part plus, when
        # provided, a padding part derived from `attention_mask`.
        causal_mask = torch.triu(torch.ones(T, T, device=input_ids.device) * float('-inf'), diagonal=1)
        combined_mask = causal_mask.unsqueeze(0).unsqueeze(1)  # (1, 1, T, T)

        if attention_mask is not None:
            # `attention_mask` is 1 for real tokens and 0 for padding; turn it into an
            # additive mask with a large negative value at padded positions.
            pad_mask = (1.0 - attention_mask.to(torch.float32)) * torch.finfo(torch.float32).min
            pad_mask = pad_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
            combined_mask = combined_mask + pad_mask

        combined_mask = combined_mask.to(dtype=x.dtype)
|
|
|
for block in self.blocks: |
|
x = block(x, combined_mask) |
|
|
|
x = self.norm(x) |
|
logits = self.head(x) |
|
return logits |
|
|
|
|
|
|
|
class TokenizerTrainer: |
|
def __init__(self): |
|
self.tokenizer = Tokenizer(models.BPE()) |
|
self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) |
|
self.tokenizer.decoder = decoders.ByteLevel() |
|
self.special_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"] |
|
|
|
self.tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"]) |
|
self.tokenizer_dir = os.path.dirname(self.tokenizer_path) |
|
|
|
    def _clean_text(self, text):
        text = str(text)
        # empathetic_dialogues encodes commas as the literal token "_comma_".
        text = text.replace('_comma_', ',')
        # Keep word characters, whitespace, common punctuation, and < > (needed for the
        # <user>/<assistant> role markers); drop everything else.
        text = re.sub(r'[^\w\s.,!?\'\-:;<>"]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text
|
|
|
def train(self, dataset_names): |
|
logging.info("Starting tokenizer training...") |
|
text_samples = [] |
|
samples_per_dataset = CONFIG['tokenizer_train_samples_per_dataset'] |
|
|
|
|
|
if "daily_dialog" in dataset_names: |
|
logging.info(f"Loading daily_dialog for tokenizer training (max {samples_per_dataset} dialogues)...") |
|
try: |
|
|
|
dd_dataset = load_dataset("daily_dialog", split=f"train[:{samples_per_dataset}]", trust_remote_code=True) |
|
logging.info("Processing daily_dialog...") |
|
for entry in dd_dataset: |
|
formatted_dialogue = [] |
|
dialogue = entry['dialog'][:CONFIG["max_turns"]] |
|
for i, utterance in enumerate(dialogue): |
|
role = "<user>" if i % 2 == 0 else "<assistant>" |
|
cleaned_utterance = self._clean_text(utterance) |
|
if cleaned_utterance: |
|
formatted_dialogue.append(f"{role} {cleaned_utterance}") |
|
if formatted_dialogue: |
|
text_samples.append(" </s> ".join(formatted_dialogue)) |
|
except Exception as e: |
|
logging.error(f"Failed to load or process daily_dialog for tokenizer: {e}") |
|
|
|
|
|
if "empathetic_dialogues" in dataset_names: |
|
logging.info(f"Loading empathetic_dialogues for tokenizer training (max {samples_per_dataset} dialogues)...") |
|
try: |
|
|
|
ed_dataset = load_dataset("empathetic_dialogues", split=f"train[:{samples_per_dataset * 3}]", trust_remote_code=True) |
|
logging.info("Processing empathetic_dialogues...") |
|
                processed_conv_count = 0
|
|
|
grouped_by_conv = defaultdict(list) |
|
for entry in ed_dataset: |
|
grouped_by_conv[entry['conv_id']].append(entry) |
|
|
|
|
|
for conv_id, entries in grouped_by_conv.items(): |
|
if processed_conv_count >= samples_per_dataset: |
|
break |
|
|
|
sorted_entries = sorted(entries, key=lambda x: x['utterance_idx']) |
|
formatted_dialogue = [] |
|
|
|
if sorted_entries[0]['context']: |
|
cleaned_context = self._clean_text(sorted_entries[0]['context']) |
|
if cleaned_context: |
|
formatted_dialogue.append(f"<user> {cleaned_context}") |
|
|
|
last_role = '<user>' if formatted_dialogue else None |
|
for entry in sorted_entries: |
|
cleaned_utterance = self._clean_text(entry['utterance']) |
|
if cleaned_utterance: |
|
|
|
current_role = '<assistant>' if last_role == '<user>' else '<user>' |
|
formatted_dialogue.append(f"{current_role} {cleaned_utterance}") |
|
last_role = current_role |
|
|
|
formatted_dialogue = formatted_dialogue[:CONFIG["max_turns"]] |
|
if formatted_dialogue: |
|
text_samples.append(" </s> ".join(formatted_dialogue)) |
|
processed_conv_count += 1 |
|
|
|
except Exception as e: |
|
logging.error(f"Failed to load or process empathetic_dialogues for tokenizer: {e}") |
|
|
|
|
|
|
|
if "blended_skill_talk" in dataset_names: |
|
logging.info(f"Loading blended_skill_talk for tokenizer training (max {samples_per_dataset} dialogues)...") |
|
try: |
|
|
|
bst_dataset = load_dataset("blended_skill_talk", split=f"train[:{samples_per_dataset}]", trust_remote_code=True) |
|
logging.info("Processing blended_skill_talk...") |
|
for entry in bst_dataset: |
|
formatted_dialogue = [] |
|
|
|
dialogue_turns_raw = entry['previous_utterance'] |
|
|
|
if entry.get('free_turker_utterance'): |
|
dialogue_turns_raw.append(entry['free_turker_utterance']) |
|
if entry.get('guided_turker_utterance'): |
|
dialogue_turns_raw.append(entry['guided_turker_utterance']) |
|
|
|
turns_to_process = dialogue_turns_raw[:CONFIG["max_turns"]] |
|
for i, utterance in enumerate(turns_to_process): |
|
role = "<user>" if i % 2 == 0 else "<assistant>" |
|
cleaned_utterance = self._clean_text(utterance) |
|
if cleaned_utterance: |
|
formatted_dialogue.append(f"{role} {cleaned_utterance}") |
|
if formatted_dialogue: |
|
text_samples.append(" </s> ".join(formatted_dialogue)) |
|
except Exception as e: |
|
logging.error(f"Failed to load or process blended_skill_talk for tokenizer: {e}") |
|
|
|
|
|
if "AlekseyKorshuk/persona-chat" in dataset_names: |
|
pc_dataset_name = "AlekseyKorshuk/persona-chat" |
|
logging.info(f"Loading {pc_dataset_name} for tokenizer training (max {samples_per_dataset} dialogues)...") |
|
try: |
|
pc_dataset = load_dataset(pc_dataset_name, split=f"train[:{samples_per_dataset}]", trust_remote_code=True) |
|
logging.info(f"Processing {pc_dataset_name}...") |
|
for entry in pc_dataset: |
|
|
|
if 'utterances' in entry and entry['utterances']: |
|
|
|
history = entry['utterances'][-1]['history'] |
|
history = history[:CONFIG["max_turns"]] |
|
formatted_dialogue = [] |
|
for i, utterance in enumerate(history): |
|
role = "<user>" if i % 2 == 0 else "<assistant>" |
|
cleaned_utterance = self._clean_text(utterance) |
|
if cleaned_utterance: |
|
formatted_dialogue.append(f"{role} {cleaned_utterance}") |
|
if formatted_dialogue: |
|
text_samples.append(" </s> ".join(formatted_dialogue)) |
|
else: |
|
logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry}") |
|
|
|
except Exception as e: |
|
logging.error(f"Failed to load or process {pc_dataset_name} for tokenizer: {e}") |
|
|
|
|
|
logging.info(f"Total text samples for tokenizer training: {len(text_samples)}") |
|
if not text_samples: |
|
raise ValueError("No text samples collected for tokenizer training. Check dataset loading and paths.") |
|
|
|
|
|
os.makedirs(self.tokenizer_dir, exist_ok=True) |
|
|
|
logging.info(f"Training BPE tokenizer with vocab size {CONFIG['vocab_size']}...") |
|
trainer = trainers.BpeTrainer( |
|
vocab_size=CONFIG["vocab_size"], |
|
special_tokens=self.special_tokens, |
|
min_frequency=2, |
|
show_progress=True |
|
) |
|
|
|
def text_iterator(): |
|
for sample in text_samples: |
|
yield sample |
|
|
|
self.tokenizer.train_from_iterator(text_iterator(), trainer=trainer, length=len(text_samples)) |
|
|
|
eos_token_id = self.tokenizer.token_to_id("</s>") |
|
if eos_token_id is None: |
|
logging.warning("</s> token not found in trained tokenizer vocab! Using <pad> as fallback for post-processor.") |
|
eos_token_id = self.tokenizer.token_to_id("<pad>") or 0 |
|
|
|
|
|
self.tokenizer.post_processor = processors.TemplateProcessing( |
|
single="$A </s>", |
|
pair="$A </s> $B </s>", |
|
special_tokens=[("</s>", eos_token_id)], |
|
) |
|
|
|
logging.info(f"Saving tokenizer to {self.tokenizer_path}") |
|
self.tokenizer.save(self.tokenizer_path) |
|
logging.info("Tokenizer training complete.") |
|
|
|
def get_tokenizer(self): |
|
if not os.path.exists(self.tokenizer_path): |
|
raise FileNotFoundError(f"Tokenizer file not found at {self.tokenizer_path}. Train tokenizer first.") |
|
tokenizer = Tokenizer.from_file(self.tokenizer_path) |
|
|
|
required_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"] |
|
for token in required_tokens: |
|
if tokenizer.token_to_id(token) is None: |
|
raise ValueError(f"Crucial special token '{token}' not found in loaded tokenizer '{self.tokenizer_path}'!") |
|
return tokenizer |
|
|
|
|
|
|
|
class CombinedChatDataset(Dataset): |
|
def __init__(self, tokenizer): |
|
self.tokenizer = tokenizer |
|
self.pad_id = self.tokenizer.token_to_id("<pad>") |
|
self.eos_id = self.tokenizer.token_to_id("</s>") |
|
self.bos_id = self.tokenizer.token_to_id("<s>") |
|
self.user_id = self.tokenizer.token_to_id("<user>") |
|
self.assistant_id = self.tokenizer.token_to_id("<assistant>") |
|
self.max_length = CONFIG["max_seq_len"] |
|
|
|
self._clean_text = TokenizerTrainer()._clean_text |
|
|
|
self.all_processed_conversations = [] |
|
|
|
|
|
if "daily_dialog" in CONFIG["datasets"]: |
|
logging.info("Loading and processing daily_dialog dataset...") |
|
try: |
|
dd_dataset = load_dataset("daily_dialog", split="train", trust_remote_code=True) |
|
logging.info(f"Processing {len(dd_dataset)} daily_dialog conversations...") |
|
for entry in dd_dataset: |
|
conversation = [] |
|
dialogue = entry['dialog'][:CONFIG["max_turns"]] |
|
if not dialogue: continue |
|
for i, utterance in enumerate(dialogue): |
|
role = "<user>" if i % 2 == 0 else "<assistant>" |
|
cleaned_text = self._clean_text(utterance) |
|
if cleaned_text: |
|
conversation.append({'role': role, 'text': cleaned_text}) |
|
if conversation: |
|
self.all_processed_conversations.append(conversation) |
|
except Exception as e: |
|
logging.error(f"Failed to load or process daily_dialog for training: {e}") |
|
|
|
|
|
if "empathetic_dialogues" in CONFIG["datasets"]: |
|
logging.info("Loading and processing empathetic_dialogues dataset...") |
|
try: |
|
ed_dataset = load_dataset("empathetic_dialogues", split="train", trust_remote_code=True) |
|
logging.info("Grouping empathetic_dialogues by conversation ID...") |
|
conversations_grouped = defaultdict(list) |
|
for entry in ed_dataset: |
|
conversations_grouped[entry['conv_id']].append(entry) |
|
|
|
logging.info(f"Processing {len(conversations_grouped)} empathetic_dialogues conversations...") |
|
for conv_id, entries in conversations_grouped.items(): |
|
conversation = [] |
|
sorted_entries = sorted(entries, key=lambda x: x['utterance_idx']) |
|
|
|
if sorted_entries[0]['context']: |
|
context_text = self._clean_text(sorted_entries[0]['context']) |
|
if context_text: |
|
conversation.append({'role': '<user>', 'text': context_text}) |
|
|
|
last_role = conversation[-1]['role'] if conversation else None |
|
for entry in sorted_entries: |
|
text = self._clean_text(entry['utterance']) |
|
if not text: continue |
|
|
|
current_role = '<assistant>' if last_role == '<user>' else '<user>' |
|
conversation.append({'role': current_role, 'text': text}) |
|
last_role = current_role |
|
|
|
|
|
conversation = conversation[:CONFIG["max_turns"]] |
|
if conversation: |
|
self.all_processed_conversations.append(conversation) |
|
|
|
except Exception as e: |
|
logging.error(f"Failed to load or process empathetic_dialogues for training: {e}") |
|
|
|
|
|
if "blended_skill_talk" in CONFIG["datasets"]: |
|
logging.info("Loading and processing blended_skill_talk dataset...") |
|
try: |
|
bst_dataset = load_dataset("blended_skill_talk", split="train", trust_remote_code=True) |
|
logging.info(f"Processing {len(bst_dataset)} blended_skill_talk conversations...") |
|
for entry in bst_dataset: |
|
conversation = [] |
|
|
|
dialogue_turns_raw = entry['previous_utterance'] |
|
if entry.get('free_turker_utterance'): |
|
dialogue_turns_raw.append(entry['free_turker_utterance']) |
|
if entry.get('guided_turker_utterance'): |
|
dialogue_turns_raw.append(entry['guided_turker_utterance']) |
|
|
|
if not dialogue_turns_raw: continue |
|
|
|
turns_to_process = dialogue_turns_raw[:CONFIG["max_turns"]] |
|
|
|
for i, utterance in enumerate(turns_to_process): |
|
role = "<user>" if i % 2 == 0 else "<assistant>" |
|
cleaned_text = self._clean_text(utterance) |
|
if cleaned_text: |
|
conversation.append({'role': role, 'text': cleaned_text}) |
|
if conversation: |
|
self.all_processed_conversations.append(conversation) |
|
except Exception as e: |
|
logging.error(f"Failed to load or process blended_skill_talk for training: {e}") |
|
|
|
|
|
if "AlekseyKorshuk/persona-chat" in CONFIG["datasets"]: |
|
pc_dataset_name = "AlekseyKorshuk/persona-chat" |
|
logging.info(f"Loading and processing {pc_dataset_name} dataset...") |
|
try: |
|
pc_dataset = load_dataset(pc_dataset_name, split="train", trust_remote_code=True) |
|
logging.info(f"Processing {len(pc_dataset)} {pc_dataset_name} conversations...") |
|
for entry in pc_dataset: |
|
conversation = [] |
|
if 'utterances' in entry and entry['utterances']: |
|
|
|
history = entry['utterances'][-1]['history'] |
|
history = history[:CONFIG["max_turns"]] |
|
|
|
for i, utterance in enumerate(history): |
|
role = "<user>" if i % 2 == 0 else "<assistant>" |
|
cleaned_text = self._clean_text(utterance) |
|
if cleaned_text: |
|
conversation.append({'role': role, 'text': cleaned_text}) |
|
|
|
if conversation: |
|
self.all_processed_conversations.append(conversation) |
|
else: |
|
logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry.keys()}") |
|
|
|
except Exception as e: |
|
logging.error(f"Failed to load or process {pc_dataset_name} for training: {e}") |
|
|
|
|
|
logging.info(f"Total processed conversations from all datasets: {len(self.all_processed_conversations)}") |
|
if not self.all_processed_conversations: |
|
raise ValueError("No processed conversations were created from any dataset. Check loading logic and dataset availability.") |
|
|
|
logging.info("Shuffling combined dataset...") |
|
random.shuffle(self.all_processed_conversations) |
|
|
|
|
|
def __len__(self): |
|
return len(self.all_processed_conversations) |
|
|
|
def __getitem__(self, idx): |
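        # Encode one conversation as: <s> <role> utterance </s> <role> utterance </s> ...,
        # truncated to max_seq_len, then shifted by one token to form (input_ids, labels).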
|
conversation = self.all_processed_conversations[idx] |
|
formatted_ids = [self.bos_id] |
|
for turn in conversation: |
|
role_id = self.user_id if turn['role'] == '<user>' else self.assistant_id |
|
|
|
try: |
|
utterance_ids = self.tokenizer.encode(turn['text'], add_special_tokens=False).ids |
|
except Exception as e: |
|
logging.error(f"Error encoding text at index {idx}, turn '{turn}': {e}") |
|
utterance_ids = [] |
|
|
|
|
|
|
|
            # Stop if adding this turn (role token + utterance + EOS) would exceed max_length.
            if len(formatted_ids) + 1 + len(utterance_ids) + 1 > self.max_length:
                # If a bare role token + EOS still fits, close the dialogue with an empty
                # final turn; otherwise simply stop here.
                if len(formatted_ids) + 1 + 1 <= self.max_length:
                    formatted_ids.append(role_id)
                    formatted_ids.append(self.eos_id)
                break

            formatted_ids.append(role_id)
            formatted_ids.extend(utterance_ids)
            formatted_ids.append(self.eos_id)

        # Safety net in case the loop overshot the maximum length.
        if len(formatted_ids) > self.max_length:
            formatted_ids = formatted_ids[:self.max_length]

        # Never end a sequence on a dangling role token with nothing after it.
        if formatted_ids and formatted_ids[-1] in (self.user_id, self.assistant_id):
            formatted_ids.pop()
|
|
|
|
|
|
|
if len(formatted_ids) < 2: |
|
logging.warning(f"Sequence at index {idx} is too short after processing (<2 tokens). Skipping. Original length: {len(conversation)}") |
|
|
|
return None |
|
|
|
input_ids = formatted_ids[:-1] |
|
labels = formatted_ids[1:] |
|
|
|
|
|
if len(input_ids) == 0: |
|
logging.warning(f"Sequence at index {idx} resulted in empty input_ids after slicing. Skipping.") |
|
return None |
|
|
|
|
|
return {"input_ids": input_ids, "labels": labels} |
|
|
|
@staticmethod |
|
def collate_fn(batch): |
|
|
|
batch = [item for item in batch if item is not None] |
|
if not batch: |
|
return None |
|
|
|
max_len = max(len(item["input_ids"]) for item in batch) |
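
        # NOTE: reloading the tokenizer from disk here, on every batch, only to recover the
        # <pad> id is wasteful; caching the id once at startup would avoid the repeated I/O.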
|
|
|
|
|
try: |
|
|
|
tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"]) |
|
|
|
tokenizer = Tokenizer.from_file(tokenizer_path) |
|
pad_id = tokenizer.token_to_id("<pad>") |
|
if pad_id is None: raise ValueError("<pad> token not found") |
|
except Exception as e: |
|
logging.error(f"Collate Error: Failed to load tokenizer or get pad_id ('{CONFIG['tokenizer_name']}'): {e}") |
|
pad_id = 0 |
|
|
|
inputs, labels, masks = [], [], [] |
|
for item in batch: |
|
input_len = len(item["input_ids"]) |
|
pad_len = max_len - input_len |
|
inputs.append(item["input_ids"] + [pad_id] * pad_len) |
|
|
|
labels.append(item["labels"] + [pad_id] * pad_len) |
|
masks.append([1] * input_len + [0] * pad_len) |
|
|
|
return { |
|
"input_ids": torch.tensor(inputs, dtype=torch.long), |
|
"labels": torch.tensor(labels, dtype=torch.long), |
|
"attention_mask": torch.tensor(masks, dtype=torch.long) |
|
} |
|
|
|
|
|
|
|
class HROMTrainer: |
|
def __init__(self, model, tokenizer): |
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
logging.info(f"Using device: {self.device}") |
|
self.model = model.to(self.device) |
|
|
|
self.use_amp = (self.device.type == "cuda" and hasattr(torch.cuda.amp, "GradScaler")) |
|
self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None |
|
logging.info(f"Automatic Mixed Precision (AMP): {'Enabled' if self.use_amp else 'Disabled'}") |
|
|
|
self.optimizer = torch.optim.AdamW( |
|
self.model.parameters(), |
|
lr=CONFIG["learning_rate"], |
|
betas=(0.9, 0.95), |
|
weight_decay=0.1, |
|
fused= (self.device.type == "cuda") |
|
) |
|
self.tokenizer = tokenizer |
|
self.pad_id = self.tokenizer.token_to_id("<pad>") |
|
if self.pad_id is None: |
|
|
|
self.pad_id = CONFIG.get("pad_token_id", 0) |
|
logging.warning(f"<pad> token ID not found in tokenizer, using fallback ID: {self.pad_id}") |
|
|
|
|
|
|
|
self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id) |
|
self.base_lr = CONFIG["learning_rate"] |
|
self.warmup_steps = CONFIG["warmup_steps"] |
|
|
|
    def _adjust_learning_rate(self, step):
        # Linear warmup to the base learning rate, then hold it constant (no decay schedule).
        if self.warmup_steps > 0 and step < self.warmup_steps:
            lr = self.base_lr * (step + 1) / self.warmup_steps
        else:
            lr = self.base_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr
|
|
|
def train_step(self, batch): |
|
|
|
        # Prefer bfloat16 autocast when the GPU supports it, otherwise float16; with AMP
        # disabled, fall back to a no-op context manager.
        if self.use_amp:
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            autocast_context = torch.cuda.amp.autocast(dtype=amp_dtype)
        else:
            autocast_context = nullcontext()
|
|
|
with autocast_context: |
|
input_ids = batch["input_ids"].to(self.device) |
|
attention_mask = batch["attention_mask"].to(self.device) |
|
labels = batch["labels"].to(self.device) |
|
|
|
outputs = self.model(input_ids, attention_mask=attention_mask) |
|
|
|
|
|
logits_flat = outputs.view(-1, outputs.size(-1)) |
|
labels_flat = labels.view(-1) |
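
            # Compute the loss in float32 for stability under mixed precision; padded
            # positions are excluded via ignore_index=pad_id in the criterion.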
|
|
|
|
|
loss = self.criterion(logits_flat.float(), labels_flat) |
|
|
|
|
|
scaled_loss = loss / CONFIG["grad_accum_steps"] |
|
|
|
|
|
if self.use_amp and self.scaler: |
|
self.scaler.scale(scaled_loss).backward() |
|
else: |
|
scaled_loss.backward() |
|
|
|
return loss.item() |
|
|
|
def clip_and_step(self, current_optimizer_step): |
|
current_lr = self._adjust_learning_rate(current_optimizer_step) |
|
|
|
if self.use_amp and self.scaler: |
|
|
|
self.scaler.unscale_(self.optimizer) |
|
|
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) |
|
|
|
self.scaler.step(self.optimizer) |
|
|
|
self.scaler.update() |
|
else: |
|
|
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) |
|
|
|
self.optimizer.step() |
|
|
|
|
|
self.optimizer.zero_grad(set_to_none=True) |
|
return current_lr |
|
|
|
|
|
class SafetyManager: |
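    # Lightweight keyword-based safety layer: encodes a list of "bad words" into token-id
    # sequences and blocks any generation whose ids would contain one of them.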
|
|
|
def __init__(self, model, tokenizer): |
|
self.model = model |
|
self.tokenizer = tokenizer |
|
|
|
self.bad_words = ["kill", "murder", "suicide", "hate", "abuse", "violence", "illegal", "harm", "die", "attack", "rape", "molest", "exploit", "terror"] |
|
self.bad_word_ids = [] |
|
logging.info("Initializing safety manager...") |
|
|
|
for word in self.bad_words: |
|
|
|
ids = tokenizer.encode(f" {word}", add_special_tokens=False).ids |
|
if ids: |
|
self.bad_word_ids.append(ids) |
|
logging.debug(f"Encoded bad word '{word}' (with space) to IDs: {ids}") |
|
|
|
ids_no_space = tokenizer.encode(word, add_special_tokens=False).ids |
|
if ids_no_space and ids_no_space != ids: |
|
self.bad_word_ids.append(ids_no_space) |
|
logging.debug(f"Encoded bad word '{word}' (no space) to IDs: {ids_no_space}") |
|
|
|
if not ids and not ids_no_space: |
|
logging.warning(f"Could not encode bad word '{word}' - skipping.") |
|
|
|
|
|
self.eos_id = self.tokenizer.token_to_id("</s>") |
|
self.bos_id = self.tokenizer.token_to_id("<s>") |
|
self.user_id = self.tokenizer.token_to_id("<user>") |
|
self.assistant_id = self.tokenizer.token_to_id("<assistant>") |
|
self.pad_id = self.tokenizer.token_to_id("<pad>") |
|
|
|
if self.eos_id is None: logging.error("</s> token ID not found for SafetyManager!"); self.eos_id = 0 |
|
if self.bos_id is None: logging.error("<s> token ID not found for SafetyManager!"); self.bos_id = 0 |
|
if self.user_id is None: logging.error("<user> token ID not found for SafetyManager!") |
|
if self.assistant_id is None: logging.error("<assistant> token ID not found for SafetyManager!") |
|
if self.pad_id is None: logging.error("<pad> token ID not found for SafetyManager!"); self.pad_id = 0 |
|
|
|
|
|
def contains_sequence(self, tokens, seq): |
|
"""Checks if the list `tokens` contains the sublist `seq`.""" |
|
if not seq or not tokens or len(tokens) < len(seq): |
|
return False |
|
seq_len = len(seq) |
|
for i in range(len(tokens) - seq_len + 1): |
|
if tokens[i : i + seq_len] == seq: |
|
return True |
|
return False |
|
|
|
def content_filter(self, text_ids): |
|
"""Checks if a list of token IDs contains any bad word sequences.""" |
|
if not isinstance(text_ids, list): |
|
logging.warning("Content filter received non-list input.") |
|
return True |
|
for bad_ids in self.bad_word_ids: |
|
if self.contains_sequence(text_ids, bad_ids): |
|
|
|
detected_word = self.tokenizer.decode(bad_ids) |
|
logging.warning(f"Unsafe content detected: Found sequence corresponding to '{detected_word}' (IDs: {bad_ids}).") |
|
return False |
|
return True |
|
|
|
def generate_safely(self, prompt, max_new_tokens=50, temperature=0.5, top_k=50): |
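        # Sampling loop with temperature and top-k, a rolling max_seq_len context window,
        # and a content filter applied to each candidate continuation before it is accepted.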
|
self.model.eval() |
|
device = next(self.model.parameters()).device |
|
|
|
|
|
|
|
prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False).ids |
|
|
|
|
|
|
|
if prompt_ids and prompt_ids[0] == self.bos_id: |
|
input_ids = list(prompt_ids) |
|
else: |
|
input_ids = [self.bos_id] + list(prompt_ids) |
|
|
|
|
|
if self.assistant_id is not None: |
|
input_ids.append(self.assistant_id) |
|
else: |
|
logging.error("Assistant token ID is None, cannot properly start generation.") |
|
return "Error: Assistant token not found." |
|
|
|
|
|
generated_ids = list(input_ids) |
|
logging.debug(f"Starting safe generation with initial IDs: {generated_ids}") |
|
|
|
with torch.no_grad(): |
|
for step in range(max_new_tokens): |
|
|
|
current_input_ids = generated_ids[-CONFIG["max_seq_len"]:] |
|
current_input_tensor = torch.tensor([current_input_ids]).to(device) |
|
|
|
attention_mask = torch.ones_like(current_input_tensor) |
|
|
|
|
|
try: |
|
outputs = self.model(current_input_tensor, attention_mask=attention_mask) |
|
next_token_logits = outputs[:, -1, :] |
|
except Exception as e: |
|
logging.error(f"Model forward pass failed during generation: {e}") |
|
break |
|
|
|
|
|
|
|
|
|
|
|
|
|
                # Apply temperature scaling and (independently) top-k filtering.
                if temperature > 0 and temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                if 0 < top_k < next_token_logits.size(-1):
                    v, _ = torch.topk(next_token_logits, top_k)
                    # Neutralize NaNs before thresholding; infinities are left untouched.
                    safe_logits = torch.nan_to_num(next_token_logits, nan=-float('inf'),
                                                   posinf=float('inf'), neginf=-float('inf'))
                    threshold = v[:, [-1]]
                    safe_logits[safe_logits < threshold] = -float('inf')
                    next_token_logits = safe_logits
|
|
|
probs = torch.softmax(next_token_logits, dim=-1) |
|
|
|
if torch.isnan(probs).any(): |
|
logging.warning("NaN detected in probabilities before sampling. Replacing with uniform distribution.") |
|
probs = torch.ones_like(probs) / probs.size(-1) |
|
|
|
next_token_id = torch.multinomial(probs, num_samples=1).item() |
|
|
|
|
|
|
|
potential_sequence_ids = generated_ids + [next_token_id] |
|
|
|
|
|
if not self.content_filter(potential_sequence_ids): |
|
logging.warning(f"Potential unsafe token ({next_token_id}, '{self.tokenizer.decode([next_token_id])}') blocked POST-sampling. Stopping generation.") |
|
|
|
break |
|
|
|
|
|
generated_ids.append(next_token_id) |
|
|
|
|
|
if next_token_id == self.eos_id: |
|
logging.debug(f"EOS token generated at step {step+1}. Stopping generation.") |
|
break |
|
|
|
|
|
if step == max_new_tokens - 1: |
|
logging.debug("Max new tokens reached. Stopping generation.") |
|
|
|
if generated_ids[-1] != self.eos_id and self.eos_id is not None: |
|
generated_ids.append(self.eos_id) |
|
|
|
self.model.train() |
|
|
|
|
|
start_index = len(input_ids) |
|
response_ids = generated_ids[start_index:] |
|
|
|
|
|
|
|
decoded_text = self.tokenizer.decode(response_ids, skip_special_tokens=True).strip() |
|
|
|
return decoded_text |
|
|
|
|
|
def debug_generation(self, prompt="<user> Tell me about your hobbies."): |
|
        logging.info("\n--- Debug Generation & Safety Check ---")

        # Make sure the prompt ends with an end-of-turn marker before generation.
        if not prompt.strip().endswith("</s>"):
            prompt = prompt.strip() + " </s>"

        # generate_safely prepends <s> itself, so strip any leading one from the prompt.
        if prompt.startswith("<s>"):
            prompt = prompt[len("<s>"):].strip()
|
|
|
|
|
generated_response = self.generate_safely(prompt, max_new_tokens=60, temperature=0.7, top_k=50) |
|
|
|
logging.info(f"Prompt Sent: '{prompt}'") |
|
logging.info(f"Generated Response: '{generated_response}'") |
|
logging.info("\n--- End Debug Generation ---\n") |
|
|
|
|
|
class CheckpointManager: |
|
def __init__(self): |
|
|
|
self.checkpoint_dir = CONFIG["checkpoint_dir"] |
|
os.makedirs(self.checkpoint_dir, exist_ok=True) |
|
logging.info(f"Checkpoint directory set to: {self.checkpoint_dir}") |
|
|
|
def save(self, model, optimizer, step): |
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
|
|
prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "") |
|
|
|
step_str = str(step) |
|
filename = f"hrom_{prefix}_step{step_str}_{timestamp}.pt" |
|
path = os.path.join(self.checkpoint_dir, filename) |
|
state = { |
|
"model": model.state_dict(), |
|
"optimizer": optimizer.state_dict(), |
|
"step": step if isinstance(step, int) else -1, |
|
"config": CONFIG |
|
} |
|
logging.info(f"Saving checkpoint to {path}...") |
|
try: |
|
torch.save(state, path) |
|
logging.info(f"Checkpoint saved successfully at step {step_str}.") |
|
self._cleanup_old_checkpoints() |
|
except Exception as e: |
|
logging.error(f"Failed to save checkpoint '{path}': {e}") |
|
|
|
def _cleanup_old_checkpoints(self): |
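        """Delete the oldest checkpoint files so at most `max_checkpoints` remain."""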
|
max_checkpoints = CONFIG.get("max_checkpoints", 5) |
|
if max_checkpoints <= 0: |
|
return |
|
|
|
try: |
|
|
|
prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "") |
|
pattern = re.compile(rf"hrom_{prefix}_step(\d+|.+)_(\d{{8}}_\d{{6}})\.pt") |
|
|
|
checkpoints = [] |
|
for f in os.listdir(self.checkpoint_dir): |
|
match = pattern.match(f) |
|
if match: |
|
filepath = os.path.join(self.checkpoint_dir, f) |
|
checkpoints.append((filepath, os.path.getmtime(filepath))) |
|
|
|
|
|
checkpoints.sort(key=lambda x: x[1]) |
|
|
|
num_to_delete = len(checkpoints) - max_checkpoints |
|
if num_to_delete > 0: |
|
|
|
for i in range(num_to_delete): |
|
file_to_remove, _ = checkpoints[i] |
|
try: |
|
os.remove(file_to_remove) |
|
|
|
except OSError as e: |
|
logging.error(f"Error removing checkpoint {file_to_remove}: {e}") |
|
except Exception as e: |
|
logging.error(f"Error during checkpoint cleanup: {e}") |
|
|
|
|
|
def load_latest(self, model, optimizer): |
|
try: |
|
|
|
prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "") |
|
pattern = re.compile(rf"hrom_{prefix}_step(\d+|.+)_(\d{{8}}_\d{{6}})\.pt") |
|
checkpoints = [] |
|
for f in os.listdir(self.checkpoint_dir): |
|
match = pattern.match(f) |
|
if match: |
|
filepath = os.path.join(self.checkpoint_dir, f) |
|
checkpoints.append((filepath, os.path.getmtime(filepath))) |
|
|
|
if not checkpoints: |
|
logging.info("No valid checkpoints found to load.") |
|
return 0 |
|
|
|
|
|
checkpoints.sort(key=lambda x: x[1], reverse=True) |
|
|
|
latest_checkpoint_path, _ = checkpoints[0] |
|
logging.info(f"Loading latest checkpoint from: {latest_checkpoint_path}") |
|
map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
checkpoint = torch.load(latest_checkpoint_path, map_location=map_location) |
|
|
|
|
|
loaded_config = checkpoint.get("config", {}) |
|
|
|
critical_keys = ["dim", "n_layers", "n_heads", "ff_dim", "vocab_size", "max_seq_len", "tokenizer_name"] |
|
mismatched_keys = [] |
|
if loaded_config: |
|
for key in critical_keys: |
|
|
|
if key in loaded_config and key in CONFIG and loaded_config[key] != CONFIG[key]: |
|
mismatched_keys.append((key, loaded_config[key], CONFIG[key])) |
|
|
|
elif key in loaded_config and key not in CONFIG: |
|
mismatched_keys.append((key, loaded_config[key], "Not in current CONFIG")) |
|
|
|
elif key not in loaded_config and key in CONFIG: |
|
mismatched_keys.append((key, "Not in loaded CONFIG", CONFIG[key])) |
|
|
|
|
|
if mismatched_keys: |
|
logging.warning("--- CONFIG MISMATCH DETECTED ---") |
|
logging.warning(f"Checkpoint '{os.path.basename(latest_checkpoint_path)}' was saved with different critical parameters:") |
|
for key, loaded_val, current_val in mismatched_keys: |
|
logging.warning(f" - {key}: Checkpoint='{loaded_val}', Current='{current_val}'") |
|
|
|
|
|
logging.warning("Proceeding with loading, but results may be unexpected or errors may occur.") |
|
else: |
|
logging.warning("Checkpoint does not contain configuration info. Cannot check compatibility.") |
|
|
|
|
|
|
|
try: |
|
|
|
model.load_state_dict(checkpoint['model'], strict=True) |
|
except RuntimeError as e: |
|
logging.error(f"Failed to load model state_dict: {e}") |
|
logging.error("This often happens due to architecture mismatch (check CONFIG) or corrupted checkpoint.") |
|
logging.error("Starting training from scratch.") |
|
return 0 |
|
|
|
try: |
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
except ValueError as e: |
|
logging.warning(f"Could not load optimizer state_dict: {e}. Optimizer state will be reset.") |
|
|
|
|
|
optimizer.state = defaultdict(dict) |
|
logging.warning("Optimizer state reset.") |
|
except Exception as e: |
|
logging.error(f"Unexpected error loading optimizer state: {e}. Starting training from scratch.") |
|
return 0 |
|
|
|
start_step = checkpoint.get('step', 0) |
|
|
|
start_step = max(0, start_step) + 1 if isinstance(start_step, int) else 0 |
|
|
|
|
|
logging.info(f"Checkpoint loaded successfully. Resuming from optimizer step {start_step}.") |
|
|
|
for state in optimizer.state.values(): |
|
for k, v in state.items(): |
|
if isinstance(v, torch.Tensor): |
|
try: |
|
state[k] = v.to(map_location) |
|
except Exception as e: |
|
logging.error(f"Failed to move optimizer tensor '{k}' to device '{map_location}': {e}") |
|
return start_step |
|
|
|
except FileNotFoundError: |
|
logging.info(f"No checkpoint directory '{self.checkpoint_dir}' or files found. Starting training from scratch.") |
|
return 0 |
|
except Exception as e: |
|
logging.error(f"Error loading checkpoint from '{self.checkpoint_dir}': {e}. Starting training from scratch.") |
|
|
|
|
|
|
|
return 0 |
|
|
|
|
|
|
|
|
|
def train(): |
|
logging.info("Starting HROM training process on combined datasets (daily_dialog, empathetic_dialogues, blended_skill_talk, AlekseyKorshuk/persona-chat)...") |
|
logging.info(f"Configuration: {CONFIG}") |
|
|
|
|
|
tokenizer_trainer = TokenizerTrainer() |
|
tokenizer_path = tokenizer_trainer.tokenizer_path |
|
if not os.path.exists(tokenizer_path): |
|
logging.info(f"Combined tokenizer '{CONFIG['tokenizer_name']}' not found. Training tokenizer...") |
|
try: |
|
|
|
tokenizer_trainer.train(CONFIG["datasets"]) |
|
except Exception as e: |
|
logging.error(f"Failed during tokenizer training: {e}", exc_info=True) |
|
return |
|
else: |
|
logging.info(f"Loading existing combined tokenizer from {tokenizer_path}") |
|
|
|
try: |
|
tokenizer = tokenizer_trainer.get_tokenizer() |
|
|
|
CONFIG['pad_token_id'] = tokenizer.token_to_id("<pad>") |
|
CONFIG['bos_token_id'] = tokenizer.token_to_id("<s>") |
|
CONFIG['eos_token_id'] = tokenizer.token_to_id("</s>") |
|
logging.info(f"Loaded tokenizer. Vocab size: {tokenizer.get_vocab_size()}. Special IDs: PAD={CONFIG['pad_token_id']}, BOS={CONFIG['bos_token_id']}, EOS={CONFIG['eos_token_id']}") |
|
except (FileNotFoundError, ValueError) as e: |
|
logging.error(f"Failed to load tokenizer: {e}. Cannot continue.") |
|
return |
|
|
|
|
|
logging.info("Initializing HROM model...") |
|
|
|
if CONFIG['vocab_size'] != tokenizer.get_vocab_size(): |
|
logging.warning(f"Config vocab_size ({CONFIG['vocab_size']}) differs from tokenizer vocab size ({tokenizer.get_vocab_size()}). Using tokenizer's size.") |
|
CONFIG['vocab_size'] = tokenizer.get_vocab_size() |
|
model = HROM() |
|
|
|
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
logging.info(f"Model initialized. Total parameters: {total_params:,}") |
|
logging.info(f"Trainable parameters: {trainable_params:,}") |
|
logging.info(f"Parameters (Millions): Total={total_params/1e6:.2f}M, Trainable={trainable_params/1e6:.2f}M") |
|
|
|
|
|
|
|
logging.info("Setting up combined dataset and dataloader...") |
|
try: |
|
logging.info("Pre-loading/caching datasets...") |
|
for ds_name in CONFIG["datasets"]: |
|
logging.info(f"Checking cache for '{ds_name}'...") |
|
try: |
|
|
|
_ = load_dataset(ds_name, split="train[:1]", download_mode="reuse_cache_if_exists", trust_remote_code=True) |
|
except Exception as e: |
|
|
|
logging.error(f"Could not pre-check dataset '{ds_name}': {e}") |
|
logging.info("Dataset download/cache check presumed complete.") |
|
|
|
|
|
dataset = CombinedChatDataset(tokenizer) |
|
|
|
|
|
if len(dataset) == 0: |
|
logging.error("Dataset is empty after processing all sources. Cannot train.") |
|
return |
|
|
|
dataloader = DataLoader( |
|
dataset, |
|
batch_size=CONFIG["batch_size"], |
|
collate_fn=CombinedChatDataset.collate_fn, |
|
shuffle=True, |
|
|
|
num_workers=min(4, os.cpu_count() // 2 if (os.cpu_count() and os.cpu_count() > 1) else 1), |
|
pin_memory=torch.cuda.is_available(), |
|
prefetch_factor=2 if torch.cuda.is_available() and os.cpu_count() and os.cpu_count() > 1 else None, |
|
drop_last=False |
|
) |
|
except Exception as e: |
|
logging.error(f"Failed to initialize dataset/dataloader: {e}", exc_info=True) |
|
return |
|
|
|
|
|
logging.info("Initializing Trainer, Checkpoint Manager, and Safety Manager...") |
|
|
|
trainer_obj = HROMTrainer(model, tokenizer) |
|
checkpoint_manager = CheckpointManager() |
|
safety = SafetyManager(model, tokenizer) |
|
|
|
|
|
start_optimizer_step = checkpoint_manager.load_latest(model, trainer_obj.optimizer) |
|
|
|
model.to(trainer_obj.device) |
|
|
|
|
|
logging.info(f"Starting training from optimizer step {start_optimizer_step}") |
|
optimizer_step = start_optimizer_step |
|
total_loss_accum = 0.0 |
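
    # Estimate how many raw batches have already been consumed so training can resume
    # in the correct epoch; this assumes the dataloader length is unchanged across runs.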
|
|
|
batch_step = optimizer_step * CONFIG["grad_accum_steps"] |
|
epochs_completed = batch_step // len(dataloader) if len(dataloader) > 0 else 0 |
|
start_epoch = epochs_completed |
|
|
|
|
|
try: |
|
if len(dataloader) == 0: |
|
raise ValueError("DataLoader has zero length. Cannot estimate total steps.") |
|
total_optimizer_steps = (len(dataloader) * CONFIG["num_epochs"]) // CONFIG["grad_accum_steps"] |
|
logging.info(f"Estimated dataset size: {len(dataset)}") |
|
logging.info(f"Estimated batches per epoch: {len(dataloader)}") |
|
logging.info(f"Gradient Accumulation Steps: {CONFIG['grad_accum_steps']}") |
|
logging.info(f"Effective Batch Size: {CONFIG['batch_size'] * CONFIG['grad_accum_steps']}") |
|
logging.info(f"Target Epochs: {CONFIG['num_epochs']}") |
|
logging.info(f"Estimated total optimizer steps for {CONFIG['num_epochs']} epochs: {total_optimizer_steps}") |
|
except Exception as e: |
|
logging.warning(f"Could not accurately estimate dataloader length or total steps: {e}") |
|
total_optimizer_steps = -1 |
|
|
|
|
|
model.train() |
|
|
|
for epoch in range(start_epoch, CONFIG["num_epochs"]): |
|
logging.info(f"--- Starting Epoch {epoch+1}/{CONFIG['num_epochs']} ---") |
|
epoch_loss = 0.0 |
|
num_batches_in_epoch = 0 |
|
|
|
|
|
for i, batch in enumerate(dataloader): |
|
|
|
if batch is None: |
|
logging.warning(f"Skipping empty batch at step {i} in epoch {epoch+1}") |
|
continue |
|
|
|
|
|
loss = trainer_obj.train_step(batch) |
|
            if loss is None or not math.isfinite(loss):
|
logging.error(f"NaN, Inf, or None loss detected: {loss}. Epoch {epoch+1}, Batch {i}, Opt Step {optimizer_step}. Stopping.") |
|
|
|
checkpoint_manager.save(model, trainer_obj.optimizer, f"{optimizer_step}_error") |
|
return |
|
|
|
total_loss_accum += loss |
|
epoch_loss += loss |
|
num_batches_in_epoch += 1 |
|
batch_step += 1 |
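
            # Step the optimizer only every grad_accum_steps batches; in between, gradients
            # simply accumulate on the parameters.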
|
|
|
|
|
|
|
if batch_step % CONFIG["grad_accum_steps"] == 0: |
|
current_lr = trainer_obj.clip_and_step(optimizer_step) |
|
|
|
|
|
avg_loss = total_loss_accum / CONFIG["grad_accum_steps"] |
|
total_loss_accum = 0.0 |
|
|
|
|
|
if optimizer_step % CONFIG["debug_interval"] == 0: |
|
logging.info(f"Epoch {epoch+1} | Opt Step {optimizer_step} | Batch Step {batch_step} | Avg Loss: {avg_loss:.4f} | LR: {current_lr:.2e}") |
|
|
|
if optimizer_step % (CONFIG["debug_interval"] * 5) == 0: |
|
safety.debug_generation("<user> Hi there! How are you doing today?") |
|
|
|
|
|
if optimizer_step > 0 and optimizer_step % CONFIG["checkpoint_interval"] == 0: |
|
logging.info(f"Checkpoint interval reached at optimizer step {optimizer_step}.") |
|
checkpoint_manager.save(model, trainer_obj.optimizer, optimizer_step) |
|
|
|
safety.debug_generation("<user> Hi! How are you?") |
|
|
|
optimizer_step += 1 |
|
|
|
|
|
avg_epoch_loss = epoch_loss / num_batches_in_epoch if num_batches_in_epoch > 0 else 0 |
|
logging.info(f"--- Finished Epoch {epoch+1}/{CONFIG['num_epochs']} | Average Epoch Loss: {avg_epoch_loss:.4f} ---") |
|
|
|
|
|
checkpoint_manager.save(model, trainer_obj.optimizer, f"epoch{epoch+1}_step{optimizer_step}") |
|
|
|
safety.debug_generation("<user> Hi! Whats up?") |
|
|
|
|
|
logging.info(f"Training finished after {CONFIG['num_epochs']} target epochs.") |
|
|
|
logging.info("Saving final model state...") |
|
checkpoint_manager.save(model, trainer_obj.optimizer, f"final_step{optimizer_step}") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
train() |