Spaces:
Running
Running
import csv
import os
import platform

import nibabel
import numpy as np
import torch as T
from monai.data import DataLoader, Dataset
from monai.metrics import compute_meandice
from monai.networks.nets.unet import UNet
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    Compose,
    KeepLargestConnectedComponent,
    LoadImaged,
    NormalizeIntensityd,
    Resized,
    ThresholdIntensityd,
    ToTensord,
)

from Engine.utils import (
    load_model,
    read_yaml,
    store_output,
)
def init_eval_file(directory):
    """Create (or truncate) ``<directory>/evaluate.csv`` and write its header row.

    The header names the four columns that ``append_eval`` writes per row:
    input file, segmentation file, loss and Dice score. ``directory`` is
    created if it does not exist yet.

    Args:
        directory: folder that will hold ``evaluate.csv``.
    """
    if directory:
        os.makedirs(directory, exist_ok=True)
    # os.path.join keeps this path identical to the one the evaluation loop
    # appends to, whether or not ``directory`` ends with a separator.
    with open(os.path.join(directory, "evaluate.csv"), "w") as file:
        file.write("input_filename,segmentation_filename,val_loss,dice_score\n")
def append_eval(eval_path, input_filename, segment_filename, val_loss, dice_score):
    """Append one result row to the evaluation CSV at ``eval_path``.

    Fields are written through the ``csv`` module, so file names containing
    commas or quotes are quoted instead of silently shifting the columns.
    Rows are terminated with a bare newline.

    Args:
        eval_path: path of the CSV file to append to (created if missing).
        input_filename: first column, the evaluated input file (or a label
            such as ``'Total'``).
        segment_filename: second column, the segmentation output path.
        val_loss: third column, the loss value.
        dice_score: fourth column, the Dice score.
    """
    # newline="" is required by the csv module; lineterminator keeps "\n" rows
    # instead of the writer's default "\r\n".
    with open(eval_path, "a", newline="") as file:
        csv.writer(file, lineterminator="\n").writerow(
            [input_filename, segment_filename, val_loss, dice_score]
        )
def evaluate(model, loss_function, loader, device, evaluate_config):
    """Run inference over ``loader`` and record per-case loss and Dice to CSV.

    For every batch the raw model output is scored with ``loss_function``,
    binarised at ``evaluate_config['output_threshold']``, reduced to its
    largest connected component, optionally saved, and compared to the label
    with ``compute_meandice``. One row per case is appended to
    ``<save_directory>/evaluate.csv``, followed by a ``Total`` row holding the
    mean loss and mean Dice.

    Args:
        model: network to evaluate; it is put into eval mode.
        loss_function: callable ``(output, label)`` returning a scalar tensor.
        loader: iterable of dicts with ``"image"``, ``"label"`` and
            ``"image_meta_dict"`` entries. Only the first filename of each
            batch is used, so a batch size of 1 is presumably expected.
        device: torch device the inputs are moved to.
        evaluate_config: mapping with ``'save_directory'``,
            ``'output_threshold'`` and ``'save_segmentations'`` keys.

    Returns:
        ``(mean_loss, mean_dice)``, or ``None`` when ``loader`` yields no
        batches (no ``Total`` row is written in that case).
    """
    model.eval()
    save_directory = evaluate_config['save_directory']
    threshold = evaluate_config['output_threshold']
    # Same path init_eval_file creates, independent of a trailing separator.
    eval_path = os.path.join(save_directory, "evaluate.csv")
    keep_largest = KeepLargestConnectedComponent([1], connectivity=3)
    losses = []
    dices = []
    with T.no_grad():
        for image in loader:
            inp = image["image"].to(device)
            label = image["label"].to(device)
            meta = image['image_meta_dict']
            original_filename = meta['filename_or_obj'][0]
            # os.path.basename handles the platform's separator(s) itself.
            segmentation_filename = os.path.join(
                save_directory, f"seg_{os.path.basename(original_filename)}"
            )

            output = model(inp)
            loss = loss_function(output, label).item()

            # Binarise with one comparison so the result cannot depend on the
            # order of in-place assignments (two sequential assignments would
            # zero every voxel when the threshold exceeds 1). Stays on device.
            output = (output >= threshold).to(output.dtype)
            output = keep_largest(output)

            if evaluate_config['save_segmentations']:
                original_image = nibabel.load(original_filename)
                store_output(
                    output,
                    original_image,
                    segmentation_filename,
                    meta['affine'].squeeze(0).numpy(),
                )

            dice = compute_meandice(output, label).item()
            append_eval(eval_path, original_filename, segmentation_filename, loss, dice)
            losses.append(loss)
            dices.append(dice)

    # An empty loader would otherwise raise ZeroDivisionError.
    if not losses:
        return None
    mean_loss = sum(losses) / len(losses)
    mean_dice = sum(dices) / len(dices)
    append_eval(eval_path, 'Total', '', mean_loss, mean_dice)
    return mean_loss, mean_dice
def initiate(config_file):
    """Evaluate a trained model using the YAML configuration at ``config_file``.

    Steps:
    1. Read the config.
    2. Expand each dataset entry into full image/boxes/label paths.
    3. Build a batch-size-1 loader.
    4. Restore the model weights onto ``config["device"]``.
    5. Reset the evaluation CSV.
    6. Delegate to ``evaluate``.

    Args:
        config_file: path to the YAML configuration file.

    Returns:
        Whatever ``evaluate`` returns.
    """
    config = read_yaml(config_file)
    device = T.device(config["device"])
    # NOTE(review): evaluation reads the "train_dataset" entry -- confirm this
    # is intended rather than a held-out split.
    data = read_yaml(config["data"]["train_dataset"])

    # The entry's 'label' value is used as the file name under all three
    # prefixes; read it once before 'label' itself is overwritten.
    for entry in data['data']:
        name = entry['label']
        entry['image'] = data['image_prefix'] + name
        entry['boxes'] = data['boxes_prefix'] + name
        entry['label'] = data['label_prefix'] + name

    keys = ["image", "boxes", "label"]
    transform = Compose(
        [
            LoadImaged(keys=keys),
            AddChanneld(keys=keys),
            ToTensord(keys=keys),
        ]
    )
    dataset = Dataset(data['data'], transform)
    loader = T.utils.data.DataLoader(dataset, 1)

    model, loss = load_model(config['model'], eval=True)
    # map_location lets weights saved on a GPU be restored on a CPU-only run.
    model.load_state_dict(T.load(config["model"]["weights"], map_location=device))
    model.to(device)

    init_eval_file(config['evaluate']['save_directory'])
    return evaluate(model, loss, loader, device, config['evaluate'])