|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
|
|
|
|
class CustomTverskyLoss(nn.Module):
    """Tversky loss evaluated independently for every sample of a batch.

    The first dimension of ``inputs`` / ``targets`` is iterated as the batch
    dimension.  Each sample is flattened and scored with::

        TI = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)

    where ``TP = sum(x * y)``, ``FP = sum(x * (1 - y))`` and
    ``FN = sum((1 - x) * y)``.  The loss of a sample is ``1 - TI``.

    Args:
        alpha: Weight of the false-positive term.
        beta: Weight of the false-negative term.
        size_average: If true, ``forward`` returns the mean of the per-sample
            losses; otherwise it returns the vector of per-sample losses.
    """

    def __init__(self, alpha=0.1, beta=0.9, size_average=True):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.size_average = size_average

    def forward(self, inputs, targets, smooth=1):
        """Compute the Tversky loss between ``inputs`` and ``targets``.

        Args:
            inputs: Prediction tensor; iterated along its first dimension.
                Presumably soft values in [0, 1] -- this is not checked.
            targets: Ground-truth tensor with exactly the same shape.
            smooth: Constant added to numerator and denominator so the ratio
                stays defined when a sample has no positive entries.

        Returns:
            A scalar tensor (mean over samples) when ``size_average`` is true,
            otherwise a 1-D tensor holding one loss per sample.

        Raises:
            ValueError: If ``inputs`` and ``targets`` differ in shape.
        """
        if inputs.shape != targets.shape:
            raise ValueError("Shape mismatch: inputs and targets must have the same shape")

        per_sample_losses = []
        for input_sample, target_sample in zip(inputs, targets):
            # reshape() instead of view(): view() raises on non-contiguous
            # samples (e.g. slices of a permuted/transposed batch), while
            # reshape() copies only when it has to.
            input_sample = input_sample.reshape(-1)
            target_sample = target_sample.reshape(-1)

            true_positives = (input_sample * target_sample).sum()
            false_positives = (input_sample * (1 - target_sample)).sum()
            false_negatives = ((1 - input_sample) * target_sample).sum()

            denominator = (
                true_positives
                + self.alpha * false_positives
                + self.beta * false_negatives
                + smooth
            )
            tversky_index = (true_positives + smooth) / denominator
            per_sample_losses.append(1 - tversky_index)

        losses = torch.stack(per_sample_losses)
        return losses.mean() if self.size_average else losses
