|
import gradio as gr |
|
import nibabel as nib |
|
import numpy as np |
|
import os |
|
import shutil |
|
import pickle |
|
import pandas as pd |
|
|
|
|
|
def load_model(path='svm_pipeline.pkl'):
    '''
    Load the pickled model pipeline used for predictions.

    Input: Path to the pickle file (defaults to 'svm_pipeline.pkl', resolved
    relative to the current working directory)

    Output: The unpickled object (predict_region calls .predict on it)
    '''
    # SECURITY: pickle.load can execute arbitrary code while unpickling.
    # Only point this at a model file you created or trust -- never at
    # user-supplied data.
    with open(path, 'rb') as f:
        return pickle.load(f)
|
|
|
|
|
# Unpickle the model once at import time so every request reuses the same
# object (predict_region calls model.predict on it).
model = load_model()
|
|
|
|
|
def get_image_data(filepath):
    '''
    Read an image with nibabel and hand back its voxel values as floats.

    Input: Path of the image file on disk

    Output: The array returned by the image's get_fdata()
    '''
    return nib.load(filepath).get_fdata()
|
|
|
|
|
def image_to_vector(image_data, atlas_data):
    '''
    Create a vector from a region by time matrix from an image using the atlas

    Each voxel is labelled with the atlas value at the same flattened
    position, and the image's time courses are averaged over all voxels that
    share a label.

    Input:

    - image_data: Image array whose last axis is treated as time
      (e.g. shape (x, y, z, t))

    - atlas_data: Label array holding one value per voxel of image_data
      (e.g. shape (x, y, z)); every distinct non-NaN value, including 0,
      becomes a region

    Output: 1-D array containing the (regions x time) matrix flattened row by
    row -- all time points of the lowest label first, regions in ascending
    label order

    Raises: ValueError if atlas_data does not have exactly one label per voxel
    of image_data
    '''
    time_dim = image_data.shape[-1]

    # One row per voxel, one column per time point.
    voxels_x_time = image_data.reshape(-1, time_dim)
    voxel_labels = atlas_data.reshape(-1)

    # pd.concat below aligns on the row index, so a size mismatch would not
    # fail: it would silently pad the shorter side with NaN and corrupt the
    # region averages. Reject it explicitly instead.
    if voxels_x_time.shape[0] != voxel_labels.size:
        raise ValueError(
            f'Atlas has {voxel_labels.size} voxels but image has '
            f'{voxels_x_time.shape[0]} (image shape {image_data.shape}, '
            f'atlas shape {atlas_data.shape})'
        )

    df_times = pd.DataFrame(
        voxels_x_time, columns=[f'time_{i}' for i in range(time_dim)]
    )
    df_full = pd.concat(
        [pd.Series(voxel_labels, name='atlas_region'), df_times], axis=1
    )

    # Mean of every time column per label. groupby sorts labels ascending and
    # drops NaN labels, so the row order is deterministic.
    regions_x_time = df_full.groupby('atlas_region').mean()

    return regions_x_time.to_numpy().reshape(-1)
|
|
|
|
|
def preprocess_and_extract_features(nifti_data, atlas_data, num_required_features=116):
    '''
    Preprocess the input image data and extract a fixed-length feature row
    using the atlas.

    Input:

    - nifti_data: The NIfTI image data (passed to image_to_vector)

    - atlas_data: The atlas label data (passed to image_to_vector)

    - num_required_features: Length the feature vector is zero-padded or
      truncated to (default 116)

    Output: Feature array of shape (1, num_required_features), i.e. a single
    sample row as passed to model.predict
    '''
    features = image_to_vector(nifti_data, atlas_data)

    # NOTE(review): image_to_vector flattens the regions x time matrix row by
    # row, so keeping only the first N values keeps the leading region(s)'
    # time courses, not one value per region. 116 is also the region count of
    # the standard AAL atlas (the app loads 'aal_mask_pad.nii.gz'), which hints
    # that one value per region may have been intended -- confirm against how
    # the pickled pipeline was trained before changing this.
    if features.size < num_required_features:
        features = np.pad(features, (0, num_required_features - features.size), 'constant')
    else:
        features = features[:num_required_features]

    return features.reshape(1, -1)
|
|
|
def predict_region(input_file):
    '''
    Gradio callback: run the loaded model on an uploaded NIfTI image.

    Input: The value from the gr.File input -- a filepath string, an object
    exposing the path as `.name`, or None when nothing was uploaded

    Output: The model's first prediction converted to str, or a message
    starting with "Error:" describing what went wrong
    '''
    if input_file is None:
        return "Error: no file was uploaded"

    temp_file_path = None

    try:
        # Older Gradio versions pass a tempfile wrapper (path in .name);
        # newer versions pass the path string itself.
        source_path = getattr(input_file, 'name', input_file)

        # nibabel picks plain vs. gzip decoding from the file extension, so
        # the path it loads must carry the suffix matching the real encoding.
        with open(source_path, 'rb') as f:
            is_gzip = f.read(2) == b'\x1f\x8b'
        suffix = '.nii.gz' if is_gzip else '.nii'

        if source_path.endswith(suffix):
            nifti_path = source_path
        else:
            temp_file_path = source_path + suffix
            shutil.copy(source_path, temp_file_path)
            nifti_path = temp_file_path

        atlas_filepath = 'aal_mask_pad.nii.gz'
        if not os.path.exists(atlas_filepath):
            raise FileNotFoundError(f"Atlas file not found at: {atlas_filepath}")

        data = get_image_data(nifti_path)
        atlas_data = get_image_data(atlas_filepath)

        features = preprocess_and_extract_features(data, atlas_data)

        prediction = model.predict(features)
        return str(prediction[0])

    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing.
        return f"Error: {e}"

    finally:
        # Remove only the renamed copy made above, never the original upload.
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
|
|
|
|
|
|
|
# Web UI: one NIfTI file upload in, the prediction (or error) text out.
interface = gr.Interface(
    title="Region Prediction",
    description="Upload a region image in NIfTI format to get the prediction.",
    fn=predict_region,
    inputs=gr.File(label="Region Image (NIfTI file)"),
    outputs="text",
    allow_flagging="never",
)

# Start the Gradio server; this executes as a top-level statement whenever
# the module runs.
interface.launch()