This model crops mammography images to eliminate unnecessary background. It uses a lightweight `mobilenetv3_small_100` backbone and predicts bounding-box coordinates in normalized `xywh` format.
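The following is a minimal sketch of converting between normalized and pixel `xywh` boxes. It assumes that (x, y) is the top-left corner, as in the OpenCV-based ground-truth code further down, and that x and w are fractions of the image width while y and h are fractions of the image height. The helper names are hypothetical and are not part of the model's API.

```python
def denormalize_xywh(box, img_h, img_w):
    """Hypothetical helper: normalized (x, y, w, h) -> integer pixel box.

    Assumes (x, y) is the top-left corner; x and w scale with width,
    y and h scale with height.
    """
    x, y, w, h = box
    return (
        round(x * img_w),
        round(y * img_h),
        round(w * img_w),
        round(h * img_h),
    )


def normalize_xywh(box, img_h, img_w):
    """Hypothetical helper: pixel (x, y, w, h) -> normalized box."""
    x, y, w, h = box
    return (x / img_w, y / img_h, w / img_w, h / img_h)
```

For example, `denormalize_xywh((0.1, 0.05, 0.5, 0.9), 2000, 1000)` maps a box on a 2000 x 1000 (height x width) image to pixel values.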
The model was trained and validated on 54,706 screening mammography images from the RSNA Screening Mammography Breast Cancer Detection challenge, using a 90%/10% train/validation split. On a single validation fold, the model achieved the following mean absolute errors (in normalized coordinates):
- x: 0.0032
- y: 0.0030
- w: 0.0054
- h: 0.0088
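To make these numbers concrete, the sketch below shows one way a per-coordinate mean absolute error over normalized boxes can be computed, and how a normalized error translates to pixels. The function names are hypothetical, and the 2048 x 1024 (height x width) image size is an illustrative assumption, not the size used in training.

```python
def per_coordinate_mae(preds, targets):
    """Hypothetical helper: MAE of x, y, w, h, each averaged over all boxes."""
    if len(preds) != len(targets) or not preds:
        raise ValueError("need equal-length, non-empty box lists")
    sums = [0.0] * 4
    for p, t in zip(preds, targets):
        for i in range(4):
            sums[i] += abs(p[i] - t[i])
    return tuple(s / len(preds) for s in sums)


def mae_in_pixels(mae, img_h, img_w):
    """Scale normalized MAEs to pixels: x and w by width, y and h by height."""
    mx, my, mw, mh = mae
    return (mx * img_w, my * img_h, mw * img_w, mh * img_h)


# Reported validation MAEs on a hypothetical 2048 x 1024 (height x width) image
print(mae_in_pixels((0.0032, 0.0030, 0.0054, 0.0088), img_h=2048, img_w=1024))
```

On that hypothetical image size, the reported errors correspond to roughly 3.3, 6.1, 5.5, and 18.0 pixels for x, y, w, and h.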
The ground-truth coordinates were generated using the following code:
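```python
import cv2


def crop_roi(img):
    # Trim a 5-pixel border from each side of the (grayscale) image
    img = img[5:-5, 5:-5]
    # Label 8-connected components of the foreground mask (pixels > 10)
    output = cv2.connectedComponentsWithStats((img > 10).astype("uint8"), 8, cv2.CV_32S)
    stats = output[2]  # per-label [left, top, width, height, area]
    # Largest component by area, skipping label 0 (background)
    idx = stats[1:, 4].argmax() + 1
    x1, y1, w, h = stats[idx][:4]
    # Shift the top-left corner up/left by 5 pixels, clipped at 0
    x1 = max(0, x1 - 5)
    y1 = max(0, y1 - 5)
    img_h, img_w = img.shape[:2]  # not used in the returned values
    return x1, y1, w, h
```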
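The core of the labeling code is "bounding box of the largest 8-connected foreground component". As a dependency-free illustration of that step only, the sketch below replaces OpenCV's labeling with a plain breadth-first search on a 2D list. It is a swapped-in technique for explanation and testing, not the code used to generate the labels; it omits the 5-pixel trim and offset, and its tie-breaking between equally large components may differ from OpenCV's. The function name is hypothetical.

```python
from collections import deque


def largest_component_bbox(mask):
    """Hypothetical helper: (x, y, w, h) of the largest 8-connected region.

    mask is a list of equal-length rows of truthy/falsy values.
    Returns None if there is no foreground pixel.
    """
    rows, cols = len(mask), len(mask[0])
    seen = [[False] * cols for _ in range(rows)]
    best = None  # (area, x, y, w, h)
    for r0 in range(rows):
        for c0 in range(cols):
            if not mask[r0][c0] or seen[r0][c0]:
                continue
            # Breadth-first search over this component, tracking area and extents
            queue = deque([(r0, c0)])
            seen[r0][c0] = True
            area, rmin, rmax, cmin, cmax = 0, r0, r0, c0, c0
            while queue:
                r, c = queue.popleft()
                area += 1
                rmin, rmax = min(rmin, r), max(rmax, r)
                cmin, cmax = min(cmin, c), max(cmax, c)
                # Visit all 8 neighbours (diagonals count as connected)
                for dr in (-1, 0, 1):
                    for dc in (-1, 0, 1):
                        nr, nc = r + dr, c + dc
                        if (0 <= nr < rows and 0 <= nc < cols
                                and mask[nr][nc] and not seen[nr][nc]):
                            seen[nr][nc] = True
                            queue.append((nr, nc))
            # Keep the first component (raster order) with the largest area
            if best is None or area > best[0]:
                best = (area, cmin, rmin, cmax - cmin + 1, rmax - rmin + 1)
    return None if best is None else best[1:]
```

Applied to a thresholded image such as `[[v > 10 for v in row] for row in img]`, it returns the box of the largest bright region, mirroring the `stats[idx][:4]` lookup in the original code.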
This labeling approach is not guaranteed to be foolproof, but a cursory review of a sample of cropped images showed excellent performance.
The model was trained with a relatively large batch size (256) to mitigate noise. Input pixel values are expected to be in the range [0, 255].
If you are loading from DICOM, convert the input into an 8-bit image, pass it through `model.preprocess`, and make sure the result is a float `torch.Tensor` before passing it to the model. Normalization happens within the model itself.
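One simple way to get an 8-bit image from a higher-bit-depth DICOM pixel array (for example, the `pixel_array` of a dataset read with `pydicom.dcmread`) is min-max scaling. The sketch below is an assumption for illustration: it is not necessarily what `model.load_image_from_dicom` does internally, and DICOM windowing may be preferable in practice. The `invert` flag covers images whose `PhotometricInterpretation` is `MONOCHROME1`, where low values are bright. The function name is hypothetical.

```python
import numpy as np


def to_uint8(pixels, invert=False):
    """Hypothetical helper: min-max scale an array to uint8 in [0, 255].

    Set invert=True for MONOCHROME1 images, where low values are bright.
    """
    arr = np.asarray(pixels, dtype=np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    if hi > lo:
        arr = (arr - lo) / (hi - lo)  # scale to [0, 1]
    else:
        arr = np.zeros_like(arr)  # constant image: avoid division by zero
    if invert:
        arr = 1.0 - arr
    return np.round(arr * 255.0).astype(np.uint8)
```

The resulting array can then go through `model.preprocess` as in the usage example.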
To use the model:
```python
import cv2
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
model = model.eval()

img = cv2.imread(..., 0)  # load as 8-bit grayscale
img_shape = torch.tensor([img.shape[:2]])
x = model.preprocess(img)
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # add batch and channel dims
x = x.float()

# if you do not provide img_shape,
# the model will return normalized coordinates
with torch.inference_mode():
    coords = model(x, img_shape)

# only 1 sample in batch
coords = coords[0].numpy()
# cast to int for array slicing (values unchanged if already integer)
x, y, w, h = coords.astype(int)
# coords already rescaled with img_shape
cropped_img = img[y: y + h, x: x + w]
```
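A predicted box can land slightly outside the image. Python and NumPy clip slice ends that run past the array, but a negative start index wraps around from the end (for example, `img[-3:47]` on an 800-pixel-wide axis is empty). The hypothetical helper below, which is not part of the model's API, is a minimal sketch of clipping a pixel-space box to the image before cropping.

```python
def clamp_box(x, y, w, h, img_h, img_w):
    """Hypothetical helper: clip a pixel (x, y, w, h) box to the image bounds."""
    x0 = min(max(int(x), 0), img_w)
    y0 = min(max(int(y), 0), img_h)
    # Right/bottom edges are clipped to the image and never precede the start
    x1 = min(max(int(x + w), x0), img_w)
    y1 = min(max(int(y + h), y0), img_h)
    return x0, y0, x1 - x0, y1 - y0
```

With the usage example, this would be applied as `x, y, w, h = clamp_box(x, y, w, h, *img.shape[:2])` before computing `cropped_img`.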
If you have `pydicom` installed, you can also load a DICOM image directly:

```python
img = model.load_image_from_dicom(path_to_dicom)
```
Base model: `timm/mobilenetv3_small_100.lamb_in1k`