cycool29 committed on
Commit
32d3d37
·
1 Parent(s): d9217c6
handetect/__pycache__/configs.cpython-310.pyc ADDED
Binary file (1.56 kB). View file
 
handetect/__pycache__/data_loader.cpython-310.pyc ADDED
Binary file (1.07 kB). View file
 
handetect/__pycache__/models.cpython-310.pyc CHANGED
Binary files a/handetect/__pycache__/models.cpython-310.pyc and b/handetect/__pycache__/models.cpython-310.pyc differ
 
handetect/augment.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import Augmentor
3
+
4
# Task identifiers; each one maps to a "Task <n>" folder under data/train/raw.
tasks = ["1", "2", "3", "4", "5", "6"]

# Number of images every class should end up with (originals + augmented).
TARGET_IMAGES_PER_CLASS = 100

for task in tasks:
    raw_task_dir = f"data/train/raw/Task {task}"
    augmented_task_dir = f"data/train/augmented/Task {task}"

    # Loop through all class folders of this task and generate augmented
    # images for each class. sorted() keeps the processing order deterministic.
    for i in sorted(os.listdir(raw_task_dir)):
        class_dir = f"{raw_task_dir}/{i}"
        # Only directories are classes; this skips .DS_Store and any other
        # stray file that would otherwise be handed to Augmentor.
        if not os.path.isdir(class_dir):
            continue

        print("Augmenting images in class: ", i)
        # output_directory is given as a bare name; the move below picks the
        # generated folder up from <class_dir>/<i>.
        p = Augmentor.Pipeline(class_dir, output_directory=i, save_format="png")
        p.rotate(probability=0.8, max_left_rotation=5, max_right_rotation=5)
        p.flip_left_right(probability=0.8)
        p.zoom_random(probability=0.8, percentage_area=0.8)
        p.flip_top_bottom(probability=0.8)
        p.random_brightness(probability=0.8, min_factor=0.5, max_factor=1.5)
        p.random_contrast(probability=0.8, min_factor=0.5, max_factor=1.5)
        p.random_color(probability=0.8, min_factor=0.5, max_factor=1.5)

        # Generate only as many images as are missing to reach the target.
        # Skip sampling when the class is empty (nothing to sample from) or
        # already has enough originals, instead of passing 0 or a negative
        # count to sample().
        # NOTE(review): sample(0) appears to mean "process every image once"
        # in Augmentor -- confirm against the installed version.
        missing = TARGET_IMAGES_PER_CLASS - len(p.augmentor_images)
        if p.augmentor_images and missing > 0:
            p.sample(missing)

        # Move the generated folder to data/train/augmented/Task <n>/<class>,
        # creating the task folder first if it does not exist.
        os.makedirs(augmented_task_dir, exist_ok=True)
        augmented_class_dir = f"{augmented_task_dir}/{i}"
        os.rename(f"{class_dir}/{i}", augmented_class_dir)

        # Rename all the augmented images to 01.png, 02.png, 03.png, ...
        for number, j in enumerate(sorted(os.listdir(augmented_class_dir)), start=1):
            os.rename(
                f"{augmented_class_dir}/{j}",
                f"{augmented_class_dir}/{number:02d}.png",
            )
40
+
41
+
handetect/configs.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torchvision import transforms
4
+ from torch.utils.data import Dataset
5
+ from models import *
6
+
7
# Constants
RANDOM_SEED = 123
BATCH_SIZE = 64
NUM_EPOCHS = 100
LEARNING_RATE = 0.02750299610194638
STEP_SIZE = 10  # scheduler step size, in epochs
GAMMA = 0.5  # multiplicative learning-rate decay applied by the scheduler
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_PRINT = 100  # number of batches between running-loss printouts
TASK = 1
ORIG_DATA_DIR = r"data/train/raw/Task " + str(TASK)
AUG_DATA_DIR = r"data/train/augmented/Task " + str(TASK)
# One class per sub-folder. Count directories only, so stray files such as
# .DS_Store do not inflate the size of the classifier head.
NUM_CLASSES = sum(
    os.path.isdir(os.path.join(ORIG_DATA_DIR, entry))
    for entry in os.listdir(ORIG_DATA_DIR)
)
MODEL_SAVE_PATH = "output/checkpoints/model.pth"
MODEL = shufflenet_v2_x0_5(num_classes=NUM_CLASSES)

