import gradio as gr
import nibabel as nib
import numpy as np
import os
import shutil
import pickle
import pandas as pd


# Function to load the trained model from the pickle file
def load_model():
    with open('svm_pipeline.pkl', 'rb') as f:
        return pickle.load(f)


# Load the trained model once at startup
model = load_model()


# Function to load image data from a filepath
def get_image_data(filepath):
    '''
    Access the floating point data of an image.
    Input: Filepath to the image
    Output: The image's floating point data
    '''
    img = nib.load(filepath)
    data = img.get_fdata()
    return data


# Function to build a flattened region-by-time feature vector from an image using the atlas
def image_to_vector(image_data, atlas_data):
    '''
    Create a vector from the image's region-by-time matrix, using the atlas
    to assign voxels to regions.
    Input:
    - image_data: Data for the image to summarize
    - atlas_data: Atlas data to apply to the image data
    Output: A vector of the image's region-by-time matrix
    '''
    # Assume the time dimension is the last dimension of the image data
    time_dim = image_data.shape[-1]
    column_names = [f'time_{i}' for i in range(time_dim)]
    region_names = [f'region_{region}' for region in np.unique(atlas_data)]

    # Reshape the image data to 2D (voxels x time)
    reshaped_image_data = image_data.reshape(-1, time_dim)

    # Create a DataFrame with one row per voxel and one column per time point
    df_times = pd.DataFrame(reshaped_image_data, columns=column_names)

    # Reshape the atlas data to 1D (voxels)
    reshaped_atlas_data = atlas_data.reshape(-1)

    # Attach each voxel's atlas region label to its time series
    df_full = pd.concat([pd.Series(reshaped_atlas_data, name='atlas_region'), df_times], axis=1)

    # Group by atlas region and average across voxels, giving one mean signal per region per time point
    regions_x_time = df_full.groupby('atlas_region').mean()
    regions_x_time.index = region_names

    # Flatten the region-by-time matrix to a vector
    regions_x_time_vector = regions_x_time.to_numpy().reshape(-1)
    return regions_x_time_vector


# Function to preprocess the input image and extract features
def preprocess_and_extract_features(nifti_data, atlas_data):
    '''
    Preprocess the input image data and extract features using the atlas.
    Input:
    - nifti_data: The NIfTI image data
    - atlas_data: The atlas data
    Output: Extracted feature vector of fixed length, shaped for the model
    '''
    features = image_to_vector(nifti_data, atlas_data)
    num_required_features = 116  # Fixed feature length expected by the trained pipeline

    # If fewer features are found, pad with zeros; if more, truncate
    if features.size < num_required_features:
        features = np.pad(features, (0, num_required_features - features.size), 'constant')
    else:
        features = features[:num_required_features]
    return features.reshape(1, -1)


def predict_region(input_file):
    temp_file_path = None  # Defined up front so the finally block can always check it
    try:
        # Copy the upload to a path ending in .nii.gz so nibabel recognizes the format
        temp_file_path = input_file.name + ".nii.gz"
        shutil.copy(input_file.name, temp_file_path)

        # Load the uploaded NIfTI file
        data = get_image_data(temp_file_path)

        # Load the atlas
        atlas_filepath = 'aal_mask_pad.nii.gz'
        if not os.path.exists(atlas_filepath):
            raise FileNotFoundError(f"Atlas file not found at: {atlas_filepath}")
        atlas_data = get_image_data(atlas_filepath)

        # Preprocess and extract features
        features = preprocess_and_extract_features(data, atlas_data)

        # Predict using the loaded model
        prediction = model.predict(features)
        return str(prediction[0])
    except Exception as e:
        return f"Error: {e}"
    finally:
        # Clean up the temporary file
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)


# Create the Gradio interface
interface = gr.Interface(
    fn=predict_region,
    inputs=gr.File(label="Region Image (NIfTI file)"),
    outputs="text",
    title="Region Prediction",
    description="Upload a region image in NIfTI format to get the prediction.",
    allow_flagging="never"  # Disable flagging
)

# Launch the Gradio interface
interface.launch()
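
# --- Optional local smoke test (hypothetical paths) ---
# A minimal sketch of the same prediction path without the web UI, assuming a 4D
# scan saved as 'example_scan.nii.gz'; that filename is illustrative only.
# Run it above interface.launch() (or with launch() commented out), since launch()
# blocks the script.
#
#   scan_data = get_image_data('example_scan.nii.gz')   # hypothetical input scan
#   atlas_data = get_image_data('aal_mask_pad.nii.gz')  # same atlas the app uses
#   features = preprocess_and_extract_features(scan_data, atlas_data)  # shape (1, 116)
#   print(model.predict(features)[0])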