Upload folder using huggingface_hub
- 45.png +0 -0
- README.md +109 -0
- config.json +16 -0
- configuration_film_unet2d.py +23 -0
- image_processing_film_unet2d.py +218 -0
- model.safetensors +3 -0
- modeling_film_unet2d.py +141 -0
- preprocessor_config.json +8 -0
- test.py +32 -0
- test_4_stages.py +53 -0
- test_5_stages_testicles.py +46 -0
- tmp.png +0 -0
- tmp_4_stages.png +0 -0
- unet_4_stages/config.json +16 -0
- unet_4_stages/configuration_film_unet2d.py +23 -0
- unet_4_stages/image_processing_film_unet2d.py +218 -0
- unet_4_stages/model.safetensors +3 -0
- unet_4_stages/modeling_film_unet2d.py +141 -0
- unet_4_stages/preprocessor_config.json +8 -0
45.png
ADDED
README.md
ADDED
@@ -0,0 +1,109 @@
# FILMUnet2D (Transformers-compatible)

This model is a 2D U-Net with FiLM conditioning for multi-organ segmentation.

## Installation

Make sure you have `transformers` and `torch` installed.

```bash
pip install transformers torch
```

## Usage

You can load the model and processor using the `Auto` classes from `transformers`. Since this repository contains custom code, make sure to pass `trust_remote_code=True`.

```python
import torch
from transformers import AutoModel, AutoImageProcessor
from PIL import Image

# 1. Load model and processor
repo_id = "Morelli001/US_UNet2DFiLM"

processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# 2. Load and preprocess your image
# The processor handles resizing, letterboxing, and normalization.
image = Image.open("path/to/your/image.png").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# 3. Prepare conditioning input
# This should be an integer tensor representing the organ ID.
# Replace `4` with the appropriate ID for your use case.
organ_id = torch.tensor([4])

# 4. Run inference
with torch.no_grad():
    outputs = model(**inputs, organ_id=organ_id)

# 5. Post-process the output to get the final segmentation mask
# The processor can convert the logits to a binary mask, automatically handling
# the removal of letterbox padding and resizing to the original image dimensions.
mask = processor.post_process_semantic_segmentation(
    outputs,
    inputs,
    threshold=0.7,
    return_as_pil=True
)[0]

# 6. Save the result
mask.save("output_mask.png")

print("Segmentation mask saved to output_mask.png")
```

### Model Details

- **Architecture:** U-Net with FiLM layers for conditional segmentation.
- **Conditioning:** The model's output is conditioned on an `organ_id` input.
- **Input:** RGB images.
- **Output:** A single-channel segmentation mask.

### Configuration

The model configuration can be accessed via `model.config`. Key parameters include:
- `in_channels`: Number of input channels (default: 3).
- `num_classes`: Number of output classes (default: 1).
- `n_organs`: The number of different organs the model was trained to condition on.
- `depth`: The depth of the U-Net.
- `size`: The base number of filters in the first layer.
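
For example, these fields can be read from the Hub without downloading the weights (a minimal sketch using the same `repo_id` as above):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Morelli001/US_UNet2DFiLM", trust_remote_code=True)
print(cfg.in_channels, cfg.num_classes, cfg.n_organs, cfg.depth, cfg.size)
# 3 1 9 5 32
```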

### Organ IDs

The `organ_id` passed to the model corresponds to the following mapping:

```python
organ_to_class_dict = {
    "appendix": 0,
    "breast": 1,
    "breast_luminal": 1,
    "cardiac": 2,
    "thyroid": 3,
    "fetal": 4,
    "kidney": 5,
    "liver": 6,
    "testicle": 7,
}
```
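
The mapping can be used directly to build the conditioning tensor instead of hard-coding an integer (a small sketch reusing the dictionary above):

```python
import torch

organ_to_class_dict = {"appendix": 0, "breast": 1, "breast_luminal": 1, "cardiac": 2,
                       "thyroid": 3, "fetal": 4, "kidney": 5, "liver": 6, "testicle": 7}

# Condition the network on the thyroid class.
organ_id = torch.tensor([organ_to_class_dict["thyroid"]])  # tensor([3])
```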

### Alternative Versions

This repository contains multiple versions of the model located in subfolders. You can load a specific version by using the `subfolder` parameter.

#### 4-Stage U-Net

This version has a U-Net depth of 4.

```python
from transformers import AutoModel

model_4_stages = AutoModel.from_pretrained(
    "Morelli001/US_UNet2DFiLM",
    subfolder="unet_4_stages",
    trust_remote_code=True
)
```
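
The matching image processor can be loaded from the same subfolder (a minimal sketch; in this repository the subfolder's preprocessing settings are identical to the root-level ones):

```python
from transformers import AutoImageProcessor

processor_4_stages = AutoImageProcessor.from_pretrained(
    "Morelli001/US_UNet2DFiLM",
    subfolder="unet_4_stages",
    trust_remote_code=True
)
```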
config.json
ADDED
@@ -0,0 +1,16 @@
{
  "model_type": "film_unet2d",
  "architectures": ["FilmUnet2DModel"],
  "auto_map": {
    "AutoConfig": "configuration_film_unet2d.FilmUnet2DConfig",
    "AutoModel": "modeling_film_unet2d.FilmUnet2DModel",
    "AutoImageProcessor": "image_processing_film_unet2d.FilmUnet2DImageProcessor"
  },
  "in_channels": 3,
  "num_classes": 1,
  "n_organs": 9,
  "size": 32,
  "depth": 5,
  "film_start": 0,
  "use_film": 1
}
configuration_film_unet2d.py
ADDED
@@ -0,0 +1,23 @@

from transformers.configuration_utils import PretrainedConfig

class FilmUnet2DConfig(PretrainedConfig):
    model_type = "film_unet2d"

    def __init__(self,
                 in_channels=3,
                 num_classes=1,
                 n_organs=9,
                 size=32,
                 depth=5,
                 film_start=0,
                 use_film=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.n_organs = n_organs
        self.size = size
        self.depth = depth
        self.film_start = film_start
        self.use_film = use_film
image_processing_film_unet2d.py
ADDED
@@ -0,0 +1,218 @@
# image_processing_film_unet2d.py
from typing import List, Union, Tuple, Optional
import numpy as np
from PIL import Image
import torch
from transformers.image_processing_utils import ImageProcessingMixin

ArrayLike = Union[np.ndarray, torch.Tensor, Image.Image]

def _to_rgb_numpy(im: ArrayLike) -> np.ndarray:
    # -> float32 HWC in [0,255], 3 channels
    if isinstance(im, Image.Image):
        if im.mode != "RGB":
            im = im.convert("RGB")
        arr = np.array(im, dtype=np.uint8).astype(np.float32)
    elif isinstance(im, torch.Tensor):
        t = im.detach().cpu()
        if t.ndim != 3:
            raise ValueError("Tensor must be 3D (CHW or HWC).")
        if t.shape[0] in (1, 3):  # CHW
            if t.shape[0] == 1:
                t = t.repeat(3, 1, 1)
            t = t.permute(1, 2, 0)  # HWC
        elif t.shape[-1] == 1:  # HWC gray
            t = t.repeat(1, 1, 3)
        arr = t.numpy()
        if arr.dtype in (np.float32, np.float64) and arr.max() <= 1.5:
            arr = (arr * 255.0).astype(np.float32)
        else:
            arr = arr.astype(np.float32)
    else:
        arr = np.array(im)
        if arr.ndim == 2:
            arr = np.repeat(arr[..., None], 3, axis=-1)
        arr = arr.astype(np.float32)
        if arr.max() <= 1.5:
            arr = (arr * 255.0).astype(np.float32)
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError("Expected RGB image with shape HxWx3.")
    return arr

def _letterbox_keep_ratio(arr: np.ndarray, target_hw: Tuple[int, int]):
    """Resize with aspect ratio preserved and pad with 0 (black) to target (H,W).
    Returns: out(H,W,3), (top, left, new_h, new_w)
    """
    th, tw = target_hw
    h, w = arr.shape[:2]
    scale = min(th / h, tw / w)
    nh, nw = int(round(h * scale)), int(round(w * scale))
    if nh <= 0 or nw <= 0:
        raise ValueError("Invalid resize result.")
    pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
    pil = pil.resize((nw, nh), resample=Image.BILINEAR)
    rs = np.array(pil, dtype=np.float32)
    out = np.zeros((th, tw, 3), dtype=np.float32)
    top = (th - nh) // 2
    left = (tw - nw) // 2
    out[top:top+nh, left:left+nw] = rs
    return out, (top, left, nh, nw)

def _zscore_ignore_black(chw: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    mask = (chw.sum(axis=0) > 0)  # HxW
    if not mask.any():
        return chw.copy()
    valid = chw[:, mask]
    mean = valid.mean()
    std = valid.std()
    return (chw - mean) / std if std > eps else (chw - mean)

class FilmUnet2DImageProcessor(ImageProcessingMixin):
    """
    Processor for FILMUnet2D:
      - Convert to RGB
      - Keep-aspect-ratio resize+pad (letterbox) to 512x512 (configurable)
      - Normalize with mean/std in 0–255 space (like your training)
      - Optional z-score 'self_norm' ignoring black pixels
    Returns dict with:
      - pixel_values: torch.FloatTensor [B,3,H,W]
      - original_sizes: torch.LongTensor [B,2] (H,W)
      - letterbox_params: torch.LongTensor [B,4] (top, left, nh, nw)  # NEW
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Tuple[int, int] = (512, 512),
        keep_ratio: bool = True,
        image_mean: Tuple[float, float, float] = (123.675, 116.28, 103.53),
        image_std: Tuple[float, float, float] = (58.395, 57.12, 57.375),
        self_norm: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_resize = bool(do_resize)
        self.size = tuple(size)
        self.keep_ratio = bool(keep_ratio)
        self.image_mean = tuple(float(x) for x in image_mean)
        self.image_std = tuple(float(x) for x in image_std)
        self.self_norm = bool(self_norm)

    def __call__(
        self,
        images: Union[ArrayLike, List[ArrayLike]],
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ):
        imgs = images if isinstance(images, (list, tuple)) else [images]
        batch = []
        orig_sizes = []
        lb_params = []

        for im in imgs:
            arr = _to_rgb_numpy(im)  # HWC float32 in 0–255
            oh, ow = arr.shape[:2]
            orig_sizes.append((oh, ow))

            if self.do_resize:
                if self.keep_ratio:
                    arr, meta = _letterbox_keep_ratio(arr, self.size)  # meta=(top,left,nh,nw)
                else:
                    h, w = self.size
                    pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
                    arr = np.array(pil.resize((w, h), resample=Image.BILINEAR), dtype=np.float32)
                    meta = (0, 0, h, w)
            else:
                # no resize: still expose meta so postprocess can handle consistently
                h, w = arr.shape[:2]
                pad_h = self.size[0] - h
                pad_w = self.size[1] - w
                top = max(pad_h // 2, 0)
                left = max(pad_w // 2, 0)
                out = np.zeros((*self.size, 3), dtype=np.float32)
                out[top:top+h, left:left+w] = arr[:self.size[0]-top, :self.size[1]-left]
                arr = out
                meta = (top, left, h, w)

            lb_params.append(meta)

            mean = np.array(self.image_mean, dtype=np.float32).reshape(1, 1, 3)
            std = np.array(self.image_std, dtype=np.float32).reshape(1, 1, 3)
            arr = (arr - mean) / std  # HWC

            chw = np.transpose(arr, (2, 0, 1))  # C,H,W
            if self.self_norm:
                chw = _zscore_ignore_black(chw)
            batch.append(chw)

        pixel_values = np.stack(batch, axis=0)  # B,C,H,W
        if return_tensors == "pt":
            pixel_values = torch.from_numpy(pixel_values).to(torch.float32)
            original_sizes = torch.tensor(orig_sizes, dtype=torch.long)
            letterbox_params = torch.tensor(lb_params, dtype=torch.long)
        else:
            original_sizes = orig_sizes
            letterbox_params = lb_params

        return {
            "pixel_values": pixel_values,
            "original_sizes": original_sizes,     # (B,2) H,W
            "letterbox_params": letterbox_params  # (B,4) top,left,nh,nw in 512x512
        }

    # ---------- POST-PROCESSING ----------
    def post_process_semantic_segmentation(
        self,
        outputs: dict,
        processor_inputs: Optional[dict] = None,
        threshold: float = 0.5,
        return_as_pil: bool = True,
    ):
        """
        Turn model outputs into masks resized back to the ORIGINAL image sizes,
        with letterbox padding removed.

        Args:
            outputs: dict from model forward (expects 'logits': [B,1,512,512])
            processor_inputs: the dict returned by __call__ (must contain
                'original_sizes' [B,2] and 'letterbox_params' [B,4])
            threshold: probability threshold for binarization
            return_as_pil: return a list of PIL Images (uint8 0/255) if True,
                else a list of torch tensors [H,W] uint8

        Returns:
            List of masks back in original sizes (H,W).
        """
        logits = outputs["logits"]  # [B,1,H,W]
        probs = torch.sigmoid(logits)
        masks = (probs > threshold).to(torch.uint8) * 255  # [B,1,H,W] uint8

        if processor_inputs is None:
            raise ValueError("processor_inputs must be provided to undo letterboxing.")

        orig_sizes = processor_inputs["original_sizes"]  # [B,2]
        lb_params = processor_inputs["letterbox_params"]  # [B,4] top,left,nh,nw

        results = []
        B = masks.shape[0]
        for i in range(B):
            m = masks[i, 0]  # [512,512]
            top, left, nh, nw = [int(x) for x in lb_params[i].tolist()]
            # crop letterbox
            m_cropped = m[top:top+nh, left:left+nw]  # [nh,nw]
            # resize back to original
            oh, ow = [int(x) for x in orig_sizes[i].tolist()]
            m_resized = torch.nn.functional.interpolate(
                m_cropped.unsqueeze(0).unsqueeze(0).float(),
                size=(oh, ow),
                mode="nearest"
            )[0, 0].to(torch.uint8)  # [oh,ow]

            if return_as_pil:
                results.append(Image.fromarray(m_resized.cpu().numpy(), mode="L"))
            else:
                results.append(m_resized)

        return results
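
A minimal round-trip check of this processor on a synthetic non-square image (a sketch, assuming the script is run from a checkout of this repository so the module is importable):

```python
import numpy as np
import torch
from image_processing_film_unet2d import FilmUnet2DImageProcessor

proc = FilmUnet2DImageProcessor()  # defaults: 512x512 letterbox, keep_ratio=True
img = (np.random.rand(300, 600, 3) * 255).astype(np.uint8)  # H=300, W=600

enc = proc(images=img, return_tensors="pt")
print(enc["pixel_values"].shape)   # torch.Size([1, 3, 512, 512])
print(enc["original_sizes"])       # tensor([[300, 600]])
print(enc["letterbox_params"])     # tensor([[128, 0, 256, 512]]) -> top, left, nh, nw

# Fake logits at the model resolution, then undo the letterboxing.
fake_out = {"logits": torch.randn(1, 1, 512, 512)}
mask = proc.post_process_semantic_segmentation(fake_out, enc, threshold=0.5, return_as_pil=False)[0]
print(mask.shape)                  # torch.Size([300, 600])
```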
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:66b39afecca324389126efc9d7995b707cfe2cc0330e2bfa12728829bd79b2a6
size 603375348
modeling_film_unet2d.py
ADDED
@@ -0,0 +1,141 @@

import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from .configuration_film_unet2d import FilmUnet2DConfig

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(inplace=True),
        )
    def forward(self, x): return self.block(x)

class FiLM2d(nn.Module):
    def __init__(self, n_organs, in_channels, emb_dim=64, hidden=None):
        super().__init__()
        hidden = hidden or 2 * in_channels
        self.embed = nn.Embedding(n_organs, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, 2*in_channels)
        )
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.constant_(self.mlp[-1].bias[:in_channels], 0)
        nn.init.constant_(self.mlp[-1].bias[in_channels:], 1)
    def forward(self, x, organ_id):
        beta_gamma = self.mlp(self.embed(organ_id))
        beta, gamma = beta_gamma.chunk(2, dim=-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        return gamma * x + beta

class DownFiLM(nn.Module):
    def __init__(self, in_chs, out_chs, n_organs):
        super().__init__()
        self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i, o in zip(in_chs, out_chs)])
        self.film_blocks = nn.ModuleList([FiLM2d(n_organs, o) for o in out_chs])
        self.pool = nn.MaxPool2d(2, 2)
    def forward(self, x, organ_id):
        for c, f in zip(self.conv_blocks, self.film_blocks):
            x = f(c(x), organ_id)
        return self.pool(x), x

class Down(nn.Module):
    def __init__(self, in_chs, out_chs):
        super().__init__()
        self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i, o in zip(in_chs, out_chs)])
        self.pool = nn.MaxPool2d(2, 2)
    def forward(self, x):
        for c in self.conv_blocks: x = c(x)
        return self.pool(x), x

class UpFiLM(nn.Module):
    def __init__(self, in_chs, out_chs, n_organs, up=True):
        super().__init__()
        self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i, o in zip(in_chs, out_chs)])
        self.film_blocks = nn.ModuleList([FiLM2d(n_organs, o) for o in out_chs])
        self.up_conv_op = nn.ConvTranspose2d(out_chs[-1], out_chs[-1], kernel_size=2, stride=2) if up else None
    def forward(self, x, organ_id):
        for c, f in zip(self.conv_blocks, self.film_blocks):
            x = f(c(x), organ_id)
        return self.up_conv_op(x) if self.up_conv_op is not None else x

class Up(nn.Module):
    def __init__(self, in_chs, out_chs, up=True):
        super().__init__()
        self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i, o in zip(in_chs, out_chs)])
        self.up_conv_op = nn.ConvTranspose2d(out_chs[-1], out_chs[-1], kernel_size=2, stride=2) if up else None
    def forward(self, x):
        for c in self.conv_blocks: x = c(x)
        return self.up_conv_op(x) if self.up_conv_op is not None else x

class UNet2DFiLMCore(nn.Module):
    def __init__(self, cfg: FilmUnet2DConfig):
        super().__init__()
        size, depth, n_organs = cfg.size, cfg.depth, cfg.n_organs
        use_film, film_start = cfg.use_film, cfg.film_start
        self.encoder = nn.ModuleDict()
        if use_film and 0 >= film_start:
            self.encoder["0"] = DownFiLM([cfg.in_channels, size], [size, size*2], n_organs)
        else:
            self.encoder["0"] = Down([cfg.in_channels, size], [size, size*2])
        for i in range(1, depth):
            in_ch = [size*(2**i), size*(2**i)]
            out_ch = [size*(2**i), size*(2**(i+1))]
            if use_film and i >= film_start:
                self.encoder[str(i)] = DownFiLM(in_ch, out_ch, n_organs)
            else:
                self.encoder[str(i)] = Down(in_ch, out_ch)
        if use_film:
            self.bottleneck = UpFiLM([size*(2**depth), size*(2**depth)], [size*(2**depth), size*(2**(depth+1))], n_organs)
        else:
            self.bottleneck = Up([size*(2**depth), size*(2**depth)], [size*(2**depth), size*(2**(depth+1))])
        self.decoder = nn.ModuleDict()
        for i in range(depth, 1, -1):
            use_film_here = use_film and (i-1) >= film_start
            if use_film_here:
                self.decoder[str(i-1)] = UpFiLM([size*(2**(i+1))+size*(2**i), size*(2**i)], [size*(2**i), size*(2**i)], n_organs)
            else:
                self.decoder[str(i-1)] = Up([size*(2**(i+1))+size*(2**i), size*(2**i)], [size*(2**i), size*(2**i)])
        if use_film and 0 >= film_start:
            self.decoder["0"] = UpFiLM([size*4+size*2, size*2], [size*2, size*2], n_organs, up=False)
        else:
            self.decoder["0"] = Up([size*4+size*2, size*2], [size*2, size*2], up=False)
        self.out_layer = ConvBlock(
            size * 2,
            cfg.num_classes,
            k=1, s=1, p=0
        )
    def forward(self, pixel_values, organ_id):
        feats = []
        out, feat = (self.encoder["0"](pixel_values, organ_id) if isinstance(self.encoder["0"], DownFiLM) else self.encoder["0"](pixel_values))
        feats.append(feat)
        for k in list(self.encoder.keys())[1:]:
            blk = self.encoder[k]
            out, feat = (blk(out, organ_id) if isinstance(blk, DownFiLM) else blk(out))
            feats.append(feat)
        out = self.bottleneck(out, organ_id) if isinstance(self.bottleneck, UpFiLM) else self.bottleneck(out)
        for k in self.decoder:
            cat = torch.cat([out, feats[int(k)]], dim=1)
            blk = self.decoder[k]
            out = blk(cat, organ_id) if isinstance(blk, UpFiLM) else blk(cat)
        return self.out_layer(out)

class FilmUnet2DModel(PreTrainedModel):
    config_class = FilmUnet2DConfig
    base_model_prefix = "model"

    def __init__(self, config: FilmUnet2DConfig):
        super().__init__(config)
        self.model = UNet2DFiLMCore(config)
        self.post_init()

    def forward(self, pixel_values, organ_id, labels=None, **kwargs):
        logits = self.model(pixel_values, organ_id)
        if labels is None:
            return {"logits": logits}
        loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
        return {"loss": loss, "logits": logits}
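
A quick shape check of the full model as loaded from the Hub (a minimal sketch; it downloads the published weights and runs on a random input, so it only verifies the conditioning interface and the output size):

```python
import torch
from transformers import AutoConfig, AutoModel

repo_id = "Morelli001/US_UNet2DFiLM"
cfg = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

pixel_values = torch.randn(1, cfg.in_channels, 512, 512)  # random stand-in image
organ_id = torch.tensor([4])                              # "fetal" in the README mapping

with torch.no_grad():
    out = model(pixel_values=pixel_values, organ_id=organ_id)
print(out["logits"].shape)  # torch.Size([1, 1, 512, 512])
```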
preprocessor_config.json
ADDED
@@ -0,0 +1,8 @@
{
  "do_resize": true,
  "size": [512, 512],
  "keep_ratio": true,
  "image_mean": [123.675, 116.28, 103.53],
  "image_std": [58.395, 57.12, 57.375],
  "self_norm": false
}
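
These values are forwarded to `FilmUnet2DImageProcessor.__init__` when the processor is loaded through `AutoImageProcessor`; a quick check (a minimal sketch using the Hub repo id from the README):

```python
from transformers import AutoImageProcessor

proc = AutoImageProcessor.from_pretrained("Morelli001/US_UNet2DFiLM", trust_remote_code=True)
print(proc.size, proc.keep_ratio, proc.self_norm)  # (512, 512) True False
```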
test.py
ADDED
@@ -0,0 +1,32 @@
# test_load_film_unet2d.py
import torch, os
from transformers import AutoModel, AutoConfig, AutoImageProcessor

# ✅ point to your local folder (or your HF repo id after pushing)
repo_or_path = os.path.abspath("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers")

print("Loading config...")
cfg = AutoConfig.from_pretrained(repo_or_path, trust_remote_code=True)
print(cfg)

print("Loading model and weights...")
proc = AutoImageProcessor.from_pretrained(repo_or_path, trust_remote_code=True)

model = AutoModel.from_pretrained(repo_or_path, trust_remote_code=True)
model.eval()

# --- quick synthetic forward ---
# x = torch.randn(1, cfg.in_channels, 512, 512)
from PIL import Image
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
x = Image.open("/home/nicola/Downloads/45.png").convert("RGB")
inputs = proc(images=x, return_tensors="pt")  # {'pixel_values': B,C,H,W}
organ_id = torch.tensor([4])  # any valid organ id < cfg.n_organs
with torch.no_grad():
    out = model(**inputs, organ_id=organ_id)

# Post-process: undo letterbox & resize back to original, with threshold 0.7
masks = proc.post_process_semantic_segmentation(out, inputs, threshold=0.7, return_as_pil=True)

# Save the first (since you used a single image, that'll be masks[0])
masks[0].save("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp.png")
test_4_stages.py
ADDED
@@ -0,0 +1,53 @@
# test_load_film_unet_4_stages.py
import torch, os
from transformers import AutoModel, AutoConfig, AutoImageProcessor
from PIL import Image

# This script tests the 4-stage U-Net model.

# ✅ Point to the root folder of your repository
repo_or_path = os.path.abspath("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers")
subfolder_4_stages = "unet_4_stages"

# --- IMPORTANT ---
# You need to place the correct model weights for the 4-stage U-Net in the
# 'unet_4_stages' directory. The file should be named 'model.safetensors'.
# The path is: /home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/unet_4_stages/model.safetensors
# -----------------

print("Loading 4-stage model and processor...")
try:
    proc = AutoImageProcessor.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_or_path, subfolder=subfolder_4_stages, trust_remote_code=True)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading the 4-stage model: {e}")
    print("Please ensure the 'model.safetensors' file in the 'unet_4_stages' directory is compatible with the 4-stage architecture.")
    exit()

# --- Inference ---
image_path = "/home/nicola/Downloads/45.png"
if not os.path.exists(image_path):
    print(f"Error: Image file not found at {image_path}")
    exit()

print(f"Loading image from {image_path}...")
image = Image.open(image_path).convert("RGB")
inputs = proc(images=image, return_tensors="pt")

# Use an appropriate organ ID for your test case
organ_id = torch.tensor([4])

print("Running inference...")
with torch.no_grad():
    out = model(**inputs, organ_id=organ_id)

# Post-process to get the segmentation mask
masks = proc.post_process_semantic_segmentation(out, inputs, threshold=0.7, return_as_pil=True)

# Save the output mask
output_path = "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp_4_stages.png"
masks[0].save(output_path)

print(f"✅ Test complete. Segmentation mask saved to {output_path}")
test_5_stages_testicles.py
ADDED
@@ -0,0 +1,46 @@
# test_load_film_unet_5_stages_testicles.py
import torch, os
from transformers import AutoModel, AutoImageProcessor
from PIL import Image

# This script tests the 5-stage U-Net model fine-tuned on testicle ultrasounds.

# ✅ Point to the root folder of your repository
repo_or_path = os.path.abspath("/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers")
subfolder_5_stages = "unet_5_stages_testicles"

print("Loading 5-stage testicle-finetuned model and processor...")
try:
    proc = AutoImageProcessor.from_pretrained(repo_or_path, subfolder=subfolder_5_stages, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_or_path, subfolder=subfolder_5_stages, trust_remote_code=True)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading the model: {e}")
    exit()

# --- Inference ---
image_path = "/home/nicola/Downloads/45.png"
if not os.path.exists(image_path):
    print(f"Error: Image file not found at {image_path}")
    exit()

print(f"Loading image from {image_path}...")
image = Image.open(image_path).convert("RGB")
inputs = proc(images=image, return_tensors="pt")

# From the dictionary you provided, 'testicle' corresponds to ID 7
organ_id = torch.tensor([7])

print("Running inference with organ_id=7 (testicle)...")
with torch.no_grad():
    out = model(**inputs, organ_id=organ_id)

# Post-process to get the segmentation mask
masks = proc.post_process_semantic_segmentation(out, inputs, threshold=0.7, return_as_pil=True)

# Save the output mask
output_path = "/home/nicola/Downloads/FILMUnet2D_transformers_repo/FILMUnet2D_transformers/tmp_5_stages_testicles.png"
masks[0].save(output_path)

print(f"✅ Test complete. Segmentation mask saved to {output_path}")
tmp.png
ADDED
tmp_4_stages.png
ADDED
unet_4_stages/config.json
ADDED
@@ -0,0 +1,16 @@
{
  "model_type": "film_unet2d",
  "architectures": ["FilmUnet2DModel"],
  "auto_map": {
    "AutoConfig": "configuration_film_unet2d.FilmUnet2DConfig",
    "AutoModel": "modeling_film_unet2d.FilmUnet2DModel",
    "AutoImageProcessor": "image_processing_film_unet2d.FilmUnet2DImageProcessor"
  },
  "in_channels": 3,
  "num_classes": 1,
  "n_organs": 9,
  "size": 32,
  "depth": 4,
  "film_start": 0,
  "use_film": 1
}
unet_4_stages/configuration_film_unet2d.py
ADDED
@@ -0,0 +1,23 @@

from transformers.configuration_utils import PretrainedConfig

class FilmUnet2DConfig(PretrainedConfig):
    model_type = "film_unet2d"

    def __init__(self,
                 in_channels=3,
                 num_classes=1,
                 n_organs=9,
                 size=32,
                 depth=5,
                 film_start=0,
                 use_film=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.n_organs = n_organs
        self.size = size
        self.depth = depth
        self.film_start = film_start
        self.use_film = use_film
unet_4_stages/image_processing_film_unet2d.py
ADDED
@@ -0,0 +1,218 @@
# image_processing_film_unet2d.py
from typing import List, Union, Tuple, Optional
import numpy as np
from PIL import Image
import torch
from transformers.image_processing_utils import ImageProcessingMixin

ArrayLike = Union[np.ndarray, torch.Tensor, Image.Image]

def _to_rgb_numpy(im: ArrayLike) -> np.ndarray:
    # -> float32 HWC in [0,255], 3 channels
    if isinstance(im, Image.Image):
        if im.mode != "RGB":
            im = im.convert("RGB")
        arr = np.array(im, dtype=np.uint8).astype(np.float32)
    elif isinstance(im, torch.Tensor):
        t = im.detach().cpu()
        if t.ndim != 3:
            raise ValueError("Tensor must be 3D (CHW or HWC).")
        if t.shape[0] in (1, 3):  # CHW
            if t.shape[0] == 1:
                t = t.repeat(3, 1, 1)
            t = t.permute(1, 2, 0)  # HWC
        elif t.shape[-1] == 1:  # HWC gray
            t = t.repeat(1, 1, 3)
        arr = t.numpy()
        if arr.dtype in (np.float32, np.float64) and arr.max() <= 1.5:
            arr = (arr * 255.0).astype(np.float32)
        else:
            arr = arr.astype(np.float32)
    else:
        arr = np.array(im)
        if arr.ndim == 2:
            arr = np.repeat(arr[..., None], 3, axis=-1)
        arr = arr.astype(np.float32)
        if arr.max() <= 1.5:
            arr = (arr * 255.0).astype(np.float32)
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError("Expected RGB image with shape HxWx3.")
    return arr

def _letterbox_keep_ratio(arr: np.ndarray, target_hw: Tuple[int, int]):
    """Resize with aspect ratio preserved and pad with 0 (black) to target (H,W).
    Returns: out(H,W,3), (top, left, new_h, new_w)
    """
    th, tw = target_hw
    h, w = arr.shape[:2]
    scale = min(th / h, tw / w)
    nh, nw = int(round(h * scale)), int(round(w * scale))
    if nh <= 0 or nw <= 0:
        raise ValueError("Invalid resize result.")
    pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
    pil = pil.resize((nw, nh), resample=Image.BILINEAR)
    rs = np.array(pil, dtype=np.float32)
    out = np.zeros((th, tw, 3), dtype=np.float32)
    top = (th - nh) // 2
    left = (tw - nw) // 2
    out[top:top+nh, left:left+nw] = rs
    return out, (top, left, nh, nw)

def _zscore_ignore_black(chw: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    mask = (chw.sum(axis=0) > 0)  # HxW
    if not mask.any():
        return chw.copy()
    valid = chw[:, mask]
    mean = valid.mean()
    std = valid.std()
    return (chw - mean) / std if std > eps else (chw - mean)

class FilmUnet2DImageProcessor(ImageProcessingMixin):
    """
    Processor for FILMUnet2D:
      - Convert to RGB
      - Keep-aspect-ratio resize+pad (letterbox) to 512x512 (configurable)
      - Normalize with mean/std in 0–255 space (like your training)
      - Optional z-score 'self_norm' ignoring black pixels
    Returns dict with:
      - pixel_values: torch.FloatTensor [B,3,H,W]
      - original_sizes: torch.LongTensor [B,2] (H,W)
      - letterbox_params: torch.LongTensor [B,4] (top, left, nh, nw)  # NEW
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Tuple[int, int] = (512, 512),
        keep_ratio: bool = True,
        image_mean: Tuple[float, float, float] = (123.675, 116.28, 103.53),
        image_std: Tuple[float, float, float] = (58.395, 57.12, 57.375),
        self_norm: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_resize = bool(do_resize)
        self.size = tuple(size)
        self.keep_ratio = bool(keep_ratio)
        self.image_mean = tuple(float(x) for x in image_mean)
        self.image_std = tuple(float(x) for x in image_std)
        self.self_norm = bool(self_norm)

    def __call__(
        self,
        images: Union[ArrayLike, List[ArrayLike]],
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ):
        imgs = images if isinstance(images, (list, tuple)) else [images]
        batch = []
        orig_sizes = []
        lb_params = []

        for im in imgs:
            arr = _to_rgb_numpy(im)  # HWC float32 in 0–255
            oh, ow = arr.shape[:2]
            orig_sizes.append((oh, ow))

            if self.do_resize:
                if self.keep_ratio:
                    arr, meta = _letterbox_keep_ratio(arr, self.size)  # meta=(top,left,nh,nw)
                else:
                    h, w = self.size
                    pil = Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))
                    arr = np.array(pil.resize((w, h), resample=Image.BILINEAR), dtype=np.float32)
                    meta = (0, 0, h, w)
            else:
                # no resize: still expose meta so postprocess can handle consistently
                h, w = arr.shape[:2]
                pad_h = self.size[0] - h
                pad_w = self.size[1] - w
                top = max(pad_h // 2, 0)
                left = max(pad_w // 2, 0)
                out = np.zeros((*self.size, 3), dtype=np.float32)
                out[top:top+h, left:left+w] = arr[:self.size[0]-top, :self.size[1]-left]
                arr = out
                meta = (top, left, h, w)

            lb_params.append(meta)

            mean = np.array(self.image_mean, dtype=np.float32).reshape(1, 1, 3)
            std = np.array(self.image_std, dtype=np.float32).reshape(1, 1, 3)
            arr = (arr - mean) / std  # HWC

            chw = np.transpose(arr, (2, 0, 1))  # C,H,W
            if self.self_norm:
                chw = _zscore_ignore_black(chw)
            batch.append(chw)

        pixel_values = np.stack(batch, axis=0)  # B,C,H,W
        if return_tensors == "pt":
            pixel_values = torch.from_numpy(pixel_values).to(torch.float32)
            original_sizes = torch.tensor(orig_sizes, dtype=torch.long)
            letterbox_params = torch.tensor(lb_params, dtype=torch.long)
        else:
            original_sizes = orig_sizes
            letterbox_params = lb_params

        return {
            "pixel_values": pixel_values,
            "original_sizes": original_sizes,     # (B,2) H,W
            "letterbox_params": letterbox_params  # (B,4) top,left,nh,nw in 512x512
        }

    # ---------- POST-PROCESSING ----------
    def post_process_semantic_segmentation(
        self,
        outputs: dict,
        processor_inputs: Optional[dict] = None,
        threshold: float = 0.5,
        return_as_pil: bool = True,
    ):
        """
        Turn model outputs into masks resized back to the ORIGINAL image sizes,
        with letterbox padding removed.

        Args:
            outputs: dict from model forward (expects 'logits': [B,1,512,512])
            processor_inputs: the dict returned by __call__ (must contain
                'original_sizes' [B,2] and 'letterbox_params' [B,4])
            threshold: probability threshold for binarization
            return_as_pil: return a list of PIL Images (uint8 0/255) if True,
                else a list of torch tensors [H,W] uint8

        Returns:
            List of masks back in original sizes (H,W).
        """
        logits = outputs["logits"]  # [B,1,H,W]
        probs = torch.sigmoid(logits)
        masks = (probs > threshold).to(torch.uint8) * 255  # [B,1,H,W] uint8

        if processor_inputs is None:
            raise ValueError("processor_inputs must be provided to undo letterboxing.")

        orig_sizes = processor_inputs["original_sizes"]  # [B,2]
        lb_params = processor_inputs["letterbox_params"]  # [B,4] top,left,nh,nw

        results = []
        B = masks.shape[0]
        for i in range(B):
            m = masks[i, 0]  # [512,512]
            top, left, nh, nw = [int(x) for x in lb_params[i].tolist()]
            # crop letterbox
            m_cropped = m[top:top+nh, left:left+nw]  # [nh,nw]
            # resize back to original
            oh, ow = [int(x) for x in orig_sizes[i].tolist()]
            m_resized = torch.nn.functional.interpolate(
                m_cropped.unsqueeze(0).unsqueeze(0).float(),
                size=(oh, ow),
                mode="nearest"
            )[0, 0].to(torch.uint8)  # [oh,ow]

            if return_as_pil:
                results.append(Image.fromarray(m_resized.cpu().numpy(), mode="L"))
            else:
                results.append(m_resized)

        return results
unet_4_stages/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9bf559438d470c899a73302e24c3f150cd673ffb67a9e7b844623f8156bbabf6
size 151840188
unet_4_stages/modeling_film_unet2d.py
ADDED
@@ -0,0 +1,141 @@

import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from .configuration_film_unet2d import FilmUnet2DConfig

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(inplace=True),
        )
    def forward(self, x): return self.block(x)

class FiLM2d(nn.Module):
    def __init__(self, n_organs, in_channels, emb_dim=64, hidden=None):
        super().__init__()
        hidden = hidden or 2 * in_channels
        self.embed = nn.Embedding(n_organs, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, 2*in_channels)
        )
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.constant_(self.mlp[-1].bias[:in_channels], 0)
        nn.init.constant_(self.mlp[-1].bias[in_channels:], 1)
    def forward(self, x, organ_id):
        beta_gamma = self.mlp(self.embed(organ_id))
        beta, gamma = beta_gamma.chunk(2, dim=-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        return gamma * x + beta

class DownFiLM(nn.Module):
    def __init__(self, in_chs, out_chs, n_organs):
        super().__init__()
        self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i, o in zip(in_chs, out_chs)])
        self.film_blocks = nn.ModuleList([FiLM2d(n_organs, o) for o in out_chs])
        self.pool = nn.MaxPool2d(2, 2)
    def forward(self, x, organ_id):
        for c, f in zip(self.conv_blocks, self.film_blocks):
            x = f(c(x), organ_id)
        return self.pool(x), x

class Down(nn.Module):
    def __init__(self, in_chs, out_chs):
        super().__init__()
        self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i, o in zip(in_chs, out_chs)])
        self.pool = nn.MaxPool2d(2, 2)
    def forward(self, x):
        for c in self.conv_blocks: x = c(x)
        return self.pool(x), x

class UpFiLM(nn.Module):
    def __init__(self, in_chs, out_chs, n_organs, up=True):
        super().__init__()
        self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i, o in zip(in_chs, out_chs)])
        self.film_blocks = nn.ModuleList([FiLM2d(n_organs, o) for o in out_chs])
        self.up_conv_op = nn.ConvTranspose2d(out_chs[-1], out_chs[-1], kernel_size=2, stride=2) if up else None
    def forward(self, x, organ_id):
        for c, f in zip(self.conv_blocks, self.film_blocks):
            x = f(c(x), organ_id)
        return self.up_conv_op(x) if self.up_conv_op is not None else x

class Up(nn.Module):
    def __init__(self, in_chs, out_chs, up=True):
        super().__init__()
        self.conv_blocks = nn.ModuleList([ConvBlock(i, o) for i, o in zip(in_chs, out_chs)])
        self.up_conv_op = nn.ConvTranspose2d(out_chs[-1], out_chs[-1], kernel_size=2, stride=2) if up else None
    def forward(self, x):
        for c in self.conv_blocks: x = c(x)
        return self.up_conv_op(x) if self.up_conv_op is not None else x

class UNet2DFiLMCore(nn.Module):
    def __init__(self, cfg: FilmUnet2DConfig):
        super().__init__()
        size, depth, n_organs = cfg.size, cfg.depth, cfg.n_organs
        use_film, film_start = cfg.use_film, cfg.film_start
        self.encoder = nn.ModuleDict()
        if use_film and 0 >= film_start:
            self.encoder["0"] = DownFiLM([cfg.in_channels, size], [size, size*2], n_organs)
        else:
            self.encoder["0"] = Down([cfg.in_channels, size], [size, size*2])
        for i in range(1, depth):
            in_ch = [size*(2**i), size*(2**i)]
            out_ch = [size*(2**i), size*(2**(i+1))]
            if use_film and i >= film_start:
                self.encoder[str(i)] = DownFiLM(in_ch, out_ch, n_organs)
            else:
                self.encoder[str(i)] = Down(in_ch, out_ch)
        if use_film:
            self.bottleneck = UpFiLM([size*(2**depth), size*(2**depth)], [size*(2**depth), size*(2**(depth+1))], n_organs)
        else:
            self.bottleneck = Up([size*(2**depth), size*(2**depth)], [size*(2**depth), size*(2**(depth+1))])
        self.decoder = nn.ModuleDict()
        for i in range(depth, 1, -1):
            use_film_here = use_film and (i-1) >= film_start
            if use_film_here:
                self.decoder[str(i-1)] = UpFiLM([size*(2**(i+1))+size*(2**i), size*(2**i)], [size*(2**i), size*(2**i)], n_organs)
            else:
                self.decoder[str(i-1)] = Up([size*(2**(i+1))+size*(2**i), size*(2**i)], [size*(2**i), size*(2**i)])
        if use_film and 0 >= film_start:
            self.decoder["0"] = UpFiLM([size*4+size*2, size*2], [size*2, size*2], n_organs, up=False)
        else:
            self.decoder["0"] = Up([size*4+size*2, size*2], [size*2, size*2], up=False)
        self.out_layer = ConvBlock(
            size * 2,
            cfg.num_classes,
            k=1, s=1, p=0
        )
    def forward(self, pixel_values, organ_id):
        feats = []
        out, feat = (self.encoder["0"](pixel_values, organ_id) if isinstance(self.encoder["0"], DownFiLM) else self.encoder["0"](pixel_values))
        feats.append(feat)
        for k in list(self.encoder.keys())[1:]:
            blk = self.encoder[k]
            out, feat = (blk(out, organ_id) if isinstance(blk, DownFiLM) else blk(out))
            feats.append(feat)
        out = self.bottleneck(out, organ_id) if isinstance(self.bottleneck, UpFiLM) else self.bottleneck(out)
        for k in self.decoder:
            cat = torch.cat([out, feats[int(k)]], dim=1)
            blk = self.decoder[k]
            out = blk(cat, organ_id) if isinstance(blk, UpFiLM) else blk(cat)
        return self.out_layer(out)

class FilmUnet2DModel(PreTrainedModel):
    config_class = FilmUnet2DConfig
    base_model_prefix = "model"

    def __init__(self, config: FilmUnet2DConfig):
        super().__init__(config)
        self.model = UNet2DFiLMCore(config)
        self.post_init()

    def forward(self, pixel_values, organ_id, labels=None, **kwargs):
        logits = self.model(pixel_values, organ_id)
        if labels is None:
            return {"logits": logits}
        loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
        return {"loss": loss, "logits": logits}
unet_4_stages/preprocessor_config.json
ADDED
@@ -0,0 +1,8 @@
{
  "do_resize": true,
  "size": [512, 512],
  "keep_ratio": true,
  "image_mean": [123.675, 116.28, 103.53],
  "image_std": [58.395, 57.12, 57.375],
  "self_norm": false
}