| | import torch |
| | from torch.utils.data import Dataset |
| | import json |
| | import random |
| |
|
| |
|
class MTPDataset(Dataset):
    """Instruction/response dataset with optional light text augmentation.

    Reads a JSONL corpus where each line is a JSON object; only entries
    containing both an ``'instruction'`` and a ``'response'`` key are kept.
    Each item is rendered into a single prompt/answer string, tokenized,
    wrapped in BOS/EOS, truncated to ``max_seq_len``, and returned as a
    next-token-prediction pair ``(input_ids, target_ids)`` where
    ``target_ids`` is ``input_ids`` shifted left by one position.
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=512,
                 use_augmentation=False, augmentation_prob=0.3):
        """
        Args:
            corpus_path: Path to a JSONL file (one JSON object per line).
            tokenizer: Tokenizer exposing ``encode(str) -> list[int]``,
                ``bos_id()`` and ``eos_id()`` (SentencePiece-style API —
                inferred from usage below).
            max_seq_len: Hard cap on token count including BOS/EOS.
            use_augmentation: Enable the stochastic text perturbations
                in :meth:`augment_text`.
            augmentation_prob: Probability that a given text is perturbed
                at all (per call to :meth:`augment_text`).
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []

        # Load and filter the corpus. Blank/whitespace-only lines are
        # skipped explicitly: json.loads('') raises JSONDecodeError, so a
        # trailing newline in the file would otherwise crash loading.
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                entry = json.loads(line)
                if 'instruction' in entry and 'response' in entry:
                    self.data.append(entry)

        print(f"✓ Loaded {len(self.data)} examples from corpus")
        if use_augmentation:
            print(f"✓ Data augmentation enabled (prob={augmentation_prob})")

    def __len__(self):
        """Number of usable (instruction, response) examples."""
        return len(self.data)

    def augment_text(self, text):
        """Apply simple stochastic perturbations to ``text``.

        No-op unless augmentation is enabled AND this call falls inside
        ``augmentation_prob``. Perturbations (each gated by its own coin
        flip): strip surrounding whitespace; toggle terminal punctuation
        (drop a trailing '.' or append one when no terminal mark exists).
        """
        if not self.use_augmentation or random.random() > self.augmentation_prob:
            return text

        # 30% chance: normalize surrounding whitespace.
        if random.random() < 0.3:
            text = text.strip()

        # 20% chance: perturb terminal punctuation.
        if random.random() < 0.2:
            if text.endswith('.'):
                text = text[:-1]
            elif not text.endswith(('.', '!', '?')):
                text = text + '.'

        return text

    def __getitem__(self, idx):
        """Build the shifted (input_ids, target_ids) pair for example ``idx``."""
        entry = self.data[idx]

        instruction = entry['instruction']
        response = entry['response']

        # Augmentation is a per-call coin flip; with use_augmentation=False
        # these are identity transforms.
        instruction = self.augment_text(instruction)
        response = self.augment_text(response)

        # Prompt template — the literal text is part of the model's training
        # distribution, so it must stay byte-identical across the project.
        full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"

        tokens = self.tokenizer.encode(full_text)

        # Wrap with special tokens.
        tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]

        # Truncate to max_seq_len while keeping BOS first and forcing EOS
        # last: 1 (BOS) + (max_seq_len - 2) middle tokens + 1 (EOS).
        if len(tokens) > self.max_seq_len:
            tokens = [tokens[0]] + tokens[1:self.max_seq_len - 1] + [self.tokenizer.eos_id()]

        # Standard next-token shift: inputs are all but the last token,
        # targets are all but the first.
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)

        return input_ids, target_ids
| |
|
| |
|
def collate_fn(batch, pad_id=0, target_pad_id=None):
    """Collate variable-length (input_ids, target_ids) pairs into a batch.

    Pads every sequence on the right up to the longest input in the batch
    and stacks them into two ``(batch, max_len)`` long tensors.

    Args:
        batch: List of ``(input_ids, target_ids)`` 1-D long-tensor pairs,
            as produced by ``MTPDataset.__getitem__`` (inputs and targets
            of each pair have equal length by construction).
        pad_id: Padding value for the input tensor.
        target_pad_id: Padding value for the target tensor. Defaults to
            ``pad_id`` (original behavior); pass e.g. ``-100`` so a loss
            configured with ``ignore_index=-100`` skips padded positions.

    Returns:
        Tuple ``(input_ids, target_ids)`` of shape ``(batch, max_len)``.
    """
    if target_pad_id is None:
        target_pad_id = pad_id

    input_ids = [item[0] for item in batch]
    target_ids = [item[1] for item in batch]

    # Pad only up to the longest sequence in THIS batch, not max_seq_len.
    max_len = max(len(ids) for ids in input_ids)

    input_ids_padded = []
    target_ids_padded = []

    for inp, tgt in zip(input_ids, target_ids):
        pad_len = max_len - len(inp)
        input_ids_padded.append(
            torch.cat([inp, torch.full((pad_len,), pad_id, dtype=torch.long)]))
        target_ids_padded.append(
            torch.cat([tgt, torch.full((pad_len,), target_pad_id, dtype=torch.long)]))

    return torch.stack(input_ids_padded), torch.stack(target_ids_padded)