| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader, IterableDataset |
| from datasets import load_dataset, concatenate_datasets, interleave_datasets |
| from typing import Dict, List, Optional, Any, Union |
| import random |
| import numpy as np |
| from tqdm import tqdm |
| import warnings |
| from PIL import Image |
| import requests |
| from io import BytesIO |
| from torchvision import transforms |
| import logging |
|
|
| |
# Configure the root logger at INFO level when this module is imported and
# create the module-scoped logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress every UserWarning raised in this process.
warnings.filterwarnings(action="ignore", category=UserWarning)
|
|
| from data_config import ( |
| PRETRAIN_DATASETS, |
| POSTTRAIN_DATASETS, |
| TEST_DATASETS, |
| PRETRAIN_MIX, |
| POSTTRAIN_MIX, |
| PREPROCESSING_CONFIG, |
| DATASET_CACHE_DIR, |
| HF_CACHE_DIR |
| ) |
|
|
| |
# Per-channel statistics handed to ``transforms.Normalize`` below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Image preprocessing shared by the datasets in this module: resize to
# 224x224, convert to a tensor, then normalize each channel.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)
|
|
class PreTrainDataset(IterableDataset):
    """Weighted, endlessly cycling stream of tokenized pre-training samples.

    The datasets named by a ``PRETRAIN_MIX`` entry are loaded with
    ``load_dataset`` and sampled according to the mix weights. Every raw
    sample is tokenized (and, for ``image_text`` sources, paired with a
    transformed image tensor) before it is yielded. Exhausted sources are
    restarted, so iteration only ends when ``max_samples`` is reached.

    Args:
        mix_name: Key into ``PRETRAIN_MIX``.
        tokenizer: Callable tokenizer; required.
        max_length: Length every token sequence is truncated/padded to.
        streaming: Streaming flag used for datasets whose config does not
            set ``streaming`` itself.
        seed: Base seed for source sampling; the DataLoader worker id is
            added to it so workers draw different sequences.
        max_samples: Total number of samples to yield, summed over all
            DataLoader workers. ``None`` or ``0`` means unlimited.

    Raises:
        ValueError: If the tokenizer is missing, the mix is unknown or
            empty, the weights are invalid, or no dataset could be loaded.
    """

    # ``config['type']`` values this class knows how to process.
    _TEXT_TYPES = ('text', 'code')
    _IMAGE_TEXT_TYPE = 'image_text'
    _SUPPORTED_TYPES = _TEXT_TYPES + (_IMAGE_TEXT_TYPE,)

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        streaming: bool = True,
        seed: int = 42,
        max_samples: Optional[int] = None
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.streaming = streaming
        self.seed = seed
        self.max_samples = max_samples
        self.samples_generated = 0

        if mix_name not in PRETRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(PRETRAIN_MIX.keys())}")

        mix_config = PRETRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        if not weights:
            # Without weights ``zip`` below would silently drop every dataset;
            # fall back to sampling all sources uniformly.
            weights = [1.0] * len(dataset_names)
        elif len(weights) != len(dataset_names):
            logger.warning(
                f"Mix {mix_name} lists {len(dataset_names)} datasets but {len(weights)} weights; "
                "entries without a partner are ignored"
            )

        logger.info(f"Loading pretrain mix: {mix_name}")
        logger.info(f" Datasets: {dataset_names}")
        logger.info(f" Weights: {weights}")

        self.datasets = []
        self.probabilities = []

        for name, weight in zip(dataset_names, weights):
            if name not in PRETRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in PRETRAIN_DATASETS, skipping")
                continue

            config = PRETRAIN_DATASETS[name]

            # A source of an unsupported type can never produce a sample; skip
            # it here instead of fetching and discarding its data forever.
            if config.get('type') not in self._SUPPORTED_TYPES:
                logger.warning(f"Dataset {name} has unsupported type {config.get('type')!r}, skipping")
                continue

            try:
                ds = self._load_dataset(config)
                if ds is not None:
                    self.datasets.append((name, ds, config))
                    self.probabilities.append(weight)
                    logger.info(f" Successfully loaded {name}")
            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not self.datasets:
            raise ValueError("No datasets loaded successfully")

        # Validate up front: invalid probabilities would otherwise only fail
        # inside the sampling loop.
        total = sum(self.probabilities)
        if total <= 0 or any(p < 0 for p in self.probabilities):
            raise ValueError(
                f"Weights for mix {mix_name} must be non-negative and sum to a positive value, "
                f"got {self.probabilities}"
            )
        self.probabilities = [p / total for p in self.probabilities]

        logger.info(f"Successfully loaded {len(self.datasets)} datasets")

    def _load_dataset(self, config: Dict):
        """Load one dataset described by ``config``.

        Reads ``hf_path`` (required), ``split`` (default ``'train'``),
        ``streaming`` (default ``self.streaming``) and the optional
        ``config`` key, which is passed to ``load_dataset`` as ``name``.

        Returns:
            The loaded dataset, or ``None`` if loading failed (the error is
            logged).
        """
        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'streaming': config.get('streaming', self.streaming),
                'cache_dir': HF_CACHE_DIR,
            }

            if 'config' in config:
                load_kwargs['name'] = config['config']

            return load_dataset(**load_kwargs)
        except Exception as e:
            logger.error(f"Failed to load {config.get('hf_path', 'unknown')}: {e}")
            return None

    def _process_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize the text field of ``sample``.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and
            ``type='text'``, or ``None`` when the text is missing, not a
            string, shorter than 10 characters after stripping, or
            tokenization fails.
        """
        try:
            text_field = config.get('text_field', 'text')
            text = sample.get(text_field, '')

            if not text or not isinstance(text, str):
                return None

            text = text.strip()
            if len(text) < 10:
                return None

            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'type': 'text'
            }
        except Exception as e:
            logger.debug(f"Error processing text sample: {e}")
            return None

    def _process_image_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize the caption of ``sample`` and transform its image.

        The image field may hold a PIL image or a URL string; URLs are
        downloaded with a 5 second timeout.

        Returns:
            A dict with ``input_ids``, ``attention_mask``, ``image`` and
            ``type='image_text'``, or ``None`` when the sample is
            incomplete, the image cannot be obtained, or processing fails.
        """
        try:
            text_field = config.get('text_field', 'caption')
            image_field = config.get('image_field', 'image')

            text = sample.get(text_field, '')
            image = sample.get(image_field)

            if not text or image is None:
                return None

            if isinstance(image, str):
                try:
                    response = requests.get(image, timeout=5)
                    # Reject HTTP error responses explicitly instead of
                    # handing an error page to the image decoder.
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert('RGB')
                except Exception as img_error:
                    logger.debug(f"Failed to load image from URL: {img_error}")
                    return None
            elif isinstance(image, Image.Image):
                image = image.convert('RGB')
            else:
                return None

            image_tensor = image_transform(image)

            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'image': image_tensor,
                'type': 'image_text'
            }
        except Exception as e:
            logger.debug(f"Error processing image-text sample: {e}")
            return None

    def _process_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Dispatch ``sample`` to the processor matching ``config['type']``."""
        sample_type = config.get('type')
        if sample_type in self._TEXT_TYPES:
            return self._process_text_sample(sample, config)
        if sample_type == self._IMAGE_TEXT_TYPE:
            return self._process_image_text_sample(sample, config)
        logger.debug(f"Unknown type: {sample_type}")
        return None

    def __iter__(self):
        """Yield processed samples drawn from the sources by mix weight.

        A source that runs out is restarted. Samples that fail processing
        are skipped and do not count towards ``max_samples``.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        # A private generator draws the same sequence as seeding the global
        # NumPy RNG would, without clobbering RNG state used elsewhere.
        rng = np.random.RandomState(self.seed + worker_id)

        # Every worker runs its own copy of this loop, so split the sample
        # budget between workers; otherwise the DataLoader would emit
        # ``num_workers * max_samples`` samples.
        if self.max_samples:
            budget, remainder = divmod(self.max_samples, num_workers)
            if worker_id < remainder:
                budget += 1
        else:
            budget = None

        iterators = [iter(ds) for _, ds, _ in self.datasets]
        self.samples_generated = 0

        while budget is None or self.samples_generated < budget:
            idx = rng.choice(len(self.datasets), p=self.probabilities)
            _, dataset, config = self.datasets[idx]

            try:
                sample = next(iterators[idx])
            except StopIteration:
                # Source exhausted: start it over and draw again.
                try:
                    iterators[idx] = iter(dataset)
                except Exception as e:
                    logger.error(f"Failed to recreate iterator for dataset {idx}: {e}")
                    break
                continue
            except Exception as e:
                logger.debug(f"Error in iterator: {e}")
                continue

            processed = self._process_sample(sample, config)
            if processed is not None:
                self.samples_generated += 1
                yield processed
