docker_off-mind / app.py
ant-man19's picture
Upload 5 files
8952bd9 verified
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import PlainTextResponse
import nibabel as nib
import numpy as np
import os
import shutil
import pickle
import pandas as pd
from tempfile import NamedTemporaryFile
def load_model(path='svm_pipeline.pkl'):
    """Load a pickled model object from disk.

    Args:
        path: Location of the pickle file. Defaults to ``svm_pipeline.pkl``
            relative to the current working directory.

    Returns:
        The object that was pickled into the file.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    # SECURITY: unpickling can execute arbitrary code embedded in the file, so
    # ``path`` must only ever point at a trusted artifact shipped with the app.
    with open(path, 'rb') as f:
        return pickle.load(f)
# Load the trained model once, at import time, into a module-level object that
# the request handler uses for prediction.
model = load_model()
def get_image_data(filepath):
    """Open the image at ``filepath`` with nibabel and return its data array.

    The array is whatever the loaded image's ``get_fdata()`` produces.
    """
    return nib.load(filepath).get_fdata()
def image_to_vector(image_data, atlas_data):
    """Average an image's time series per atlas region and flatten the result.

    The last axis of ``image_data`` is treated as time; all leading axes are
    flattened into a voxel axis. Voxels are grouped by the value found at the
    same flattened position in ``atlas_data``, and each time point is averaged
    within every group.

    Args:
        image_data: Array whose last axis is time.
        atlas_data: Array of region labels with one entry per voxel of
            ``image_data`` (same number of elements as the leading axes).

    Returns:
        1-D array of length ``n_regions * n_timepoints``. Regions are ordered
        by ascending label and each region's time points are contiguous.
        Every distinct atlas value forms a region, including 0 if present.

    Raises:
        ValueError: If the atlas and the image do not have the same number of
            voxels.
    """
    image_data = np.asarray(image_data)
    atlas_data = np.asarray(atlas_data)

    time_dim = image_data.shape[-1]
    voxel_series = image_data.reshape(-1, time_dim)
    voxel_labels = atlas_data.reshape(-1)

    # Without this check a size mismatch is silently NaN-padded by index
    # alignment and produces averages over the wrong voxels.
    if voxel_series.shape[0] != voxel_labels.shape[0]:
        raise ValueError(
            f"Atlas has {voxel_labels.shape[0]} voxels but image has "
            f"{voxel_series.shape[0]}; their spatial shapes must match"
        )

    # Row/column labels are irrelevant because only the raw values are
    # returned, so group directly on the label array.
    regions_x_time = pd.DataFrame(voxel_series).groupby(voxel_labels).mean()
    return regions_x_time.to_numpy().reshape(-1)
def preprocess_and_extract_features(nifti_data, atlas_data, num_required_features=116):
    """Build a fixed-length, single-row feature matrix from an image.

    The image is reduced to a flat vector with ``image_to_vector`` and then
    zero-padded at the end, or truncated, to exactly
    ``num_required_features`` entries.

    Args:
        nifti_data: Image array passed through to ``image_to_vector``.
        atlas_data: Atlas array passed through to ``image_to_vector``.
        num_required_features: Length of the returned feature row. Defaults
            to 116.

    Returns:
        Array of shape ``(1, num_required_features)``.

    Raises:
        ValueError: If ``num_required_features`` is negative.
    """
    if num_required_features < 0:
        # A negative bound would make the slice below drop entries from the
        # end instead of fixing the length.
        raise ValueError("num_required_features must be non-negative")

    features = image_to_vector(nifti_data, atlas_data)

    # NOTE(review): truncation keeps only the leading entries of the flattened
    # vector -- confirm this matches how the model's training features were built.
    missing = num_required_features - features.size
    if missing > 0:
        features = np.pad(features, (0, missing), 'constant')
    else:
        features = features[:num_required_features]
    return features.reshape(1, -1)
# FastAPI application instance; the route decorator below registers the
# prediction endpoint on it.
app = FastAPI()
@app.post("/predict", response_class=PlainTextResponse)
async def predict_region(file: UploadFile = File(...)):
    """Predict a label for an uploaded image and return it as plain text.

    The upload is spooled to a temporary file, loaded, reduced to a feature
    row using the atlas at ``aal_mask_pad.nii.gz``, and passed to the
    module-level model.

    Args:
        file: The uploaded image file.

    Returns:
        ``str()`` of the first element of the model's prediction.

    Raises:
        HTTPException: With status 500 and the underlying error message for
            any failure, including a missing atlas file.
    """
    temp_file_path = None
    try:
        # nibabel chooses compressed vs. plain reading from the file extension,
        # so keep ".nii" for uncompressed uploads instead of forcing ".nii.gz".
        upload_name = (file.filename or "").lower()
        suffix = ".nii" if upload_name.endswith(".nii") else ".nii.gz"

        # Write through the handle NamedTemporaryFile returns and close it
        # deterministically, rather than abandoning it and reopening by name.
        with NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            temp_file_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        data = get_image_data(temp_file_path)

        atlas_filepath = 'aal_mask_pad.nii.gz'
        if not os.path.exists(atlas_filepath):
            raise FileNotFoundError(f"Atlas file not found at: {atlas_filepath}")
        atlas_data = get_image_data(atlas_filepath)

        features = preprocess_and_extract_features(data, atlas_data)
        prediction = model.predict(features)
        return str(prediction[0])
    except Exception as e:
        # Chain the cause so the original traceback is kept.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # delete=False means the temporary file must be removed explicitly.
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)