# Preprocessing pipeline applied to every image.
preprocess = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Resize images to 64x64
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize
    ]
)
30
+
31
+ # Custom dataset class
32
class CustomDataset(Dataset):
    """Thin ``Dataset`` adapter around an indexable collection of (image, label) pairs."""

    def __init__(self, dataset):
        # Keep a reference only; nothing is copied or transformed here.
        self.data = dataset

    def __len__(self):
        """Return the number of samples in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        img, label = sample
        return img, label
handetect/data_loader.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from configs import *
2
+ from torchvision.datasets import ImageFolder
3
+ from torch.utils.data import random_split, DataLoader, Dataset
4
+
5
+
6
+ def load_data(original_dir, augmented_dir, preprocess):
7
+ # Load the dataset using ImageFolder
8
+ original_dataset = ImageFolder(root=original_dir, transform=preprocess)
9
+ augmented_dataset = ImageFolder(root=augmented_dir, transform=preprocess)
10
+ dataset = original_dataset + augmented_dataset
11
+
12
+ print("Classes: ", *original_dataset.classes, sep = ' ')
13
+ print("Length of original dataset: ", len(original_dataset))
14
+ print("Length of augmented dataset: ", len(augmented_dataset))
15
+ print("Length of total dataset: ", len(dataset))
16
+
17
+ # Split the dataset into train and validation sets
18
+ train_size = int(0.8 * len(dataset))
19
+ val_size = len(dataset) - train_size
20
+ train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
21
+
22
+ # Create data loaders for the custom dataset
23
+ train_loader = DataLoader(
24
+ CustomDataset(train_dataset), batch_size=BATCH_SIZE, shuffle=True, num_workers=0
25
+ )
26
+ valid_loader = DataLoader(
27
+ CustomDataset(val_dataset), batch_size=BATCH_SIZE, num_workers=0
28
+ )
29
+
30
+ return train_loader, valid_loader
handetect/eval.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torchvision.transforms import transforms
4
+ from sklearn.metrics import f1_score
5
+ from models import *
6
+ import pathlib
7
+ from PIL import Image
8
+ from torchmetrics import ConfusionMatrix, Accuracy
9
+ import matplotlib.pyplot as plt
10
+ from configs import *
11
+
12
+ image_path = "data/test/Task 1/"
13
+
14
+ # constants
15
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
+ NUM_CLASSES = 6
17
+
18
+ # load the model
19
+ images = list(pathlib.Path(image_path).rglob("*.png"))
20
+ classes = os.listdir(image_path)
21
+ print(images)
22
+
23
+ true_classs = []
24
+ predicted_labels = []
25
+
26
+ model = mobilenet_v3_small(pretrained=False, num_classes=NUM_CLASSES)
27
+ model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
28
+ model.eval()
29
+ model = model.to(DEVICE)
30
+
31
+
32
+ # Define transformation for preprocessing
33
+ preprocess = transforms.Compose(
34
+ [
35
+ transforms.Resize((64, 64)), # Resize images to 64x64
36
+ transforms.Grayscale(num_output_channels=3), # Convert to grayscale
37
+ transforms.ToTensor(), # Convert to tensor
38
+ transforms.Normalize((0.5,), (0.5,)), # Normalize (for grayscale)
39
+ ]
40
+ )
41
+
42
+ # evaluate the model
43
+ all_predictions = []
44
+ true_labels = []
45
+
46
+ def predict_image(image_path, model, transform):
47
+ model.eval()
48
+ correct_predictions = 0
49
+ total_predictions = len(images)
50
+
51
+ with torch.no_grad():
52
+ for i in images:
53
+ print('---------------------------')
54
+ # Check the true label of the image by checking the sequence of the folder in Task 1
55
+ true_class = classes.index(i.parts[-2])
56
+ print("Image path:", i)
57
+ print("True class:", true_class)
58
+ image = Image.open(i)
59
+ image = transform(image).unsqueeze(0)
60
+ image = image.to(DEVICE)
61
+ output = model(image)
62
+ predicted_class = torch.argmax(output, dim=1).item()
63
+ # Print the predicted class
64
+ print("Predicted class:", predicted_class)
65
+ # Append true and predicted labels to their respective lists
66
+ true_classs.append(true_class)
67
+ predicted_labels.append(predicted_class)
68
+
69
+ # Check if the prediction is correct
70
+ if predicted_class == true_class:
71
+ correct_predictions += 1
72
+
73
+ # Calculate accuracy and f1 socre
74
+ accuracy = correct_predictions / total_predictions
75
+ print("Accuracy:", accuracy)
76
+ f1 = f1_score(true_classs, predicted_labels, average='weighted')
77
+ print("Weighted F1 Score:", f1)
78
+
79
+ # Call predict_image function
80
+ predict_image(image_path, model, preprocess)
81
+
82
+ # Convert the lists to tensors
83
+ predicted_labels_tensor = torch.tensor(predicted_labels)
84
+ true_classs_tensor = torch.tensor(true_classs)
85
+
86
+ conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task='multiclass')
87
+ conf_matrix.update(predicted_labels_tensor, true_classs_tensor)
88
+
89
+ # Plot confusion matrix
90
+ conf_matrix.plot()
91
+ plt.show()
handetect/models.py CHANGED
@@ -2,7 +2,7 @@
2
  # This file stores all the models used in the project.#
3
  #######################################################
4
 
5
- import torch
6
  from torchvision.models import resnet50
7
  from torchvision.models import resnet18
8
  from torchvision.models import squeezenet1_0
@@ -13,3 +13,5 @@ from torchvision.models import googlenet
13
  from torchvision.models import inception_v3
14
  from torchvision.models import mobilenet_v2
15
  from torchvision.models import mobilenet_v3_small
 
 
 
2
  # This file stores all the models used in the project.#
3
  #######################################################
4
 
5
+ # Import all models from torchvision.models
6
  from torchvision.models import resnet50
7
  from torchvision.models import resnet18
8
  from torchvision.models import squeezenet1_0
 
13
  from torchvision.models import inception_v3
14
  from torchvision.models import mobilenet_v2
15
  from torchvision.models import mobilenet_v3_small
16
+ from torchvision.models import mobilenet_v3_large
17
+ from torchvision.models import shufflenet_v2_x0_5
handetect/predict.py CHANGED
@@ -6,36 +6,20 @@ from PIL import Image
6
  from handetect.models import *
7
  from torchmetrics import ConfusionMatrix
8
  import matplotlib.pyplot as plt
 
9
 
10
- # Define the path to your model checkpoint
11
- model_checkpoint_path = "model.pth"
12
-
13
- DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
-
15
- NUM_CLASSES = 6
16
-
17
- # Define transformation for preprocessing the input image
18
- preprocess = transforms.Compose(
19
- [
20
- transforms.Resize((64, 64)), # Resize the image to match training input size
21
- transforms.Grayscale(num_output_channels=3), # Convert the image to grayscale
22
- transforms.ToTensor(),
23
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Normalize the image
24
- ]
25
- )
26
 
27
  # Load your model (change this according to your model definition)
28
- model = mobilenet_v3_small(pretrained=False, num_classes=NUM_CLASSES)
29
- model.load_state_dict(
30
- torch.load(model_checkpoint_path, map_location=DEVICE)
31
  ) # Load the model on the same device
32
- model.eval()
33
- model = model.to(DEVICE)
34
- model.eval()
35
  torch.set_grad_enabled(False)
36
 
37
 
38
- def predict_image(image_path, model=model, transform=preprocess):
39
  # Define images variable to recursively list all the data file in the image_path
40
  classes = ['Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease']
41
 
 
6
  from handetect.models import *
7
  from torchmetrics import ConfusionMatrix
8
  import matplotlib.pyplot as plt
9
+ from configs import *
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Load your model (change this according to your model definition)
13
+ MODEL.load_state_dict(
14
+ torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
 
15
  ) # Load the model on the same device
16
+ MODEL.eval()
17
+ MODEL = MODEL.to(DEVICE)
18
+ MODEL.eval()
19
  torch.set_grad_enabled(False)
20
 
21
 
22
+ def predict_image(image_path, model=MODEL, transform=preprocess):
23
  # Define images variable to recursively list all the data file in the image_path
24
  classes = ['Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease']
25
 
handetect/train.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torchvision.transforms import transforms
6
+ from torch.utils.data import DataLoader, random_split, Dataset
7
+ from torchvision.datasets import ImageFolder
8
+ import matplotlib.pyplot as plt
9
+ from models import *
10
+ from scipy.ndimage import gaussian_filter1d
11
+ from torch.utils.tensorboard import SummaryWriter # print to tensorboard
12
+ from torchvision.utils import make_grid
13
+ import pandas as pd
14
+ from configs import *
15
+ import data_loader
16
+
17
+ # torch.cuda.empty_cache()
18
+ # os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024"
19
+
20
writer = SummaryWriter(log_dir="output/tensorboard")


# Data loader
train_loader, valid_loader = data_loader.load_data(
    ORIG_DATA_DIR, AUG_DATA_DIR, preprocess
)


# Initialize model, criterion, optimizer, and scheduler
MODEL = MODEL.to(DEVICE)
criterion = nn.CrossEntropyLoss()
# Adam optimizer
optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)
# StepLR scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

# Per-epoch histories: last-batch losses, epoch-average losses and accuracies.
TRAIN_LOSS_HIST = []
VAL_LOSS_HIST = []
AVG_TRAIN_LOSS_HIST = []
AVG_VAL_LOSS_HIST = []
TRAIN_ACC_HIST = []
VAL_ACC_HIST = []


def _save_curve_plot(train_line, val_line, train_label, val_label, ylabel, title, filename):
    """Plot a train/validation pair of per-epoch curves and save it to ``filename``."""
    epochs = range(1, NUM_EPOCHS + 1)
    plt.plot(epochs, train_line, label=train_label)
    plt.plot(epochs, val_line, label=val_label)
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.legend()
    plt.title(title)
    plt.savefig(filename)
    plt.clf()  # start the next figure from a clean canvas


# Training loop
for epoch in range(NUM_EPOCHS):
    MODEL.train(True)  # Set model to training mode
    running_loss = 0.0  # loss since the last printout; reset every NUM_PRINT batches
    epoch_loss = 0.0  # loss over the whole epoch; never reset inside the epoch
    total_train = 0
    correct_train = 0

    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = MODEL(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss

        if (i + 1) % NUM_PRINT == 0:
            print(
                "[Epoch %d, Batch %d] Loss: %.6f"
                % (epoch + 1, i + 1, running_loss / NUM_PRINT)
            )
            running_loss = 0.0

        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    TRAIN_ACC_HIST.append(correct_train / total_train)

    # Loss of the last training batch of the epoch
    TRAIN_LOSS_HIST.append(loss.item())

    # Calculate the average training loss for the epoch. This uses the
    # accumulator that is not reset by the periodic printout above.
    avg_train_loss = epoch_loss / len(train_loader)
    writer.add_scalar("Loss/Train", avg_train_loss, epoch)
    writer.add_scalar("Accuracy/Train", correct_train / total_train, epoch)
    AVG_TRAIN_LOSS_HIST.append(avg_train_loss)

    # Print average training loss for the epoch
    print("[Epoch %d] Average Training Loss: %.6f" % (epoch + 1, avg_train_loss))

    # Learning rate scheduling
    lr_1 = optimizer.param_groups[0]["lr"]
    print("Learning Rate: {:.15f}".format(lr_1))
    scheduler.step()

    # Validation loop
    MODEL.eval()  # Set model to evaluation mode
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = MODEL(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    # Loss of the last validation batch of the epoch
    VAL_LOSS_HIST.append(loss.item())

    # Calculate the average validation loss for the epoch
    avg_val_loss = val_loss / len(valid_loader)
    AVG_VAL_LOSS_HIST.append(avg_val_loss)
    print("Average Validation Loss: %.6f" % (avg_val_loss))

    # Calculate the accuracy of validation set
    val_accuracy = correct_val / total_val
    VAL_ACC_HIST.append(val_accuracy)
    print("Validation Accuracy: %.6f" % (val_accuracy))

    writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
    writer.add_scalar("Accuracy/Validation", val_accuracy, epoch)
    # Add sample images to TensorBoard
    sample_images, _ = next(iter(valid_loader))  # Get a batch of sample images
    sample_images = sample_images.to(DEVICE)
    grid_image = make_grid(
        sample_images, nrow=8, normalize=True
    )  # Create a grid of images
    writer.add_image("Sample Images", grid_image, global_step=epoch)

# End of training loop

# Save the model; create the checkpoint folder first so torch.save does not
# fail on a fresh checkout.
checkpoint_dir = os.path.dirname(MODEL_SAVE_PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(MODEL.state_dict(), MODEL_SAVE_PATH)
print("Model saved at", MODEL_SAVE_PATH)

print("Generating loss plot...")
# Smooth the per-epoch curves slightly before plotting.
avg_train_loss_line = gaussian_filter1d(AVG_TRAIN_LOSS_HIST, sigma=2)
avg_val_loss_line = gaussian_filter1d(AVG_VAL_LOSS_HIST, sigma=2)
train_loss_line = gaussian_filter1d(TRAIN_LOSS_HIST, sigma=2)
val_loss_line = gaussian_filter1d(VAL_LOSS_HIST, sigma=2)
train_acc_line = gaussian_filter1d(TRAIN_ACC_HIST, sigma=2)
val_acc_line = gaussian_filter1d(VAL_ACC_HIST, sigma=2)

_save_curve_plot(
    train_loss_line,
    val_loss_line,
    "Train Loss",
    "Validation Loss",
    "Loss",
    "Train Loss and Validation Loss",
    "loss_plot.png",
)
_save_curve_plot(
    avg_train_loss_line,
    avg_val_loss_line,
    "Average Train Loss",
    "Average Validation Loss",
    "Loss",
    "Average Train Loss and Average Validation Loss",
    "avg_loss_plot.png",
)
_save_curve_plot(
    train_acc_line,
    val_acc_line,
    "Train Accuracy",
    "Validation Accuracy",
    "Accuracy",
    "Train Accuracy and Validation Accuracy",
    "accuracy_plot.png",
)

dummy_input = torch.randn(1, 3, 64, 64).to(DEVICE)  # Adjust input shape accordingly
writer.add_graph(MODEL, dummy_input)
# Close TensorBoard writer
writer.close()
handetect/tuning.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torchvision.transforms import transforms
6
+ from torch.utils.data import DataLoader, random_split, Dataset
7
+ from torchvision.datasets import ImageFolder
8
+ from models import *
9
+ from torch.utils.tensorboard import SummaryWriter #print to tensorboard
10
+ from torchvision.utils import make_grid
11
+ import optuna
12
+ from configs import *
13
+
14
writer = SummaryWriter()

# Load the dataset using ImageFolder
original_dataset = ImageFolder(root=ORIG_DATA_DIR, transform=preprocess)
augmented_dataset = ImageFolder(root=AUG_DATA_DIR, transform=preprocess)

# Error if the classes in the original dataset and augmented dataset are not
# the same. The class lists discovered by ImageFolder are compared rather than
# raw os.listdir() output, which is unordered and may contain stray files.
# An explicit raise is used because assert statements are stripped under -O.
if original_dataset.classes != augmented_dataset.classes:
    raise ValueError(
        "Classes in original dataset and augmented dataset are not the same"
    )

dataset = original_dataset + augmented_dataset

print("Classes: ", original_dataset.classes)
print("Length of original dataset: ", len(original_dataset))
print("Length of augmented dataset: ", len(augmented_dataset))
print("Length of total dataset: ", len(dataset))
31
+
32
+
33
+ # Custom dataset class
34
class CustomDataset(Dataset):
    """Expose an indexable collection of (image, label) pairs as a ``Dataset``."""

    def __init__(self, dataset):
        # The wrapped collection is stored as-is.
        self.data = dataset

    def __len__(self):
        """Number of samples available."""
        wrapped = self.data
        return len(wrapped)

    def __getitem__(self, idx):
        """Fetch sample ``idx`` as an ``(image, label)`` tuple."""
        img, label = self.data[idx]
        pair = (img, label)
        return pair
44
+
45
+
46
# 80/20 train/validation split of the combined dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Data loaders; only the training loader shuffles
train_loader = DataLoader(
    CustomDataset(train_dataset),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
valid_loader = DataLoader(
    CustomDataset(val_dataset),
    batch_size=BATCH_SIZE,
    num_workers=0,
)

# Model on the target device, loss function and Adam optimizer
MODEL = MODEL.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)

# ReduceLROnPlateau: cut the learning rate tenfold after 10 steps without a
# decrease of the monitored value
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=10,
    verbose=True,
)

# Training and validation histories
TRAIN_LOSS_HIST, VAL_LOSS_HIST = [], []
AVG_TRAIN_LOSS_HIST, AVG_VAL_LOSS_HIST = [], []
TRAIN_ACC_HIST, VAL_ACC_HIST = [], []
77
+
78
def objective(trial):
    """Optuna objective: train with the suggested hyperparameters, return validation accuracy.

    Every trial trains its own copy of the module-level ``MODEL`` with its own
    data loaders, optimizer and scheduler, so trials do not influence each
    other and the suggested batch size is actually applied. Losses and
    accuracies are appended to the module-level history lists and written to
    TensorBoard.

    Args:
        trial: The Optuna trial providing ``learning_rate`` and ``batch_size``.

    Returns:
        Validation accuracy after the last epoch (0.0 if no epoch was run).
    """
    import copy  # local import: only needed to clone the model per trial

    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    # Build loaders with the suggested batch size; the module-level loaders
    # are fixed to BATCH_SIZE.
    trial_train_loader = DataLoader(
        CustomDataset(train_dataset), batch_size=batch_size, shuffle=True, num_workers=0
    )
    trial_valid_loader = DataLoader(
        CustomDataset(val_dataset), batch_size=batch_size, num_workers=0
    )

    # Train a copy so the module-level MODEL keeps its initial weights and
    # every trial starts from the same point.
    model = copy.deepcopy(MODEL).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # The scheduler must wrap this trial's optimizer, otherwise it would
    # adjust the learning rate of an optimizer that is never stepped.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=10
    )

    val_accuracy = 0.0

    for epoch in range(NUM_EPOCHS):
        model.train(True)
        running_loss = 0.0  # loss since the last printout; reset every NUM_PRINT batches
        epoch_loss = 0.0  # loss over the whole epoch
        total_train = 0
        correct_train = 0

        for i, (inputs, labels) in enumerate(trial_train_loader, 0):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss

            if (i + 1) % NUM_PRINT == 0:
                print(
                    "[Epoch %d, Batch %d] Loss: %.6f"
                    % (epoch + 1, i + 1, running_loss / NUM_PRINT)
                )
                running_loss = 0.0

            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        # Loss of the last training batch of the epoch
        TRAIN_LOSS_HIST.append(loss.item())
        train_accuracy = correct_train / total_train
        TRAIN_ACC_HIST.append(train_accuracy)
        # Calculate the average training loss for the epoch from the
        # accumulator that the periodic printout does not reset.
        avg_train_loss = epoch_loss / len(trial_train_loader)

        writer.add_scalar('Loss/Train', avg_train_loss, epoch)
        writer.add_scalar('Accuracy/Train', train_accuracy, epoch)
        AVG_TRAIN_LOSS_HIST.append(avg_train_loss)

        # Print average training loss for the epoch
        print("[Epoch %d] Average Training Loss: %.6f" % (epoch + 1, avg_train_loss))

        # Learning rate scheduling
        lr_1 = optimizer.param_groups[0]["lr"]
        print("Learning Rate: {:.15f}".format(lr_1))
        scheduler.step(avg_train_loss)

        # Validation loop
        model.eval()  # Set model to evaluation mode
        val_loss = 0.0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for inputs, labels in trial_valid_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                # Calculate accuracy
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        # Loss of the last validation batch of the epoch
        VAL_LOSS_HIST.append(loss.item())

        # Calculate the average validation loss for the epoch
        avg_val_loss = val_loss / len(trial_valid_loader)
        AVG_VAL_LOSS_HIST.append(avg_val_loss)
        print("Average Validation Loss: %.6f" % (avg_val_loss))

        # Calculate the accuracy of validation set
        val_accuracy = correct_val / total_val
        VAL_ACC_HIST.append(val_accuracy)
        print("Validation Accuracy: %.6f" % (val_accuracy))
        writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
        writer.add_scalar('Accuracy/Validation', val_accuracy, epoch)
        # Add sample images to TensorBoard
        sample_images, _ = next(iter(trial_valid_loader))  # Get a batch of sample images
        sample_images = sample_images.to(DEVICE)
        grid_image = make_grid(sample_images, nrow=8, normalize=True)  # Create a grid of images
        writer.add_image('Sample Images', grid_image, global_step=epoch)

    # The model does not change after the last epoch's validation pass, so
    # that accuracy is the evaluation score; no second pass is needed.
    evaluation_score = val_accuracy

    # Return the evaluation score
    return evaluation_score
182
+
183
+
184
if __name__ == "__main__":
    # The objective returns an accuracy, so the study maximises it.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=300, timeout=800)

    # Summary of how the trials ended
    trial_state = optuna.trial.TrialState
    print("Number of finished trials: ", len(study.trials))
    pruned_trials = list(
        filter(lambda t: t.state == trial_state.PRUNED, study.trials)
    )
    print("Number of pruned trials: ", len(pruned_trials))
    complete_trials = list(
        filter(lambda t: t.state == trial_state.COMPLETE, study.trials)
    )
    print("Number of complete trials: ", len(complete_trials))

    # Report the best trial and the hyperparameters it used
    trial = study.best_trial
    print("Best trial:")
    print(" Value: ", trial.value)
    print(" Params: ")
    best_params = trial.params
    for name in best_params:
        print(f" {name}: {best_params[name]}")
202
+
203
+
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