|
import logging |
|
|
|
import pytest |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
|
|
import kornia |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class TestIntegrationFocalLoss: |
|
|
|
thresh = 1e-1 |
|
lr = 1e-3 |
|
num_iterations = 1000 |
|
num_classes = 2 |
|
|
|
|
|
alpha = 2.0 |
|
gamma = 2.0 |
|
|
|
def generate_sample(self, base_target, std_val=0.1): |
|
target = base_target.float() / base_target.max() |
|
noise = std_val * torch.rand(1, 1, 6, 5).to(base_target.device) |
|
return target + noise |
|
|
|
@staticmethod |
|
def init_weights(m): |
|
if isinstance(m, nn.Conv2d): |
|
torch.nn.init.xavier_uniform_(m.weight) |
|
|
|
def test_conv2d_relu(self, device): |
|
|
|
|
|
target = torch.LongTensor(1, 6, 5).fill_(0).to(device) |
|
for i in range(1, self.num_classes): |
|
target[..., i:-i, i:-i] = i |
|
|
|
m = nn.Sequential(nn.Conv2d(1, self.num_classes, kernel_size=3, padding=1), nn.ReLU(True)).to(device) |
|
m.apply(self.init_weights) |
|
|
|
optimizer = optim.Adam(m.parameters(), lr=self.lr) |
|
|
|
criterion = kornia.losses.FocalLoss(alpha=self.alpha, gamma=self.gamma, reduction='mean') |
|
|
|
|
|
|
|
for _ in range(self.num_iterations): |
|
sample = self.generate_sample(target).to(device) |
|
output = m(sample) |
|
loss = criterion(output, target.to(device)) |
|
logger.debug(f"Loss: {loss.item()}") |
|
|
|
optimizer.zero_grad() |
|
loss.backward() |
|
optimizer.step() |
|
|
|
sample = self.generate_sample(target).to(device) |
|
output_argmax = torch.argmax(m(sample), dim=1) |
|
logger.debug(f"Output argmax: \n{output_argmax}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
val = F.mse_loss(output_argmax.float(), target.float()) |
|
if not val.item() < self.thresh: |
|
pytest.xfail("Wrong seed or initial weight values.") |
|
|