# =============================================================================
# training/data_loader.py
# =============================================================================
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Iterator
import json
import random
from core.tokenizer import MambaTokenizer
from core.preprocess import TextPreprocessor
class MambaDataset(Dataset):
    """Dataset for Mamba language-model training."""

    def __init__(self, data_path: str, tokenizer: MambaTokenizer,
                 preprocessor: TextPreprocessor, config):
        self.config = config
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.max_length = config.max_seq_len

        # Load data
        self.data = self._load_data(data_path)

    def _load_data(self, data_path: str) -> List[str]:
        """Load training data from a .json or .txt file."""
        data = []
        try:
            if data_path.endswith('.json'):
                with open(data_path, 'r') as f:
                    raw_data = json.load(f)
                if isinstance(raw_data, list):
                    data = [item['text'] if isinstance(item, dict) else str(item)
                            for item in raw_data]
                else:
                    data = [raw_data['text']]
            elif data_path.endswith('.txt'):
                with open(data_path, 'r') as f:
                    content = f.read()
                # Split the raw text into max_length-sized chunks
                data = self.preprocessor.chunk_text(content, self.max_length)
            print(f"Loaded {len(data)} training examples")
        except Exception as e:
            print(f"Error loading data: {e}")
            # Fall back to dummy data for testing
            data = [f"This is example text number {i}." for i in range(1000)]
        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a training example."""
        text = self.data[idx]

        # Preprocess text
        clean_text = self.preprocessor.clean_text(text)

        # Tokenize
        encoded = self.tokenizer.encode(clean_text, max_length=self.max_length)

        # Create input and target sequences
        input_ids = encoded['input_ids'].squeeze(0)  # [seq_len]

        # For language modeling, the target is the input shifted left by one,
        # padded with EOS; the trailing [:-1] slices below keep inputs and
        # targets the same length.
        target_ids = torch.cat([input_ids[1:], torch.tensor([self.tokenizer.tokenizer.eos_token_id])])

        return {
            'input_ids': input_ids[:-1],     # [seq_len-1]
            'target_ids': target_ids[:-1],   # [seq_len-1]
            'attention_mask': encoded['attention_mask'].squeeze(0)[:-1]
        }
class DomainSpecificDataset(Dataset):
    """Dataset for domain-specific specialist training."""

    def __init__(self, domain_data: Dict[str, List[str]], domain_id: int,
                 tokenizer: MambaTokenizer, preprocessor: TextPreprocessor, config):
        self.domain_id = domain_id
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.config = config

        # Get domain-specific data
        domain_name = f"domain_{domain_id}"
        self.data = domain_data.get(domain_name, [])

        if not self.data:
            # Fall back to synthetic domain data for testing
            self.data = [f"Domain {domain_id} specific text example {i}."
                         for i in range(100)]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a domain-specific training example."""
        text = self.data[idx]

        # Preprocess and tokenize
        clean_text = self.preprocessor.clean_text(text)
        encoded = self.tokenizer.encode(clean_text, max_length=self.config.max_seq_len)

        input_ids = encoded['input_ids'].squeeze(0)
        target_ids = torch.cat([input_ids[1:], torch.tensor([self.tokenizer.tokenizer.eos_token_id])])

        return {
            'input_ids': input_ids[:-1],
            'target_ids': target_ids[:-1],
            'attention_mask': encoded['attention_mask'].squeeze(0)[:-1],
            'domain_id': self.domain_id
        }
def create_data_loaders(config, tokenizer: MambaTokenizer,
                        preprocessor: TextPreprocessor) -> Dict[str, DataLoader]:
    """Create the main and per-domain data loaders for training."""
    # Main training dataset
    train_dataset = MambaDataset(
        data_path=getattr(config, 'train_data_path', 'train_data.txt'),
        tokenizer=tokenizer,
        preprocessor=preprocessor,
        config=config
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    # Domain-specific datasets for specialist training
    domain_loaders = {}

    # Load domain-specific data (placeholder; should load actual domain-specific datasets)
    domain_data = {}

    for domain_id in range(config.num_specialists):
        domain_dataset = DomainSpecificDataset(
            domain_data=domain_data,
            domain_id=domain_id,
            tokenizer=tokenizer,
            preprocessor=preprocessor,
            config=config
        )

        domain_loader = DataLoader(
            domain_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=2
        )

        domain_loaders[domain_id] = domain_loader

    return {
        'main': train_loader,
        'domains': domain_loaders
    }
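

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only): wiring the loaders together with a minimal
# stand-in config. The attribute names below (max_seq_len, batch_size,
# num_specialists, train_data_path) match what this module reads from `config`;
# MambaTokenizer and TextPreprocessor are assumed to be constructible with no
# arguments here, which may differ from the actual core.* implementations.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        max_seq_len=256,
        batch_size=8,
        num_specialists=2,
        train_data_path='train_data.txt',
    )

    tokenizer = MambaTokenizer()       # assumed no-arg constructor
    preprocessor = TextPreprocessor()  # assumed no-arg constructor

    loaders = create_data_loaders(config, tokenizer, preprocessor)

    # Pull one batch from the main loader and inspect its shapes.
    batch = next(iter(loaders['main']))
    print(batch['input_ids'].shape, batch['target_ids'].shape)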