import torch
import torchio as tio

from apps.model import model

DEFAULT_CKPT_PATH = "apps/best_model_36.ckpt"

# (checkpoint path, device) whose weights currently sit in `model`; None until the
# first call. Used to avoid re-reading the checkpoint from disk on every request.
_loaded_model_key = None


def _ensure_model_loaded(ckpt_path, device):
    """Load `ckpt_path` into the shared `model` on `device` unless already loaded."""
    global _loaded_model_key
    key = (str(ckpt_path), str(device))
    if _loaded_model_key != key:
        # NOTE(review): torch.load unpickles the file -- only point this at trusted
        # checkpoints. Newer torch versions default to weights_only=True, which may
        # reject checkpoints that store non-tensor metadata; confirm on upgrade.
        checkpoint = torch.load(ckpt_path, map_location=device)
        # Assumes a checkpoint dict with a "state_dict" entry whose keys match
        # `model` (e.g. a Lightning .ckpt) -- TODO confirm.
        model.load_state_dict(checkpoint["state_dict"])
        model.to(device)
        _loaded_model_key = key
    # Always (re)enter eval mode in case other code switched the shared model to train().
    model.eval()


def preprocess_input(
    uploaded_file,
    *,
    ckpt_path=DEFAULT_CKPT_PATH,
    target_shape=(300, 300, 400),
    patch_size=96,
    patch_overlap=(8, 8, 8),
    batch_size=4,
):
    """Preprocess a CT volume and run patch-based inference with the shared model.

    The volume is reoriented to canonical (RAS+) orientation, rescaled to [0, 1]
    intensities, and resized to ``target_shape``. Overlapping patches are then fed
    through ``model`` and the predictions are stitched back together with a
    torchio ``GridAggregator``.

    Args:
        uploaded_file: Anything ``tio.ScalarImage`` accepts (e.g. a path to a
            NIfTI/DICOM file).
        ckpt_path: Checkpoint to load into ``model``. It is reloaded only when
            this path or the device changes between calls.
        target_shape: Spatial shape passed to ``tio.Resize``.
        patch_size: Patch size for ``tio.inference.GridSampler``.
        patch_overlap: Patch overlap for ``tio.inference.GridSampler``.
        batch_size: Number of patches per forward pass.

    Returns:
        ``(output_tensor, dataset)``. ``output_tensor`` is the aggregated model
        prediction over the resized volume (presumably shaped
        ``(channels, *target_shape)`` -- confirm against the model).
        ``dataset`` is the single-subject ``tio.SubjectsDataset``; note that
        indexing it re-applies the preprocessing transform.
    """
    subject = tio.Subject({"CT": tio.ScalarImage(uploaded_file)})
    preprocess = tio.Compose([
        tio.ToCanonical(),
        tio.RescaleIntensity((0, 1)),
        tio.Resize(target_shape),
    ])
    dataset = tio.SubjectsDataset([subject], transform=preprocess)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _ensure_model_loaded(ckpt_path, device)

    grid_sampler = tio.inference.GridSampler(dataset[0], patch_size, patch_overlap)
    aggregator = tio.inference.GridAggregator(grid_sampler)
    patch_loader = tio.data.SubjectsLoader(grid_sampler, batch_size=batch_size)

    with torch.no_grad():
        for patches_batch in patch_loader:
            input_tensor = patches_batch["CT"]["data"].to(device)  # batch of patches
            locations = patches_batch[tio.LOCATION]  # where each patch sits in the volume
            pred = model(input_tensor)
            aggregator.add_batch(pred, locations)

    return aggregator.get_output_tensor(), dataset