kaveh commited on
Commit
78245b3
·
1 Parent(s): 35a6153

added spheroid model

Browse files
Files changed (1) hide show
  1. models/s2f_model.py +128 -8
models/s2f_model.py CHANGED
@@ -370,6 +370,114 @@ class S2FGenerator(nn.Module):
370
  return expanded_state
371
 
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  class PatchGANDiscriminator(nn.Module):
374
  """PatchGAN Discriminator (included for create_s2f_model compatibility)."""
375
  def __init__(self, in_channels=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
@@ -418,15 +526,27 @@ def create_s2f_model(
418
  use_multi_scale_input=True,
419
  ndf=64,
420
  n_layers=3,
 
421
  ):
422
- """Create S2F model with generator and discriminator."""
423
- generator = S2FGenerator(
424
- in_channels=in_channels,
425
- out_channels=out_channels,
426
- img_size=img_size,
427
- bridge_type=bridge_type,
428
- use_multi_scale_input=use_multi_scale_input,
429
- )
 
 
 
 
 
 
 
 
 
 
 
430
  discriminator = PatchGANDiscriminator(
431
  in_channels=in_channels + out_channels,
432
  ndf=ndf,
 
370
  return expanded_state
371
 
372
 
373
+ class SpheroidAttentionGate(nn.Module):
374
+ """Attention Gate from ForceNet2WithAttention (s2f_spheroid). Checkpoint-compatible for ckp_spheroid_*.pth."""
375
+ def __init__(self, F_g, F_l, F_int):
376
+ super(SpheroidAttentionGate, self).__init__()
377
+ self.W_g = nn.Sequential(
378
+ nn.Conv2d(F_g, F_int, kernel_size=1),
379
+ nn.BatchNorm2d(F_int)
380
+ )
381
+ self.W_x = nn.Sequential(
382
+ nn.Conv2d(F_l, F_int, kernel_size=1),
383
+ nn.BatchNorm2d(F_int)
384
+ )
385
+ self.psi = nn.Sequential(
386
+ nn.ReLU(inplace=True),
387
+ nn.Conv2d(F_int, 1, kernel_size=1),
388
+ nn.Sigmoid()
389
+ )
390
+
391
+ def forward(self, g, x):
392
+ g1 = self.W_g(g)
393
+ x1 = self.W_x(x)
394
+ psi = self.psi(g1 + x1)
395
+ return x * psi
396
+
397
+
398
+ class S2FSpheroidGenerator(nn.Module):
399
+ """
400
+ S2F model tuned for spheroid data. Uses sigmoid output [0, 1] for inference.
401
+ """
402
+ def __init__(self, in_channels=1, out_channels=1, predict_numbers=False, img_size=1024, use_tanh_output=True):
403
+ super(S2FSpheroidGenerator, self).__init__()
404
+ self.predict_numbers = predict_numbers
405
+ self.img_size = img_size
406
+ self.use_tanh_output = use_tanh_output
407
+
408
+ def conv_block(in_c, out_c):
409
+ return nn.Sequential(
410
+ nn.Conv2d(in_c, out_c, 3, padding=1),
411
+ nn.BatchNorm2d(out_c),
412
+ nn.ReLU(inplace=True),
413
+ ResidualBlock(out_c, out_c)
414
+ )
415
+
416
+ # Encoder
417
+ self.encoder1 = conv_block(in_channels, 32)
418
+ self.pool1 = nn.MaxPool2d(2)
419
+ self.encoder2 = conv_block(32, 64)
420
+ self.pool2 = nn.MaxPool2d(2)
421
+ self.encoder3 = conv_block(64, 128)
422
+ self.pool3 = nn.MaxPool2d(2)
423
+ self.encoder4 = conv_block(128, 256)
424
+ self.pool4 = nn.MaxPool2d(2)
425
+ self.bridge = nn.Sequential(
426
+ nn.Conv2d(256, 512, kernel_size=3, padding=2, dilation=2),
427
+ nn.BatchNorm2d(512),
428
+ nn.ReLU(),
429
+ ResidualBlock(512, 512)
430
+ )
431
+
432
+ self.att3 = SpheroidAttentionGate(256, 256, 128)
433
+ self.att2 = SpheroidAttentionGate(128, 128, 64)
434
+ self.att1 = SpheroidAttentionGate(64, 64, 32)
435
+
436
+ self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
437
+ self.dec3 = conv_block(512, 256)
438
+ self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
439
+ self.dec2 = conv_block(256, 128)
440
+ self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
441
+ self.dec1 = conv_block(128, 64)
442
+ self.up0 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
443
+ self.dec0 = conv_block(64, 32)
444
+ self.pred_conv = nn.Conv2d(32, out_channels, kernel_size=1)
445
+
446
+ def forward(self, x):
447
+ e1 = self.encoder1(x)
448
+ e2 = self.encoder2(self.pool1(e1))
449
+ e3 = self.encoder3(self.pool2(e2))
450
+ e4 = self.encoder4(self.pool3(e3))
451
+ b = self.bridge(self.pool4(e4))
452
+
453
+ g3 = self.up3(b)
454
+ x3 = self.att3(g3, e4)
455
+ d3 = self.dec3(torch.cat([g3, x3], dim=1))
456
+
457
+ g2 = self.up2(d3)
458
+ x2 = self.att2(g2, e3)
459
+ d2 = self.dec2(torch.cat([g2, x2], dim=1))
460
+
461
+ g1 = self.up1(d2)
462
+ x1 = self.att1(g1, e2)
463
+ d1 = self.dec1(torch.cat([g1, x1], dim=1))
464
+
465
+ g0 = self.up0(d1)
466
+ d0 = self.dec0(torch.cat([g0, e1], dim=1))
467
+
468
+ out = self.pred_conv(d0)
469
+ out_resized = F.interpolate(out, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)
470
+
471
+ if self.use_tanh_output:
472
+ return torch.tanh(out_resized)
473
+ else:
474
+ return torch.sigmoid(out_resized)
475
+
476
+ def set_output_mode(self, use_tanh=True):
477
+ """Set output activation: tanh [-1,1] for training, sigmoid [0,1] for inference."""
478
+ self.use_tanh_output = use_tanh
479
+
480
+
481
  class PatchGANDiscriminator(nn.Module):
482
  """PatchGAN Discriminator (included for create_s2f_model compatibility)."""
483
  def __init__(self, in_channels=2, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
 
526
  use_multi_scale_input=True,
527
  ndf=64,
528
  n_layers=3,
529
+ model_type='s2f',
530
  ):
531
+ """Create S2F model with generator and discriminator.
532
+ model_type: 's2f' for single-cell, 's2f_spheroid' for spheroid.
533
+ """
534
+ if model_type == 's2f':
535
+ generator = S2FGenerator(
536
+ in_channels=in_channels,
537
+ out_channels=out_channels,
538
+ img_size=img_size,
539
+ bridge_type=bridge_type,
540
+ use_multi_scale_input=use_multi_scale_input,
541
+ )
542
+ elif model_type == 's2f_spheroid':
543
+ generator = S2FSpheroidGenerator(
544
+ in_channels=in_channels,
545
+ out_channels=out_channels,
546
+ img_size=img_size,
547
+ )
548
+ else:
549
+ raise ValueError(f"Invalid model type: {model_type}")
550
  discriminator = PatchGANDiscriminator(
551
  in_channels=in_channels + out_channels,
552
  ndf=ndf,