import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, lr_scheduler
torch.backends.cudnn.benchmark = False # You can set it to True if you experience performance gains
torch.backends.cudnn.deterministic = False
from src.loss_functions.losses import AsymmetricLoss, ASLSingleLabel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class MyCNN(nn.Module):
    def __init__(self, num_classes=12, dropout_prob=0.2, in_channels=3):
        super(MyCNN, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, padding=1)
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 3 * 3, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, num_classes)
        # Dropout layers
        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.dropout2 = nn.Dropout(p=dropout_prob)

    def forward(self, x_input):
        # Apply convolutional and pooling layers
        # x = self.upsample(x_input)
        x = F.leaky_relu(self.conv1(x_input))
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.leaky_relu(self.conv3(x))
        x = F.max_pool2d(x, 2)
        # Flatten the output for the fully connected layers
        x = x.view(x.size(0), -1)
        # Apply fully connected layers with dropout in between
        x = self.dropout1(x)
        x = F.leaky_relu(self.fc1(x))
        x = self.dropout2(x)
        x = F.leaky_relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x
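# Hedged shape note (assumption, not stated in the original code): with three
# 2x2 max-pools, fc1's input size of 64 * 3 * 3 implies 24x24 spatial inputs
# (24 -> 12 -> 6 -> 3). A quick sanity check of the forward pass could look like:
#   probe = MyCNN(num_classes=12, in_channels=256)
#   print(probe(torch.randn(4, 256, 24, 24)).shape)  # expected: torch.Size([4, 12])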
# Initialize the model
cell_attribute_model = MyCNN(num_classes=12, dropout_prob=0.5, in_channels=256).to(device)
cell_attribute_model.train() # Set the model in training mode
# Initialize optimizer, criterion, and scheduler
optimizer_cell_model = torch.optim.SGD(cell_attribute_model.parameters(), lr=0.01, weight_decay=0.01)
step_size = 5
gamma = 0.1
scheduler_cell_model = lr_scheduler.StepLR(optimizer_cell_model, step_size=step_size, gamma=gamma)
#criterion = nn.CrossEntropyLoss()
criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.08, disable_torch_grad_focal_loss=True)
# criterion = ASLSingleLabel()
# num_classes = 2
#criterion = nn.BCEWithLogitsLoss() # Binary Cross-Entropy Loss
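# Hedged note on the chosen criterion: AsymmetricLoss (from the ASL repo's
# src/loss_functions/losses.py) is a multi-label loss; it applies a sigmoid to
# the raw logits internally and expects a 0/1 target tensor of the same shape
# as the logits. In cell_training() below, the 12 logits per cell are therefore
# reshaped to (batch, 6, 2) -- assumed to be 6 binary attributes, each one-hot
# encoded over 2 classes -- and compared against a matching one-hot target.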
def cell_training(cell_attribute_model_main, cell_datas, labels):
    obj_batch_size = len(cell_datas)
    # Set the model in training mode
    # optimizer_cell_model.zero_grad()
    # Filter out instances with label value 2 and their corresponding cell_datas:
    # column 0 of `labels` is skipped, and a row is kept only if none of the
    # remaining attribute columns equals 2
    valid_indices = [i for i, row in enumerate(labels[:, 1:]) if not torch.any(row == 2).item()]
    if not valid_indices:
        # print("No valid instances, skipping training.")
        object_batch_loss = torch.tensor(0.0, requires_grad=True, device=device)  # Zero loss as a torch.Tensor
        return object_batch_loss
    filtered_cell_datas = [cell_datas[i] for i in valid_indices]
    filtered_labels = labels[:, 1:][valid_indices]
    # Each element in filtered_cell_datas is assumed to be a tensor of shape (1, in_channels, height, width)
    cell_images = torch.stack(filtered_cell_datas).to(device)
    cell_datas_batch = cell_images.squeeze(1)
    filtered_labels = filtered_labels.to(device)
    # Initialize the model with the dynamically determined in_channels
    # in_channels = filtered_cell_datas[0].size(1)  # Assuming the first element in filtered_cell_datas defines in_channels
    # cell_attribute_model_main.conv1.in_channels = in_channels
    # Forward pass
    outputs_my = cell_attribute_model_main(cell_datas_batch.float())
    outputs_my = outputs_my.view(len(valid_indices), -1)
    # Process labels to create target_tensor
    # label_att = filtered_labels[:, 5].float()  # Assuming label[5] contains 0 or 1
    # target_tensor = label_att.view(-1, 1)
    # Compute the loss
    num_classes = 2
    one_hot_encoded_tensors = []
    # Perform one-hot encoding for each column separately
    for i in range(filtered_labels.size(1)):
        # Extract the current column
        column_values = filtered_labels[:, i].long()
        # Generate one-hot encoded tensor for the current column
        one_hot_encoded_col = torch.eye(num_classes, device=filtered_labels.device)[column_values]
        # Reshape to match the original shape
        one_hot_encoded_col = one_hot_encoded_col.unsqueeze(1)
        one_hot_encoded_tensors.append(one_hot_encoded_col)
    # Concatenate the one-hot encoded tensors along the second dimension (axis=1)
    one_hot_encoded_result = torch.cat(one_hot_encoded_tensors, dim=1)
    # Reshape logits to (batch, 6 attributes, 2 classes) to match the one-hot targets
    outputs_my = outputs_my.view(outputs_my.size(0), 6, 2)
    object_batch_loss = criterion(outputs_my, one_hot_encoded_result)
    # Check whether the loss contains NaN
    if torch.isnan(object_batch_loss):
        # If NaN, trigger a breakpoint to inspect variables, then fall back to a
        # zero loss tensor so the caller still receives something differentiable
        breakpoint()
        object_batch_loss = torch.tensor(0.0, requires_grad=True, device=device)
    # torch.use_deterministic_algorithms(False, warn_only=True)
    # Normalise the loss by the number of valid instances in the batch
    object_batch_loss = object_batch_loss / len(filtered_labels)
    # Backward pass and optimization (expected to be done by the caller)
    # object_batch_loss.backward(retain_graph=True)
    # optimizer_cell_model.step()
    # scheduler_cell_model.step()
    # Explicitly release tensors
    # del cell_images, target_tensor
    # torch.cuda.empty_cache()
    return object_batch_loss
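
# --- Hedged usage sketch (not part of the original pipeline) ---
# cell_training() only computes and returns the per-batch loss; the backward
# pass and optimizer step are left to the caller. The shapes below are
# assumptions: 256-channel 24x24 cell crops stored as (1, 256, 24, 24) tensors,
# and labels of shape (batch, 7) whose columns 1..6 hold 0/1 attribute flags
# (2 marks "unknown" and is filtered out inside cell_training()).
if __name__ == "__main__":
    fake_cells = [torch.randn(1, 256, 24, 24) for _ in range(4)]
    fake_labels = torch.randint(0, 2, (4, 7), device=device)
    optimizer_cell_model.zero_grad()
    loss = cell_training(cell_attribute_model, fake_cells, fake_labels)
    loss.backward()
    optimizer_cell_model.step()
    scheduler_cell_model.step()  # StepLR is typically stepped once per epoch
    print("batch loss:", loss.item())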