Ashish Reddy committed · Commit 8ea429a
1 Parent(s): edfd494
la

- requirements.txt +1 -2
- train.py +1 -29
requirements.txt
CHANGED
@@ -1,3 +1,2 @@
 torch
-gradio
-wandb
+gradio
train.py
CHANGED
@@ -1,4 +1,4 @@
-import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F,
+import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F, time
 
 batch_size = 64
 max_len = 256
@@ -28,25 +28,6 @@ else:
 device = torch.device('cpu')
 print("Using device's CPU")
 
-"""
---- WandB Integration ---
-"""
-
-
-wandb.init(
-    project="nano-model-shakesphere-training",
-    config={
-        "learning_rate": learning_rate,
-        "architecture": "decoder-only-model",
-        "dataset": "tinyshakesphere",
-        "d_model": d_model,
-        "n_layer": n_layer,
-        "n_head": n_head,
-        "max_iters": max_iters,
-        "dropout": dropout
-    }
-)
-
 with open('input.txt', 'r', encoding='utf-8') as f:
     text = f.read()
 
@@ -128,12 +109,6 @@ if __name__ == "__main__":
             losses = estimate_loss()
             print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
 
-            wandb.log({
-                "iter": iter,
-                "train/loss": losses['train'],
-                "val/loss": losses['val'],
-                "lr": learning_rate
-            })
         iter_start = time.time()
         xb, yb = get_batch("train")
         logits, loss = model(xb, yb)
@@ -143,9 +118,6 @@ if __name__ == "__main__":
 
         iter_time = time.time() - iter_start
         print(f"Iteration {iter} completed in {iter_time:.2f} seconds")
-        wandb.log({"iter_time": iter_time})
-
-    wandb.finish()
 
     print("Training finished. Saving model state...")
     torch.save(model.state_dict(), 'nanogpt_model.pth')
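
Note: the removed code follows the standard Weights & Biases pattern (wandb.init once with the run config, wandb.log per iteration, wandb.finish at the end). If the intent is only to drop the hard dependency for the Space, the same tracking could be kept optional behind an import guard. The sketch below is not part of this commit; it reuses the hyperparameter names from the old train.py (learning_rate, d_model, n_layer, n_head, max_iters, dropout), and the use_wandb flag and maybe_* helpers are hypothetical additions.

# Minimal sketch (not part of this commit): optional W&B tracking instead of removal.
try:
    import wandb
    use_wandb = True
except ImportError:
    # wandb is not installed (e.g. after this commit's requirements.txt), so logging becomes a no-op.
    use_wandb = False

def maybe_init_wandb(config: dict):
    # Start a run only when wandb is available; otherwise training proceeds without tracking.
    if use_wandb:
        wandb.init(project="nano-model-shakesphere-training", config=config)

def maybe_log(metrics: dict):
    if use_wandb:
        wandb.log(metrics)

def maybe_finish():
    if use_wandb:
        wandb.finish()

In the training loop, the removed wandb.log calls would then become maybe_log({"train/loss": losses['train'], "val/loss": losses['val'], "lr": learning_rate}) and maybe_log({"iter_time": iter_time}), with maybe_finish() after the loop.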