File size: 4,316 Bytes
bb92fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
094a954
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import gradio as gr
import nibabel as nib
import numpy as np
import os
import shutil
import pickle
import pandas as pd

# Function to load the model from the pickle file
def load_model(path='svm_pipeline.pkl'):
    '''
    Load a trained model from a pickle file.

    Input: path - Path to the pickled model. Defaults to 'svm_pipeline.pkl',
           resolved relative to the current working directory.

    Output: The unpickled model object

    Raises: FileNotFoundError if the file does not exist.
    '''
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only ever load model files from a trusted source, never user uploads.
    with open(path, 'rb') as f:
        return pickle.load(f)

# Load the trained model once, at import time, so every request handled by
# predict_region reuses the same in-memory object. A missing or unreadable
# pickle file therefore fails at startup rather than on each request.
model = load_model()

# Function to load image data from a filepath
def get_image_data(filepath):
    '''
    Read an image file with nibabel and return its voxel data.

    Input: Path to an image file that nibabel.load can open

    Output: The image's data array as returned by get_fdata()
    '''
    return nib.load(filepath).get_fdata()

# Function to create a vector from a region by time matrix from an image using the atlas
def image_to_vector(image_data, atlas_data):
    '''
    Average an image over atlas regions and flatten the result to a vector.

    The last axis of image_data is treated as time. Every other axis is
    flattened into voxels, which must line up one-to-one with the flattened
    atlas labels. Voxels sharing an atlas label are averaged per time point,
    giving a (regions x time) matrix that is flattened row by row.

    Input:
        - image_data: Image array whose last dimension is time
        - atlas_data: Atlas label array with one label per image voxel

    Output: 1-D vector of length n_regions * n_timepoints. Regions appear in
            ascending label order (pandas groupby default).

    Raises: ValueError if the image's voxel count does not match the number
            of atlas labels. Without this check, pd.concat would silently
            align the two on the index and pad with NaN.
    '''
    time_dim = image_data.shape[-1]
    column_names = [f'time_{i}' for i in range(time_dim)]

    # Flatten to (voxels x time) and (voxels,) respectively
    reshaped_image_data = image_data.reshape(-1, time_dim)
    reshaped_atlas_data = atlas_data.reshape(-1)

    if reshaped_image_data.shape[0] != reshaped_atlas_data.shape[0]:
        raise ValueError(
            f"Image has {reshaped_image_data.shape[0]} voxels but atlas has "
            f"{reshaped_atlas_data.shape[0]} labels; spatial shapes must match "
            f"(image {image_data.shape}, atlas {atlas_data.shape})"
        )

    df_times = pd.DataFrame(reshaped_image_data, columns=column_names)
    df_full = pd.concat(
        [pd.Series(reshaped_atlas_data, name='atlas_region'), df_times], axis=1
    )

    # Mean signal per atlas region at each time point
    regions_x_time = df_full.groupby('atlas_region').mean()
    # Label rows from the groupby result itself so the labels always match
    # the rows actually produced
    regions_x_time.index = [f'region_{region}' for region in regions_x_time.index]

    # Flatten the region x time matrix row by row
    return regions_x_time.to_numpy().reshape(-1)

# Function to preprocess the input image and extract features
def preprocess_and_extract_features(nifti_data, atlas_data, num_required_features=116):
    '''
    Extract an atlas-based feature vector and fit it to the model's input size.

    Input:
        - nifti_data: The NIfTI image data
        - atlas_data: The atlas data
        - num_required_features: Length the model expects (default 116, the
          value previously hard-coded)

    Output: Array of shape (1, num_required_features), i.e. a single sample row

    Raises: ValueError if num_required_features is negative.
    '''
    if num_required_features < 0:
        raise ValueError(
            f"num_required_features must be non-negative, got {num_required_features}"
        )

    features = image_to_vector(nifti_data, atlas_data)

    # The model takes a fixed-length input: zero-pad short vectors and
    # truncate long ones. NOTE(review): truncation silently drops trailing
    # features; confirm this matches how the training features were built.
    if features.size < num_required_features:
        features = np.pad(features, (0, num_required_features - features.size), 'constant')
    else:
        features = features[:num_required_features]

    return features.reshape(1, -1)

def _loadable_nifti_path(source_path):
    '''
    Return (path nibabel can load, temporary copy to delete or None).

    nibabel picks the format and compression from the file extension. An
    upload that already ends in .nii/.nii.gz is used as-is. Otherwise it is
    copied to a fresh temp file whose suffix is chosen from the gzip magic
    bytes, so an uncompressed file is not mislabelled as .nii.gz.
    '''
    import tempfile

    if source_path.lower().endswith(('.nii', '.nii.gz')):
        return source_path, None

    with open(source_path, 'rb') as f:
        is_gzipped = f.read(2) == b'\x1f\x8b'
    suffix = '.nii.gz' if is_gzipped else '.nii'

    fd, temp_path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    try:
        shutil.copy(source_path, temp_path)
    except Exception:
        # Do not leak the empty temp file if the copy fails
        os.remove(temp_path)
        raise
    return temp_path, temp_path


def predict_region(input_file):
    '''
    Gradio callback: run the loaded model on an uploaded NIfTI image.

    Input: The uploaded file. This is either a filepath (str / os.PathLike)
           or a file wrapper exposing its path as ``.name``; which one depends
           on the Gradio version and gr.File settings.

    Output: The model's first prediction as a string, or a message starting
            with "Error:" if anything fails. Errors are reported in the UI
            instead of being raised.
    '''
    if input_file is None:
        return "Error: no file uploaded"

    temp_file_path = None  # Set only when we create a copy that must be removed
    try:
        if isinstance(input_file, (str, os.PathLike)):
            source_path = os.fspath(input_file)
        else:
            source_path = input_file.name
        nifti_path, temp_file_path = _loadable_nifti_path(source_path)

        data = get_image_data(nifti_path)

        atlas_filepath = 'aal_mask_pad.nii.gz'
        if not os.path.exists(atlas_filepath):
            raise FileNotFoundError(f"Atlas file not found at: {atlas_filepath}")
        atlas_data = get_image_data(atlas_filepath)

        features = preprocess_and_extract_features(data, atlas_data)

        prediction = model.predict(features)
        return str(prediction[0])
    except Exception as e:
        # UI boundary: show the message to the user rather than crashing
        return f"Error: {e}"
    finally:
        # Remove only the temp copy we made, never the user's original upload
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            

# Create Gradio interface
interface = gr.Interface(
    fn=predict_region,
    inputs=gr.File(label="Region Image (NIfTI file)"),
    outputs="text",
    title="Region Prediction",
    description="Upload a region image in NIfTI format to get the prediction.",
    allow_flagging="never"  # Disable flagging
)

# Launch the Gradio interface only when run as a script, so importing this
# module (e.g. from tests or other tooling) does not start a web server as a
# side effect.
if __name__ == "__main__":
    interface.launch()