File size: 4,316 Bytes
bb92fcc 094a954 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import gradio as gr
import nibabel as nib
import numpy as np
import os
import shutil
import pickle
import pandas as pd
# Function to load the model from the pickle file
def load_model(model_path='svm_pipeline.pkl'):
    '''
    Load the trained model from a pickle file
    Input: Path to the pickle file (defaults to 'svm_pipeline.pkl'; a relative
           path is resolved against the current working directory)
    Output: The object stored in the pickle file
    '''
    # NOTE: unpickling can execute arbitrary code, so only point this at
    # model files that ship with the application, never at user uploads.
    with open(model_path, 'rb') as f:
        return pickle.load(f)
# Load the trained model once, at import time; predict_region() reads this
# module-level object on every request instead of unpickling again.
model = load_model()
# Function to load image data from a filepath
def get_image_data(filepath):
    '''
    Read an image from disk and hand back its voxel values
    Input: Filepath to the image (anything nibabel can load)
    Output: The array returned by the loaded image's get_fdata()
    '''
    return nib.load(filepath).get_fdata()
# Function to create a vector from a region by time matrix from an image using the atlas
def image_to_vector(image_data, atlas_data):
    '''
    Create a vector from a region by time matrix from an image using the atlas

    Each voxel is labelled with the atlas value at the same flattened
    position. Voxels sharing a label are averaged separately at every time
    point, giving a (regions x time) matrix that is flattened row by row.

    Input:
    - image_data: array whose last axis is treated as time
    - atlas_data: array of region labels with exactly one entry per image voxel
    Output: A 1-D vector of length n_regions * n_timepoints, regions ordered
            by ascending label value
    Raises: ValueError if the atlas and the image do not contain the same
            number of voxels
    '''
    # The time dimension is assumed to be the last dimension of the image data
    time_dim = image_data.shape[-1]

    # Flatten to (voxels x time) for the image and (voxels,) for the atlas
    reshaped_image_data = image_data.reshape(-1, time_dim)
    reshaped_atlas_data = atlas_data.reshape(-1)

    # pd.concat(axis=1) aligns on index and pads the shorter input with NaN,
    # so a mismatched atlas would silently give meaningless region means.
    # Fail loudly instead.
    n_voxels = reshaped_image_data.shape[0]
    if reshaped_atlas_data.size != n_voxels:
        raise ValueError(
            f"Atlas has {reshaped_atlas_data.size} voxels (shape {atlas_data.shape}) "
            f"but the image has {n_voxels} voxels per time point "
            f"(shape {image_data.shape}); they must cover the same grid."
        )

    column_names = [f'time_{i}' for i in range(time_dim)]
    region_names = [f'region_{region}' for region in np.unique(atlas_data)]

    # One row per voxel: its atlas label followed by its time series
    df_times = pd.DataFrame(reshaped_image_data, columns=column_names)
    df_full = pd.concat(
        [pd.Series(reshaped_atlas_data, name='atlas_region'), df_times],
        axis=1,
    )

    # Average the voxels of each region, separately for every time point
    regions_x_time = df_full.groupby('atlas_region').mean()
    regions_x_time.index = region_names

    # Flatten the region x time matrix to a vector
    return regions_x_time.to_numpy().reshape(-1)
# Function to preprocess the input image and extract features
def preprocess_and_extract_features(nifti_data, atlas_data, num_required_features=116):
    '''
    Preprocess the input image data and extract features using the atlas.
    Input:
    - nifti_data: The NIfTI image data
    - atlas_data: The atlas data
    - num_required_features: Length the feature vector is forced to
      (default 116, the value this function previously hard-coded;
      presumably the input width the pickled model expects - verify)
    Output: Extracted feature vector as a 2-D array of shape
            (1, num_required_features)
    Raises: ValueError if num_required_features is negative
    '''
    if num_required_features < 0:
        raise ValueError("num_required_features must be non-negative")

    features = image_to_vector(nifti_data, atlas_data)

    # If fewer features are found, pad with trailing zeros; otherwise keep
    # only the leading num_required_features values
    missing = num_required_features - features.size
    if missing > 0:
        features = np.pad(features, (0, missing), 'constant')
    else:
        features = features[:num_required_features]

    # Single-row 2-D array, the shape passed to model.predict()
    return features.reshape(1, -1)
def _nifti_suffix(path):
    '''
    Pick the NIfTI extension that matches a file's actual content
    Input: Path to the uploaded file
    Output: '.nii.gz' if the file starts with the gzip magic bytes, else '.nii'
    '''
    with open(path, 'rb') as f:
        is_gzip = f.read(2) == b'\x1f\x8b'
    return '.nii.gz' if is_gzip else '.nii'


def predict_region(input_file):
    '''
    Run the model on an uploaded NIfTI image and report the prediction
    Input: The uploaded file as passed in by the Gradio File input; its
           `.name` attribute is used as the path (a plain path string is
           accepted as well)
    Output: The first predicted value as a string, or a message starting
            with "Error: " if anything goes wrong
    '''
    import tempfile  # only needed here; keeps the file-level imports untouched

    if input_file is None:
        return "Error: No file was uploaded."

    try:
        source_path = getattr(input_file, 'name', input_file)

        # Copy the upload to a private temp dir under a name whose extension
        # matches its content. Always appending ".nii.gz" made uncompressed
        # .nii uploads get opened as gzip and fail to load.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(
                temp_dir, 'upload' + _nifti_suffix(source_path)
            )
            shutil.copy(source_path, temp_file_path)
            data = get_image_data(temp_file_path)

            # Path to the atlas file
            atlas_filepath = 'aal_mask_pad.nii.gz'
            if not os.path.exists(atlas_filepath):
                raise FileNotFoundError(f"Atlas file not found at: {atlas_filepath}")
            atlas_data = get_image_data(atlas_filepath)

            # Extract features while the temp copy still exists
            # (NOTE(review): the loaded array may still be backed by the
            # file on disk - confirm against nibabel's memory-mapping docs).
            features = preprocess_and_extract_features(data, atlas_data)

        # Predict using the loaded model
        prediction = model.predict(features)
        return str(prediction[0])
    except Exception as e:
        # UI boundary: surface any failure as text instead of a traceback
        return f"Error: {e}"
# Create Gradio interface: one file upload in, one text result out
# (predict_region returns either the prediction or an "Error: ..." message)
interface = gr.Interface(
    fn=predict_region,
    inputs=gr.File(label="Region Image (NIfTI file)"),
    outputs="text",
    title="Region Prediction",
    description="Upload a region image in NIfTI format to get the prediction.",
    allow_flagging="never",  # Disable flagging
)

# Launch the Gradio interface
interface.launch()