File size: 5,994 Bytes
fc54e43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# =============================================================================
# training/data_loader.py
# =============================================================================
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Iterator
import json
import random
from core.tokenizer import MambaTokenizer
from core.preprocess import TextPreprocessor

class MambaDataset(Dataset):
    """Language-modeling dataset yielding next-token prediction pairs.

    Texts are loaded once at construction time from a ``.json`` or ``.txt``
    file. Each item is cleaned, tokenized, and split into an input sequence
    and a target sequence that is the input shifted one position left.
    """

    def __init__(self, data_path: str, tokenizer: MambaTokenizer,
                 preprocessor: TextPreprocessor, config):
        """Store the collaborators and eagerly load every example.

        Args:
            data_path: Path to a ``.json`` or ``.txt`` training file.
            tokenizer: Object whose ``encode(text, max_length=...)`` returns
                a mapping with ``'input_ids'`` and ``'attention_mask'``.
            preprocessor: Object providing ``clean_text`` and ``chunk_text``.
            config: Object exposing ``max_seq_len``.
        """
        self.config = config
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.max_length = config.max_seq_len

        # Load data
        self.data = self._load_data(data_path)

    def _load_data(self, data_path: str) -> List[str]:
        """Load training texts from ``data_path``.

        ``.json`` files may hold a list (dict items contribute their
        ``'text'`` value, anything else is stringified) or a single object
        with a ``'text'`` key. ``.txt`` files are read whole and split by
        ``preprocessor.chunk_text``. Any other extension yields an empty
        list.

        Loading is best-effort: on any failure the error is printed and
        1000 synthetic sentences are returned instead, so callers never see
        an exception from this method.

        Returns:
            The list of raw training texts.
        """
        data: List[str] = []

        try:
            if data_path.endswith('.json'):
                # Explicit UTF-8: the platform default encoding is
                # locale-dependent and can mis-decode non-ASCII corpora.
                with open(data_path, 'r', encoding='utf-8') as f:
                    raw_data = json.load(f)
                if isinstance(raw_data, list):
                    data = [item['text'] if isinstance(item, dict) else str(item)
                            for item in raw_data]
                else:
                    data = [raw_data['text']]

            elif data_path.endswith('.txt'):
                with open(data_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Split into chunks
                data = self.preprocessor.chunk_text(content, self.max_length)

            print(f"Loaded {len(data)} training examples")

        except Exception as e:
            print(f"Error loading data: {e}")
            # Create dummy data for testing
            data = [f"This is example text number {i}." for i in range(1000)]

        return data

    def __len__(self) -> int:
        """Return the number of loaded examples."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the training example at position ``idx``.

        Returns:
            A dict with ``'input_ids'`` (all tokens but the last),
            ``'target_ids'`` (all tokens but the first, i.e. the next-token
            labels) and ``'attention_mask'`` trimmed to match the inputs.
        """
        text = self.data[idx]

        # Preprocess text
        clean_text = self.preprocessor.clean_text(text)

        # Tokenize
        encoded = self.tokenizer.encode(clean_text, max_length=self.max_length)

        # assumes encode() returns tensors with a leading batch dim of 1 — TODO confirm
        input_ids = encoded['input_ids'].squeeze(0)  # [seq_len]
        attention_mask = encoded['attention_mask'].squeeze(0)

        # For language modeling the target is the input shifted by one.
        # Slicing directly avoids appending an EOS token only to drop it
        # again, which also removes the dependency on eos_token_id being set.
        return {
            'input_ids': input_ids[:-1],  # [seq_len-1]
            'target_ids': input_ids[1:],  # [seq_len-1]
            'attention_mask': attention_mask[:-1],
        }

class DomainSpecificDataset(Dataset):
    """Dataset of texts for one domain, used for specialist training.

    Items have the same layout as a next-token language-modeling example
    plus the integer ``domain_id`` the dataset was built for.
    """

    def __init__(self, domain_data: Dict[str, List[str]], domain_id: int,
                 tokenizer: MambaTokenizer, preprocessor: TextPreprocessor, config):
        """Select the texts for ``domain_id`` out of ``domain_data``.

        Args:
            domain_data: Mapping from ``"domain_<id>"`` to a list of texts.
            domain_id: Index of the domain this dataset serves.
            tokenizer: Object whose ``encode(text, max_length=...)`` returns
                a mapping with ``'input_ids'`` and ``'attention_mask'``.
            preprocessor: Object providing ``clean_text``.
            config: Object exposing ``max_seq_len``.
        """
        self.domain_id = domain_id
        self.tokenizer = tokenizer
        self.preprocessor = preprocessor
        self.config = config

        # Get domain-specific data
        domain_name = f"domain_{domain_id}"
        self.data = domain_data.get(domain_name, [])

        if not self.data:
            # Create synthetic domain data for testing
            self.data = [f"Domain {domain_id} specific text example {i}."
                         for i in range(100)]

    def __len__(self) -> int:
        """Return the number of texts available for this domain."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the domain-specific training example at position ``idx``.

        Returns:
            A dict with ``'input_ids'`` (all tokens but the last),
            ``'target_ids'`` (all tokens but the first), ``'attention_mask'``
            trimmed to match the inputs, and the plain-int ``'domain_id'``.
        """
        text = self.data[idx]

        # Preprocess and tokenize
        clean_text = self.preprocessor.clean_text(text)
        encoded = self.tokenizer.encode(clean_text, max_length=self.config.max_seq_len)

        # assumes encode() returns tensors with a leading batch dim of 1 — TODO confirm
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # Targets are the inputs shifted by one. Slicing directly avoids
        # appending an EOS token only to drop it again, so a tokenizer
        # without an eos_token_id no longer breaks item access.
        return {
            'input_ids': input_ids[:-1],
            'target_ids': input_ids[1:],
            'attention_mask': attention_mask[:-1],
            'domain_id': self.domain_id,
        }

def create_data_loaders(config, tokenizer: MambaTokenizer,
                       preprocessor: TextPreprocessor) -> Dict[str, DataLoader]:
    """Build the main training loader and one loader per specialist domain.

    Args:
        config: Object exposing ``batch_size`` and ``num_specialists``; an
            optional ``train_data_path`` overrides the default
            ``'train_data.txt'``.
        tokenizer: Tokenizer handed to every dataset.
        preprocessor: Text preprocessor handed to every dataset.

    Returns:
        A dict with ``'main'`` (the training ``DataLoader``) and
        ``'domains'`` (a dict mapping each domain id to its ``DataLoader``).
    """
    # Collaborators shared by every dataset built below.
    shared = dict(tokenizer=tokenizer, preprocessor=preprocessor, config=config)

    # Main training loader.
    main_path = getattr(config, 'train_data_path', 'train_data.txt')
    main_loader = DataLoader(
        MambaDataset(data_path=main_path, **shared),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    # Placeholder: real domain-specific corpora should be loaded here. With
    # an empty mapping each domain dataset uses its synthetic fallback texts.
    per_domain_texts = {}

    # One shuffled loader per specialist, keyed by domain id.
    specialist_loaders = {
        idx: DataLoader(
            DomainSpecificDataset(
                domain_data=per_domain_texts,
                domain_id=idx,
                **shared,
            ),
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=2,
        )
        for idx in range(config.num_specialists)
    }

    return {'main': main_loader, 'domains': specialist_loaders}