|
|
|
|
class PostTrainDataset(Dataset):
    """Map-style dataset of (instruction, response) pairs for post-training.

    The datasets named by a ``POSTTRAIN_MIX`` entry are loaded, tagged with
    their source name and, when more than one loads, interleaved by mix
    weight. ``__getitem__`` formats a prompt and response according to the
    source's config and returns fixed-length token tensors.

    Args:
        mix_name: Key into ``POSTTRAIN_MIX``.
        tokenizer: Callable tokenizer exposing ``pad_token_id`` and
            ``eos_token_id``; required.
        max_length: Combined budget; instruction and response each get
            ``max_length // 2`` tokens.
        max_samples: Optional cap on the number of samples kept.
        split: Dataset split passed to ``load_dataset``.

    Raises:
        ValueError: If the tokenizer is missing, the mix is unknown or
            empty, or no dataset could be loaded.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        if mix_name not in POSTTRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(POSTTRAIN_MIX.keys())}")

        mix_config = POSTTRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading posttrain mix: {mix_name}")
        logger.info(f" Datasets: {dataset_names}")

        all_datasets = []
        # Weight of each dataset that actually loaded, kept in the same order
        # as ``all_datasets`` so skipped sources cannot shift the pairing.
        loaded_weights = []
        # Source name -> config, consulted by ``__getitem__``. Rows only carry
        # the source name instead of a copy of the whole config dict.
        self._source_configs = {}

        for position, name in enumerate(dataset_names):
            if name not in POSTTRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in POSTTRAIN_DATASETS")
                continue

            config = POSTTRAIN_DATASETS[name]
            try:
                load_kwargs = {
                    'path': config['hf_path'],
                    'split': split,
                    'streaming': config.get('streaming', False),
                    'cache_dir': HF_CACHE_DIR,
                }
                if 'data_files' in config:
                    load_kwargs['data_files'] = config['data_files']
                if 'config' in config:
                    load_kwargs['name'] = config['config']

                ds = load_dataset(**load_kwargs)

                if config.get('max_samples'):
                    if hasattr(ds, 'take'):
                        ds = ds.take(config['max_samples'])
                    elif hasattr(ds, 'select'):
                        ds = ds.select(range(min(len(ds), config['max_samples'])))

                # Bind the name as a default argument: a lazily evaluated map
                # must not see the loop variable's final value.
                def add_source(example, _name=name):
                    example['_source'] = _name
                    return example

                ds = ds.map(add_source)
                all_datasets.append(ds)
                loaded_weights.append(weights[position] if position < len(weights) else None)
                self._source_configs[name] = config

                ds_len = len(ds) if hasattr(ds, '__len__') else 'streaming'
                logger.info(f" Loaded {name}: {ds_len} samples")

            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not all_datasets:
            raise ValueError("No datasets loaded successfully")

        if len(all_datasets) == 1:
            self.dataset = all_datasets[0]
        else:
            probabilities = self._mix_probabilities(loaded_weights)
            if probabilities is None:
                logger.warning(
                    f"Mix {mix_name} has missing or invalid weights; interleaving datasets uniformly"
                )
            self.dataset = interleave_datasets(
                all_datasets,
                probabilities=probabilities,
                seed=42,
                stopping_strategy='all_exhausted'
            )

        if max_samples and hasattr(self.dataset, '__len__'):
            actual_len = min(len(self.dataset), max_samples)
            self.dataset = self.dataset.select(range(actual_len))

        dataset_len = len(self.dataset) if hasattr(self.dataset, '__len__') else 'streaming'
        logger.info(f"Total samples: {dataset_len}")

    @staticmethod
    def _mix_probabilities(weights):
        """Normalize ``weights`` to probabilities.

        Returns ``None`` (meaning "no explicit probabilities") when the
        list is empty, contains a missing or negative weight, or does not
        sum to a positive value.
        """
        if not weights or any(w is None or w < 0 for w in weights):
            return None
        total = sum(weights)
        if total <= 0:
            return None
        return [w / total for w in weights]

    @staticmethod
    def _render_dialogue(turns, role_key: str, content_key: str) -> str:
        """Render all turns except the last as ``role: content`` lines.

        The last turn is treated as the response and left out; the prompt
        ends with an ``assistant:`` cue.
        """
        lines = [
            f"{turn.get(role_key, 'user')}: {turn.get(content_key, '')}"
            for turn in turns[:-1]
        ]
        return "\n".join(lines) + "\nassistant:"

    @staticmethod
    def _pad_with_mask(ids, target_len: int, pad_token_id):
        """Right-pad ``ids`` to ``target_len`` and build the attention mask.

        Sequences already at (or beyond) ``target_len`` are returned
        unchanged with an all-ones mask of length ``target_len``.
        """
        current_len = ids.size(0)
        if current_len >= target_len:
            return ids, torch.ones(target_len, dtype=torch.long)

        missing = target_len - current_len
        padding = torch.full((missing,), pad_token_id, dtype=torch.long)
        mask = torch.cat([
            torch.ones(current_len, dtype=torch.long),
            torch.zeros(missing, dtype=torch.long),
        ])
        return torch.cat([ids, padding]), mask

    def _format_instruction(self, sample: Dict, config: Dict) -> str:
        """Build the prompt text for ``sample`` according to ``config['type']``.

        Returns an empty string when formatting fails.
        """
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                input_field = config.get('input_field', 'input')
                context_field = config.get('context_field', None)

                instruction = sample.get(instruction_field, '')
                input_text = sample.get(input_field, '')
                context = sample.get(context_field, '') if context_field else ''

                prompt_parts = [f"Instruction: {instruction}"]
                if context:
                    prompt_parts.append(f"Context: {context}")
                if input_text:
                    prompt_parts.append(f"Input: {input_text}")
                prompt_parts.append("Response:")
                return "\n".join(prompt_parts)

            if data_type == 'conversation':
                # ``conversations`` takes precedence; ``messages`` is only
                # consulted when the former key is absent.
                if 'conversations' in sample:
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        return self._render_dialogue(conversations, 'from', 'value')
                elif 'messages' in sample:
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        return self._render_dialogue(messages, 'role', 'content')
                return sample.get('text', '')

            if data_type == 'code_instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                instruction = sample.get(instruction_field, '')
                return f"### Instruction:\n{instruction}\n### Response:"

            if data_type == 'multimodal_instruction':
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    return self._render_dialogue(conversations, 'from', 'value')
                return ""

            return sample.get(config.get('instruction_field', 'text'), '')
        except Exception as e:
            logger.debug(f"Error formatting instruction: {e}")
            return ""

    def _get_response(self, sample: Dict, config: Dict) -> str:
        """Extract the target response text for ``sample``.

        Returns an empty string when no response can be found.
        """
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'conversation':
                if 'conversations' in sample:
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        return conversations[-1].get('value', '')
                elif 'messages' in sample:
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        return messages[-1].get('content', '')
                return ""

            if data_type == 'multimodal_instruction':
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    return conversations[-1].get('value', '')
                return ""

            # 'instruction', 'code_instruction' and any other type read a
            # plain response field.
            response_field = config.get('response_field', 'output')
            return sample.get(response_field, '')
        except Exception as e:
            logger.debug(f"Error getting response: {e}")
            return ""

    def __len__(self):
        """Return the number of samples, or 0 when the dataset has no length."""
        return len(self.dataset) if hasattr(self.dataset, '__len__') else 0

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx``.

        Returns:
            A dict with ``instruction``, ``response``, ``instruction_mask``,
            ``response_mask`` (each of length ``max_length // 2``), ``task``
            and ``modality_data``; or ``None`` when the sample cannot be
            encoded.
        """
        try:
            sample = self.dataset[idx]

            source = sample.get('_source')
            config = self._source_configs.get(source)
            if config is None:
                logger.warning(f"Sample at index {idx} has unknown _source: {source!r}")
                return None

            instruction_text = self._format_instruction(sample, config)
            response_text = self._get_response(sample, config)

            if not instruction_text or not response_text:
                return None

            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            instruction_max_len = self.max_length // 2
            response_max_len = self.max_length // 2

            instruction_enc = self.tokenizer(
                instruction_text,
                truncation=True,
                max_length=instruction_max_len,
                add_special_tokens=False,
                return_tensors='pt'
            )
            instr_ids, instr_mask = self._pad_with_mask(
                instruction_enc['input_ids'].squeeze(0), instruction_max_len, pad_token_id
            )

            # Reserve one position for the EOS token appended below.
            response_enc = self.tokenizer(
                response_text,
                truncation=True,
                max_length=response_max_len - 1,
                add_special_tokens=False,
                return_tensors='pt'
            )
            eos_token = torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long)
            resp_ids = torch.cat([response_enc['input_ids'].squeeze(0), eos_token])
            resp_ids, resp_mask = self._pad_with_mask(resp_ids, response_max_len, pad_token_id)

            result = {
                'instruction': instr_ids,
                'response': resp_ids,
                'instruction_mask': instr_mask,
                'response_mask': resp_mask,
                'task': sample.get('_source', 'unknown'),
                'modality_data': None
            }

            if config.get('type') == 'multimodal_instruction' and 'image' in sample:
                try:
                    image = sample['image']
                    if isinstance(image, Image.Image):
                        image = image.convert('RGB')
                        image_tensor = image_transform(image)
                        result['modality_data'] = {'image': image_tensor}
                except Exception as e:
                    logger.debug(f"Error processing image: {e}")

            return result

        except Exception as e:
            # Keep the traceback behind the logger's level instead of always
            # printing it to stderr.
            logger.debug(f"Error getting item at index {idx}: {e}", exc_info=True)
            return None
