|
import torch.nn.functional as F
|
|
import numpy as np
|
|
import torch
|
|
from torchvision.transforms._functional_tensor import normalize
|
|
import pandas as pd
|
|
from skimage.segmentation import expand_labels
|
|
from skimage.measure import label
|
|
import albumentations as A
|
|
from monai.inferers import SlidingWindowInferer
|
|
from dnafiber.deployment import _get_model
|
|
from dnafiber.postprocess import refine_segmentation
|
|
|
|
# Shared preprocessing pipeline used by ``infer``: per-channel normalization
# with the mean/std below, then conversion of the result to a torch tensor.
transform = A.Compose(
    [
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        A.ToTensorV2(),
    ]
)
|
|
|
|
|
|
def preprocess_image(image):
    """Convert a 3-D NumPy image into a normalized tensor with a batch axis.

    The last axis of ``image`` is moved to the front and a leading axis of
    size one is added before the per-channel normalization is applied.

    Args:
        image: 3-D ``numpy.ndarray``; presumably laid out as (H, W, 3) since
            three mean/std values are used — confirm with callers.

    Returns:
        The normalized ``torch.Tensor`` of shape ``(1, C, H, W)`` for an
        ``(H, W, C)`` input.
    """
    # NOTE(review): torchvision's ``normalize`` is expected to reject
    # non-floating tensors, so callers presumably pass a float array — confirm.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    batch = torch.from_numpy(image).permute(2, 0, 1)[None]
    return normalize(batch, channel_mean, channel_std)
|
|
|
|
|
|
def convert_to_dataset(counts):
    """Tabulate red/green counts together with their green-to-red ratio.

    Args:
        counts: Mapping whose values can be indexed with ``"red"`` and
            ``"green"``. The mapping keys end up in the ``index`` column.

    Returns:
        A ``pandas.DataFrame`` with the columns ``index``, ``red``, ``green``
        and ``ratio`` (in that order), one row per entry of ``counts``.
        ``ratio`` is ``green / red``, or NaN where ``red`` equals 0.
    """
    keys, reds, greens, ratios = [], [], [], []
    for key, entry in counts.items():
        green = entry["green"]
        red = entry["red"]
        keys.append(key)
        greens.append(green)
        reds.append(red)
        # A zero red count has no defined ratio; store NaN instead of dividing.
        ratios.append(np.nan if red == 0 else green / red)
    return pd.DataFrame(
        {"index": keys, "red": reds, "green": greens, "ratio": ratios}
    )
|
|
|
|
|
|
def convert_mask_to_image(mask, expand=False):
    """Render a 2-D label mask as an RGB image.

    Label 1 is painted red and label 2 green; every other value stays black.

    Args:
        mask: 2-D array of labels.
        expand: When truthy, the labels are first grown with
            ``expand_labels`` using this value as the ``distance`` argument.

    Returns:
        ``uint8`` array of shape ``(h, w, 3)`` where ``(h, w)`` is the shape
        of ``mask``.
    """
    if expand:
        mask = expand_labels(mask, distance=expand)

    height, width = mask.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    # (label value, RGB colour) pairs, painted in this order.
    palette = (
        (1, (255, 0, 0)),
        (2, (0, 255, 0)),
    )
    for label_value, rgb in palette:
        canvas[mask == label_value] = rgb

    return canvas
|
|
|
|
|
|
@torch.inference_mode()
def infer(model, image, device, scale=0.13, to_numpy=True, only_probabilities=False):
    """Run the segmentation model on one image and return a full-size result.

    The image goes through the module-level ``transform``, is resized by the
    factor ``scale / 0.26`` and fed to the model under ``torch.autocast``.
    When the resized height or width exceeds 1024 the model is applied with
    MONAI's ``SlidingWindowInferer`` instead of a single forward call. The
    prediction is finally resampled to the height and width the image had
    before resizing.

    Args:
        model: Callable applied to the batched tensor. A string is instead
            passed to ``_get_model`` as ``revision`` to obtain the model.
        image: Input handed to ``transform(image=...)``.
        device: Device the input tensor is moved to; also given to
            ``torch.device`` to select the autocast device type.
        scale: Divided by the model pixel size (0.26) to obtain the resize
            factor — presumably the pixel size of ``image`` in the same unit;
            confirm with callers.
        to_numpy: Return the label map as a NumPy array rather than a tensor.
            Not consulted when ``only_probabilities`` is set.
        only_probabilities: Return the softmax probabilities instead of the
            label map.

    Returns:
        With ``only_probabilities``: a CPU tensor holding the softmax over
        dim 1, bilinearly resized to the original height and width.
        Otherwise: the per-pixel argmax over dim 1 as ``uint8`` with all
        size-one dimensions squeezed out, as a NumPy array if ``to_numpy``
        is true, else as a tensor.
    """
    if isinstance(model, str):
        model = _get_model(device=device, revision=model)

    model_pixel_size = 0.26
    resize_factor = scale / model_pixel_size

    batch = transform(image=image)["image"].unsqueeze(0).to(device)
    orig_h, orig_w = batch.shape[2], batch.shape[3]
    resized_hw = (int(orig_h * resize_factor), int(orig_w * resize_factor))

    torch_device = torch.device(device)
    with torch.autocast(device_type=torch_device.type):
        batch = F.interpolate(batch, size=resized_hw, mode="bilinear")
        if max(batch.shape[2], batch.shape[3]) > 1024:
            # Too large for one pass: stitch overlapping 1024x1024 windows.
            tiler = SlidingWindowInferer(
                roi_size=(1024, 1024),
                sw_batch_size=4,
                overlap=0.25,
                mode="gaussian",
                device=torch_device,
                progress=True,
            )
            prediction = tiler(batch, model)
        else:
            prediction = model(batch)

    probabilities = F.softmax(prediction, dim=1)

    if only_probabilities:
        # Move to CPU first so the full-resolution upsampling happens there.
        return F.interpolate(
            probabilities.cpu(),
            size=(orig_h, orig_w),
            mode="bilinear",
        )

    # Nearest-neighbour resampling keeps the class indices intact.
    class_map = probabilities.argmax(dim=1, keepdim=True).float()
    labels = F.interpolate(class_map, size=(orig_h, orig_w), mode="nearest")
    labels = labels.squeeze().byte()

    return labels.cpu().numpy() if to_numpy else labels
|
|
|