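"""Dataset utilities for Mamba training: a general-purpose text dataset,
a domain-specific dataset for specialist training, and a factory that
builds their DataLoaders."""
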
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
import json

from core.tokenizer import MambaTokenizer
from core.preprocess import TextPreprocessor

class MambaDataset(Dataset):
    """Dataset for Mamba training."""

    def __init__(self, data_path: str, tokenizer: MambaTokenizer,
                 preprocessor: TextPreprocessor, config):
        self.config = config
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.max_length = config.max_seq_len

        self.data = self._load_data(data_path)

    def _load_data(self, data_path: str) -> List[str]:
        """Load training data from a .json or .txt file."""
        data = []

        try:
            if data_path.endswith('.json'):
                with open(data_path, 'r') as f:
                    raw_data = json.load(f)
                if isinstance(raw_data, list):
                    # Accept either {"text": ...} records or plain strings.
                    data = [item['text'] if isinstance(item, dict) else str(item)
                            for item in raw_data]
                else:
                    data = [raw_data['text']]

            elif data_path.endswith('.txt'):
                with open(data_path, 'r') as f:
                    content = f.read()
                # Split one large document into max_length-sized chunks.
                data = self.preprocessor.chunk_text(content, self.max_length)

            print(f"Loaded {len(data)} training examples")

        except Exception as e:
            print(f"Error loading data: {e}")
            # Fall back to synthetic placeholder text so training can still run.
            data = [f"This is example text number {i}." for i in range(1000)]

        return data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a training example."""
        text = self.data[idx]

        clean_text = self.preprocessor.clean_text(text)
        encoded = self.tokenizer.encode(clean_text, max_length=self.max_length)

        input_ids = encoded['input_ids'].squeeze(0)

        # Shift by one position for next-token prediction: the model reads
        # tokens [0, n-2] and is trained to predict tokens [1, n-1].
        return {
            'input_ids': input_ids[:-1],
            'target_ids': input_ids[1:],
            'attention_mask': encoded['attention_mask'].squeeze(0)[:-1]
        }

class DomainSpecificDataset(Dataset):
    """Dataset for domain-specific specialist training."""

    def __init__(self, domain_data: Dict[str, List[str]], domain_id: int,
                 tokenizer: MambaTokenizer, preprocessor: TextPreprocessor, config):
        self.domain_id = domain_id
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.config = config

        domain_name = f"domain_{domain_id}"
        self.data = domain_data.get(domain_name, [])

        if not self.data:
            # Fall back to synthetic examples when no corpus is provided
            # for this domain.
            self.data = [f"Domain {domain_id} specific text example {i}."
                         for i in range(100)]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a domain-specific training example."""
        text = self.data[idx]

        clean_text = self.preprocessor.clean_text(text)
        encoded = self.tokenizer.encode(clean_text, max_length=self.config.max_seq_len)

        input_ids = encoded['input_ids'].squeeze(0)

        # Same one-position shift as MambaDataset, plus a domain label so
        # the trainer can route examples to the right specialist.
        return {
            'input_ids': input_ids[:-1],
            'target_ids': input_ids[1:],
            'attention_mask': encoded['attention_mask'].squeeze(0)[:-1],
            'domain_id': torch.tensor(self.domain_id)
        }

def create_data_loaders(config, tokenizer: MambaTokenizer,
                        preprocessor: TextPreprocessor) -> Dict[str, DataLoader]:
    """Create the main training loader plus one loader per domain specialist."""
    train_dataset = MambaDataset(
        data_path=getattr(config, 'train_data_path', 'train_data.txt'),
        tokenizer=tokenizer,
        preprocessor=preprocessor,
        config=config
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    domain_loaders = {}

    # Placeholder: populate with real per-domain corpora keyed "domain_{id}".
    # An empty dict makes every specialist fall back to its synthetic examples.
    domain_data = {}

    for domain_id in range(config.num_specialists):
        domain_dataset = DomainSpecificDataset(
            domain_data=domain_data,
            domain_id=domain_id,
            tokenizer=tokenizer,
            preprocessor=preprocessor,
            config=config
        )

        domain_loader = DataLoader(
            domain_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=2
        )

        domain_loaders[domain_id] = domain_loader

    return {
        'main': train_loader,
        'domains': domain_loaders
    }
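

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline. It assumes
    # MambaTokenizer and TextPreprocessor have no-arg constructors and that
    # the tokenizer pads every sequence to max_length (otherwise the default
    # collate cannot stack a batch). SimpleNamespace stands in for the real
    # config object; the attribute names below are the ones this module reads.
    from types import SimpleNamespace

    config = SimpleNamespace(
        max_seq_len=512,
        batch_size=8,
        num_specialists=4,
        train_data_path='train_data.txt',
    )

    loaders = create_data_loaders(config, MambaTokenizer(), TextPreprocessor())
    batch = next(iter(loaders['main']))
    print(batch['input_ids'].shape)  # expected: [batch_size, max_seq_len - 1]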