| """ |
| VicAI Dataset |
| Dataset handling for training on Wikipedia and other text sources. |
| """ |
|
|
| import os |
| import random |
| import re |
from typing import List, Optional
|
|
| import requests |
| import torch |
| from torch.utils.data import Dataset, IterableDataset |
|
|
|
|
| class WikipediaDataset(IterableDataset): |
| """Stream Wikipedia articles for training.""" |
| |
    def __init__(
        self,
        tokenizer,
        max_length: int = 2048,
        languages: Optional[List[str]] = None,
        min_article_length: int = 100,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Avoid a mutable default argument; fall back to English.
        self.languages = languages or ['en']
        self.min_article_length = min_article_length
        # Query the wiki of the first requested language.
        self.base_url = f"https://{self.languages[0]}.wikipedia.org/w/api.php"
| |
    def _fetch_random_article(self) -> Optional[str]:
        """Fetch a random Wikipedia article via the MediaWiki API."""
        try:
            params = {
                'action': 'query',
                'format': 'json',
                'generator': 'random',
                'grnnamespace': 0,  # main/article namespace only
                'grnlimit': 1,
                'prop': 'extracts',
                'explaintext': True,  # plain text rather than HTML
                'exsentences': 50,
            }
            response = requests.get(self.base_url, params=params, timeout=10)
            response.raise_for_status()
            data = response.json()

            pages = data['query']['pages']
            for page_data in pages.values():
                text = page_data.get('extract', '')
                if len(text) > self.min_article_length:
                    return text
            return None
        except Exception as e:
            print(f"Error fetching article: {e}")
            return None
| |
    def _fetch_article_by_title(self, title: str) -> Optional[str]:
        """Fetch a specific Wikipedia article by title."""
        try:
            params = {
                'action': 'query',
                'format': 'json',
                'titles': title,
                'prop': 'extracts',
                'explaintext': True,
                'exlimit': 1,
            }
            response = requests.get(self.base_url, params=params, timeout=10)
            response.raise_for_status()
            data = response.json()

            pages = data['query']['pages']
            for page_id, page_data in pages.items():
                if page_id != '-1':  # the API reports missing pages under id '-1'
                    return page_data.get('extract', '')
            return None
        except Exception as e:
            print(f"Error fetching article '{title}': {e}")
            return None
| |
    def _clean_text(self, text: str) -> str:
        """Clean Wikipedia text."""
        # Drop bracketed citation markers such as [12].
        text = re.sub(r'\[\d+\]', '', text)
        # Collapse runs of whitespace into single spaces.
        text = re.sub(r'\s+', ' ', text)
        # Replace characters outside a basic alphanumeric/punctuation set.
        text = re.sub(r'[^\w\s.,!?;:\'\"()-]', ' ', text)
        return text.strip()
| |
    def _tokenize_text(self, text: str) -> List[int]:
        """Clean and tokenize an article into a flat list of token ids."""
        cleaned = self._clean_text(text)
        tokens = self.tokenizer.encode(cleaned, add_special_tokens=True)
        return tokens
| |
    def __iter__(self):
        """Stream fixed-length training examples from random Wikipedia articles."""
        while True:
            text = self._fetch_random_article()
            if text:
                tokens = self._tokenize_text(text)

                # Split the article into non-overlapping chunks of max_length tokens.
                for i in range(0, len(tokens), self.max_length):
                    chunk = tokens[i:i + self.max_length]
                    if len(chunk) > 10:  # skip tiny trailing fragments
                        # Pad the final chunk up to max_length.
                        if len(chunk) < self.max_length:
                            chunk.extend([self.tokenizer.pad_token_id] * (self.max_length - len(chunk)))

                        # Next-token prediction: inputs are chunk[:-1], labels are chunk[1:].
                        input_ids = torch.tensor(chunk[:-1])
                        labels = torch.tensor(chunk[1:])

                        yield {
                            'input_ids': input_ids,
                            'labels': labels,
                            'attention_mask': (input_ids != self.tokenizer.pad_token_id).long(),
                        }
|
|
|
|
| class TextFileDataset(Dataset): |
| """Dataset from local text files.""" |
| |
| def __init__( |
| self, |
| file_path: str, |
| tokenizer, |
| max_length: int = 2048, |
| stride: int = 512, |
| ): |
| self.tokenizer = tokenizer |
| self.max_length = max_length |
| self.stride = stride |
| |
| print(f"Loading dataset from {file_path}...") |
| with open(file_path, 'r', encoding='utf-8') as f: |
| text = f.read() |
| |
| |
        # Tokenize the entire file once.
        self.tokens = tokenizer.encode(text, add_special_tokens=False)

        # Build overlapping windows: each chunk holds max_length + 1 tokens so
        # that the shifted inputs and labels are both max_length long.
        self.chunks = []
        for i in range(0, len(self.tokens) - max_length, stride):
            chunk = self.tokens[i:i + max_length + 1]
            if len(chunk) == max_length + 1:
                self.chunks.append(chunk)
| |
| print(f"Created {len(self.chunks)} chunks from {len(self.tokens)} tokens") |
| |
| def __len__(self): |
| return len(self.chunks) |
| |
| def __getitem__(self, idx): |
| chunk = self.chunks[idx] |
| input_ids = torch.tensor(chunk[:-1]) |
| labels = torch.tensor(chunk[1:]) |
| |
| return { |
| 'input_ids': input_ids, |
| 'labels': labels, |
| 'attention_mask': torch.ones_like(input_ids), |
| } |
|
|
|
|
| class MixedDataset(IterableDataset): |
| """Mix multiple data sources.""" |
| |
| def __init__( |
| self, |
| datasets: List[IterableDataset], |
| weights: Optional[List[float]] = None, |
| ): |
| self.datasets = datasets |
| self.weights = weights or [1.0] * len(datasets) |
| assert len(self.datasets) == len(self.weights) |
| |
| |
        # Normalize the weights so they sum to 1.
        total = sum(self.weights)
        self.weights = [w / total for w in self.weights]
| |
    def __iter__(self):
        """Sample from the underlying datasets according to their weights."""
        iterators = [iter(ds) for ds in self.datasets]

        while True:
            # Pick a source dataset in proportion to its normalized weight.
            dataset_idx = random.choices(range(len(self.datasets)), weights=self.weights)[0]

            try:
                yield next(iterators[dataset_idx])
            except StopIteration:
                # Restart finite datasets once they are exhausted.
                iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                yield next(iterators[dataset_idx])
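
# Minimal usage sketch (defined but never called). Assumes a local corpus file
# at the hypothetical path below; note that the map-style TextFileDataset is
# iterated via its __getitem__ when wrapped by MixedDataset.
def _example_mixed_stream(tokenizer, corpus_path: str = "sample_corpus.txt"):
    """Illustrative only: mix streamed Wikipedia with a local file 70/30."""
    wiki = WikipediaDataset(tokenizer, max_length=512)
    local = TextFileDataset(corpus_path, tokenizer, max_length=512, stride=256)
    mixed = MixedDataset([wiki, local], weights=[0.7, 0.3])

    stream = iter(mixed)
    return [next(stream) for _ in range(4)]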
|
|
|
|
| class PretokenizedDataset(Dataset): |
| """Dataset from pre-tokenized binary files.""" |
| |
| def __init__(self, data_dir: str, max_length: int = 2048): |
| self.data_dir = data_dir |
| self.max_length = max_length |
| |
| |
        # Discover pre-tokenized tensor files (*.pt) in the data directory.
        self.files = []
        for fname in os.listdir(data_dir):
            if fname.endswith('.pt'):
                self.files.append(os.path.join(data_dir, fname))
| |
| self.files.sort() |
| print(f"Found {len(self.files)} pre-tokenized files") |
| |
| |
        # Count how many full (max_length + 1)-token chunks each file holds, so
        # every item yields inputs and labels of exactly max_length tokens.
        self.lengths = []
        for f in self.files:
            data = torch.load(f, map_location='cpu')
            self.lengths.append(max(0, (len(data) - 1) // max_length))
| |
| self.total_length = sum(self.lengths) |
| |
| def __len__(self): |
| return self.total_length |
| |
    def __getitem__(self, idx):
        # Map the global index to (file, chunk-within-file).
        file_idx = 0
        remaining = idx
        for i, length in enumerate(self.lengths):
            if remaining < length:
                file_idx = i
                break
            remaining -= length

        # Load the file lazily and slice out max_length + 1 tokens.
        data = torch.load(self.files[file_idx], map_location='cpu')
        start = remaining * self.max_length
        chunk = data[start:start + self.max_length + 1]
| |
| input_ids = chunk[:-1] |
| labels = chunk[1:] |
| |
| return { |
| 'input_ids': input_ids, |
| 'labels': labels, |
| 'attention_mask': torch.ones_like(input_ids), |
| } |
|
|
|
|
| def download_wikipedia_dump(output_dir: str, language: str = 'en'): |
| """Download Wikipedia dump for offline processing.""" |
| os.makedirs(output_dir, exist_ok=True) |
| |
| |
| base_url = f"https://dumps.wikimedia.org/{language}wiki/latest/" |
| files = [ |
| f"{language}wiki-latest-pages-articles-multistream.xml.bz2", |
| ] |
| |
| for filename in files: |
| url = base_url + filename |
| output_path = os.path.join(output_dir, filename) |
| |
| if os.path.exists(output_path): |
| print(f"{filename} already exists") |
| continue |
| |
| print(f"Downloading {filename}...") |
| try: |
            response = requests.get(url, stream=True, timeout=30)
| response.raise_for_status() |
| |
| with open(output_path, 'wb') as f: |
| for chunk in response.iter_content(chunk_size=8192): |
| f.write(chunk) |
| |
| print(f"Downloaded {filename}") |
| except Exception as e: |
| print(f"Error downloading {filename}: {e}") |
|
|
|
|
def create_sample_corpus(output_file: str = "sample_corpus.txt", num_articles: int = 1000):
    """Create a sample corpus by fetching random Wikipedia articles."""
    print(f"Creating sample corpus with {num_articles} articles...")

    # Only the fetching helpers are used here, so no tokenizer is required.
    dataset = WikipediaDataset(tokenizer=None, max_length=100000)

    articles = []
    for i in range(num_articles):
        text = dataset._fetch_random_article()
        if text:
            articles.append(text)

        if (i + 1) % 100 == 0:
            print(f"  Fetched {i + 1}/{num_articles} articles")

    # Separate articles with an end-of-text marker.
    with open(output_file, 'w', encoding='utf-8') as f:
        for article in articles:
            f.write(article + '\n\n<|endoftext|>\n\n')

    print(f"Sample corpus saved to {output_file}")
    return output_file
|
|
|
|
def prepare_openwebtext_data(output_dir: str, max_files: int = 100):
    """
    Download and prepare the OpenWebText corpus.
    Note: this is a placeholder - the actual corpus requires a separate download.
    """
    os.makedirs(output_dir, exist_ok=True)
    print("OpenWebText data preparation placeholder.")
    print("Please download OpenWebText from https://github.com/jcpeterson/openwebtext")
    print(f"and place the files in {output_dir}")
|
|
|
|
if __name__ == "__main__":
    from tokenizer import BPETokenizer

    # Train a tiny tokenizer on a few sample sentences for a quick smoke test.
    sample_texts = [
        "This is a sample text for testing.",
        "Wikipedia contains many interesting articles.",
        "Machine learning models need lots of data.",
    ]
    tokenizer = BPETokenizer(vocab_size=1000)
    tokenizer.train(sample_texts)

    # Stream a few chunks from Wikipedia (requires network access).
    print("\nTesting Wikipedia dataset...")
    wiki_dataset = WikipediaDataset(tokenizer, max_length=128)

    for i, batch in enumerate(wiki_dataset):
        if i >= 3:
            break
        print(f"\nBatch {i + 1}:")
        print(f"  Input shape: {batch['input_ids'].shape}")
        print(f"  Labels shape: {batch['labels'].shape}")
|
|