File size: 1,734 Bytes
9067733
 
 
 
 
03d287b
 
 
 
 
9067733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, RandomHorizontalFlip, Normalize, Resize, RandomRotation

from DeePixBis.Dataset import PixWiseDataset
from DeePixBis.Model import DeePixBiS
from DeePixBis.Loss import PixWiseBCELoss
from DeePixBis.Metrics import predict, test_accuracy, test_loss
from DeePixBis.Trainer import Trainer

# Restore the previously saved DeePixBiS weights so this run continues
# training from them. map_location='cpu' keeps loading working on machines
# without CUDA even if the checkpoint was written from GPU tensors;
# load_state_dict then copies the values into the freshly built model's
# parameters.
model = DeePixBiS()
model.load_state_dict(torch.load('./DeePixBiS.pth', map_location='cpu'))

# Project loss from DeePixBis.Loss (presumably pixel-wise binary cross-entropy,
# per its name -- confirm there).
loss_fn = PixWiseBCELoss()

# Adam over all model parameters with a learning rate of 1e-4.
opt = torch.optim.Adam(model.parameters(), lr=0.0001)

# Per-channel normalization constants for the three input channels:
# each channel becomes (x - 0.5) / 0.5.
_channel_half = [0.5] * 3

# Deterministic evaluation pipeline: resize to 224x224, convert to a tensor,
# then normalize. ToTensor scales 8-bit images into [0, 1], so the
# normalization maps them into [-1, 1].
test_tfms = Compose([
    Resize([224, 224]),
    ToTensor(),
    Normalize(_channel_half, _channel_half),
])

# Training pipeline: the same steps as evaluation, with light augmentation
# (random horizontal flip, random rotation of up to +/-10 degrees) applied
# before tensor conversion.
train_tfms = Compose([
    Resize([224, 224]),
    RandomHorizontalFlip(),
    RandomRotation(10),
    ToTensor(),
    Normalize(_channel_half, _channel_half),
])

# Build the training and validation datasets from CSV manifests.
# PixWiseDataset is a project wrapper; .dataset() presumably returns a
# torch-compatible Dataset -- confirm in DeePixBis.Dataset.
train_dataset = PixWiseDataset('./train_data.csv', transform=train_tfms)
train_ds = train_dataset.dataset()

# NOTE(review): validation is read from test_data.csv -- confirm this split is
# not also used as the held-out test set.
val_dataset = PixWiseDataset('./test_data.csv', transform=test_tfms)
val_ds = val_dataset.dataset()

batch_size = 10
# Pinned host memory only helps when batches are copied to a CUDA device;
# requesting it without one is wasted work.
use_pinned_memory = torch.cuda.is_available()
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=0,
                      pin_memory=use_pinned_memory)
# Evaluation does not need shuffling; a fixed order keeps validation runs
# reproducible.
val_dl = DataLoader(val_ds, batch_size, shuffle=False, num_workers=0,
                    pin_memory=use_pinned_memory)

# Train the model and write the updated weights back to './DeePixBiS.pth',
# the same checkpoint this script loads at startup, so each run continues
# from where the previous one stopped.
# Progress note from the original author: 5 epochs had been run at the time.
#
# NOTE(review): the 4th Trainer argument appears to be the number of epochs
# for this run -- confirm against DeePixBis.Trainer.
num_epochs = 1
trainer = Trainer(train_dl, val_dl, model, num_epochs, opt, loss_fn)

print('Training Beginning\n')
trainer.fit()
print('\nTraining Complete')

# Save to a temporary file first, then atomically swap it into place, so an
# interrupted save cannot corrupt the only copy of the weights.
checkpoint_path = './DeePixBiS.pth'
tmp_checkpoint_path = checkpoint_path + '.tmp'
torch.save(model.state_dict(), tmp_checkpoint_path)
os.replace(tmp_checkpoint_path, checkpoint_path)