Spaces:
Sleeping
Sleeping
"""Server that will listen for GET and POST requests from the client.""" | |
import time | |
import logging | |
from pathlib import Path | |
from typing import List | |
from fastapi import FastAPI, File, Form, UploadFile, HTTPException | |
from fastapi.responses import JSONResponse, Response | |
from fastapi.middleware.cors import CORSMiddleware | |
from concrete.ml.deployment import FHEModelServer | |
import numpy as np | |
import gc | |
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware.
# Any origin, method and header is accepted. Credentials are deliberately NOT
# enabled: FastAPI/Starlette document that allow_credentials=True cannot be
# combined with "*" wildcards. With that combination Starlette echoes back
# whatever Origin the browser sends, which lets any website make credentialed
# requests. None of the visible handlers read cookies or auth headers. If
# credentialed browser requests are ever needed, list the allowed origins
# explicitly instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=False,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Initialize the FHE server at import time from the "deployment_files"
# directory next to this module (presumably the compiled Concrete ML model
# artifacts; confirm).
DEPLOYMENT_DIR = Path(__file__).parent / "deployment_files"
FHE_SERVER = FHEModelServer(DEPLOYMENT_DIR)
def get_server_file_path(file_type: str, user_id: str) -> Path:
    """Get the path of a per-user file in the ``server_tmp`` directory.

    The file name is ``f"{file_type}_{user_id}"``. The long type names
    ``"encrypted_image"`` and ``"evaluation_key"`` are shortened to
    ``"encrypted"`` and ``"evaluation"``; any other ``file_type`` is used as is.

    Both arguments may come from client-supplied form data or upload file
    names. The resulting name must therefore be a single path component, so
    that values such as ``"../../etc/x"`` cannot escape ``server_tmp``.

    NOTE(review): names can still collide across types. For example
    ``("encrypted", "output_42")`` and ``("encrypted_output", "42")`` map to
    the same file.

    Args:
        file_type: Kind of file (e.g. ``"encrypted"``, ``"evaluation"``,
            ``"encrypted_output"``).
        user_id: Client identifier.

    Returns:
        ``<this module's directory>/server_tmp/<file_type>_<user_id>``.

    Raises:
        ValueError: If the combined file name contains a path separator or
            otherwise is not a single path component.
    """
    short_names = {"encrypted_image": "encrypted", "evaluation_key": "evaluation"}
    file_type = short_names.get(file_type, file_type)

    file_name = f"{file_type}_{user_id}"
    # Path(...).name equals the input only for a plain, single component. This
    # rejects "/", "..", trailing "/." and (on Windows) "\\" and drive prefixes.
    if Path(file_name).name != file_name:
        raise ValueError(f"Unsafe file_type/user_id for server file: {file_name!r}")

    return Path(__file__).parent / "server_tmp" / file_name
async def send_input(user_id: str = Form(), files: List[UploadFile] = File(...)):
    """Receive the encrypted input image and the evaluation key from the client.

    Each upload is stored via ``get_server_file_path``. The part of the upload's
    filename before the first ``"_"`` is used as the file type (e.g.
    ``"encrypted_..."`` becomes ``"encrypted"``).

    NOTE(review): no route decorator is visible for this coroutine here. It is
    presumably meant to be registered on ``app`` (e.g. as a POST endpoint);
    confirm.

    Args:
        user_id: Client identifier, embedded in the stored file names.
        files: Uploaded files.

    Returns:
        JSONResponse with a confirmation message.

    Raises:
        HTTPException: 400 if an upload has no filename or its type prefix is
            not a plain file-name component; 500 on any other failure.
    """
    # Stream each upload to disk in 1 MiB chunks instead of holding the whole
    # file (evaluation keys may be large) in memory at once.
    chunk_size = 1024 * 1024
    try:
        for file in files:
            # filename can be None; the prefix is client-controlled, so refuse
            # anything that is not a single path component (no "/", "..", ...).
            file_type = (file.filename or "").split("_")[0]
            if not file_type or Path(file_type).name != file_type:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid upload filename: {file.filename!r}",
                )

            file_path = get_server_file_path(file_type, user_id)
            logger.info("Saving file to: %s", file_path)
            file_path.parent.mkdir(parents=True, exist_ok=True)  # Ensure the directory exists
            with file_path.open("wb") as buffer:
                while True:
                    chunk = await file.read(chunk_size)
                    if not chunk:
                        break
                    buffer.write(chunk)
            logger.info("File saved successfully at %s", file_path)

        return JSONResponse(content={"message": "Files received successfully"})
    except HTTPException:
        # Deliberate client errors (the 400 above) must not be re-wrapped as 500.
        raise
    except Exception as e:
        logger.exception("Error in send_input")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
def get_memory_usage(status_path: str = "/proc/self/status") -> float:
    """Get the current resident memory (VmRSS) of this process in GiB.

    Parses the ``VmRSS:`` line of a Linux ``/proc`` status file. Its value is
    in kB and is converted with two divisions by 1024.

    The value is only used for diagnostic logging. When it is unavailable
    (no ``/proc`` on this host, no ``VmRSS`` line, unparsable value), the
    function returns ``float("nan")`` instead of raising, so a missing metric
    cannot abort the caller. NaN still formats with ``:.2f`` and propagates
    through subtraction.

    Args:
        status_path: Status file to parse. Defaults to the current process's
            ``/proc/self/status``.

    Returns:
        Resident set size in GiB, or NaN if it cannot be determined.
    """
    try:
        with open(status_path, encoding="utf-8") as f:
            lines = f.read().splitlines()
    except OSError:
        return float("nan")

    for line in lines:
        if line.startswith("VmRSS:"):
            # Expected form: "VmRSS:\t  123456 kB"
            try:
                return int(line.split()[1]) / 1024 / 1024  # kB -> GiB
            except (IndexError, ValueError):
                break
    return float("nan")
