rzimmerdev committed
Commit a0f925f • 1 Parent(s): abe84f3

Fixed devices and demo
Files changed:
- .gitignore +1 -0
- Dockerfile +3 -0
- app.py +3 -0
- images/.gitkeep +0 -0
- main.py +23 -0
- notebooks/mlruns/0/meta.yaml +6 -0
- notebooks/optimizers.ipynb +37 -0
- requirements.txt +69 -0
- src/__pycache__/dataset.cpython-310.pyc +0 -0
- src/__pycache__/dataset.cpython-39.pyc +0 -0
- src/__pycache__/demo.cpython-310.pyc +0 -0
- src/__pycache__/downloader.cpython-310.pyc +0 -0
- src/__pycache__/downloader.cpython-39.pyc +0 -0
- src/__pycache__/models.cpython-310.pyc +0 -0
- src/__pycache__/models.cpython-39.pyc +0 -0
- src/__pycache__/predict.cpython-310.pyc +0 -0
- src/__pycache__/predict.cpython-39.pyc +0 -0
- src/__pycache__/train.cpython-310.pyc +0 -0
- src/__pycache__/trainer.cpython-310.pyc +0 -0
- src/__pycache__/trainer.cpython-39.pyc +0 -0
- src/dataset.py +3 -3
- src/demo.py +24 -0
- src/predict.py +35 -15
- src/train.py +49 -37
- src/trainer.py +0 -1
.gitignore CHANGED
@@ -6,3 +6,4 @@
 /downloads/*
 /checkpoints/*
 /checkpoints/
+/venv/
Dockerfile ADDED
@@ -0,0 +1,3 @@
+FROM python
+WORKDIR .
+RUN pip install -r requirements.txt
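As committed, the Dockerfile never copies requirements.txt (or the rest of the project) into the image, pins no Python tag, and leaves WORKDIR at ".", so the RUN step cannot find the requirements file at build time. A more conventional layout might look like the sketch below; the python:3.10-slim tag, the /app workdir, and the python app.py entry point on Gradio's default port 7860 are assumptions, not part of this commit.

# Hypothetical, more complete Dockerfile (not part of this commit)
FROM python:3.10-slim

WORKDIR /app

# Install dependencies first so the layer is cached across code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the project
COPY . .

# Gradio serves on port 7860 by default
EXPOSE 7860
CMD ["python", "app.py"]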
app.py ADDED
@@ -0,0 +1,3 @@
+from src.demo import main
+
+main()
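app.py calls main() with no arguments, so the demo falls back to main's default device, which src/demo.py sets to "cuda"; on a CPU-only Space that default is not available. A hedged sketch of an entry point that picks the device at startup instead:

# Hypothetical variant of app.py: choose the device based on availability
import torch

from src.demo import main

device = "cuda" if torch.cuda.is_available() else "cpu"
main(device=device)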
images/.gitkeep DELETED
File without changes
main.py CHANGED
@@ -0,0 +1,23 @@
+import torch
+
+from models import CNN
+from dataset import DatasetMNIST, download_mnist
+from train import get_dataloaders, train_net_manually, train_net_lightning
+
+
+def main(device):
+    mnist = download_mnist("downloads/mnist/")
+    dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])
+    train_loader, validate_loader, test_loader = get_dataloaders(dataset, test_data)
+
+    # Training manually
+    net = CNN(input_channels=1, num_classes=10).to(device)
+    optim = torch.optim.Adam(net.parameters(), lr=1e-4)
+    loss_fn = torch.nn.CrossEntropyLoss()
+    max_epochs = 1
+
+    train_net_manually(net, optim, loss_fn, train_loader, validate_loader, max_epochs, device)
+
+
+if __name__ == "__main__":
+    main("cpu")
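main.py imports CNN, the dataset helpers, and the training helpers without the src. prefix, while src/train.py, src/predict.py, and src/demo.py all import through src.; run from the repository root, these bare imports would only resolve if src/ were added to PYTHONPATH. A sketch of the import block with the prefix made consistent (an assumption, not part of the commit):

# Hypothetical consistent imports, assuming main.py is run from the repository root
import torch

from src.models import CNN
from src.dataset import DatasetMNIST, download_mnist
from src.train import get_dataloaders, train_net_manually, train_net_lightning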
notebooks/mlruns/0/meta.yaml ADDED
@@ -0,0 +1,6 @@
+artifact_location: mlflow-artifacts:/0
+creation_time: 1671993138419
+experiment_id: '0'
+last_update_time: 1671993138419
+lifecycle_stage: active
+name: Default
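MLflow stores creation_time and last_update_time as Unix epoch milliseconds; a quick worked check of the value above:

# Converting the MLflow timestamp above (epoch milliseconds) to a readable date
from datetime import datetime, timezone

creation_time_ms = 1671993138419
print(datetime.fromtimestamp(creation_time_ms / 1000, tz=timezone.utc))
# 2022-12-25 18:32:18.419000+00:00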
notebooks/optimizers.ipynb ADDED
@@ -0,0 +1,37 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
requirements.txt ADDED
@@ -0,0 +1,69 @@
+aiohttp==3.8.3
+aiosignal==1.3.1
+altair==4.2.0
+anyio==3.6.2
+async-timeout==4.0.2
+attrs==22.2.0
+certifi==2022.12.7
+charset-normalizer==2.1.1
+click==8.1.3
+contourpy==1.0.6
+cycler==0.11.0
+entrypoints==0.4
+fastapi==0.88.0
+ffmpy==0.3.0
+fonttools==4.38.0
+frozenlist==1.3.3
+fsspec==2022.11.0
+gradio==3.15.0
+h11==0.14.0
+httpcore==0.16.3
+httpx==0.23.1
+idna==3.4
+Jinja2==3.1.2
+jsonschema==4.17.3
+kiwisolver==1.4.4
+lightning-utilities==0.5.0
+linkify-it-py==1.0.3
+markdown-it-py==2.1.0
+MarkupSafe==2.1.1
+matplotlib==3.6.2
+mdit-py-plugins==0.3.3
+mdurl==0.1.2
+multidict==6.0.4
+numpy==1.24.0
+orjson==3.8.3
+packaging==22.0
+pandas==1.5.2
+Pillow==9.3.0
+plotly==5.11.0
+protobuf==3.20.1
+pycryptodome==3.16.0
+pydantic==1.10.2
+pydub==0.25.1
+pyparsing==3.0.9
+pyrsistent==0.19.2
+python-dateutil==2.8.2
+python-multipart==0.0.5
+pytorch-lightning==1.8.6
+pytz==2022.7
+PyYAML==6.0
+requests==2.28.1
+rfc3986==1.5.0
+six==1.16.0
+sniffio==1.3.0
+starlette==0.22.0
+tenacity==8.1.0
+tensorboardX==2.5.1
+toolz==0.12.0
+torch==1.13.1
+torchaudio==0.13.1
+torchmetrics==0.11.0
+torchvision==0.14.1
+tqdm==4.64.1
+typing_extensions==4.4.0
+uc-micro-py==1.0.1
+urllib3==1.26.13
+uvicorn==0.20.0
+websockets==10.4
+yarl==1.8.2
src/__pycache__/dataset.cpython-310.pyc ADDED (binary file, 2.09 kB)
src/__pycache__/dataset.cpython-39.pyc ADDED (binary file, 2.08 kB)
src/__pycache__/demo.cpython-310.pyc ADDED (binary file, 977 Bytes)
src/__pycache__/downloader.cpython-310.pyc ADDED (binary file, 1.87 kB)
src/__pycache__/downloader.cpython-39.pyc ADDED (binary file, 1.85 kB)
src/__pycache__/models.cpython-310.pyc ADDED (binary file, 1.76 kB)
src/__pycache__/models.cpython-39.pyc ADDED (binary file, 1.77 kB)
src/__pycache__/predict.cpython-310.pyc ADDED (binary file, 2.5 kB)
src/__pycache__/predict.cpython-39.pyc ADDED (binary file, 2.41 kB)
src/__pycache__/train.cpython-310.pyc ADDED (binary file, 2.4 kB)
src/__pycache__/trainer.cpython-310.pyc ADDED (binary file, 1.68 kB)
src/__pycache__/trainer.cpython-39.pyc ADDED (binary file, 1.74 kB)
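This commit checks the compiled __pycache__ artifacts into the repository even though it also extends .gitignore. A common alternative, noted here only as a suggestion and not part of the change, is to ignore bytecode caches as well:

# Possible additional .gitignore entries (hypothetical)
__pycache__/
*.pyc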
src/dataset.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
 from src.downloader import download_dataset


-def 
+def download_mnist(download_dir):
     download_dataset("mnist", download_dir)

     return {"train": (download_dir + "train_images", download_dir + "train_labels"),
@@ -45,8 +45,8 @@ class DatasetMNIST(Dataset):


 if __name__ == "__main__":
-    download_dir = "
-    mnist = 
+    download_dir = "downloads/mnist/"
+    mnist = download_mnist(download_dir)

     dataset = DatasetMNIST(*mnist["train"])
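download_mnist builds the image and label paths by plain string concatenation (download_dir + "train_images"), so the directory argument must end with a slash, as it does in the callers ("downloads/mnist/"). A slash-agnostic sketch using os.path.join; the mnist_paths name and the test filenames are assumptions, not part of the commit:

# Hypothetical path handling (not in the commit); test filenames are assumed
import os


def mnist_paths(download_dir):
    # os.path.join tolerates a missing trailing slash on download_dir
    return {
        "train": (os.path.join(download_dir, "train_images"),
                  os.path.join(download_dir, "train_labels")),
        "test": (os.path.join(download_dir, "test_images"),
                 os.path.join(download_dir, "test_labels")),
    }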
src/demo.py ADDED
@@ -0,0 +1,24 @@
+import torch
+import gradio as gr
+from src.predict import predict_interval, load_torch_net
+
+
+def predict_gradio_canvas(x, net, device="cuda"):
+    if x is None:
+        return {0: 0}
+    else:
+        x = torch.from_numpy(x.reshape(1, 28, 28)).to(dtype=torch.float32, device=device)
+        return predict_interval(x, net, device)
+
+
+def main(device="cuda"):
+    net = load_torch_net("../checkpoints/pytorch/version_1.pt")
+
+    gr.Interface(fn=lambda x: predict_gradio_canvas(x, net, device),
+                 inputs="sketchpad",
+                 outputs="label",
+                 live=True).launch()
+
+
+if __name__ == "__main__":
+    main(device="cpu")
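main() loads the checkpoint from "../checkpoints/pytorch/version_1.pt", a path relative to the process working directory, while train_net_manually saves to "checkpoints/pytorch/version_1.pt"; whether the two line up depends on where each script is launched from. A small sketch that resolves the checkpoint relative to this file instead; the load_demo_net helper and the assumed layout (checkpoints/ next to src/) are not part of the commit:

# Hypothetical helper (not in the commit): resolve the checkpoint relative to src/demo.py
import os

from src.predict import load_torch_net

CHECKPOINT = os.path.join(os.path.dirname(__file__), "..", "checkpoints", "pytorch", "version_1.pt")


def load_demo_net(device="cpu"):
    # Move the model to the requested device and switch to inference mode before serving
    return load_torch_net(CHECKPOINT).to(device).eval()

If the checkpoint was written on a GPU machine, load_torch_net would also need map_location="cpu" in its torch.load call to run on a CPU-only host; that is left as-is here.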
src/predict.py CHANGED
@@ -7,19 +7,20 @@ import numpy as np
 import plotly.express as px
 from plotly.subplots import make_subplots

-from trainer import LitTrainer
-from models import CNN
-from dataset import DatasetMNIST, 
+from src.trainer import LitTrainer
+from src.models import CNN
+from src.dataset import DatasetMNIST, download_mnist


-def load_pl_net(path="
+def load_pl_net(path="checkpoints/lightning_logs/version_26/checkpoints/epoch=9-step=1000.ckpt"):
     pl_net = LitTrainer.load_from_checkpoint(path, model=CNN(1, 10))
     return pl_net


-def load_torch_net(path="
-
-net 
+def load_torch_net(path="checkpoints/pytorch/version_0.pt"):
+    state_dict = torch.load(path)
+    net = CNN(1, 10)
+    net.load_state_dict(state_dict)
     return net

@@ -30,22 +31,41 @@ def get_sequence(model):

     while i < 10:
         x, y = dataset[j]
-
-        p = 
-
-        if 
+
+        predicted, p = predict(x, model)
+
+        if predicted == i and p > 0.95:
             img = np.flip(np.array(x.reshape(28, 28)), 0)
-        fig.add_trace(px.imshow(img).data[0], row=int(i/5)+1, col=i%5+1)
+            fig.add_trace(px.imshow(img).data[0], row=int(i/5)+1, col=i % 5+1)
         i += 1
         j += 1
     return fig


+def predict(x, model, device="cuda"):
+    y_pred = model(x.to(device)).detach().cpu()
+    predicted = int(np.argmax(y_pred))
+    p = torch.max(nn.functional.softmax(y_pred, dim=0))
+
+    return predicted, p
+
+
+def predict_interval(x, model, device="cuda"):
+    y_pred = model(x.to(device))
+
+    print(y_pred)
+
+    predicted = np.argsort(y_pred.cpu().detach().numpy())
+    p = nn.functional.softmax(y_pred, dim=0)
+
+    return {int(i): float(p[i]) for i in predicted}
+
+
 if __name__ == "__main__":
-    mnist = 
+    mnist = download_mnist("downloads/mnist/")
     dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])

     print("PyTorch Lightning Network")
-    get_sequence(load_pl_net().to("cuda")).
+    get_sequence(load_pl_net().to("cuda")).show()
     print("Manual Network")
-    get_sequence(load_torch_net().to("cuda")).
+    get_sequence(load_torch_net().to("cuda")).show()
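predict_interval returns a {class_index: probability} dict, which is the format the Gradio "label" output in src/demo.py consumes directly. The core of that construction, shown standalone under the assumption that the model yields a flat 10-element score vector for a single image:

# Standalone sketch of the label-dict construction (assumes a flat 10-element score vector)
import torch
import torch.nn.functional as F

scores = torch.randn(10)          # stand-in for model(x) on one image
p = F.softmax(scores, dim=0)      # probabilities over the ten digit classes
ranked = torch.argsort(scores)    # ascending order, like np.argsort in predict_interval
label_dict = {int(i): float(p[i]) for i in ranked}
print(label_dict)                 # suitable for gr.Interface(outputs="label")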
src/train.py CHANGED
@@ -1,57 +1,69 @@
 import torch
-import 
-from torch import nn, optim
-from torch.utils.data import random_split
+from torch.utils.data import random_split, DataLoader
 import pytorch_lightning as pl
+from pytorch_lightning.loggers import MLFlowLogger

-from trainer import LitTrainer
-from models import CNN
+from src.trainer import LitTrainer


-def main():
-    from torch.utils.data import DataLoader
-    from src.dataset import DatasetMNIST, load_mnist
-
-    mnist = load_mnist("../downloads/mnist/")
-    dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])
+def argmax(a):
+    return max(range(len(a)), key=lambda x: a[x])


+def get_dataloaders(dataset, test_data):
     train_size = round(len(dataset) * 0.8)
     validate_size = len(dataset) - train_size
     train_data, validate_data = random_split(dataset, [train_size, validate_size])

-
-
-
+    # For 8 CPU cores
+    return DataLoader(train_data, num_workers=8), \
+        DataLoader(validate_data, num_workers=8), \
+        DataLoader(test_data, num_workers=8)

-
-    opt = optim.Adam(net.parameters(), lr=1e-4)
-    loss_fn = nn.CrossEntropyLoss()
-    max_epochs = 10
-    for i in range(max_epochs):
-        for idx, batch in enumerate(train_dataloader):
-            x, y = batch
-            x = x.to("cuda")
-            y = y.to("cuda")
+
+def train_loop(net, batch, loss_fn, optim, device="cuda"):
+    x, y = batch
+    x = x.to(device)
+    y = y.to(device)
+
+    y_pred = net(x).reshape(1, -1)
+    loss = loss_fn(y_pred, y)
+    truth_count = argmax(y_pred.flatten()) == y
+
+    optim.zero_grad()
+    loss.backward()
+    optim.step()
+
+    return loss.item(), truth_count
+
+
+def train_net_manually(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10, device="cuda"):
+    for i in range(epochs):

-
-
+        print("Epoch: 0")

-
-
-            opt.step()
+        epoch_loss = 0
+        epoch_truth_count = 0
+        for idx, batch in enumerate(train_loader):
+            loss, truth_count = train_loop(net, batch, loss_fn, optim, device)
+
+            epoch_loss += loss
+            epoch_truth_count += truth_count

             if idx % 1000 == 0:
-                print(f"Loss: {loss
+                print(f"Loss: {loss} ({idx} / {len(train_loader)} x {i})")

-
+        print(f"Epoch Loss: {epoch_loss}")
+        print(f"Epoch Accuracy: {epoch_truth_count / len(train_loader)}")
+        torch.save(net.state_dict(), "checkpoints/pytorch/version_1.pt")

-    # grayscale channels = 1, mnist num_labels = 10
-    trainer = pl.Trainer(limit_train_batches=100, max_epochs=10, default_root_dir="../checkpoints")
-    pl_net = LitTrainer(CNN(input_channels=1, num_classes=10))
-    trainer.fit(pl_net, train_dataloader, validate_dataloader)
-    trainer.test(model=pl_net, dataloaders=test_dataloader)
+def train_net_lightning(net, optim, loss_fn, train_loader, validate_loader=None, epochs=10):
+    logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")

-
-
+    pl_net = LitTrainer(net)
+    pl_net.optim = optim
+    pl_net.loss = loss_fn
+    trainer = pl.Trainer(limit_train_batches=100, max_epochs=epochs,
+                         default_root_dir="../checkpoints", logger=logger)
+    trainer.fit(pl_net, train_loader, validate_loader)
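train_loop assumes batches of size one: get_dataloaders leaves DataLoader at its default batch_size=1, which is why each prediction is reshaped to (1, -1) and why comparing argmax(y_pred.flatten()) against y yields a per-sample hit count. A hedged sketch of a batched variant, under the assumption that CNN already returns (batch, 10) logits; train_loop_batched is hypothetical and not part of the commit:

# Hypothetical batched variant of train_loop (not part of the commit)
import torch


def train_loop_batched(net, batch, loss_fn, optim, device="cuda"):
    x, y = batch
    x, y = x.to(device), y.to(device)

    logits = net(x)                                  # assumed shape: (batch, 10)
    loss = loss_fn(logits, y)
    correct = (logits.argmax(dim=1) == y).sum().item()

    optim.zero_grad()
    loss.backward()
    optim.step()

    return loss.item(), correct

With batches larger than one, the epoch accuracy denominator in train_net_manually would also need to count samples rather than batches.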
src/trainer.py CHANGED
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 # coding: utf-8
-import torch
 from torch import nn, optim
 import pytorch_lightning as pl
