import torch
import torch.nn as nn
from torch.nn import functional as F
import math, time, os
from torch.utils.data import Dataset, DataLoader
import tiktoken
# from torch.cuda.amp import autocast, GradScaler
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from datasets import load_dataset
from tqdm import tqdm

# Load dataset
dataset = load_dataset("Bingsu/openwebtext_20p")

# This gives you cleaned, plain text articles
print(dataset['train'][100]['text'][:500])  # pyright: ignore[reportArgumentType]  # Print the first 500 characters of one article
print(dataset['train'][600000])  # pyright: ignore[reportArgumentType]


class TextDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, block_size):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.dataset['train'])

    def __getitem__(self, idx):
        # Start from a random sample (idx is intentionally ignored)
        rand_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()
        text = self.dataset['train'][rand_idx]['text']
        tokens = self.tokenizer.encode(text)
        # Keep appending more samples if too short
        while len(tokens) < self.block_size + 1:
            next_idx = torch.randint(0, len(self.dataset['train']), (1,)).item()
            next_text = self.dataset['train'][next_idx]['text']
            tokens.extend(self.tokenizer.encode(" " + next_text))
            # Prevent runaway growth
            if len(tokens) > self.block_size * 2:
                break
        # Truncate to block_size + 1 and split into input/target shifted by one token
        tokens = torch.tensor(tokens[: self.block_size + 1])
        x = tokens[: self.block_size]
        y = tokens[1 : self.block_size + 1]
        return x.long(), y.long()
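
# Sketch (illustrative only, not executed here): each sample is a pair of LongTensors
# of shape (block_size,), where y is x shifted left by one position:
#   ds = TextDataset(dataset, tiktoken.get_encoding("gpt2"), block_size=256)
#   x, y = ds[0]                       # x.shape == y.shape == torch.Size([256])
#   assert torch.equal(x[1:], y[:-1])  # next-token targets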

# hyperparameters
train_model = True
block_size = 256
n_layers = 8
n_heads = 8
dropout_p = 0.1
batch_size = 8
learning_rate = 3e-4
n_embedding = 512
max_iters = 5000
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class GPTModel(nn.Module):
    def __init__(self, vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size):
        super(GPTModel, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embedding)
        self.position_embedding = nn.Embedding(block_size, n_embedding)
        self.layers = nn.ModuleList([
            # batch_first=True so inputs are (batch, seq, embed); the causal mask is applied in forward()
            nn.TransformerEncoderLayer(d_model=n_embedding, nhead=n_heads, dropout=dropout_p, batch_first=True)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(n_embedding)
        self.head = nn.Linear(n_embedding, vocab_size)
        self.dropout = nn.Dropout(dropout_p)
        self.block_size = block_size

    def forward(self, x):
        bsz, seq_len = x.size()
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0).expand(bsz, seq_len)
        x = self.token_embedding(x) + self.position_embedding(positions)
        x = self.dropout(x)
        # Causal mask (True = blocked) so each position only attends to itself and earlier tokens
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1)
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits
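
# Quick shape check (sketch, assuming the GPT-2 vocab size of 50257): for token ids of
# shape (batch, seq_len) the model returns logits of shape (batch, seq_len, vocab_size):
#   m = GPTModel(vocab_size=50257, n_embedding=512, n_layers=8, n_heads=8,
#                dropout_p=0.1, block_size=256)
#   m(torch.randint(0, 50257, (2, 256))).shape   # -> torch.Size([2, 256, 50257])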

# Initialize tokenizer and dataset
tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = TextDataset(dataset, tokenizer, block_size=block_size)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16)

# Define model objects
vocab_size = tokenizer.n_vocab
model = GPTModel(vocab_size, n_embedding, n_layers, n_heads, dropout_p, block_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()
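
# Note: nn.CrossEntropyLoss expects (N, C) logits and (N,) integer targets, so the
# training loop below flattens logits from (batch, block_size, vocab_size) to
# (batch * block_size, vocab_size) and targets from (batch, block_size) to
# (batch * block_size,) before computing the loss.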

# Training loop
def train():
    torch.set_float32_matmul_precision('high')
    scaler = GradScaler(device)
    if train_model:
        compiled_model = torch.compile(model)
        pbar = tqdm(range(max_iters), desc="Training", ncols=100)
        data_iter = iter(train_dataloader)
        for count in pbar:
            xb, yb = next(data_iter)
            xb, yb = xb.to(device), yb.to(device)
            # forward pass in mixed precision
            with autocast(device, dtype=torch.float16):
                logits = compiled_model(xb)
                loss = loss_fn(logits.view(-1, vocab_size), yb.view(-1))
            # backward pass with gradient scaling
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # update bar text dynamically
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})


@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens, block_size, device):
    model.eval()
    # Encode the prompt text into token IDs
    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
    for _ in range(max_new_tokens):
        # Only keep the last block_size tokens for context
        input_tokens = tokens[:, -block_size:]
        # Get logits and take the last token's distribution
        logits = model(input_tokens)
        logits = logits[:, -1, :]  # (batch=1, vocab)
        probs = F.softmax(logits, dim=-1)
        # Sample from the distribution
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat((tokens, next_token), dim=1)
    # Decode back into text
    output_text = tokenizer.decode(tokens[0].tolist())
    return output_text
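
# Usage (sketch): after training or loading a checkpoint, e.g.
#   generate_text(model, tokenizer, "Once upon a time", max_new_tokens=50,
#                 block_size=block_size, device=device)
# Sampling above is plain multinomial over the full softmax; temperature or top-k
# filtering is not implemented.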

def save_model(model, filepath):
    # Create the checkpoint directory if it does not exist yet
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    torch.save(model.state_dict(), filepath)

def load_model(model, filepath):
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine
    model.load_state_dict(torch.load(filepath, map_location=device))
    return model

def main():
    if train_model:
        train()
        save_model(model, "checkpoints/gpt_model-1.pth")
    else:
        load_model(model, "checkpoints/gpt_model-1.pth")
    # Example of generating text after training or loading
    prompt = "me when the "
    generated_text = generate_text(model, tokenizer, prompt, max_new_tokens=50, block_size=block_size, device=device)
    print(generated_text)

if __name__ == "__main__":
    main()