|
|
|
|
class PreferenceDataset(Dataset):
    """Map-style dataset of (chosen, rejected) text pairs.

    Loads one ``POSTTRAIN_DATASETS`` entry whose ``type`` is
    ``'preference'`` and tokenizes both texts of each pair to
    ``max_length``.

    Args:
        dataset_name: Key into ``POSTTRAIN_DATASETS``.
        tokenizer: Callable tokenizer; required.
        max_length: Length both sequences are truncated/padded to.
        max_samples: Optional cap on the number of pairs kept.
        split: Dataset split passed to ``load_dataset``.

    Raises:
        ValueError: If the tokenizer is missing, the dataset is unknown, or
            it is not a preference dataset.
    """

    def __init__(
        self,
        dataset_name: str = 'hh_rlhf',
        tokenizer=None,
        max_length: int = 1024,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        if dataset_name not in POSTTRAIN_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(POSTTRAIN_DATASETS.keys())}")

        config = POSTTRAIN_DATASETS[dataset_name]
        if config.get('type') != 'preference':
            raise ValueError(f"{dataset_name} is not a preference dataset (type: {config.get('type')})")

        logger.info(f"Loading preference dataset: {dataset_name}")

        load_kwargs = {
            'path': config['hf_path'],
            'split': split,
            'cache_dir': HF_CACHE_DIR,
        }
        # Honor the same optional config keys as PostTrainDataset so an entry
        # that restricts ``data_files`` loads the same files in both classes.
        if 'data_files' in config:
            load_kwargs['data_files'] = config['data_files']
        if 'config' in config:
            load_kwargs['name'] = config['config']

        self.dataset = load_dataset(**load_kwargs)

        self.chosen_field = config.get('chosen_field', 'chosen')
        self.rejected_field = config.get('rejected_field', 'rejected')

        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Loaded {len(self.dataset)} preference pairs")

    def _encode(self, text):
        """Tokenize ``text`` to ``max_length`` and return ``(input_ids, attention_mask)``."""
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __len__(self):
        """Return the number of preference pairs."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the encoded pair at ``idx``.

        Returns:
            A 4-tuple ``(chosen_ids, rejected_ids, chosen_mask,
            rejected_mask)``, or ``None`` when either text is empty or
            encoding fails.
        """
        try:
            sample = self.dataset[idx]

            chosen_text = sample.get(self.chosen_field, '')
            rejected_text = sample.get(self.rejected_field, '')

            if not chosen_text or not rejected_text:
                return None

            chosen_ids, chosen_mask = self._encode(chosen_text)
            rejected_ids, rejected_mask = self._encode(rejected_text)

            return (chosen_ids, rejected_ids, chosen_mask, rejected_mask)

        except Exception as e:
            logger.debug(f"Error getting preference item at index {idx}: {e}")
            return None
