Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import PlainTextResponse | |
| import nibabel as nib | |
| import numpy as np | |
| import os | |
| import shutil | |
| import pickle | |
| import pandas as pd | |
| from tempfile import NamedTemporaryFile | |
def load_model(path='svm_pipeline.pkl'):
    """Unpickle and return the trained model stored at *path*.

    Args:
        path: Pickle file to read. Relative paths resolve against the current
            working directory. Defaults to ``'svm_pipeline.pkl'``.

    Returns:
        Whatever object was pickled into *path*.
    """
    # pickle.load can execute arbitrary code while deserializing: only point
    # this at trusted, locally shipped files, never at user-supplied uploads.
    with open(path, 'rb') as f:
        return pickle.load(f)
# Load the trained model once, at import time, so every request handled by
# this module reuses the same object; a missing or unreadable pickle therefore
# fails at startup rather than on the first request.
model = load_model()
def get_image_data(filepath):
    """Open the image at *filepath* with nibabel and return its voxel array.

    The array comes from nibabel's ``get_fdata`` (floating-point data).
    """
    return nib.load(filepath).get_fdata()
def image_to_vector(image_data, atlas_data):
    """Average voxel time series within each atlas region and flatten them.

    Voxels are paired with atlas labels by their position after C-order
    flattening, so ``atlas_data`` must supply exactly one label per spatial
    voxel of ``image_data`` (normally the two share the same spatial shape).

    Args:
        image_data: Array whose last axis is treated as time; all leading
            axes are flattened into voxels.
        atlas_data: Array of region labels, one per voxel.

    Returns:
        1-D array of length ``n_regions * n_timepoints`` in region-major
        order (every time point of the lowest label first, then the next
        label, ...). Regions are the distinct non-NaN labels in ascending
        order.

    Raises:
        ValueError: If the atlas and the image have different voxel counts.
    """
    time_dim = image_data.shape[-1]
    voxels_x_time = image_data.reshape(-1, time_dim)
    labels = atlas_data.reshape(-1)

    # Without this check pd.concat below would silently align the two inputs
    # on their index and pad the shorter one with NaN, giving wrong averages.
    if labels.shape[0] != voxels_x_time.shape[0]:
        raise ValueError(
            f"Atlas has {labels.shape[0]} voxels but image has "
            f"{voxels_x_time.shape[0]}; their spatial shapes must match."
        )

    column_names = [f'time_{i}' for i in range(time_dim)]
    df_times = pd.DataFrame(voxels_x_time, columns=column_names)
    df_full = pd.concat([pd.Series(labels, name='atlas_region'), df_times], axis=1)

    # groupby sorts labels ascending and drops NaN labels by default.
    regions_x_time = df_full.groupby('atlas_region').mean()
    # Name rows from the groups actually produced: np.unique would also count
    # a NaN label and then mismatch the index length.
    regions_x_time.index = [f'region_{region}' for region in regions_x_time.index]
    return regions_x_time.to_numpy().reshape(-1)
def preprocess_and_extract_features(nifti_data, atlas_data, num_required_features=116):
    """Build the fixed-width, single-row feature matrix fed to the model.

    Args:
        nifti_data: Image array handed to ``image_to_vector`` (last axis is
            time).
        atlas_data: Region-label array handed to ``image_to_vector``.
        num_required_features: Width the model expects. The region-by-time
            vector is zero-padded or truncated to this length. Defaults to 116.

    Returns:
        Array of shape ``(1, num_required_features)``.
    """
    features = image_to_vector(nifti_data, atlas_data)
    # NOTE(review): image_to_vector returns a region-major vector (all time
    # points of the first region, then the next, ...), so truncation keeps
    # only the leading regions' early time points, not one value per region.
    # 116 presumably corresponds to the AAL atlas region count -- confirm this
    # matches how the model's training features were built.
    shortfall = num_required_features - features.size
    if shortfall > 0:
        features = np.pad(features, (0, shortfall), 'constant')
    else:
        features = features[:num_required_features]
    return features.reshape(1, -1)
# FastAPI application instance.
app = FastAPI()


# NOTE(review): route path '/predict' -- confirm it matches what clients call.
@app.post("/predict", response_class=PlainTextResponse)
async def predict_region(file: UploadFile = File(...)):
    """Classify an uploaded NIfTI image and return the prediction as text.

    The upload is spooled to a temporary ``.nii.gz`` file (nibabel loads from
    a path), reduced to atlas-region features and passed to ``model``.

    Args:
        file: The uploaded NIfTI image.

    Returns:
        ``str`` of the first element of ``model.predict(...)``, sent as
        ``text/plain``.

    Raises:
        HTTPException: Status 500 carrying the error message for any failure
            while loading, extracting features, or predicting.
    """
    temp_file_path = None
    try:
        # delete=False: the file must outlive this handle so nibabel can
        # reopen it by path; it is removed in ``finally`` instead. Writing
        # through the same handle avoids leaking a second open file object.
        with NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
            temp_file_path = tmp.name
            shutil.copyfileobj(file.file, tmp)
        data = get_image_data(temp_file_path)

        # NOTE(review): the atlas is re-read from disk on every request;
        # consider loading it once at startup like ``model``.
        atlas_filepath = 'aal_mask_pad.nii.gz'
        if not os.path.exists(atlas_filepath):
            raise FileNotFoundError(f"Atlas file not found at: {atlas_filepath}")
        atlas_data = get_image_data(atlas_filepath)

        # NOTE(review): this synchronous, CPU/disk-bound work runs on the event
        # loop inside an ``async def``; consider offloading it to a thread.
        features = preprocess_and_extract_features(data, atlas_data)
        prediction = model.predict(features)
        return str(prediction[0])
    except Exception as e:
        # Keep the original exception as __cause__ for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)