Morelli001 committed · Commit aee1a39 · verified · 1 Parent(s): feee4b6

Upload folder using huggingface_hub

45.png ADDED
README.md ADDED
@@ -0,0 +1,109 @@
# FILMUnet2D (Transformers-compatible)

This model is a 2D U-Net with FiLM conditioning for multi-organ segmentation.

## Installation

Make sure you have `transformers`, `torch`, and `pillow` installed (the usage example below loads images with PIL).

```bash
pip install transformers torch pillow
```

## Usage

You can load the model and processor using the `Auto` classes from `transformers`. Since this repository contains custom code, make sure to pass `trust_remote_code=True`.

```python
import torch
from transformers import AutoModel, AutoImageProcessor
from PIL import Image

# 1. Load model and processor
repo_id = "Morelli001/US_UNet2DFiLM"

processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# 2. Load and preprocess your image
# The processor handles resizing, letterboxing, and normalization.
image = Image.open("path/to/your/image.png").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# 3. Prepare conditioning input
# This should be an integer tensor representing the organ ID.
# Replace `4` with the appropriate ID for your use case (see the Organ IDs section below).
organ_id = torch.tensor([4])

# 4. Run inference
with torch.no_grad():
    outputs = model(**inputs, organ_id=organ_id)

# 5. Post-process the output to get the final segmentation mask
# The processor can convert the logits to a binary mask, automatically handling
# the removal of letterbox padding and resizing to the original image dimensions.
mask = processor.post_process_semantic_segmentation(
    outputs,
    inputs,
    threshold=0.7,
    return_as_pil=True
)[0]

# 6. Save the result
mask.save("output_mask.png")

print("Segmentation mask saved to output_mask.png")
```
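Inference can optionally run on a GPU. A minimal sketch, assuming the same `model`, `processor`, `inputs`, and `organ_id` as above:

```python
# Optional: move the model and tensors to a GPU if one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

with torch.no_grad():
    outputs = model(
        pixel_values=inputs["pixel_values"].to(device),
        organ_id=organ_id.to(device),
    )

# Post-processing reuses the letterbox metadata kept in `inputs`.
mask = processor.post_process_semantic_segmentation(
    outputs, inputs, threshold=0.7, return_as_pil=True
)[0]
```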

### Model Details

- **Architecture:** U-Net with FiLM layers for conditional segmentation.
- **Conditioning:** The model's output is conditioned on an `organ_id` input.
- **Input:** RGB images.
- **Output:** A single-channel segmentation mask.

### Configuration

The model configuration can be accessed via `model.config`. Key parameters include:
- `in_channels`: Number of input channels (default: 3).
- `num_classes`: Number of output classes (default: 1).
- `n_organs`: The number of different organs the model was trained to condition on.
- `depth`: The depth of the U-Net (number of encoder stages; see the sketch below).
- `size`: The base number of filters in the first layer.

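As a rough guide to what `size` and `depth` mean in practice, the encoder widens by a factor of two per stage. A small sketch, assuming the stage formula used in `modeling_film_unet2d.py`:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Morelli001/US_UNet2DFiLM", trust_remote_code=True)

# Each encoder stage i ends with size * 2**(i + 1) output channels.
widths = [config.size * 2 ** (i + 1) for i in range(config.depth)]
print(widths)  # [64, 128, 256, 512, 1024] for size=32, depth=5
```
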
### Organ IDs

The `organ_id` passed to the model corresponds to the following mapping:

```python
organ_to_class_dict = {
    "appendix": 0,
    "breast": 1,
    "breast_luminal": 1,
    "cardiac": 2,
    "thyroid": 3,
    "fetal": 4,
    "kidney": 5,
    "liver": 6,
    "testicle": 7,
}
```

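For readability, the conditioning tensor can be built from this mapping rather than from a hard-coded integer:

```python
import torch

# Build the conditioning tensor from the mapping above.
organ_id = torch.tensor([organ_to_class_dict["kidney"]])  # tensor([5])
```
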
### Alternative Versions

This repository contains multiple versions of the model located in subfolders. You can load a specific version by using the `subfolder` parameter.

#### 4-Stage U-Net

This version has a U-Net depth of 4.

```python
from transformers import AutoModel

model_4_stages = AutoModel.from_pretrained(
    "Morelli001/US_UNet2DFiLM",
    subfolder="unet_4_stages",
    trust_remote_code=True
)
```
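The matching image processor can be loaded from the same subfolder, as the bundled `test_4_stages.py` does:

```python
from transformers import AutoImageProcessor

processor_4_stages = AutoImageProcessor.from_pretrained(
    "Morelli001/US_UNet2DFiLM",
    subfolder="unet_4_stages",
    trust_remote_code=True
)
```

The bundled `test_5_stages_testicles.py` follows the same pattern with a `unet_5_stages_testicles` subfolder.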
config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "model_type": "film_unet2d",
3
+ "architectures": ["FilmUnet2DModel"],
4
+ "auto_map": {
5
+ "AutoConfig": "configuration_film_unet2d.FilmUnet2DConfig",
6
+ "AutoModel": "modeling_film_unet2d.FilmUnet2DModel",
7
+ "AutoImageProcessor": "image_processing_film_unet2d.FilmUnet2DImageProcessor"
8
+ },
9
+ "in_channels": 3,
10
+ "num_classes": 1,
11
+ "n_organs": 9,
12
+ "size": 32,
13
+ "depth": 5,
14
+ "film_start": 0,
15
+ "use_film": 1
16
+ }
configuration_film_unet2d.py ADDED
@@ -0,0 +1,23 @@
1
+
2
+ from transformers.configuration_utils import PretrainedConfig
3
+
4
+ class FilmUnet2DConfig(PretrainedConfig):
5
+ model_type = "film_unet2d"
6
+
7
+ def __init__(self,
8
+ in_channels=3,
9
+ num_classes=1,
10
+ n_organs=9,
11
+ size=32,
12
+ depth=5,
13
+ film_start=0,
14
+ use_film=True,
15
+ **kwargs):
16
+ super().__init__(**kwargs)
17
+ self.in_channels = in_channels
18
+ self.num_classes = num_classes
19
+ self.n_organs = n_organs
20
+ self.size = size
21
+ self.depth = depth
22
+ self.film_start = film_start
23
+ self.use_film = use_film
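One way to build a randomly initialized variant from this configuration is to override fields when loading it and construct the model from the config. A sketch, assuming `AutoConfig.from_pretrained` forwards keyword overrides such as `depth` onto the config:

```python
from transformers import AutoConfig, AutoModel

# Sketch: override `depth` and build an untrained model from the resulting config.
config = AutoConfig.from_pretrained(
    "Morelli001/US_UNet2DFiLM", trust_remote_code=True, depth=4
)
model = AutoModel.from_config(config, trust_remote_code=True)
```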
image_processing_film_unet2d.py ADDED
@@ -0,0 +1,218 @@
1
+ # image_processing_film_unet2d.py
2
+ from typing import List, Union, Tuple, Optional
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from transformers.image_processing_utils import ImageProcessingMixin
7
+
8
+ ArrayLike = Union[np.ndarray, torch.Tensor, Image.Image]
9
+
10
+ def _to_rgb_numpy(im: ArrayLike) -> np.ndarray:
11
+ # -> float32 HWC in [0,255], 3 channels
12
+ if isinstance(im, Image.Image):
13
+ if im.mode != "RGB":
14
+ im = im.convert("RGB")
15
+ arr = np.array(im, dtype=np.uint8).astype(np.float32)
16
+ elif isinstance(im, torch.Tensor):
17
+ t = im.detach().cpu()
18
+ if t.ndim != 3:
19
+ raise ValueError("Tensor must be 3D (CHW or HWC).")
20
+ if t.shape[0] in (1, 3): # CHW
21
+ if t.shape[0] == 1:
22
+ t = t.repeat(3, 1, 1)
23
+ t = t.permute(1, 2, 0) # HWC
24
+ elif t.shape[-1] == 1: # HWC gray
25
+ t = t.repeat(1, 1, 3)
26
+ arr = t.numpy()
27
+ if arr.dtype in (np.float32, np.float64) and arr.max() <= 1.5:
28
+ arr = (arr * 255.0).astype(np.float32)
29
+ else:
30
+ arr = arr.astype(np.float32)
31
+ else:
32
+ arr = np.array(im)
33
+ if arr.ndim == 2:
34
+ arr = np.repeat(arr[..., None], 3, axis=-1)
35
+ arr = arr.astype(np.float32)
36
+ if arr.max() <= 1.5:
37
+ arr = (arr * 255.0).astype(np.float32)
38
+ if arr.ndim != 3 or arr.shape[-1] != 3:
39
+ raise ValueError("Expected RGB image with shape HxWx3.")
40
+ return arr
41
+
42
+ def _letterbox_keep_ratio(arr: np.ndarray, target_hw: Tuple[int, int]):
43
+ """Resize with aspect ratio preserved and pad with 0 (black) to target (H,W).
44
+ Returns: out(H,W,3), (top, left, new_h, new_w)
45
+ """
46
+ th, tw = target_hw
47
+ h, w = arr.shape[:2]
48
+ scale = min(th / h, tw / w)
49
+ nh, nw = int(round(h * scale)), int(round(w * scale))
50
+ if nh <= 0 or nw <= 0:
51
+ raise ValueError("Invalid resize result.")
52
+ pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
53
+ pil = pil.resize((nw, nh), resample=Image.BILINEAR)
54
+ rs = np.array(pil, dtype=np.float32)
55
+ out = np.zeros((th, tw, 3), dtype=np.float32)
56
+ top = (th - nh) // 2
57
+ left = (tw - nw) // 2
58
+ out[top:top+nh, left:left+nw] = rs
59
+ return out, (top, left, nh, nw)
60
+
61
+ def _zscore_ignore_black(chw: np.ndarray, eps: float = 1e-8) -> np.ndarray:
62
+ mask = (chw.sum(axis=0) > 0) # HxW
63
+ if not mask.any():
64
+ return chw.copy()
65
+ valid = chw[:, mask]
66
+ mean = valid.mean()
67
+ std = valid.std()
68
+ return (chw - mean) / std if std > eps else (chw - mean)
69
+
70
+ class FilmUnet2DImageProcessor(ImageProcessingMixin):
71
+ """
72
+ Processor for FILMUnet2D:
73
+ - Convert to RGB
74
+ - Keep-aspect-ratio resize+pad (letterbox) to 512x512 (configurable)
75
+ - Normalize with mean/std in 0–255 space (like your training)
76
+ - Optional z-score 'self_norm' ignoring black pixels
77
+ Returns dict with:
78
+ - pixel_values: torch.FloatTensor [B,3,H,W]
79
+ - original_sizes: torch.LongTensor [B,2] (H,W)
80
+ - letterbox_params: torch.LongTensor [B,4] (top, left, nh, nw) # NEW
81
+ """
82
+
83
+ model_input_names = ["pixel_values"]
84
+
85
+ def __init__(
86
+ self,
87
+ do_resize: bool = True,
88
+ size: Tuple[int, int] = (512, 512),
89
+ keep_ratio: bool = True,
90
+ image_mean: Tuple[float, float, float] = (123.675, 116.28, 103.53),
91
+ image_std: Tuple[float, float, float] = (58.395, 57.12, 57.375),
92
+ self_norm: bool = False,
93
+ **kwargs,
94
+ ):
95
+ super().__init__(**kwargs)
96
+ self.do_resize = bool(do_resize)
97
+ self.size = tuple(size)
98
+ self.keep_ratio = bool(keep_ratio)
99
+ self.image_mean = tuple(float(x) for x in image_mean)
100
+ self.image_std = tuple(float(x) for x in image_std)
101
+ self.self_norm = bool(self_norm)
102
+
103
+ def __call__(
104
+ self,
105
+ images: Union[ArrayLike, List[ArrayLike]],
106
+ return_tensors: Optional[str] = "pt",
107
+ **kwargs,
108
+ ):
109
+ imgs = images if isinstance(images, (list, tuple)) else [images]
110
+ batch = []
111
+ orig_sizes = []
112
+ lb_params = []
113
+
114
+ for im in imgs:
115
+ arr = _to_rgb_numpy(im) # HWC float32 in 0–255
116
+ oh, ow = arr.shape[:2]
117
+ orig_sizes.append((oh, ow))
118
+
119
+ if self.do_resize:
120
+ if self.keep_ratio:
121
+ arr, meta = _letterbox_keep_ratio(arr, self.size) # meta=(top,left,nh,nw)
122
+ else:
123
+ h, w = self.size
124
+ pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
125
+ arr = np.array(pil.resize((w, h), resample=Image.BILINEAR), dtype=np.float32)
126
+ meta = (0, 0, h, w)
127
+ else:
128
+ # no resize: still expose meta so postprocess can handle consistently
129
+ h, w = arr.shape[:2]
130
+ pad_h = self.size[0] - h
131
+ pad_w = self.size[1] - w
132
+ top = max(pad_h // 2, 0)
133
+ left = max(pad_w // 2, 0)
134
+ out = np.zeros((*self.size, 3), dtype=np.float32)
135
+ out[top:top+h, left:left+w] = arr[:self.size[0]-top, :self.size[1]-left]
136
+ arr = out
137
+ meta = (top, left, h, w)
138
+
139
+ lb_params.append(meta)
140
+
141
+ mean = np.array(self.image_mean, dtype=np.float32).reshape(1, 1, 3)
142
+ std = np.array(self.image_std, dtype=np.float32).reshape(1, 1, 3)
143
+ arr = (arr - mean) / std # HWC
144
+
145
+ chw = np.transpose(arr, (2, 0, 1)) # C,H,W
146
+ if self.self_norm:
147
+ chw = _zscore_ignore_black(chw)
148
+ batch.append(chw)
149
+
150
+ pixel_values = np.stack(batch, axis=0) # B,C,H,W
151
+ if return_tensors == "pt":
152
+ pixel_values = torch.from_numpy(pixel_values).to(torch.float32)
153
+ original_sizes = torch.tensor(orig_sizes, dtype=torch.long)
154
+ letterbox_params = torch.tensor(lb_params, dtype=torch.long)
155
+ else:
156
+ original_sizes = orig_sizes
157
+ letterbox_params = lb_params
158
+
159
+ return {
160
+ "pixel_values": pixel_values,
161
+ "original_sizes": original_sizes, # (B,2) H,W
162
+ "letterbox_params": letterbox_params # (B,4) top,left,nh,nw in 512x512
163
+ }
164
+
165
+ # ---------- POST-PROCESSING ----------
166
+ def post_process_semantic_segmentation(
167
+ self,
168
+ outputs: dict,
169
+ processor_inputs: Optional[dict] = None,
170
+ threshold: float = 0.5,
171
+ return_as_pil: bool = True,
172
+ ):
173
+ """
174
+ Turn model outputs into masks resized back to the ORIGINAL image sizes,
175
+ with letterbox padding removed.
176
+
177
+ Args:
178
+ outputs: dict from model forward (expects 'logits': [B,1,512,512])
179
+ processor_inputs: the dict returned by __call__ (must contain
180
+ 'original_sizes' [B,2] and 'letterbox_params' [B,4])
181
+ threshold: probability threshold for binarization
182
+ return_as_pil: return a list of PIL Images (uint8 0/255) if True,
183
+ else a list of torch tensors [H,W] uint8
184
+
185
+ Returns:
186
+ List of masks back in original sizes (H,W).
187
+ """
188
+ logits = outputs["logits"] # [B,1,H,W]
189
+ probs = torch.sigmoid(logits)
190
+ masks = (probs > threshold).to(torch.uint8) * 255 # [B,1,H,W] uint8
191
+
192
+ if processor_inputs is None:
193
+ raise ValueError("processor_inputs must be provided to undo letterboxing.")
194
+
195
+ orig_sizes = processor_inputs["original_sizes"] # [B,2]
196
+ lb_params = processor_inputs["letterbox_params"] # [B,4] top,left,nh,nw
197
+
198
+ results = []
199
+ B = masks.shape[0]
200
+ for i in range(B):
201
+ m = masks[i, 0] # [512,512]
202
+ top, left, nh, nw = [int(x) for x in lb_params[i].tolist()]
203
+ # crop letterbox
204
+ m_cropped = m[top:top+nh, left:left+nw] # [nh,nw]
205
+ # resize back to original
206
+ oh, ow = [int(x) for x in orig_sizes[i].tolist()]
207
+ m_resized = torch.nn.functional.interpolate(
208
+ m_cropped.unsqueeze(0).unsqueeze(0).float(),
209
+ size=(oh, ow),
210
+ mode="nearest"
211
+ )[0,0].to(torch.uint8) # [oh,ow]
212
+
213
+ if return_as_pil:
214
+ results.append(Image.fromarray(m_resized.cpu().numpy(), mode="L"))
215
+ else:
216
+ results.append(m_resized)
217
+
218
+ return results
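To make the letterbox bookkeeping concrete, here is the arithmetic from `_letterbox_keep_ratio` worked through for a hypothetical 600×800 image going into the default 512×512 target:

```python
# Hypothetical 600 (H) x 800 (W) input letterboxed into 512 x 512.
th, tw = 512, 512
h, w = 600, 800
scale = min(th / h, tw / w)                   # min(0.853..., 0.64) = 0.64
nh, nw = round(h * scale), round(w * scale)   # 384, 512
top, left = (th - nh) // 2, (tw - nw) // 2    # 64, 0 -> black bars above and below
print(scale, (nh, nw), (top, left))           # 0.64 (384, 512) (64, 0)
```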
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66b39afecca324389126efc9d7995b707cfe2cc0330e2bfa12728829bd79b2a6
3
+ size 603375348
modeling_film_unet2d.py ADDED
@@ -0,0 +1,141 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers.modeling_utils import PreTrainedModel
5
+ from .configuration_film_unet2d import FilmUnet2DConfig
6
+
7
+ class ConvBlock(nn.Module):
8
+ def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
9
+ super().__init__()
10
+ self.block = nn.Sequential(
11
+ nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p),
12
+ nn.InstanceNorm2d(out_ch),
13
+ nn.LeakyReLU(inplace=True),
14
+ )
15
+ def forward(self, x): return self.block(x)
16
+
17
+ class FiLM2d(nn.Module):
18
+ def __init__(self, n_organs, in_channels, emb_dim=64, hidden=None):
19
+ super().__init__()
20
+ hidden = hidden or 2 * in_channels
21
+ self.embed = nn.Embedding(n_organs, emb_dim)
22
+ self.mlp = nn.Sequential(
23
+ nn.Linear(emb_dim, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, 2*in_channels)
24
+ )
25
+ nn.init.zeros_(self.mlp[-1].weight)
26
+ nn.init.constant_(self.mlp[-1].bias[:in_channels], 0)
27
+ nn.init.constant_(self.mlp[-1].bias[in_channels:], 1)
28
+ def forward(self, x, organ_id):
29
+ beta_gamma = self.mlp(self.embed(organ_id))
30
+ beta, gamma = beta_gamma.chunk(2, dim=-1)
31
+ beta = beta.unsqueeze(-1).unsqueeze(-1)
32
+ gamma = gamma.unsqueeze(-1).unsqueeze(-1)
33
+ return gamma * x + beta
34
+
35
+ class DownFiLM(nn.Module):
36
+ def __init__(self, in_chs, out_chs, n_organs):
37
+ super().__init__()
38
+ self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i,o in zip(in_chs,out_chs)])
39
+ self.film_blocks = nn.ModuleList([FiLM2d(n_organs, o) for o in out_chs])
40
+ self.pool = nn.MaxPool2d(2,2)
41
+ def forward(self, x, organ_id):
42
+ for c,f in zip(self.conv_blocks, self.film_blocks):
43
+ x = f(c(x), organ_id)
44
+ return self.pool(x), x
45
+
46
+ class Down(nn.Module):
47
+ def __init__(self, in_chs, out_chs):
48
+ super().__init__()
49
+ self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i,o in zip(in_chs,out_chs)])
50
+ self.pool = nn.MaxPool2d(2,2)
51
+ def forward(self, x):
52
+ for c in self.conv_blocks: x = c(x)
53
+ return self.pool(x), x
54
+
55
+ class UpFiLM(nn.Module):
56
+ def __init__(self, in_chs, out_chs, n_organs, up=True):
57
+ super().__init__()
58
+ self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i,o in zip(in_chs,out_chs)])
59
+ self.film_blocks = nn.ModuleList([FiLM2d(n_organs, o) for o in out_chs])
60
+ self.up_conv_op = nn.ConvTranspose2d(out_chs[-1], out_chs[-1], kernel_size=2, stride=2) if up else None
61
+ def forward(self, x, organ_id):
62
+ for c,f in zip(self.conv_blocks, self.film_blocks):
63
+ x = f(c(x), organ_id)
64
+ return self.up_conv_op(x) if self.up_conv_op is not None else x
65
+
66
+ class Up(nn.Module):
67
+ def __init__(self, in_chs, out_chs, up=True):
68
+ super().__init__()
69
+ self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i,o in zip(in_chs,out_chs)])
70
+ self.up_conv_op = nn.ConvTranspose2d(out_chs[-1], out_chs[-1], kernel_size=2, stride=2) if up else None
71
+ def forward(self, x):
72
+ for c in self.conv_blocks: x = c(x)
73
+ return self.up_conv_op(x) if self.up_conv_op is not None else x
74
+
75
+ class UNet2DFiLMCore(nn.Module):
76
+ def __init__(self, cfg: FilmUnet2DConfig):
77
+ super().__init__()
78
+ size, depth, n_organs = cfg.size, cfg.depth, cfg.n_organs
79
+ use_film, film_start = cfg.use_film, cfg.film_start
80
+ self.encoder = nn.ModuleDict()
81
+ if use_film and 0 >= film_start:
82
+ self.encoder["0"] = DownFiLM([cfg.in_channels, size], [size, size*2], n_organs)
83
+ else:
84
+ self.encoder["0"] = Down([cfg.in_channels, size], [size, size*2])
85
+ for i in range(1, depth):
86
+ in_ch = [size*(2**i), size*(2**i)]
87
+ out_ch = [size*(2**i), size*(2**(i+1))]
88
+ if use_film and i >= film_start:
89
+ self.encoder[str(i)] = DownFiLM(in_ch, out_ch, n_organs)
90
+ else:
91
+ self.encoder[str(i)] = Down(in_ch, out_ch)
92
+ if use_film:
93
+ self.bottleneck = UpFiLM([size*(2**depth), size*(2**depth)], [size*(2**depth), size*(2**(depth+1))], n_organs)
94
+ else:
95
+ self.bottleneck = Up([size*(2**depth), size*(2**depth)], [size*(2**depth), size*(2**(depth+1))])
96
+ self.decoder = nn.ModuleDict()
97
+ for i in range(depth, 1, -1):
98
+ use_film_here = use_film and (i-1) >= film_start
99
+ if use_film_here:
100
+ self.decoder[str(i-1)] = UpFiLM([size*(2**(i+1))+size*(2**i), size*(2**i)], [size*(2**i), size*(2**i)], n_organs)
101
+ else:
102
+ self.decoder[str(i-1)] = Up([size*(2**(i+1))+size*(2**i), size*(2**i)], [size*(2**i), size*(2**i)])
103
+ if use_film and 0 >= film_start:
104
+ self.decoder["0"] = UpFiLM([size*4+size*2, size*2], [size*2, size*2], n_organs, up=False)
105
+ else:
106
+ self.decoder["0"] = Up([size*4+size*2, size*2], [size*2, size*2], up=False)
107
+ self.out_layer = ConvBlock(
108
+ size * 2,
109
+ cfg.num_classes,
110
+ k= 1,s= 1,p=0
111
+ )
112
+ def forward(self, pixel_values, organ_id):
113
+ feats = []
114
+ out, feat = (self.encoder["0"](pixel_values, organ_id) if isinstance(self.encoder["0"], DownFiLM) else self.encoder["0"](pixel_values))
115
+ feats.append(feat)
116
+ for k in list(self.encoder.keys())[1:]:
117
+ blk = self.encoder[k]
118
+ out, feat = (blk(out, organ_id) if isinstance(blk, DownFiLM) else blk(out))
119
+ feats.append(feat)
120
+ out = self.bottleneck(out, organ_id) if isinstance(self.bottleneck, UpFiLM) else self.bottleneck(out)
121
+ for k in self.decoder:
122
+ cat = torch.cat([out, feats[int(k)]], dim=1)
123
+ blk = self.decoder[k]
124
+ out = blk(cat, organ_id) if isinstance(blk, UpFiLM) else blk(cat)
125
+ return self.out_layer(out)
126
+
127
+ class FilmUnet2DModel(PreTrainedModel):
128
+ config_class = FilmUnet2DConfig
129
+ base_model_prefix = "model"
130
+
131
+ def __init__(self, config: FilmUnet2DConfig):
132
+ super().__init__(config)
133
+ self.model = UNet2DFiLMCore(config)
134
+ self.post_init()
135
+
136
+ def forward(self, pixel_values, organ_id, labels=None, **kwargs):
137
+ logits = self.model(pixel_values, organ_id)
138
+ if labels is None:
139
+ return {"logits": logits}
140
+ loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
141
+ return {"loss": loss, "logits": logits}
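The FiLM blocks above apply a per-channel scale and shift predicted from the organ embedding; because the final MLP layer is zero-initialized with the gamma half of the bias set to 1, the modulation starts out as an identity. A standalone sketch of the core operation:

```python
import torch

# Standalone sketch of FiLM: per-channel affine modulation of a feature map.
x = torch.randn(2, 64, 32, 32)         # features [B, C, H, W]
gamma = torch.ones(2, 64, 1, 1)        # scale, 1.0 at initialization
beta = torch.zeros(2, 64, 1, 1)        # shift, 0.0 at initialization
y = gamma * x + beta                   # identity until training updates the FiLM MLP
print(torch.allclose(x, y))            # True
```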
preprocessor_config.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "do_resize": true,
3
+ "size": [512, 512],
4
+ "keep_ratio": true,
5
+ "image_mean": [123.675, 116.28, 103.53],
6
+ "image_std": [58.395, 57.12, 57.375],
7
+ "self_norm": false
8
+ }
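These mean/std values are the usual ImageNet statistics expressed in 0–255 space; normalization is a plain per-channel affine transform, e.g.:

```python
import numpy as np

pixel = np.array([128.0, 128.0, 128.0])   # a mid-grey RGB pixel in 0-255 space
mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])
print((pixel - mean) / std)               # approx. [0.074, 0.205, 0.426]
```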
test.py ADDED
@@ -0,0 +1,32 @@
1
+ # test_load_film_unet2d.py
2
+ import torch, os
3
+ from transformers import AutoModel, AutoConfig, AutoImageProcessor
4
+
5
+ # ✅ point to your local folder (or your HF repo id after pushing)
6
+ repo_or_path = os.path.abspath("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers")
7
+
8
+ print("Loading config...")
9
+ cfg = AutoConfig.from_pretrained(repo_or_path, trust_remote_code=True)
10
+ print(cfg)
11
+
12
+ print("Loading model and weights...")
13
+ proc = AutoImageProcessor.from_pretrained(repo_or_path, trust_remote_code=True)
14
+
15
+ model = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
16
+ model.eval()
17
+
18
+ # --- quick synthetic forward ---
19
+ # x = torch.randn(1, cfg.in_channels, 512, 512)
20
+ from PIL import Image
21
+ from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
22
+ x = Image.open("/home/nicola/Downloads/45.png").convert("RGB")
23
+ inputs = proc(images=x, return_tensors="pt") # {'pixel_values': B,C,H,W}
24
+ organ_id = torch.tensor([4]) # any valid organ id < cfg.n_organs
25
+ with torch.no_grad():
26
+ out = model(**inputs, organ_id=organ_id)
27
+
28
+ # Post-process: undo letterbox & resize back to original, with threshold 0.7
29
+ masks = proc.post_process_semantic_segmentation(out, inputs, threshold=0.7, return_as_pil=True)
30
+
31
+ # Save the first (since you used a single image, that'll be masks[0])
32
+ masks[0].save("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp.png")
test_4_stages.py ADDED
@@ -0,0 +1,53 @@
1
+ # test_load_film_unet_4_stages.py
2
+ import torch, os
3
+ from transformers import AutoModel, AutoConfig, AutoImageProcessor
4
+ from PIL import Image
5
+
6
+ # This script tests the 4-stage U-Net model.
7
+
8
+ # ✅ Point to the root folder of your repository
9
+ repo_or_path = os.path.abspath("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers")
10
+ subfolder_4_stages = "unet_4_stages"
11
+
12
+ # --- IMPORTANT ---
13
+ # You need to place the correct model weights for the 4-stage U-Net in the
14
+ # 'unet_4_stages' directory. The file should be named 'model.safetensors'.
15
+ # The path is: /home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/unet_4_stages/model.safetensors
16
+ # -----------------
17
+
18
+ print("Loading 4-stage model and processor...")
19
+ try:
20
+ proc = AutoImageProcessor.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
21
+ model = AutoModel.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
22
+ model.eval()
23
+ print("Model loaded successfully.")
24
+ except Exception as e:
25
+ print(f"Error loading the 4-stage model: {e}")
26
+ print("Please ensure the 'model.safetensors' file in the 'unet_4_stages' directory is compatible with the 4-stage architecture.")
27
+ exit()
28
+
29
+ # --- Inference ---
30
+ image_path = "/home/nicola/Downloads/45.png"
31
+ if not os.path.exists(image_path):
32
+ print(f"Error: Image file not found at {image_path}")
33
+ exit()
34
+
35
+ print(f"Loading image from {image_path}...")
36
+ image = Image.open(image_path).convert("RGB")
37
+ inputs = proc(images=image, return_tensors="pt")
38
+
39
+ # Use an appropriate organ ID for your test case
40
+ organ_id = torch.tensor([4])
41
+
42
+ print("Running inference...")
43
+ with torch.no_grad():
44
+ out = model(**inputs, organ_id=organ_id)
45
+
46
+ # Post-process to get the segmentation mask
47
+ masks = proc.post_process_semantic_segmentation(out, inputs, threshold=0.7, return_as_pil=True)
48
+
49
+ # Save the output mask
50
+ output_path = "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp_4_stages.png"
51
+ masks[0].save(output_path)
52
+
53
+ print(f"✅ Test complete. Segmentation mask saved to {output_path}")
test_5_stages_testicles.py ADDED
@@ -0,0 +1,46 @@
1
+ # test_load_film_unet_5_stages_testicles.py
2
+ import torch, os
3
+ from transformers import AutoModel, AutoImageProcessor
4
+ from PIL import Image
5
+
6
+ # This script tests the 5-stage U-Net model fine-tuned on testicle ultrasounds.
7
+
8
+ # ✅ Point to the root folder of your repository
9
+ repo_or_path = os.path.abspath("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers")
10
+ subfolder_5_stages = "unet_5_stages_testicles"
11
+
12
+ print("Loading 5-stage testicle-finetuned model and processor...")
13
+ try:
14
+ proc = AutoImageProcessor.from_pretrained(repo_or_path, subfolder=subfolder_5_stages, trust_remote_code=True)
15
+ model = AutoModel.from_pretrained(repo_or_path, subfolder=subfolder_5_stages, trust_remote_code=True)
16
+ model.eval()
17
+ print("Model loaded successfully.")
18
+ except Exception as e:
19
+ print(f"Error loading the model: {e}")
20
+ exit()
21
+
22
+ # --- Inference ---
23
+ image_path = "/home/nicola/Downloads/45.png"
24
+ if not os.path.exists(image_path):
25
+ print(f"Error: Image file not found at {image_path}")
26
+ exit()
27
+
28
+ print(f"Loading image from {image_path}...")
29
+ image = Image.open(image_path).convert("RGB")
30
+ inputs = proc(images=image, return_tensors="pt")
31
+
32
+ # From the dictionary you provided, 'testicle' corresponds to ID 7
33
+ organ_id = torch.tensor([7])
34
+
35
+ print("Running inference with organ_id=7 (testicle)...")
36
+ with torch.no_grad():
37
+ out = model(**inputs, organ_id=organ_id)
38
+
39
+ # Post-process to get the segmentation mask
40
+ masks = proc.post_process_semantic_segmentation(out, inputs, threshold=0.7, return_as_pil=True)
41
+
42
+ # Save the output mask
43
+ output_path = "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp_5_stages_testicles.png"
44
+ masks[0].save(output_path)
45
+
46
+ print(f"✅ Test complete. Segmentation mask saved to {output_path}")
tmp.png ADDED
tmp_4_stages.png ADDED
unet_4_stages/config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "model_type": "film_unet2d",
3
+ "architectures": ["FilmUnet2DModel"],
4
+ "auto_map": {
5
+ "AutoConfig": "configuration_film_unet2d.FilmUnet2DConfig",
6
+ "AutoModel": "modeling_film_unet2d.FilmUnet2DModel",
7
+ "AutoImageProcessor": "image_processing_film_unet2d.FilmUnet2DImageProcessor"
8
+ },
9
+ "in_channels": 3,
10
+ "num_classes": 1,
11
+ "n_organs": 9,
12
+ "size": 32,
13
+ "depth": 4,
14
+ "film_start": 0,
15
+ "use_film": 1
16
+ }
unet_4_stages/configuration_film_unet2d.py ADDED
@@ -0,0 +1,23 @@
1
+
2
+ from transformers.configuration_utils import PretrainedConfig
3
+
4
+ class FilmUnet2DConfig(PretrainedConfig):
5
+ model_type = "film_unet2d"
6
+
7
+ def __init__(self,
8
+ in_channels=3,
9
+ num_classes=1,
10
+ n_organs=9,
11
+ size=32,
12
+ depth=5,
13
+ film_start=0,
14
+ use_film=True,
15
+ **kwargs):
16
+ super().__init__(**kwargs)
17
+ self.in_channels = in_channels
18
+ self.num_classes = num_classes
19
+ self.n_organs = n_organs
20
+ self.size = size
21
+ self.depth = depth
22
+ self.film_start = film_start
23
+ self.use_film = use_film
unet_4_stages/image_processing_film_unet2d.py ADDED
@@ -0,0 +1,218 @@
1
+ # image_processing_film_unet2d.py
2
+ from typing import List, Union, Tuple, Optional
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from transformers.image_processing_utils import ImageProcessingMixin
7
+
8
+ ArrayLike = Union[np.ndarray, torch.Tensor, Image.Image]
9
+
10
+ def _to_rgb_numpy(im: ArrayLike) -> np.ndarray:
11
+ # -> float32 HWC in [0,255], 3 channels
12
+ if isinstance(im, Image.Image):
13
+ if im.mode != "RGB":
14
+ im = im.convert("RGB")
15
+ arr = np.array(im, dtype=np.uint8).astype(np.float32)
16
+ elif isinstance(im, torch.Tensor):
17
+ t = im.detach().cpu()
18
+ if t.ndim != 3:
19
+ raise ValueError("Tensor must be 3D (CHW or HWC).")
20
+ if t.shape[0] in (1, 3): # CHW
21
+ if t.shape[0] == 1:
22
+ t = t.repeat(3, 1, 1)
23
+ t = t.permute(1, 2, 0) # HWC
24
+ elif t.shape[-1] == 1: # HWC gray
25
+ t = t.repeat(1, 1, 3)
26
+ arr = t.numpy()
27
+ if arr.dtype in (np.float32, np.float64) and arr.max() <= 1.5:
28
+ arr = (arr * 255.0).astype(np.float32)
29
+ else:
30
+ arr = arr.astype(np.float32)
31
+ else:
32
+ arr = np.array(im)
33
+ if arr.ndim == 2:
34
+ arr = np.repeat(arr[..., None], 3, axis=-1)
35
+ arr = arr.astype(np.float32)
36
+ if arr.max() <= 1.5:
37
+ arr = (arr * 255.0).astype(np.float32)
38
+ if arr.ndim != 3 or arr.shape[-1] != 3:
39
+ raise ValueError("Expected RGB image with shape HxWx3.")
40
+ return arr
41
+
42
+ def _letterbox_keep_ratio(arr: np.ndarray, target_hw: Tuple[int, int]):
43
+ """Resize with aspect ratio preserved and pad with 0 (black) to target (H,W).
44
+ Returns: out(H,W,3), (top, left, new_h, new_w)
45
+ """
46
+ th, tw = target_hw
47
+ h, w = arr.shape[:2]
48
+ scale = min(th / h, tw / w)
49
+ nh, nw = int(round(h * scale)), int(round(w * scale))
50
+ if nh <= 0 or nw <= 0:
51
+ raise ValueError("Invalid resize result.")
52
+ pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
53
+ pil = pil.resize((nw, nh), resample=Image.BILINEAR)
54
+ rs = np.array(pil, dtype=np.float32)
55
+ out = np.zeros((th, tw, 3), dtype=np.float32)
56
+ top = (th - nh) // 2
57
+ left = (tw - nw) // 2
58
+ out[top:top+nh, left:left+nw] = rs
59
+ return out, (top, left, nh, nw)
60
+
61
+ def _zscore_ignore_black(chw: np.ndarray, eps: float = 1e-8) -> np.ndarray:
62
+ mask = (chw.sum(axis=0) > 0) # HxW
63
+ if not mask.any():
64
+ return chw.copy()
65
+ valid = chw[:, mask]
66
+ mean = valid.mean()
67
+ std = valid.std()
68
+ return (chw - mean) / std if std > eps else (chw - mean)
69
+
70
+ class FilmUnet2DImageProcessor(ImageProcessingMixin):
71
+ """
72
+ Processor for FILMUnet2D:
73
+ - Convert to RGB
74
+ - Keep-aspect-ratio resize+pad (letterbox) to 512x512 (configurable)
75
+ - Normalize with mean/std in 0–255 space (like your training)
76
+ - Optional z-score 'self_norm' ignoring black pixels
77
+ Returns dict with:
78
+ - pixel_values: torch.FloatTensor [B,3,H,W]
79
+ - original_sizes: torch.LongTensor [B,2] (H,W)
80
+ - letterbox_params: torch.LongTensor [B,4] (top, left, nh, nw) # NEW
81
+ """
82
+
83
+ model_input_names = ["pixel_values"]
84
+
85
+ def __init__(
86
+ self,
87
+ do_resize: bool = True,
88
+ size: Tuple[int, int] = (512, 512),
89
+ keep_ratio: bool = True,
90
+ image_mean: Tuple[float, float, float] = (123.675, 116.28, 103.53),
91
+ image_std: Tuple[float, float, float] = (58.395, 57.12, 57.375),
92
+ self_norm: bool = False,
93
+ **kwargs,
94
+ ):
95
+ super().__init__(**kwargs)
96
+ self.do_resize = bool(do_resize)
97
+ self.size = tuple(size)
98
+ self.keep_ratio = bool(keep_ratio)
99
+ self.image_mean = tuple(float(x) for x in image_mean)
100
+ self.image_std = tuple(float(x) for x in image_std)
101
+ self.self_norm = bool(self_norm)
102
+
103
+ def __call__(
104
+ self,
105
+ images: Union[ArrayLike, List[ArrayLike]],
106
+ return_tensors: Optional[str] = "pt",
107
+ **kwargs,
108
+ ):
109
+ imgs = images if isinstance(images, (list, tuple)) else [images]
110
+ batch = []
111
+ orig_sizes = []
112
+ lb_params = []
113
+
114
+ for im in imgs:
115
+ arr = _to_rgb_numpy(im) # HWC float32 in 0–255
116
+ oh, ow = arr.shape[:2]
117
+ orig_sizes.append((oh, ow))
118
+
119
+ if self.do_resize:
120
+ if self.keep_ratio:
121
+ arr, meta = _letterbox_keep_ratio(arr, self.size) # meta=(top,left,nh,nw)
122
+ else:
123
+ h, w = self.size
124
+ pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
125
+ arr = np.array(pil.resize((w, h), resample=Image.BILINEAR), dtype=np.float32)
126
+ meta = (0, 0, h, w)
127
+ else:
128
+ # no resize: still expose meta so postprocess can handle consistently
129
+ h, w = arr.shape[:2]
130
+ pad_h = self.size[0] - h
131
+ pad_w = self.size[1] - w
132
+ top = max(pad_h // 2, 0)
133
+ left = max(pad_w // 2, 0)
134
+ out = np.zeros((*self.size, 3), dtype=np.float32)
135
+ out[top:top+h, left:left+w] = arr[:self.size[0]-top, :self.size[1]-left]
136
+ arr = out
137
+ meta = (top, left, h, w)
138
+
139
+ lb_params.append(meta)
140
+
141
+ mean = np.array(self.image_mean, dtype=np.float32).reshape(1, 1, 3)
142
+ std = np.array(self.image_std, dtype=np.float32).reshape(1, 1, 3)
143
+ arr = (arr - mean) / std # HWC
144
+
145
+ chw = np.transpose(arr, (2, 0, 1)) # C,H,W
146
+ if self.self_norm:
147
+ chw = _zscore_ignore_black(chw)
148
+ batch.append(chw)
149
+
150
+ pixel_values = np.stack(batch, axis=0) # B,C,H,W
151
+ if return_tensors == "pt":
152
+ pixel_values = torch.from_numpy(pixel_values).to(torch.float32)
153
+ original_sizes = torch.tensor(orig_sizes, dtype=torch.long)
154
+ letterbox_params = torch.tensor(lb_params, dtype=torch.long)
155
+ else:
156
+ original_sizes = orig_sizes
157
+ letterbox_params = lb_params
158
+
159
+ return {
160
+ "pixel_values": pixel_values,
161
+ "original_sizes": original_sizes, # (B,2) H,W
162
+ "letterbox_params": letterbox_params # (B,4) top,left,nh,nw in 512x512
163
+ }
164
+
165
+ # ---------- POST-PROCESSING ----------
166
+ def post_process_semantic_segmentation(
167
+ self,
168
+ outputs: dict,
169
+ processor_inputs: Optional[dict] = None,
170
+ threshold: float = 0.5,
171
+ return_as_pil: bool = True,
172
+ ):
173
+ """
174
+ Turn model outputs into masks resized back to the ORIGINAL image sizes,
175
+ with letterbox padding removed.
176
+
177
+ Args:
178
+ outputs: dict from model forward (expects 'logits': [B,1,512,512])
179
+ processor_inputs: the dict returned by __call__ (must contain
180
+ 'original_sizes' [B,2] and 'letterbox_params' [B,4])
181
+ threshold: probability threshold for binarization
182
+ return_as_pil: return a list of PIL Images (uint8 0/255) if True,
183
+ else a list of torch tensors [H,W] uint8
184
+
185
+ Returns:
186
+ List of masks back in original sizes (H,W).
187
+ """
188
+ logits = outputs["logits"] # [B,1,H,W]
189
+ probs = torch.sigmoid(logits)
190
+ masks = (probs > threshold).to(torch.uint8) * 255 # [B,1,H,W] uint8
191
+
192
+ if processor_inputs is None:
193
+ raise ValueError("processor_inputs must be provided to undo letterboxing.")
194
+
195
+ orig_sizes = processor_inputs["original_sizes"] # [B,2]
196
+ lb_params = processor_inputs["letterbox_params"] # [B,4] top,left,nh,nw
197
+
198
+ results = []
199
+ B = masks.shape[0]
200
+ for i in range(B):
201
+ m = masks[i, 0] # [512,512]
202
+ top, left, nh, nw = [int(x) for x in lb_params[i].tolist()]
203
+ # crop letterbox
204
+ m_cropped = m[top:top+nh, left:left+nw] # [nh,nw]
205
+ # resize back to original
206
+ oh, ow = [int(x) for x in orig_sizes[i].tolist()]
207
+ m_resized = torch.nn.functional.interpolate(
208
+ m_cropped.unsqueeze(0).unsqueeze(0).float(),
209
+ size=(oh, ow),
210
+ mode="nearest"
211
+ )[0,0].to(torch.uint8) # [oh,ow]
212
+
213
+ if return_as_pil:
214
+ results.append(Image.fromarray(m_resized.cpu().numpy(), mode="L"))
215
+ else:
216
+ results.append(m_resized)
217
+
218
+ return results
unet_4_stages/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bf559438d470c899a73302e24c3f150cd673ffb67a9e7b844623f8156bbabf6
3
+ size 151840188
unet_4_stages/modeling_film_unet2d.py ADDED
@@ -0,0 +1,141 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers.modeling_utils import PreTrainedModel
5
+ from .configuration_film_unet2d import FilmUnet2DConfig
6
+
7
+ class ConvBlock(nn.Module):
8
+ def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
9
+ super().__init__()
10
+ self.block = nn.Sequential(
11
+ nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p),
12
+ nn.InstanceNorm2d(out_ch),
13
+ nn.LeakyReLU(inplace=True),
14
+ )
15
+ def forward(self, x): return self.block(x)
16
+
17
+ class FiLM2d(nn.Module):
18
+ def __init__(self, n_organs, in_channels, emb_dim=64, hidden=None):
19
+ super().__init__()
20
+ hidden = hidden or 2 * in_channels
21
+ self.embed = nn.Embedding(n_organs, emb_dim)
22
+ self.mlp = nn.Sequential(
23
+ nn.Linear(emb_dim, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, 2*in_channels)
24
+ )
25
+ nn.init.zeros_(self.mlp[-1].weight)
26
+ nn.init.constant_(self.mlp[-1].bias[:in_channels], 0)
27
+ nn.init.constant_(self.mlp[-1].bias[in_channels:], 1)
28
+ def forward(self, x, organ_id):
29
+ beta_gamma = self.mlp(self.embed(organ_id))
30
+ beta, gamma = beta_gamma.chunk(2, dim=-1)
31
+ beta = beta.unsqueeze(-1).unsqueeze(-1)
32
+ gamma = gamma.unsqueeze(-1).unsqueeze(-1)
33
+ return gamma * x + beta
34
+
35
+ class DownFiLM(nn.Module):
36
+ def __init__(self, in_chs, out_chs, n_organs):
37
+ super().__init__()
38
+ self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i,o in zip(in_chs,out_chs)])
39
+ self.film_blocks = nn.ModuleList([FiLM2d(n_organs, o) for o in out_chs])
40
+ self.pool = nn.MaxPool2d(2,2)
41
+ def forward(self, x, organ_id):
42
+ for c,f in zip(self.conv_blocks, self.film_blocks):
43
+ x = f(c(x), organ_id)
44
+ return self.pool(x), x
45
+
46
+ class Down(nn.Module):
47
+ def __init__(self, in_chs, out_chs):
48
+ super().__init__()
49
+ self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i,o in zip(in_chs,out_chs)])
50
+ self.pool = nn.MaxPool2d(2,2)
51
+ def forward(self, x):
52
+ for c in self.conv_blocks: x = c(x)
53
+ return self.pool(x), x
54
+
55
+ class UpFiLM(nn.Module):
56
+ def __init__(self, in_chs, out_chs, n_organs, up=True):
57
+ super().__init__()
58
+ self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i,o in zip(in_chs,out_chs)])
59
+ self.film_blocks = nn.ModuleList([FiLM2d(n_organs, o) for o in out_chs])
60
+ self.up_conv_op = nn.ConvTranspose2d(out_chs[-1], out_chs[-1], kernel_size=2, stride=2) if up else None
61
+ def forward(self, x, organ_id):
62
+ for c,f in zip(self.conv_blocks, self.film_blocks):
63
+ x = f(c(x), organ_id)
64
+ return self.up_conv_op(x) if self.up_conv_op is not None else x
65
+
66
+ class Up(nn.Module):
67
+ def __init__(self, in_chs, out_chs, up=True):
68
+ super().__init__()
69
+ self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i,o in zip(in_chs,out_chs)])
70
+ self.up_conv_op = nn.ConvTranspose2d(out_chs[-1], out_chs[-1], kernel_size=2, stride=2) if up else None
71
+ def forward(self, x):
72
+ for c in self.conv_blocks: x = c(x)
73
+ return self.up_conv_op(x) if self.up_conv_op is not None else x
74
+
75
+ class UNet2DFiLMCore(nn.Module):
76
+ def __init__(self, cfg: FilmUnet2DConfig):
77
+ super().__init__()
78
+ size, depth, n_organs = cfg.size, cfg.depth, cfg.n_organs
79
+ use_film, film_start = cfg.use_film, cfg.film_start
80
+ self.encoder = nn.ModuleDict()
81
+ if use_film and 0 >= film_start:
82
+ self.encoder["0"] = DownFiLM([cfg.in_channels, size], [size, size*2], n_organs)
83
+ else:
84
+ self.encoder["0"] = Down([cfg.in_channels, size], [size, size*2])
85
+ for i in range(1, depth):
86
+ in_ch = [size*(2**i), size*(2**i)]
87
+ out_ch = [size*(2**i), size*(2**(i+1))]
88
+ if use_film and i >= film_start:
89
+ self.encoder[str(i)] = DownFiLM(in_ch, out_ch, n_organs)
90
+ else:
91
+ self.encoder[str(i)] = Down(in_ch, out_ch)
92
+ if use_film:
93
+ self.bottleneck = UpFiLM([size*(2**depth), size*(2**depth)], [size*(2**depth), size*(2**(depth+1))], n_organs)
94
+ else:
95
+ self.bottleneck = Up([size*(2**depth), size*(2**depth)], [size*(2**depth), size*(2**(depth+1))])
96
+ self.decoder = nn.ModuleDict()
97
+ for i in range(depth, 1, -1):
98
+ use_film_here = use_film and (i-1) >= film_start
99
+ if use_film_here:
100
+ self.decoder[str(i-1)] = UpFiLM([size*(2**(i+1))+size*(2**i), size*(2**i)], [size*(2**i), size*(2**i)], n_organs)
101
+ else:
102
+ self.decoder[str(i-1)] = Up([size*(2**(i+1))+size*(2**i), size*(2**i)], [size*(2**i), size*(2**i)])
103
+ if use_film and 0 >= film_start:
104
+ self.decoder["0"] = UpFiLM([size*4+size*2, size*2], [size*2, size*2], n_organs, up=False)
105
+ else:
106
+ self.decoder["0"] = Up([size*4+size*2, size*2], [size*2, size*2], up=False)
107
+ self.out_layer = ConvBlock(
108
+ size * 2,
109
+ cfg.num_classes,
110
+ k= 1,s= 1,p=0
111
+ )
112
+ def forward(self, pixel_values, organ_id):
113
+ feats = []
114
+ out, feat = (self.encoder["0"](pixel_values, organ_id) if isinstance(self.encoder["0"], DownFiLM) else self.encoder["0"](pixel_values))
115
+ feats.append(feat)
116
+ for k in list(self.encoder.keys())[1:]:
117
+ blk = self.encoder[k]
118
+ out, feat = (blk(out, organ_id) if isinstance(blk, DownFiLM) else blk(out))
119
+ feats.append(feat)
120
+ out = self.bottleneck(out, organ_id) if isinstance(self.bottleneck, UpFiLM) else self.bottleneck(out)
121
+ for k in self.decoder:
122
+ cat = torch.cat([out, feats[int(k)]], dim=1)
123
+ blk = self.decoder[k]
124
+ out = blk(cat, organ_id) if isinstance(blk, UpFiLM) else blk(cat)
125
+ return self.out_layer(out)
126
+
127
+ class FilmUnet2DModel(PreTrainedModel):
128
+ config_class = FilmUnet2DConfig
129
+ base_model_prefix = "model"
130
+
131
+ def __init__(self, config: FilmUnet2DConfig):
132
+ super().__init__(config)
133
+ self.model = UNet2DFiLMCore(config)
134
+ self.post_init()
135
+
136
+ def forward(self, pixel_values, organ_id, labels=None, **kwargs):
137
+ logits = self.model(pixel_values, organ_id)
138
+ if labels is None:
139
+ return {"logits": logits}
140
+ loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
141
+ return {"loss": loss, "logits": logits}
unet_4_stages/preprocessor_config.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "do_resize": true,
3
+ "size": [512, 512],
4
+ "keep_ratio": true,
5
+ "image_mean": [123.675, 116.28, 103.53],
6
+ "image_std": [58.395, 57.12, 57.375],
7
+ "self_norm": false
8
+ }