Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, File, UploadFile, HTTPException | |
from fastapi.responses import PlainTextResponse | |
import nibabel as nib | |
import numpy as np | |
import os | |
import shutil | |
import pickle | |
import pandas as pd | |
from tempfile import NamedTemporaryFile | |
# Function to load the model from the pickle file
def load_model(model_path='svm_pipeline.pkl'):
    """Load and return the object pickled at *model_path*.

    Args:
        model_path: Path of the pickle file to read. Defaults to
            'svm_pipeline.pkl'; a relative path is resolved against the
            current working directory.

    Returns:
        Whatever object the pickle file contains.

    Raises:
        FileNotFoundError: If *model_path* does not exist.
    """
    # SECURITY: pickle.load can execute arbitrary code while deserialising,
    # so only point this at model files that come from a trusted source.
    with open(model_path, 'rb') as model_file:
        return pickle.load(model_file)
# Load the trained model once, at import time; the request handler below
# calls model.predict() on this shared object.
model = load_model()
# Function to load image data from a filepath
def get_image_data(filepath):
    """Open the image at *filepath* with nibabel and return its data array.

    Args:
        filepath: Path to an image file that ``nib.load`` can open.

    Returns:
        The array produced by the loaded image's ``get_fdata()``.
    """
    return nib.load(filepath).get_fdata()
# Function to create a vector from a region by time matrix from an image using the atlas
def image_to_vector(image_data, atlas_data):
    """Average an image's time series within each atlas region and flatten.

    The last axis of *image_data* is treated as time. All leading axes are
    flattened into a single voxel axis and paired, element by element, with
    the flattened *atlas_data* labels.

    Args:
        image_data: Array whose last axis is time.
        atlas_data: Array of region labels holding exactly one label per
            voxel of *image_data* (``atlas_data.size`` must equal the
            product of the leading axes of *image_data*).

    Returns:
        1-D array of length ``n_regions * n_timepoints``: the mean time
        series of each distinct label, in ascending label order,
        concatenated region after region. Voxels whose label is NaN are
        left out (pandas ``groupby`` default).

    Raises:
        ValueError: If the number of voxels in *image_data* differs from
            the number of labels in *atlas_data*.
    """
    time_dim = image_data.shape[-1]
    voxel_series = image_data.reshape(-1, time_dim)
    region_labels = atlas_data.reshape(-1)

    # Fail loudly on a mismatch: joining differently sized tables would make
    # pandas align them by index and NaN-fill the gap, which silently yields
    # region means computed from the wrong (or too few) voxels.
    if voxel_series.shape[0] != region_labels.shape[0]:
        raise ValueError(
            f"Image has {voxel_series.shape[0]} voxels per time point but "
            f"the atlas has {region_labels.shape[0]} labels; "
            f"image shape {image_data.shape}, atlas shape {atlas_data.shape}"
        )

    column_names = [f'time_{i}' for i in range(time_dim)]
    df_times = pd.DataFrame(voxel_series, columns=column_names)

    # Grouping by the label array gives one row per distinct label; groupby
    # sorts its keys, so rows come out in ascending label order.
    regions_x_time = df_times.groupby(region_labels).mean()
    return regions_x_time.to_numpy().reshape(-1)
# Function to preprocess the input image and extract features
def preprocess_and_extract_features(nifti_data, atlas_data, num_required_features=116):
    """Turn an image into a fixed-length, single-row feature matrix.

    The image is reduced to a region-by-time vector with
    ``image_to_vector`` and then forced to exactly
    *num_required_features* values: shorter vectors are zero-padded at the
    end, longer ones are cut off after the first *num_required_features*
    entries.

    Args:
        nifti_data: Image array passed through to ``image_to_vector``.
        atlas_data: Atlas label array passed through to ``image_to_vector``.
        num_required_features: Length of the feature vector to produce.
            Defaults to 116, the value previously hard-coded here.

    Returns:
        Array of shape ``(1, num_required_features)``.
    """
    features = image_to_vector(nifti_data, atlas_data)

    # NOTE(review): the vector is laid out region by region, so truncation
    # keeps only the leading regions' time points — confirm this matches how
    # the model's training features were built.
    shortfall = num_required_features - features.size
    if shortfall > 0:
        features = np.pad(features, (0, shortfall), 'constant')
    else:
        features = features[:num_required_features]
    return features.reshape(1, -1)
# FastAPI app setup
app = FastAPI()


# NOTE(review): the handler was not registered on `app` anywhere in view, so
# the service exposed no endpoint. The "/predict" path is an assumption —
# confirm it against the clients that call this service.
@app.post("/predict", response_class=PlainTextResponse)
async def predict_region(file: UploadFile = File(...)):
    """Classify an uploaded image and return the prediction as text.

    The upload is written to a temporary ``.nii.gz`` file, loaded, reduced
    to a feature row with the 'aal_mask_pad.nii.gz' atlas (looked up
    relative to the current working directory) and passed to the
    module-level ``model``.

    Args:
        file: The uploaded image file.

    Returns:
        ``str()`` of the first element of ``model.predict(...)``.

    Raises:
        HTTPException: Status 500, with the original error message as
            detail, for any failure while storing, loading or classifying
            the upload (including a missing atlas file).
    """
    # NOTE(review): everything below is blocking work inside an `async def`;
    # check whether stalling the event loop per request is acceptable.
    temp_file_path = None
    try:
        # Write through the temp file's own handle and close it before
        # loading, instead of leaking the handle and reopening by name.
        with NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
            # Record the path first so `finally` cleans up even if the copy fails.
            temp_file_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        data = get_image_data(temp_file_path)

        atlas_filepath = 'aal_mask_pad.nii.gz'
        if not os.path.exists(atlas_filepath):
            raise FileNotFoundError(f"Atlas file not found at: {atlas_filepath}")
        atlas_data = get_image_data(atlas_filepath)

        features = preprocess_and_extract_features(data, atlas_data)
        prediction = model.predict(features)
        return str(prediction[0])
    except Exception as e:
        # Boundary handler: report every failure as HTTP 500, keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)