|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.optim import SGD, lr_scheduler |
|
|
|
# cuDNN: disable benchmark-mode autotuning, and leave nondeterministic
# algorithms allowed (deterministic=False is also PyTorch's default).
# NOTE(review): if bit-for-bit reproducibility was the intent, deterministic
# probably wants to be True -- confirm.
torch.backends.cudnn.benchmark = torch.backends.cudnn.deterministic = False
|
from src.loss_functions.losses import AsymmetricLoss, ASLSingleLabel |
|
|
|
# Use the first CUDA GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
import torch.nn.functional as F |
|
|
|
class MyCNN(nn.Module):
    """Small CNN mapping an image-like tensor to ``num_classes`` logits.

    Three ``conv(3x3, padding=1) -> leaky_relu -> max_pool(2)`` stages
    (``in_channels -> 128 -> 64 -> 64`` channels) feed a dropout-regularised
    MLP head (``576 -> 1024 -> 256 -> num_classes``). ``fc1`` expects
    ``64 * 3 * 3`` flattened features, i.e. a 3x3 map after the three
    poolings (for example a 24x24 input).

    ``upsample`` and ``global_avg_pooling`` are constructed but never used by
    ``forward``; neither holds parameters.
    """

    def __init__(self, num_classes=12, dropout_prob=0.2, in_channels=3):
        super().__init__()
        # Submodules are created in the original order so parameter
        # initialisation consumes the RNG identically and the state_dict
        # key order is unchanged.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv2d(in_channels, 128, kernel_size=3, padding=1)
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        flat_features = 64 * 3 * 3  # 64 channels on a 3x3 map after three 2x pools
        self.fc1 = nn.Linear(flat_features, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, num_classes)

        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.dropout2 = nn.Dropout(p=dropout_prob)

    def forward(self, x_input):
        """Return logits of shape ``(batch, num_classes)`` for ``x_input``."""
        features = x_input
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.max_pool2d(F.leaky_relu(conv(features)), 2)

        features = features.view(features.size(0), -1)

        # dropout -> linear -> leaky_relu, twice; dropout2 is reused before fc3.
        for drop, linear in ((self.dropout1, self.fc1), (self.dropout2, self.fc2)):
            features = F.leaky_relu(linear(drop(features)))
        return self.fc3(self.dropout2(features))
|
|
|
|
|
|
|
|
|
# StepLR hyper-parameters: multiply the learning rate by `gamma` every
# `step_size` scheduler steps.
step_size = 5
gamma = 0.1

# Attribute model: 256-channel inputs, 12 output logits, dropout 0.5,
# placed on `device` and switched to training mode.
cell_attribute_model = MyCNN(num_classes=12, dropout_prob=0.5, in_channels=256).to(device)
cell_attribute_model.train()

# Plain SGD (no momentum) with L2 weight decay.
optimizer_cell_model = torch.optim.SGD(
    params=cell_attribute_model.parameters(),
    lr=0.01,
    weight_decay=0.01,
)
scheduler_cell_model = lr_scheduler.StepLR(
    optimizer=optimizer_cell_model,
    step_size=step_size,
    gamma=gamma,
)

# Asymmetric multi-label loss from the project's loss_functions package.
criterion = AsymmetricLoss(
    gamma_neg=4,
    gamma_pos=1,
    clip=0.08,
    disable_torch_grad_focal_loss=True,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
def cell_training(cell_attribute_model_main, cell_datas, labels, loss_fn=None):
    """Compute the attribute loss for one batch of cell crops.

    Rows whose attribute labels (``labels[:, 1:]``) contain the value 2 are
    dropped. The remaining attribute columns are one-hot encoded over two
    classes and compared with the model output reshaped to
    ``(n_kept, n_attributes, 2)``.

    Args:
        cell_attribute_model_main: model producing ``n_attributes * 2``
            outputs per sample. The batch is moved to the device of its
            first parameter.
        cell_datas: indexable sequence of per-cell tensors. The kept ones are
            stacked and dimension 1 is squeezed before the forward pass.
        labels: 2-D tensor. Column 0 is not used as a target; columns 1: hold
            attribute labels, expected to be 0 or 1, with 2 marking a row to
            skip.
        loss_fn: optional callable ``loss_fn(logits, targets)`` returning a
            scalar tensor. Defaults to the module-level ``criterion``.

    Returns:
        The loss divided by the number of kept rows, or a zero scalar tensor
        with ``requires_grad=True`` when no row is kept or the loss is NaN.
    """
    import warnings

    if loss_fn is None:
        loss_fn = criterion

    # Put the batch where the model lives (falls back to the module-level
    # ``device`` for a parameter-less model).
    first_param = next(iter(cell_attribute_model_main.parameters()), None)
    target_device = first_param.device if first_param is not None else device

    attribute_labels = labels[:, 1:]
    # Check EVERY attribute column that is one-hot encoded below: a 2 left in
    # any of them would index past the 2-class one-hot table and raise.
    keep_mask = ~(attribute_labels == 2).any(dim=1)
    valid_indices = keep_mask.nonzero(as_tuple=True)[0].tolist()

    if not valid_indices:
        # Nothing usable in this batch; a zero leaf tensor keeps
        # ``loss.backward()`` valid for the caller.
        return torch.tensor(0.0, requires_grad=True, device=target_device)

    cell_datas_batch = (
        torch.stack([cell_datas[i] for i in valid_indices]).to(target_device).squeeze(1)
    )
    filtered_labels = attribute_labels[valid_indices].to(target_device)
    n_kept, n_attributes = filtered_labels.shape

    # One (n_attributes, 2) grid of logits per kept sample.
    logits = cell_attribute_model_main(cell_datas_batch.float()).view(n_kept, n_attributes, 2)
    # Two-class one-hot targets with the same (n_kept, n_attributes, 2) layout.
    targets = F.one_hot(filtered_labels.long(), num_classes=2).float()

    object_batch_loss = loss_fn(logits, targets)

    if torch.isnan(object_batch_loss):
        # Skip the batch rather than propagating NaN into the weights; return a
        # tensor (not a Python number) so the caller can still call backward().
        warnings.warn("cell_training: NaN loss, batch skipped", RuntimeWarning)
        return torch.tensor(0.0, requires_grad=True, device=target_device)

    return object_batch_loss / n_kept
|
|
|
|