rzimmerdev commited on
Commit
a0f925f
1 Parent(s): abe84f3

Fixed devices and demo

Browse files
.gitignore CHANGED
@@ -6,3 +6,4 @@
6
  /downloads/*
7
  /checkpoints/*
8
  /checkpoints/
 
 
6
  /downloads/*
7
  /checkpoints/*
8
  /checkpoints/
9
+ /venv/
Dockerfile ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ FROM python
2
+ WORKDIR .
3
+ RUN pip install -r requirements.txt
app.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from src.demo import main
2
+
3
+ main()
images/.gitkeep DELETED
File without changes
main.py CHANGED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from models import CNN
4
+ from dataset import DatasetMNIST, download_mnist
5
+ from train import get_dataloaders, train_net_manually, train_net_lightning
6
+
7
+
8
+ def main(device):
9
+ mnist = download_mnist("downloads/mnist/")
10
+ dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])
11
+ train_loader, validate_loader, test_loader = get_dataloaders(dataset, test_data)
12
+
13
+ # Training manually
14
+ net = CNN(input_channels=1, num_classes=10).to(device)
15
+ optim = torch.optim.Adam(net.parameters(), lr=1e-4)
16
+ loss_fn = torch.nn.CrossEntropyLoss()
17
+ max_epochs = 1
18
+
19
+ train_net_manually(net, optim, loss_fn, train_loader, validate_loader, max_epochs, device)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ main("cpu")
notebooks/mlruns/0/meta.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ artifact_location: mlflow-artifacts:/0
2
+ creation_time: 1671993138419
3
+ experiment_id: '0'
4
+ last_update_time: 1671993138419
5
+ lifecycle_stage: active
6
+ name: Default
notebooks/optimizers.ipynb ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "collapsed": true,
8
+ "pycharm": {
9
+ "name": "#%%\n"
10
+ }
11
+ },
12
+ "outputs": [],
13
+ "source": []
14
+ }
15
+ ],
16
+ "metadata": {
17
+ "kernelspec": {
18
+ "display_name": "Python 3",
19
+ "language": "python",
20
+ "name": "python3"
21
+ },
22
+ "language_info": {
23
+ "codemirror_mode": {
24
+ "name": "ipython",
25
+ "version": 2
26
+ },
27
+ "file_extension": ".py",
28
+ "mimetype": "text/x-python",
29
+ "name": "python",
30
+ "nbconvert_exporter": "python",
31
+ "pygments_lexer": "ipython2",
32
+ "version": "2.7.6"
33
+ }
34
+ },
35
+ "nbformat": 4,
36
+ "nbformat_minor": 0
37
+ }
requirements.txt ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.3
2
+ aiosignal==1.3.1
3
+ altair==4.2.0
4
+ anyio==3.6.2
5
+ async-timeout==4.0.2
6
+ attrs==22.2.0
7
+ certifi==2022.12.7
8
+ charset-normalizer==2.1.1
9
+ click==8.1.3
10
+ contourpy==1.0.6
11
+ cycler==0.11.0
12
+ entrypoints==0.4
13
+ fastapi==0.88.0
14
+ ffmpy==0.3.0
15
+ fonttools==4.38.0
16
+ frozenlist==1.3.3
17
+ fsspec==2022.11.0
18
+ gradio==3.15.0
19
+ h11==0.14.0
20
+ httpcore==0.16.3
21
+ httpx==0.23.1
22
+ idna==3.4
23
+ Jinja2==3.1.2
24
+ jsonschema==4.17.3
25
+ kiwisolver==1.4.4
26
+ lightning-utilities==0.5.0
27
+ linkify-it-py==1.0.3
28
+ markdown-it-py==2.1.0
29
+ MarkupSafe==2.1.1
30
+ matplotlib==3.6.2
31
+ mdit-py-plugins==0.3.3
32
+ mdurl==0.1.2
33
+ multidict==6.0.4
34
+ numpy==1.24.0
35
+ orjson==3.8.3
36
+ packaging==22.0
37
+ pandas==1.5.2
38
+ Pillow==9.3.0
39
+ plotly==5.11.0
40
+ protobuf==3.20.1
41
+ pycryptodome==3.16.0
42
+ pydantic==1.10.2
43
+ pydub==0.25.1
44
+ pyparsing==3.0.9
45
+ pyrsistent==0.19.2
46
+ python-dateutil==2.8.2
47
+ python-multipart==0.0.5
48
+ pytorch-lightning==1.8.6
49
+ pytz==2022.7
50
+ PyYAML==6.0
51
+ requests==2.28.1
52
+ rfc3986==1.5.0
53
+ six==1.16.0
54
+ sniffio==1.3.0
55
+ starlette==0.22.0
56
+ tenacity==8.1.0
57
+ tensorboardX==2.5.1
58
+ toolz==0.12.0
59
+ torch==1.13.1
60
+ torchaudio==0.13.1
61
+ torchmetrics==0.11.0
62
+ torchvision==0.14.1
63
+ tqdm==4.64.1
64
+ typing_extensions==4.4.0
65
+ uc-micro-py==1.0.1
66
+ urllib3==1.26.13
67
+ uvicorn==0.20.0
68
+ websockets==10.4
69
+ yarl==1.8.2
src/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (2.09 kB). View file
 
src/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (2.08 kB). View file
 
src/__pycache__/demo.cpython-310.pyc ADDED
Binary file (977 Bytes). View file
 
src/__pycache__/downloader.cpython-310.pyc ADDED
Binary file (1.87 kB). View file
 
src/__pycache__/downloader.cpython-39.pyc ADDED
Binary file (1.85 kB). View file
 
src/__pycache__/models.cpython-310.pyc ADDED
Binary file (1.76 kB). View file
 
src/__pycache__/models.cpython-39.pyc ADDED
Binary file (1.77 kB). View file
 
src/__pycache__/predict.cpython-310.pyc ADDED
Binary file (2.5 kB). View file
 
src/__pycache__/predict.cpython-39.pyc ADDED
Binary file (2.41 kB). View file
 
src/__pycache__/train.cpython-310.pyc ADDED
Binary file (2.4 kB). View file
 
src/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (1.68 kB). View file
 
src/__pycache__/trainer.cpython-39.pyc ADDED
Binary file (1.74 kB). View file
 
src/dataset.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
9
  from src.downloader import download_dataset
10
 
11
 
12
- def load_mnist(download_dir):
13
  download_dataset("mnist", download_dir)
14
 
15
  return {"train": (download_dir + "train_images", download_dir + "train_labels"),
@@ -45,8 +45,8 @@ class DatasetMNIST(Dataset):
45
 
46
 
47
  if __name__ == "__main__":
48
- download_dir = "../downloads/mnist/"
49
- mnist = load_mnist(download_dir)
50
 
51
  dataset = DatasetMNIST(*mnist["train"])
52
 
 
9
  from src.downloader import download_dataset
10
 
11
 
12
+ def download_mnist(download_dir):
13
  download_dataset("mnist", download_dir)
14
 
15
  return {"train": (download_dir + "train_images", download_dir + "train_labels"),
 
45
 
46
 
47
  if __name__ == "__main__":
48
+ download_dir = "downloads/mnist/"
49
+ mnist = download_mnist(download_dir)
50
 
51
  dataset = DatasetMNIST(*mnist["train"])
52
 
src/demo.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from src.predict import predict_interval, load_torch_net
4
+
5
+
6
+ def predict_gradio_canvas(x, net, device="cuda"):
7
+ if x is None:
8
+ return {0: 0}
9
+ else:
10
+ x = torch.from_numpy(x.reshape(1, 28, 28)).to(dtype=torch.float32, device=device)
11
+ return predict_interval(x, net, device)
12
+
13
+
14
+ def main(device="cuda"):
15
+ net = load_torch_net("../checkpoints/pytorch/version_1.pt")
16
+
17
+ gr.Interface(fn=lambda x: predict_gradio_canvas(x, net, device),
18
+ inputs="sketchpad",
19
+ outputs="label",
20
+ live=True).launch()
21
+
22
+
23
+ if __name__ == "__main__":
24
+ main(device="cpu")
src/predict.py CHANGED
@@ -7,19 +7,20 @@ import numpy as np
7
  import plotly.express as px
8
  from plotly.subplots import make_subplots
9
 
10
- from trainer import LitTrainer
11
- from models import CNN
12
- from dataset import DatasetMNIST, load_mnist
13
 
14
 
15
- def load_pl_net(path="../checkpoints/lightning_logs/version_26/checkpoints/epoch=9-step=1000.ckpt"):
16
  pl_net = LitTrainer.load_from_checkpoint(path, model=CNN(1, 10))
17
  return pl_net
18
 
19
 
20
- def load_torch_net(path="../checkpoints/pytorch/version_0.pt"):
21
- net = torch.load(path)
22
- net.eval()
 
23
  return net
24
 
25
 
@@ -30,22 +31,41 @@ def get_sequence(model):
30
 
31
  while i < 10:
32
  x, y = dataset[j]
33
- y_pred = model(x.to("cuda")).detach().cpu()
34
- p = torch.max(nn.functional.softmax(y_pred, dim=0))
35
- y_pred = int(np.argmax(y_pred))
36
- if y_pred == i and p > 0.95:
37
  img = np.flip(np.array(x.reshape(28, 28)), 0)
38
- fig.add_trace(px.imshow(img).data[0], row=int(i/5)+1, col=i%5+1)
39
  i += 1
40
  j += 1
41
  return fig
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  if __name__ == "__main__":
45
- mnist = load_mnist("../downloads/mnist/")
46
  dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])
47
 
48
  print("PyTorch Lightning Network")
49
- get_sequence(load_pl_net().to("cuda")).write_image("images/pl_net.png")
50
  print("Manual Network")
51
- get_sequence(load_torch_net().to("cuda")).write_image("images/pytorch_net.png")
 
7
  import plotly.express as px
8
  from plotly.subplots import make_subplots
9
 
10
+ from src.trainer import LitTrainer
11
+ from src.models import CNN
12
+ from src.dataset import DatasetMNIST, download_mnist
13
 
14
 
15
+ def load_pl_net(path="checkpoints/lightning_logs/version_26/checkpoints/epoch=9-step=1000.ckpt"):
16
  pl_net = LitTrainer.load_from_checkpoint(path, model=CNN(1, 10))
17
  return pl_net
18
 
19
 
20
+ def load_torch_net(path="checkpoints/pytorch/version_0.pt"):
21
+ state_dict = torch.load(path)
22
+ net = CNN(1, 10)
23
+ net.load_state_dict(state_dict)
24
  return net
25
 
26
 
 
31
 
32
  while i < 10:
33
  x, y = dataset[j]
34
+
35
+ predicted, p = predict(x, model)
36
+
37
+ if predicted == i and p > 0.95:
38
  img = np.flip(np.array(x.reshape(28, 28)), 0)
39
+ fig.add_trace(px.imshow(img).data[0], row=int(i/5)+1, col=i % 5+1)
40
  i += 1
41
  j += 1
42
  return fig
43
 
44
 
45
+ def predict(x, model, device="cuda"):
46
+ y_pred = model(x.to(device)).detach().cpu()
47
+ predicted = int(np.argmax(y_pred))
48
+ p = torch.max(nn.functional.softmax(y_pred, dim=0))
49
+
50
+ return predicted, p
51
+
52
+
53
+ def predict_interval(x, model, device="cuda"):
54
+ y_pred = model(x.to(device))
55
+
56
+ print(y_pred)
57
+
58
+ predicted = np.argsort(y_pred.cpu().detach().numpy())
59
+ p = nn.functional.softmax(y_pred, dim=0)
60
+
61
+ return {int(i): float(p[i]) for i in predicted}
62
+
63
+
64
  if __name__ == "__main__":
65
+ mnist = download_mnist("downloads/mnist/")
66
  dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])
67
 
68
  print("PyTorch Lightning Network")
69
+ get_sequence(load_pl_net().to("cuda")).show()
70
  print("Manual Network")
71
+ get_sequence(load_torch_net().to("cuda")).show()
src/train.py CHANGED
@@ -1,57 +1,69 @@
1
  import torch
2
- import numpy as np
3
- from torch import nn, optim
4
- from torch.utils.data import random_split
5
  import pytorch_lightning as pl
 
6
 
 
7
 
8
- from trainer import LitTrainer
9
- from models import CNN
10
 
 
 
11
 
12
- def main():
13
- from torch.utils.data import DataLoader
14
- from src.dataset import DatasetMNIST, load_mnist
15
-
16
- mnist = load_mnist("../downloads/mnist/")
17
- dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])
18
 
 
19
  train_size = round(len(dataset) * 0.8)
20
  validate_size = len(dataset) - train_size
21
  train_data, validate_data = random_split(dataset, [train_size, validate_size])
22
 
23
- train_dataloader = DataLoader(train_data, num_workers=6) # My CPU has 8 cores
24
- validate_dataloader = DataLoader(validate_data, num_workers=2)
25
- test_dataloader = DataLoader(test_data, num_workers=8) # My CPU has 8 cores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- net = CNN(input_channels=1, num_classes=10).to("cuda")
28
- opt = optim.Adam(net.parameters(), lr=1e-4)
29
- loss_fn = nn.CrossEntropyLoss()
30
- max_epochs = 10
31
- for i in range(max_epochs):
32
- for idx, batch in enumerate(train_dataloader):
33
- x, y = batch
34
- x = x.to("cuda")
35
- y = y.to("cuda")
36
 
37
- y_pred = net(x).reshape(1, -1)
38
- loss = loss_fn(y_pred, y)
 
 
39
 
40
- opt.zero_grad()
41
- loss.backward()
42
- opt.step()
43
 
44
  if idx % 1000 == 0:
45
- print(f"Loss: {loss.item()} ({idx} / {len(train_dataloader)})")
46
 
47
- torch.save(net, "../checkpoints/pytorch/version_1.pt")
 
 
48
 
49
- # grayscale channels = 1, mnist num_labels = 10
50
- trainer = pl.Trainer(limit_train_batches=100, max_epochs=10, default_root_dir="../checkpoints")
51
- pl_net = LitTrainer(CNN(input_channels=1, num_classes=10))
52
- trainer.fit(pl_net, train_dataloader, validate_dataloader)
53
- trainer.test(model=pl_net, dataloaders=test_dataloader)
54
 
 
 
55
 
56
- if __name__ == "__main__":
57
- main()
 
 
 
 
 
1
  import torch
2
+ from torch.utils.data import random_split, DataLoader
 
 
3
  import pytorch_lightning as pl
4
+ from pytorch_lightning.loggers import MLFlowLogger
5
 
6
+ from src.trainer import LitTrainer
7
 
 
 
8
 
9
+ def argmax(a):
10
+ return max(range(len(a)), key=lambda x: a[x])
11
 
 
 
 
 
 
 
12
 
13
+ def get_dataloaders(dataset, test_data):
14
  train_size = round(len(dataset) * 0.8)
15
  validate_size = len(dataset) - train_size
16
  train_data, validate_data = random_split(dataset, [train_size, validate_size])
17
 
18
+ # For 8 CPU cores
19
+ return DataLoader(train_data, num_workers=8), \
20
+ DataLoader(validate_data, num_workers=8), \
21
+ DataLoader(test_data, num_workers=8)
22
+
23
+
24
+ def train_loop(net, batch, loss_fn, optim, device="cuda"):
25
+ x, y = batch
26
+ x = x.to(device)
27
+ y = y.to(device)
28
+
29
+ y_pred = net(x).reshape(1, -1)
30
+ loss = loss_fn(y_pred, y)
31
+ truth_count = argmax(y_pred.flatten()) == y
32
+
33
+ optim.zero_grad()
34
+ loss.backward()
35
+ optim.step()
36
+
37
+ return loss.item(), truth_count
38
+
39
+
40
+ def train_net_manually(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10, device="cuda"):
41
+ for i in range(epochs):
42
 
43
+ print("Epoch: 0")
 
 
 
 
 
 
 
 
44
 
45
+ epoch_loss = 0
46
+ epoch_truth_count = 0
47
+ for idx, batch in enumerate(train_loader):
48
+ loss, truth_count = train_loop(net, batch, loss_fn, optim, device)
49
 
50
+ epoch_loss += loss
51
+ epoch_truth_count += truth_count
 
52
 
53
  if idx % 1000 == 0:
54
+ print(f"Loss: {loss} ({idx} / {len(train_loader)} x {i})")
55
 
56
+ print(f"Epoch Loss: {epoch_loss}")
57
+ print(f"Epoch Accuracy: {epoch_truth_count / len(train_loader)}")
58
+ torch.save(net.state_dict(), "checkpoints/pytorch/version_1.pt")
59
 
 
 
 
 
 
60
 
61
+ def train_net_lightning(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10):
62
+ logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
63
 
64
+ pl_net = LitTrainer(net)
65
+ pl_net.optim = optim
66
+ pl_net.loss = loss_fn
67
+ trainer = pl.Trainer(limit_train_batches=100, max_epochs=epochs,
68
+ default_root_dir="../checkpoints", logger=logger)
69
+ trainer.fit(pl_net, train_loader, validate_loader)
src/trainer.py CHANGED
@@ -1,6 +1,5 @@
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
- import torch
4
  from torch import nn, optim
5
  import pytorch_lightning as pl
6
 
 
1
  #!/usr/bin/env python
2
  # coding: utf-8
 
3
  from torch import nn, optim
4
  import pytorch_lightning as pl
5