|
|
|
|
def collate_fn_v2(batch):
    """Collate samples from the datasets in this module into one batch dict.

    ``None`` entries (samples that failed to load) are dropped first.

    * Tuple samples are stacked into ``chosen`` / ``rejected`` tensors,
      plus ``chosen_mask`` / ``rejected_mask`` for 4-tuples.
    * Dict samples: the token and mask tensors are stacked,
      ``modality_data`` images are stacked into ``{'image': tensor}`` (or
      ``None``), and every other key becomes a per-sample list in which
      samples lacking the key contribute ``None``.

    NOTE(review): stacked images only cover the samples that carry one, so
    their rows do not necessarily line up with the rest of the batch.

    Returns:
        The collated dict. An all-``None`` batch yields empty ``input_ids``
        and ``attention_mask`` tensors.
    """
    batch = [item for item in batch if item is not None]

    if not batch:
        logger.warning("Empty batch after filtering None values")
        return {
            'input_ids': torch.empty(0),
            'attention_mask': torch.empty(0)
        }

    if isinstance(batch[0], tuple):
        collated = {
            'chosen': torch.stack([item[0] for item in batch]),
            'rejected': torch.stack([item[1] for item in batch]),
        }
        if len(batch[0]) == 4:
            collated['chosen_mask'] = torch.stack([item[2] for item in batch])
            collated['rejected_mask'] = torch.stack([item[3] for item in batch])
        return collated

    stacked_keys = (
        'instruction', 'response', 'instruction_mask',
        'response_mask', 'input_ids', 'attention_mask',
    )

    # Use the union of keys (first-seen order) rather than only the first
    # sample's keys: mixed batches, e.g. text and image_text samples, would
    # otherwise raise KeyError or silently lose fields.
    keys = list(dict.fromkeys(key for item in batch for key in item))

    collated = {}
    for key in keys:
        present = [item[key] for item in batch if item.get(key) is not None]

        if key in stacked_keys:
            collated[key] = torch.stack(present) if present else None
        elif key == 'modality_data':
            images = [m['image'] for m in present if m.get('image') is not None]
            collated[key] = {'image': torch.stack(images)} if images else None
        else:
            collated[key] = [item.get(key) for item in batch]

    return collated
