import torch
from torch import nn
import torch.nn.functional as F

# Hyperparameters
batch_size = 32  # number of independent sequences processed in parallel
block_size = 8   # maximum context length for predictions
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
device = "cuda:1" if torch.cuda.is_available() else "cpu"
eval_iters = 200

torch.manual_seed(1123)

with open("input.txt") as f:
    text = f.read()

# Build the character-level vocabulary and the encode/decode mappings
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]           # string -> list of token ids
decode = lambda l: "".join([itos[i] for i in l])  # list of token ids -> string

# Encode the full text and split it 90/10 into train and validation sets
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]


def get_batch(split):
    # Sample a batch of input/target sequences from the requested split
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


@torch.no_grad()
def estimate_loss(model: nn.Module):
    # Average the loss over eval_iters batches for both splits
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            X, Y = X.to(device), Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # Each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)  # (B, T, C)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)                            # (B, T, C)
            logits = logits[:, -1, :]                           # (B, C): keep only the last time step
            probs = F.softmax(logits, dim=-1)                   # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)             # (B, T+1)
        return idx


model = BigramLanguageModel(vocab_size)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # Periodically evaluate the loss on the train and val splits
    if iter % eval_interval == 0:
        losses = estimate_loss(model)
        print(
            f"Step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )

    # Sample a batch, evaluate the loss, and take an optimizer step
    xb, yb = get_batch("train")
    xb, yb = xb.to(device), yb.to(device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Generate 100 new tokens from the trained model, starting from a single token with index 0
context = torch.zeros((1, 1), dtype=torch.long, device=device)
results = decode(model.generate(context, max_new_tokens=100)[0].tolist())
print(results)