import torch
from torchvision import transforms
from torch.utils.data import Dataset
from models import EfficientNetB3WithNorm

# Constants
RANDOM_SEED = 123
BATCH_SIZE = 8
NUM_EPOCHS = 150
WARMUP_EPOCHS = 5
LEARNING_RATE = 0.0001
STEP_SIZE = 10
GAMMA = 0.3
CUTMIX_ALPHA = 0.3
# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE = torch.device("cpu")
NUM_PRINT = 100
TASK = 1
RAW_DATA_DIR = r"data/train/raw/Task "
AUG_DATA_DIR = r"data/train/augmented/Task "
EXTERNAL_DATA_DIR = r"data/train/external/Task "
COMBINED_DATA_DIR = r"data/train/combined/Task "
TEST_DATA_DIR = r"data/test/Task "
TEMP_DATA_DIR = "data/temp/Task "
NUM_CLASSES = 7
LABEL_SMOOTHING_EPSILON = 0.1
EARLY_STOPPING_PATIENCE = 20
CLASSES = [
    "Alzheimer Disease",
    "Cerebral Palsy",
    "Dystonia",
    "Essential Tremor",
    "Healthy",
    "Huntington Disease",
    "Parkinson Disease",
]


MODEL = EfficientNetB3WithNorm(num_classes=NUM_CLASSES)
MODEL_SAVE_PATH = r"output/checkpoints/" + MODEL.__class__.__name__ + ".pth"
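
# Hypothetical helpers (not part of the original config) sketching how the
# hyperparameters above might be consumed by a training script: label-smoothed
# cross-entropy, an Adam optimizer (an assumption; the actual optimizer is
# defined elsewhere), and a StepLR schedule preceded by a linear warmup.
import torch.nn as nn
import torch.optim as optim


def build_criterion():
    # label_smoothing is supported by nn.CrossEntropyLoss since PyTorch 1.10
    return nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING_EPSILON)


def build_optimizer_and_scheduler(model):
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Linear warmup for WARMUP_EPOCHS epochs, then multiply the LR by GAMMA
    # every STEP_SIZE epochs.
    warmup = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, total_iters=WARMUP_EPOCHS
    )
    step = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
    scheduler = optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, step], milestones=[WARMUP_EPOCHS]
    )
    return optimizer, scheduler
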
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize(0.8289, 0.2006),
    ]
)
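

# Hypothetical helper (not part of the original file) showing how `preprocess`
# is typically applied when loading an image from disk. Assumes RGB input; the
# scalar mean/std above is broadcast to every channel by recent torchvision
# versions (older releases expect per-channel sequences such as [0.8289] * 3).
def load_and_preprocess(path):
    from PIL import Image  # Pillow ships as a torchvision dependency

    img = Image.open(path).convert("RGB")
    return preprocess(img)  # float tensor of shape (3, 224, 224)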


# Custom dataset class
class CustomDataset(Dataset):
    """Thin wrapper exposing an in-memory sequence of (image, label) pairs as a Dataset."""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label
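

if __name__ == "__main__":
    # Minimal smoke test (not part of the original file): wrap a tiny in-memory
    # dataset in CustomDataset and confirm the DataLoader yields batches of the
    # expected shape with the configured batch size.
    from torch.utils.data import DataLoader

    torch.manual_seed(RANDOM_SEED)
    dummy = [(torch.rand(3, 224, 224), i % NUM_CLASSES) for i in range(BATCH_SIZE * 2)]
    loader = DataLoader(CustomDataset(dummy), batch_size=BATCH_SIZE, shuffle=True)
    images, labels = next(iter(loader))
    print(images.shape, labels.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8])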