|
|
|
|
def create_pretrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None
):
    """Build a DataLoader over a streaming ``PreTrainDataset``.

    Args:
        mix_name: Pre-training mix forwarded to ``PreTrainDataset``.
        tokenizer: Tokenizer forwarded to ``PreTrainDataset``.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        max_length: Token length forwarded to ``PreTrainDataset``.
        max_samples: Sample cap forwarded to ``PreTrainDataset``.

    Returns:
        A ``DataLoader`` that collates with ``collate_fn_v2``.
    """
    pretrain_stream = PreTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        streaming=True,
        max_samples=max_samples,
    )
    loader_options = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'collate_fn': collate_fn_v2,
    }
    return DataLoader(pretrain_stream, **loader_options)
|
|
|
|
def create_posttrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over a ``PostTrainDataset``.

    Args:
        mix_name: Post-training mix forwarded to ``PostTrainDataset``.
        tokenizer: Tokenizer forwarded to ``PostTrainDataset``.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        max_length: Token budget forwarded to ``PostTrainDataset``.
        max_samples: Sample cap forwarded to ``PostTrainDataset``.
        split: Dataset split forwarded to ``PostTrainDataset``.
        shuffle: Whether the DataLoader shuffles sample order.

    Returns:
        A ``DataLoader`` that collates with ``collate_fn_v2``, pins memory
        and keeps the last partial batch.
    """
    sft_dataset = PostTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_fn_v2,
        'pin_memory': True,
        'drop_last': False,
    }
    return DataLoader(sft_dataset, **loader_options)
|
|
|
|
def create_preference_dataloader(
    dataset_name: str = 'hh_rlhf',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 1024,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over a ``PreferenceDataset``.

    Args:
        dataset_name: Preference dataset forwarded to ``PreferenceDataset``.
        tokenizer: Tokenizer forwarded to ``PreferenceDataset``.
        batch_size: Number of pairs per batch.
        num_workers: Number of DataLoader worker processes.
        max_length: Token length forwarded to ``PreferenceDataset``.
        max_samples: Pair cap forwarded to ``PreferenceDataset``.
        split: Dataset split forwarded to ``PreferenceDataset``.
        shuffle: Whether the DataLoader shuffles pair order.

    Returns:
        A ``DataLoader`` that collates with ``collate_fn_v2`` and pins
        memory.
    """
    preference_pairs = PreferenceDataset(
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_fn_v2,
        'pin_memory': True,
    }
    return DataLoader(preference_pairs, **loader_options)