Spaces:
Running
Running
added spheroid model
Browse files- models/s2f_model.py +128 -8
models/s2f_model.py
CHANGED
|
@@ -370,6 +370,114 @@ class S2FGenerator(nn.Module):
|
|
| 370 |
return expanded_state
|
| 371 |
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
class PatchGANDiscriminator(nn.Module):
|
| 374 |
"""PatchGAN Discriminator (included for create_s2f_model compatibility)."""
|
| 375 |
def __init__(self, in_channels=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
|
|
@@ -418,15 +526,27 @@ def create_s2f_model(
|
|
| 418 |
use_multi_scale_input=True,
|
| 419 |
ndf=64,
|
| 420 |
n_layers=3,
|
|
|
|
| 421 |
):
|
| 422 |
-
"""Create S2F model with generator and discriminator.
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
discriminator = PatchGANDiscriminator(
|
| 431 |
in_channels=in_channels + out_channels,
|
| 432 |
ndf=ndf,
|
|
|
|
| 370 |
return expanded_state
|
| 371 |
|
| 372 |
|
| 373 |
+
class SpheroidAttentionGate(nn.Module):
|
| 374 |
+
"""Attention Gate from ForceNet2WithAttention (s2f_spheroid). Checkpoint-compatible for ckp_spheroid_*.pth."""
|
| 375 |
+
def __init__(self, F_g, F_l, F_int):
|
| 376 |
+
super(SpheroidAttentionGate, self).__init__()
|
| 377 |
+
self.W_g = nn.Sequential(
|
| 378 |
+
nn.Conv2d(F_g, F_int, kernel_size=1),
|
| 379 |
+
nn.BatchNorm2d(F_int)
|
| 380 |
+
)
|
| 381 |
+
self.W_x = nn.Sequential(
|
| 382 |
+
nn.Conv2d(F_l, F_int, kernel_size=1),
|
| 383 |
+
nn.BatchNorm2d(F_int)
|
| 384 |
+
)
|
| 385 |
+
self.psi = nn.Sequential(
|
| 386 |
+
nn.ReLU(inplace=True),
|
| 387 |
+
nn.Conv2d(F_int, 1, kernel_size=1),
|
| 388 |
+
nn.Sigmoid()
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
def forward(self, g, x):
|
| 392 |
+
g1 = self.W_g(g)
|
| 393 |
+
x1 = self.W_x(x)
|
| 394 |
+
psi = self.psi(g1 + x1)
|
| 395 |
+
return x * psi
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
class S2FSpheroidGenerator(nn.Module):
|
| 399 |
+
"""
|
| 400 |
+
S2F model tuned for spheroid data. Uses sigmoid output [0, 1] for inference.
|
| 401 |
+
"""
|
| 402 |
+
def __init__(self, in_channels=1, out_channels=1, predict_numbers=False, img_size=1024, use_tanh_output=True):
|
| 403 |
+
super(S2FSpheroidGenerator, self).__init__()
|
| 404 |
+
self.predict_numbers = predict_numbers
|
| 405 |
+
self.img_size = img_size
|
| 406 |
+
self.use_tanh_output = use_tanh_output
|
| 407 |
+
|
| 408 |
+
def conv_block(in_c, out_c):
|
| 409 |
+
return nn.Sequential(
|
| 410 |
+
nn.Conv2d(in_c, out_c, 3, padding=1),
|
| 411 |
+
nn.BatchNorm2d(out_c),
|
| 412 |
+
nn.ReLU(inplace=True),
|
| 413 |
+
ResidualBlock(out_c, out_c)
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
# Encoder
|
| 417 |
+
self.encoder1 = conv_block(in_channels, 32)
|
| 418 |
+
self.pool1 = nn.MaxPool2d(2)
|
| 419 |
+
self.encoder2 = conv_block(32, 64)
|
| 420 |
+
self.pool2 = nn.MaxPool2d(2)
|
| 421 |
+
self.encoder3 = conv_block(64, 128)
|
| 422 |
+
self.pool3 = nn.MaxPool2d(2)
|
| 423 |
+
self.encoder4 = conv_block(128, 256)
|
| 424 |
+
self.pool4 = nn.MaxPool2d(2)
|
| 425 |
+
self.bridge = nn.Sequential(
|
| 426 |
+
nn.Conv2d(256, 512, kernel_size=3, padding=2, dilation=2),
|
| 427 |
+
nn.BatchNorm2d(512),
|
| 428 |
+
nn.ReLU(),
|
| 429 |
+
ResidualBlock(512, 512)
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
self.att3 = SpheroidAttentionGate(256, 256, 128)
|
| 433 |
+
self.att2 = SpheroidAttentionGate(128, 128, 64)
|
| 434 |
+
self.att1 = SpheroidAttentionGate(64, 64, 32)
|
| 435 |
+
|
| 436 |
+
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
|
| 437 |
+
self.dec3 = conv_block(512, 256)
|
| 438 |
+
self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
|
| 439 |
+
self.dec2 = conv_block(256, 128)
|
| 440 |
+
self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
|
| 441 |
+
self.dec1 = conv_block(128, 64)
|
| 442 |
+
self.up0 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
|
| 443 |
+
self.dec0 = conv_block(64, 32)
|
| 444 |
+
self.pred_conv = nn.Conv2d(32, out_channels, kernel_size=1)
|
| 445 |
+
|
| 446 |
+
def forward(self, x):
|
| 447 |
+
e1 = self.encoder1(x)
|
| 448 |
+
e2 = self.encoder2(self.pool1(e1))
|
| 449 |
+
e3 = self.encoder3(self.pool2(e2))
|
| 450 |
+
e4 = self.encoder4(self.pool3(e3))
|
| 451 |
+
b = self.bridge(self.pool4(e4))
|
| 452 |
+
|
| 453 |
+
g3 = self.up3(b)
|
| 454 |
+
x3 = self.att3(g3, e4)
|
| 455 |
+
d3 = self.dec3(torch.cat([g3, x3], dim=1))
|
| 456 |
+
|
| 457 |
+
g2 = self.up2(d3)
|
| 458 |
+
x2 = self.att2(g2, e3)
|
| 459 |
+
d2 = self.dec2(torch.cat([g2, x2], dim=1))
|
| 460 |
+
|
| 461 |
+
g1 = self.up1(d2)
|
| 462 |
+
x1 = self.att1(g1, e2)
|
| 463 |
+
d1 = self.dec1(torch.cat([g1, x1], dim=1))
|
| 464 |
+
|
| 465 |
+
g0 = self.up0(d1)
|
| 466 |
+
d0 = self.dec0(torch.cat([g0, e1], dim=1))
|
| 467 |
+
|
| 468 |
+
out = self.pred_conv(d0)
|
| 469 |
+
out_resized = F.interpolate(out, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
|
| 470 |
+
|
| 471 |
+
if self.use_tanh_output:
|
| 472 |
+
return torch.tanh(out_resized)
|
| 473 |
+
else:
|
| 474 |
+
return torch.sigmoid(out_resized)
|
| 475 |
+
|
| 476 |
+
def set_output_mode(self, use_tanh=True):
|
| 477 |
+
"""Set output activation: tanh [-1,1] for training, sigmoid [0,1] for inference."""
|
| 478 |
+
self.use_tanh_output = use_tanh
|
| 479 |
+
|
| 480 |
+
|
| 481 |
class PatchGANDiscriminator(nn.Module):
|
| 482 |
"""PatchGAN Discriminator (included for create_s2f_model compatibility)."""
|
| 483 |
def __init__(self, in_channels=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
|
|
|
|
| 526 |
use_multi_scale_input=True,
|
| 527 |
ndf=64,
|
| 528 |
n_layers=3,
|
| 529 |
+
model_type='s2f',
|
| 530 |
):
|
| 531 |
+
"""Create S2F model with generator and discriminator.
|
| 532 |
+
model_type: 's2f' for single-cell, 's2f_spheroid' for spheroid.
|
| 533 |
+
"""
|
| 534 |
+
if model_type == 's2f':
|
| 535 |
+
generator = S2FGenerator(
|
| 536 |
+
in_channels=in_channels,
|
| 537 |
+
out_channels=out_channels,
|
| 538 |
+
img_size=img_size,
|
| 539 |
+
bridge_type=bridge_type,
|
| 540 |
+
use_multi_scale_input=use_multi_scale_input,
|
| 541 |
+
)
|
| 542 |
+
elif model_type == 's2f_spheroid':
|
| 543 |
+
generator = S2FSpheroidGenerator(
|
| 544 |
+
in_channels=in_channels,
|
| 545 |
+
out_channels=out_channels,
|
| 546 |
+
img_size=img_size,
|
| 547 |
+
)
|
| 548 |
+
else:
|
| 549 |
+
raise ValueError(f"Invalid model type: {model_type}")
|
| 550 |
discriminator = PatchGANDiscriminator(
|
| 551 |
in_channels=in_channels + out_channels,
|
| 552 |
ndf=ndf,
|