import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from models import *
import torch.nn as nn
from torchvision.models import (
    squeezenet1_0,
    SqueezeNet1_0_Weights,
    mobilenet_v3_small,
    MobileNet_V3_Small_Weights,
    resnet18,
)
# Constants
RANDOM_SEED = 123
BATCH_SIZE = 16
NUM_EPOCHS = 40
LEARNING_RATE = 5.488903014780378e-05
STEP_SIZE = 10
GAMMA = 0.3
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_PRINT = 100
TASK = 1
RAW_DATA_DIR = r"data/train/raw/Task "
AUG_DATA_DIR = r"data/train/augmented/Task "
EXTERNAL_DATA_DIR = r"data/train/external/Task "
TEMP_DATA_DIR = "data/temp/"
NUM_CLASSES = 7
EARLY_STOPPING_PATIENCE = 20
CLASSES = [
    "Alzheimer Disease",
    "Cerebral Palsy",
    "Dystonia",
    "Essential Tremor",
    "Healthy",
    "Huntington Disease",
    "Parkinson Disease",
]
MODEL_SAVE_PATH = r"output/checkpoints/model.pth"


class SqueezeNet1_0WithDropout(nn.Module):
    def __init__(self, num_classes=1000):
        super(SqueezeNet1_0WithDropout, self).__init__()
        squeezenet = squeezenet1_0(weights=SqueezeNet1_0_Weights.DEFAULT)
        self.features = squeezenet.features
        self.classifier = nn.Sequential(
            nn.Conv2d(512, num_classes, kernel_size=1),
            nn.BatchNorm2d(num_classes),  # add batch normalization
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        x = torch.flatten(x, 1)
        return x

# class ShuffleNetV2WithDropout(nn.Module):
#     def __init__(self, num_classes=1000):
#         super(ShuffleNetV2WithDropout, self).__init__()
#         shufflenet = shufflenet_v2_x2_0(weights=ShuffleNet_V2_X2_0_Weights)
#         self.features = shufflenet.features
#         self.classifier = nn.Sequential(
#             nn.Conv2d(1024, num_classes, kernel_size=1),
#             nn.BatchNorm2d(num_classes),  # add batch normalization
#             nn.ReLU(inplace=True),
#             nn.AdaptiveAvgPool2d((1, 1)),
#         )
#
#     def forward(self, x):
#         x = self.features(x)
#         x = self.classifier(x)
#         x = torch.flatten(x, 1)
#         return x


class MobileNetV3SmallWithDropout(nn.Module):
    def __init__(self, num_classes=1000):
        super(MobileNetV3SmallWithDropout, self).__init__()
        mobilenet = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
        self.features = mobilenet.features
        self.classifier = nn.Sequential(
            nn.Conv2d(576, num_classes, kernel_size=1),
            nn.BatchNorm2d(num_classes),  # add batch normalization
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        x = torch.flatten(x, 1)
        return x


class ResNet18WithNorm(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNet18WithNorm, self).__init__()
        resnet = resnet18(weights=None)  # randomly initialized backbone
        self.features = nn.Sequential(
            *list(resnet.children())[:-2]
        )  # Remove last 2 layers (avgpool and fc)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, num_classes),
            nn.BatchNorm1d(num_classes),  # 1D batch norm: Linear outputs (N, C)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        x = torch.flatten(x, 1)
        return x


MODEL = SqueezeNet1_0WithDropout(num_classes=NUM_CLASSES)
print(CLASSES)
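
# Quick shape check (illustrative sketch, not part of the original training flow):
# push a dummy batch through MODEL to confirm the head emits NUM_CLASSES logits.
# The 64x64 input size matches the Resize used in `preprocess` below.
if __name__ == "__main__":
    _dummy = torch.randn(2, 3, 64, 64)
    with torch.no_grad():
        _logits = MODEL(_dummy)
    print(_logits.shape)  # expected: torch.Size([2, 7])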

preprocess = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Resize images to 64x64
        transforms.ToTensor(),  # Convert to tensor
        # Normalize 3 channels
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label
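

# Illustrative usage sketch (assumption, not from the original file): wrap an
# ImageFolder view of the Task-1 raw data in CustomDataset and batch it with a
# DataLoader. The directory layout under RAW_DATA_DIR is assumed to follow the
# one-folder-per-class convention that ImageFolder expects.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision.datasets import ImageFolder

    raw_dataset = ImageFolder(root=RAW_DATA_DIR + str(TASK), transform=preprocess)
    train_loader = DataLoader(
        CustomDataset(raw_dataset),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)  # e.g. [16, 3, 64, 64] and [16]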