def run_fhe(user_id: str = Form()):
    """Execute seizure detection on the encrypted input image using FHE.

    WARNING: the real ``FHE_SERVER.run(...)`` call is commented out below.
    What is currently stored as the "encrypted output" is a random,
    *unencrypted* 2-element float64 vector normalised to sum to 1 and
    serialised with ``ndarray.tobytes()``. It is a placeholder that only
    mimics a softmax result. NOTE(review): a client expecting an FHE
    ciphertext presumably cannot decrypt this; confirm before relying on it.

    NOTE(review): no route decorator is visible for this function here. It is
    presumably meant to be registered on ``app``; confirm.

    Args:
        user_id: Client identifier used to locate the uploaded encrypted image
            and evaluation key, and to name the output file.

    Returns:
        JSONResponse whose body is the measured execution time in seconds,
        rounded to two decimals.

    Raises:
        HTTPException: 404 if the encrypted image or the evaluation key has not
            been uploaded for ``user_id``; 500 for any other failure.
    """
    logger.info("Starting FHE execution for user %s", user_id)
    try:
        # Retrieve the encrypted input image and the evaluation key paths
        encrypted_image_path = get_server_file_path("encrypted", user_id)
        evaluation_key_path = get_server_file_path("evaluation", user_id)
        logger.info("Looking for encrypted_image at: %s", encrypted_image_path)
        logger.info("Looking for evaluation_key at: %s", evaluation_key_path)

        # A missing upload is a client-side problem: answer 404, not 500. The
        # server-side paths are logged above but not echoed to the client.
        if not encrypted_image_path.exists():
            raise HTTPException(
                status_code=404,
                detail=f"Encrypted image not found for user {user_id}",
            )
        if not evaluation_key_path.exists():
            raise HTTPException(
                status_code=404,
                detail=f"Evaluation key not found for user {user_id}",
            )

        # Inputs for the (currently disabled) FHE call below; the placeholder
        # path does not use them.
        encrypted_image = encrypted_image_path.read_bytes()
        evaluation_key = evaluation_key_path.read_bytes()

        memory_before = get_memory_usage()
        logger.info("Memory usage before FHE execution: %.2f GB", memory_before)

        # Force garbage collection before FHE execution
        gc.collect()

        # perf_counter is monotonic, unlike time.time(), so it suits durations.
        start = time.perf_counter()

        # Real FHE execution, currently disabled. If it is re-enabled, the 503
        # raised on MemoryError reaches the client, because HTTPException is
        # re-raised unchanged by the handler at the end of this function.
        # try:
        #     encrypted_output = FHE_SERVER.run(encrypted_image, evaluation_key)
        # except MemoryError:
        #     logger.error("FHE execution failed due to insufficient memory")
        #     raise HTTPException(status_code=503, detail="Insufficient memory during FHE execution")
        # except Exception as e:
        #     logger.error(f"FHE execution failed: {str(e)}")
        #     raise HTTPException(status_code=500, detail="FHE execution failed")
        # finally:
        #     # Force garbage collection after FHE execution
        #     gc.collect()

        # Placeholder output: two random values in [0, 1) rescaled to sum to 1,
        # mimicking a softmax. NOT encrypted.
        placeholder_output = np.random.rand(2)
        placeholder_output = placeholder_output / np.sum(placeholder_output)
        encrypted_output = placeholder_output.tobytes()

        fhe_execution_time = round(time.perf_counter() - start, 2)

        # Check memory usage after FHE execution
        memory_after = get_memory_usage()
        logger.info("Memory usage after FHE execution: %.2f GB", memory_after)
        logger.info(
            "Memory increase during FHE execution: %.2f GB",
            memory_after - memory_before,
        )

        # Store the output where get_output will look for it
        encrypted_output_path = get_server_file_path("encrypted_output", user_id)
        encrypted_output_path.write_bytes(encrypted_output)

        logger.info(
            "FHE execution completed for user %s in %s seconds",
            user_id,
            fhe_execution_time,
        )
        return JSONResponse(content=fhe_execution_time)
    except HTTPException:
        # Deliberate HTTP errors (the 404s above) keep their status code.
        raise
    except Exception as e:
        logger.exception("Error in run_fhe for user %s", user_id)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
def get_output(user_id: str = Form()):
    """Retrieve the encrypted output stored for ``user_id`` as raw bytes.

    The file read here is the one ``run_fhe`` writes via
    ``get_server_file_path("encrypted_output", user_id)``.

    NOTE(review): no route decorator is visible for this function here. It is
    presumably meant to be registered on ``app``; confirm.

    Args:
        user_id: Client identifier whose output should be returned.

    Returns:
        Response whose body is the file's bytes, labelled
        ``application/octet-stream``.

    Raises:
        HTTPException: 404 if no output exists for ``user_id``; 500 for any
            other failure.
    """
    try:
        encrypted_output_path = get_server_file_path("encrypted_output", user_id)
        # Read directly (EAFP) rather than exists()-then-open: a single
        # filesystem operation, and a missing file maps cleanly to 404 below.
        encrypted_output = encrypted_output_path.read_bytes()
    except FileNotFoundError as e:
        logger.warning("Encrypted output not found for user %s", user_id)
        raise HTTPException(status_code=404, detail="Encrypted output file not found") from e
    except Exception as e:
        logger.exception("Error in get_output for user %s", user_id)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

    return Response(encrypted_output, media_type="application/octet-stream")
if __name__ == "__main__":
    # Development entry point: serve `app` with uvicorn on every network
    # interface (0.0.0.0), port 8000. The import is deferred so the module
    # can be imported by another ASGI server without pulling in uvicorn here.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")