|
|
|
class CustomDiceLoss(nn.Module):
    """Dice loss evaluated independently for every sample of a batch.

    The first dimension of ``inputs`` / ``targets`` is iterated as the batch
    dimension.  Each sample is flattened and scored with::

        dice = (2 * sum(x * y) + smooth) / (sum(x) + sum(y) + smooth)

    The loss of a sample is ``1 - dice``.

    Args:
        weight: Accepted for signature compatibility; it is neither stored
            nor used.
        size_average: If true, ``forward`` returns the mean of the per-sample
            losses; otherwise it returns the vector of per-sample losses.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()
        self.size_average = size_average

    def forward(self, inputs, targets, smooth=1):
        """Compute the Dice loss between ``inputs`` and ``targets``.

        Args:
            inputs: Prediction tensor; iterated along its first dimension.
                Presumably soft values in [0, 1] -- this is not checked.
            targets: Ground-truth tensor with exactly the same shape.
            smooth: Constant added to numerator and denominator so the ratio
                stays defined when both samples sum to zero.

        Returns:
            A scalar tensor (mean over samples) when ``size_average`` is true,
            otherwise a 1-D tensor holding one loss per sample.

        Raises:
            ValueError: If ``inputs`` and ``targets`` differ in shape.
        """
        if inputs.shape != targets.shape:
            raise ValueError("Shape mismatch: inputs and targets must have the same shape")

        per_sample_losses = []
        for input_sample, target_sample in zip(inputs, targets):
            # reshape() instead of view(): view() raises on non-contiguous
            # samples (e.g. slices of a permuted/transposed batch), while
            # reshape() copies only when it has to.
            input_sample = input_sample.reshape(-1)
            target_sample = target_sample.reshape(-1)

            intersection = (input_sample * target_sample).sum()
            dice = (2. * intersection + smooth) / (input_sample.sum() + target_sample.sum() + smooth)
            per_sample_losses.append(1 - dice)

        losses = torch.stack(per_sample_losses)
        return losses.mean() if self.size_average else losses
|
|
|
def smooth_heaviside(phi, alpha, epsilon):
    """Return a sigmoid-smoothed step function of ``phi``.

    Args:
        phi: Tensor of values to be thresholded.
        alpha: Location of the step; the output is 0.5 where ``phi == alpha``.
        epsilon: Width of the transition; ``phi - alpha`` is divided by it
            before the sigmoid is applied.

    Returns:
        ``torch.sigmoid((phi - alpha) / epsilon)``.
    """
    return torch.sigmoid((phi - alpha) / epsilon)
|
def calc_Phi(variable, LSgrid):
    """Evaluate a level-set function and its smoothed indicator on a grid.

    Rows 0-4 of ``variable`` are read as centre ``x0``, centre ``y0``,
    length ``L``, thickness ``t`` and rotation ``angle``.  Grid coordinates
    are translated by the centre, rotated by ``angle`` and inserted into::

        phi = 1 - ((x'/(L/2))**6 + (y'/(t/2))**6 + 1e-9) ** (1/6)

    Args:
        variable: Tensor indexable as ``variable[0] .. variable[4]``.
            Presumably of shape (5, n_components) -- TODO confirm with caller.
        LSgrid: Indexable pair whose entries ``LSgrid[0]`` and ``LSgrid[1]``
            hold the x and y coordinates of the grid points; each is given a
            trailing axis so it broadcasts against the component parameters.

    Returns:
        Tuple ``(allPhi, H_phi)``: the level-set values and the result of
        ``smooth_heaviside`` applied to them with step location 0 and
        transition width 0.001.
    """
    device = variable.device
    x0, y0, L, t, angle = (variable[row] for row in range(5))

    # Grid coordinates relative to each component centre.
    dx = LSgrid[0][:, None].to(device) - x0
    dy = LSgrid[1][:, None].to(device) - y0

    # Rotate into the component's local frame.
    sin_a = torch.sin(angle)
    cos_a = torch.cos(angle)
    x1 = cos_a * dx + sin_a * dy
    y1 = -sin_a * dx + cos_a * dy

    # Small offset guards the divisions and the fractional power.
    small_constant = 1e-9
    half_length = L / 2 + small_constant
    half_thickness = t / 2 + small_constant
    temp = (x1 / half_length) ** 6 + (y1 / half_thickness) ** 6

    allPhi = 1 - (temp + small_constant) ** (1 / 6)

    alpha = torch.tensor(0.0, device=device, dtype=torch.float32)
    epsilon = torch.tensor(0.001, device=device, dtype=torch.float32)
    return allPhi, smooth_heaviside(allPhi, alpha, epsilon)
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
from matplotlib.colors import TwoSlopeNorm |
|
|
|
def preprocess_image(image_path, threshold_value=0.9, upscale=False, upscale_factor=2.0):
    """Load an image, binarise it and optionally resize it.

    Args:
        image_path: Path (or file object) handed to ``Image.open``.
        threshold_value: Fraction of 255; greyscale pixels strictly above
            ``threshold_value * 255`` become white, all others black.
        upscale: If true, resize the binarised image by ``upscale_factor``.
        upscale_factor: Multiplier applied to width and height (the results
            are truncated to integers).

    Returns:
        The binarised PIL image in mode ``'1'``, resized when requested.
    """
    cutoff = threshold_value * 255

    # Image.open() is lazy and keeps the underlying file open; the context
    # manager releases the handle as soon as the greyscale copy exists.
    with Image.open(image_path) as source:
        image = source.convert('L')

    image = image.point(lambda x: 255 if x > cutoff else 0, '1')

    if upscale:
        new_size = (
            int(image.width * upscale_factor),
            int(image.height * upscale_factor),
        )
        image = image.resize(new_size, resample=Image.BICUBIC)

    return image
|
|
|
def run_model(model, image, conf=0.05, iou=0.5, imgsz=640):
    """Call ``model`` on ``image`` and return whatever it returns.

    Args:
        model: Callable accepting the image positionally plus the keyword
            arguments ``conf``, ``iou`` and ``imgsz``.
        image: Input forwarded unchanged to ``model``.
        conf: Forwarded as the ``conf`` keyword.
        iou: Forwarded as the ``iou`` keyword.
        imgsz: Forwarded as the ``imgsz`` keyword.

    Returns:
        The value returned by ``model``.
    """
    inference_options = {"conf": conf, "iou": iou, "imgsz": imgsz}
    return model(image, **inference_options)
|
|
|
def save_results(results, filename='results.jpg'):
    """Render each result and write the rendering to ``filename``.

    Every result is written to the same path, so after the call the file
    holds the rendering of the last element of ``results``.

    Args:
        results: Iterable of objects exposing
            ``plot(boxes=..., labels=..., line_width=...)`` that returns an
            image array.
        filename: Destination path passed to ``Image.save``.
    """
    for result in results:
        rendered = result.plot(boxes=True, labels=False, line_width=1)
        # Reverse the last axis before handing the array to PIL
        # (presumably BGR -> RGB -- confirm against the plotting backend).
        Image.fromarray(rendered[..., ::-1]).save(filename)
|
|
|
def process_results(results, input_image):
    """Rasterise the predicted components of ``results[0]`` and score them.

    Steps:
      1. Convert ``input_image`` to a greyscale tensor scaled by 1/255,
         inverted (``1 - value``) and flipped along its first axis.
      2. Build a point grid with one point per pixel spanning
         ``[0, DW] x [0, DH]``, where the shorter image side has length 1.
      3. Rescale the regression predictions: columns 0-3 to the extents of
         the matching normalised bounding box (in grid units), column 4 to
         the range [0.001, 0.2].
      4. Turn columns 0-3 into centre, length and angle of a segment, use
         column 4 as thickness, and evaluate ``calc_Phi`` on the grid.
      5. Sum the smoothed indicators over components, cap the sum at 1 and
         reshape it to the image layout.
      6. Compute Dice and Tversky losses against the tensor from step 1.

    Args:
        results: Indexable collection of detection results.  Only
            ``results[0]`` is used; it must expose ``regression_preds``,
            ``boxes.xyxyn`` and ``plot(...)``.  ``regression_preds`` is
            presumably normalised to [0, 1] per bounding box -- TODO confirm.
        input_image: PIL image (anything providing ``convert('L')``).

    Returns:
        Tuple ``(input_image_array_tensor, seg_result, pred_Phi, sum_pred_H,
        final_H, dice_loss, tversky_loss)``.
    """
    diceloss = CustomDiceLoss()
    tverskyloss = CustomTverskyLoss()

    # All numeric outputs are derived from the first result, so the rendered
    # segmentation image is taken from that same result (rendering every
    # result and keeping only the last one was wasted work and could show a
    # different image than the one being scored).
    first_result = results[0]

    prediction_tensor = first_result.regression_preds.to('cpu').detach()
    input_image_array = np.array(input_image.convert('L'))
    input_image_array_tensor = torch.tensor(input_image_array) / 255.0
    input_image_array_tensor = 1.0 - input_image_array_tensor
    input_image_array_tensor = torch.flip(input_image_array_tensor, [0])

    im_array = first_result.plot(boxes=True, labels=False, line_width=1)
    seg_result = Image.fromarray(im_array[..., ::-1])

    # Domain size: the shorter image side is normalised to length 1.
    height = input_image_array.shape[0]
    width = input_image_array.shape[1]
    shortest_side = min(width, height)
    DH = height / shortest_side
    DW = width / shortest_side
    nelx = width - 1
    nely = height - 1

    # indexing='ij' is the layout torch.meshgrid used implicitly before;
    # stating it keeps the point ordering and silences the deprecation
    # warning about the missing argument.
    x, y = torch.meshgrid(
        torch.linspace(0, DW, nelx + 1),
        torch.linspace(0, DH, nely + 1),
        indexing='ij',
    )
    LSgrid = torch.stack((x.flatten(), y.flatten()), dim=0)

    pred_bboxes = first_result.boxes.xyxyn.to('cpu').detach()
    n_boxes = pred_bboxes.shape[0]
    thickness_upper = torch.full((n_boxes,), 0.2)
    thickness_lower = torch.full((n_boxes,), 0.001)

    # Per-column upper/lower bounds used to undo the normalisation:
    # (x, y, x, y) of the box corner in grid units, then the thickness bound.
    box_x_max = pred_bboxes[:, 2] * DW
    box_y_max = pred_bboxes[:, 3] * DH
    box_x_min = pred_bboxes[:, 0] * DW
    box_y_min = pred_bboxes[:, 1] * DH
    xmax = torch.stack([box_x_max, box_y_max, box_x_max, box_y_max, thickness_upper], dim=1)
    xmin = torch.stack([box_x_min, box_y_min, box_x_min, box_y_min, thickness_lower], dim=1)

    unnormalized_preds = prediction_tensor * (xmax - xmin) + xmin

    x_center = (unnormalized_preds[:, 0] + unnormalized_preds[:, 2]) / 2
    y_center = (unnormalized_preds[:, 1] + unnormalized_preds[:, 3]) / 2

    L = torch.sqrt((unnormalized_preds[:, 0] - unnormalized_preds[:, 2]) ** 2 +
                   (unnormalized_preds[:, 1] - unnormalized_preds[:, 3]) ** 2)
    # Keep the length strictly positive.
    L = L + 1e-4
    t_1 = unnormalized_preds[:, 4]

    # Tiny offset so atan2 never receives two exact zeros.
    epsilon = 1e-10
    y_diff = unnormalized_preds[:, 3] - unnormalized_preds[:, 1] + epsilon
    x_diff = unnormalized_preds[:, 2] - unnormalized_preds[:, 0] + epsilon
    theta = torch.atan2(y_diff, x_diff)

    # One row per parameter, one column per component, as calc_Phi indexes it.
    component_params = torch.stack((x_center, y_center, L, t_1, theta), dim=0)

    pred_Phi, pred_H = calc_Phi(component_params, LSgrid)

    # Union of all components: sum the indicators and cap at 1.
    sum_pred_H = torch.sum(pred_H.detach().cpu(), dim=1).clamp(max=1)

    # Grid points are ordered x-major, so a column-major reshape yields
    # (rows = y, cols = x); flipud then reverses the row order.
    final_H = np.flipud(sum_pred_H.detach().numpy().reshape((nely + 1, nelx + 1), order='F'))

    # flipud returns a negatively strided view, hence the copy before
    # handing it to torch.
    # NOTE(review): both loss modules iterate over the first dimension, so
    # for these 2-D tensors every image row is scored as its own sample and
    # the results are averaged -- confirm this is the intended metric.
    final_H_tensor = torch.tensor(final_H.copy())
    dice_loss = diceloss(final_H_tensor, input_image_array_tensor)
    tversky_loss = tverskyloss(final_H_tensor, input_image_array_tensor)

    return input_image_array_tensor, seg_result, pred_Phi, sum_pred_H, final_H, dice_loss, tversky_loss
|
|
|
def plot_results(input_image_array_tensor, seg_result, pred_Phi, sum_pred_H, final_H, dice_loss, tversky_loss, filename='combined_plots.png'):
    """Draw a 2x2 summary figure and save it to ``filename``.

    Panels: input image (top-left), segmentation rendering (top-right),
    summed prediction projection (bottom-left) and one filled contour per
    predicted component (bottom-right).  The Dice loss is printed below the
    panels.

    Args:
        input_image_array_tensor: 2-D image tensor; its shape also fixes the
            grid used to reshape the flat prediction arrays.
        seg_result: Image shown in the top-right panel.
        pred_Phi: Tensor whose columns are per-component level-set values on
            the flattened grid.
        sum_pred_H: Flat tensor of summed indicator values.
        final_H: Accepted but not used by this function.
        dice_loss: Tensor whose ``.item()`` is printed on the figure.
        tversky_loss: Accepted but not used by this function.
        filename: Output path for the saved figure.
    """
    # (rows, cols) of the image, i.e. (nely + 1, nelx + 1).
    n_rows = input_image_array_tensor.shape[0] - 1 + 1
    n_cols = input_image_array_tensor.shape[1] - 1 + 1

    def as_image(flat_values):
        # Column-major reshape followed by a vertical flip.
        return np.flipud(flat_values.reshape((n_rows, n_cols), order='F'))

    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    input_ax, seg_ax = axes[0, 0], axes[0, 1]
    projection_ax, contour_ax = axes[1, 0], axes[1, 1]

    input_ax.imshow(input_image_array_tensor.squeeze(), origin='lower', cmap='gray_r')
    input_ax.set_title('Input Image')
    input_ax.axis('on')

    seg_ax.imshow(seg_result)
    seg_ax.set_title('Segmentation Result')
    seg_ax.axis('off')

    palette = ['yellow', 'g', 'r', 'c', 'm', 'y', 'black', 'orange', 'pink', 'cyan', 'slategrey', 'wheat', 'purple', 'mediumturquoise', 'darkviolet', 'orangered']
    # The palette is repeated 100 times; components beyond that are not drawn.
    component_colors = zip(range(0, pred_Phi.shape[1]), palette * 100)
    for component, color in component_colors:
        contour_ax.contourf(as_image(pred_Phi[:, component].numpy()), [0, 1], colors=color)
    contour_ax.set_title('Prediction contours')
    contour_ax.set_aspect('equal')

    projection_ax.imshow(as_image(sum_pred_H.detach().numpy()), origin='lower', cmap='gray_r')
    projection_ax.set_title('Prediction Projection')

    plt.subplots_adjust(hspace=0.3, wspace=0.01)

    plt.figtext(0.5, 0.05, f'Dice Loss: {dice_loss.item():.4f}', ha='center', fontsize=16)

    fig.savefig(filename, dpi=600)
|
|