# =============================================================================
# training/data_loader.py
# =============================================================================
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Iterator
import json
import random
from core.tokenizer import MambaTokenizer
from core.preprocess import TextPreprocessor


class MambaDataset(Dataset):
    """Dataset for Mamba training"""

    def __init__(self, data_path: str, tokenizer: MambaTokenizer,
                 preprocessor: TextPreprocessor, config):
        self.config = config
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.max_length = config.max_seq_len
        # Load data
        self.data = self._load_data(data_path)

    def _load_data(self, data_path: str) -> List[str]:
        """Load training data from a JSON or plain-text file"""
        data = []
        try:
            if data_path.endswith('.json'):
                with open(data_path, 'r', encoding='utf-8') as f:
                    raw_data = json.load(f)
                if isinstance(raw_data, list):
                    data = [item['text'] if isinstance(item, dict) else str(item)
                            for item in raw_data]
                else:
                    data = [raw_data['text']]
            elif data_path.endswith('.txt'):
                with open(data_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Split raw text into chunks of at most max_length
                data = self.preprocessor.chunk_text(content, self.max_length)
            print(f"Loaded {len(data)} training examples")
        except Exception as e:
            print(f"Error loading data: {e}")
            # Fall back to dummy data so the pipeline can still be smoke-tested
            data = [f"This is example text number {i}." for i in range(1000)]
        return data
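
    # For reference, a minimal sketch of the JSON layouts _load_data accepts
    # (the sample strings are illustrative, only the 'text' key is assumed):
    #
    #   [{"text": "first document ..."}, {"text": "second document ..."}]
    #   ["a bare string also works", "another string"]
    #   {"text": "a single document"}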

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a training example"""
        text = self.data[idx]
        # Preprocess text
        clean_text = self.preprocessor.clean_text(text)
        # Tokenize
        encoded = self.tokenizer.encode(clean_text, max_length=self.max_length)
        # Create input and target sequences
        input_ids = encoded['input_ids'].squeeze(0)  # [seq_len]
        # For language modeling, target is input shifted by 1
        target_ids = torch.cat([input_ids[1:],
                                torch.tensor([self.tokenizer.tokenizer.eos_token_id])])
        return {
            'input_ids': input_ids[:-1],    # [seq_len - 1]
            'target_ids': target_ids[:-1],  # [seq_len - 1]
            'attention_mask': encoded['attention_mask'].squeeze(0)[:-1]
        }
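
    # A small worked example of the shift above (illustrative only):
    #   input_ids           = [t0, t1, t2, t3]
    #   target_ids          = [t1, t2, t3, EOS]
    #   returned input_ids  -> [t0, t1, t2]
    #   returned target_ids -> [t1, t2, t3]
    # i.e. position i of the model input is trained to predict token i + 1.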


class DomainSpecificDataset(Dataset):
    """Dataset for domain-specific specialist training"""

    def __init__(self, domain_data: Dict[str, List[str]], domain_id: int,
                 tokenizer: MambaTokenizer, preprocessor: TextPreprocessor, config):
        self.domain_id = domain_id
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.config = config
        # Get domain-specific data
        domain_name = f"domain_{domain_id}"
        self.data = domain_data.get(domain_name, [])
        if not self.data:
            # Create synthetic domain data for testing
            self.data = [f"Domain {domain_id} specific text example {i}."
                         for i in range(100)]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a domain-specific training example"""
        text = self.data[idx]
        # Preprocess and tokenize
        clean_text = self.preprocessor.clean_text(text)
        encoded = self.tokenizer.encode(clean_text, max_length=self.config.max_seq_len)
        input_ids = encoded['input_ids'].squeeze(0)
        target_ids = torch.cat([input_ids[1:],
                                torch.tensor([self.tokenizer.tokenizer.eos_token_id])])
        return {
            'input_ids': input_ids[:-1],
            'target_ids': target_ids[:-1],
            'attention_mask': encoded['attention_mask'].squeeze(0)[:-1],
            'domain_id': self.domain_id
        }


def create_data_loaders(config, tokenizer: MambaTokenizer,
                        preprocessor: TextPreprocessor) -> Dict[str, DataLoader]:
    """Create data loaders for training"""
    # Main training dataset
    train_dataset = MambaDataset(
        data_path=getattr(config, 'train_data_path', 'train_data.txt'),
        tokenizer=tokenizer,
        preprocessor=preprocessor,
        config=config
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
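
    # With PyTorch's default collate function, each batch from train_loader is a
    # dict of stacked tensors, e.g. batch['input_ids'] of shape
    # [batch_size, max_seq_len - 1] (assuming the tokenizer pads/truncates every
    # example to max_seq_len; otherwise a custom collate_fn would be needed).
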
    # Domain-specific datasets for specialist training
    domain_loaders = {}
    # Load domain-specific data (placeholder)
    domain_data = {}  # Should load actual domain-specific datasets
    for domain_id in range(config.num_specialists):
        domain_dataset = DomainSpecificDataset(
            domain_data=domain_data,
            domain_id=domain_id,
            tokenizer=tokenizer,
            preprocessor=preprocessor,
            config=config
        )
        domain_loader = DataLoader(
            domain_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=2
        )
        domain_loaders[domain_id] = domain_loader
    return {
        'main': train_loader,
        'domains': domain_loaders
    }
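

# Minimal usage sketch (illustrative only): the real MambaTokenizer /
# TextPreprocessor constructors and the training config class live elsewhere in
# this repo, so the values below are stand-ins for whatever the real setup uses.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical config exposing only the attributes this module actually reads.
    config = SimpleNamespace(
        max_seq_len=512,
        batch_size=8,
        num_specialists=4,
        train_data_path='train_data.txt',
    )
    tokenizer = MambaTokenizer()        # assumed default constructor
    preprocessor = TextPreprocessor()   # assumed default constructor

    loaders = create_data_loaders(config, tokenizer, preprocessor)
    batch = next(iter(loaders['main']))
    print(batch['input_ids'].shape, batch['target_ids'].shape)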