"""FastAPI service exposing ``POST /classify``: mask PII in an email, then classify it."""

import json
import logging
import re
import secrets
from typing import Any, Dict, List

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from models import ModelManager
from utils import mask_pii

logger = logging.getLogger(__name__)

app = FastAPI()

# Initialize model manager. Loading happens at import time, so a failure here
# prevents the application from starting at all.
model_manager = ModelManager()
try:
    model_manager.load_models()
except Exception as e:
    raise RuntimeError(f"Error loading models: {e}") from e


class CompactListWrapper:
    """Marker for a list that should be serialized compactly on a single line.

    The wrapped list is stored unchanged in ``data``.
    """

    def __init__(self, data_list):
        self.data = data_list


class CustomJsonEncoder(json.JSONEncoder):
    """JSON encoder that turns ``CompactListWrapper`` into a placeholder string.

    The placeholder embeds the compact JSON of the wrapped list between fixed
    start/end markers. Because the placeholder is itself emitted as a JSON
    string, its content is escaped again by the encoder; the fixed markers can
    also be imitated by ordinary string values. ``CustomFormattedJSONResponse``
    therefore uses a nonce-based subclass instead of post-processing these
    fixed markers. This class is kept with its original behavior for any
    existing callers.
    """

    def default(self, o):
        if isinstance(o, CompactListWrapper):
            return f"__COMPACT_LIST_PLACEHOLDER__{json.dumps(o.data, separators=(',',':'))}__END_PLACEHOLDER__"
        return super().default(o)


class _TokenizingJsonEncoder(CustomJsonEncoder):
    """Encoder that swaps each ``CompactListWrapper`` for an unguessable token.

    Each instance generates its own random prefix, so string values already
    present in the payload (for example text supplied by a client) cannot
    collide with a token. The compact JSON for each wrapper is kept in
    ``fragments`` rather than inside the placeholder string, so it is never
    escaped a second time.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._prefix = f"__COMPACT_LIST_{secrets.token_hex(16)}_"
        # Maps token -> compact JSON text of the wrapped list.
        self.fragments: Dict[str, str] = {}

    def default(self, o):
        if isinstance(o, CompactListWrapper):
            token = f"{self._prefix}{len(self.fragments)}__"
            self.fragments[token] = json.dumps(o.data, separators=(",", ":"))
            return token
        return super().default(o)

    def splice(self, encoded: str) -> str:
        """Return *encoded* with every quoted token replaced by its fragment."""
        if not self.fragments:
            return encoded
        token_pattern = re.compile('"(' + re.escape(self._prefix) + r'\d+__)"')
        # A callable replacement is inserted literally (no backslash
        # processing) and the single pass never rescans inserted fragments.
        return token_pattern.sub(
            lambda match: self.fragments.get(match.group(1), match.group(0)),
            encoded,
        )


class CustomFormattedJSONResponse(JSONResponse):
    """JSON response indented by two spaces, with wrapped lists kept on one line."""

    def render(self, content: Any) -> bytes:
        """Serialize *content* to UTF-8 encoded JSON bytes.

        Values wrapped in ``CompactListWrapper`` are written in compact form
        (``[1,2]``); everything else is pretty-printed with ``indent=2``.
        """
        encoder = _TokenizingJsonEncoder(indent=2)
        json_string_with_tokens = encoder.encode(content)
        return encoder.splice(json_string_with_tokens).encode("utf-8")


class EmailInput(BaseModel):
    """Request body for ``POST /classify``."""

    input_email_body: str


def _wrap_positions(entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return shallow copies of *entities* with list-valued ``position`` wrapped.

    The input dicts are copied so the objects returned by ``mask_pii`` are not
    modified.
    """
    wrapped_entities = []
    for entity in entities:
        entity_copy = entity.copy()
        if isinstance(entity_copy.get("position"), list):
            entity_copy["position"] = CompactListWrapper(entity_copy["position"])
        wrapped_entities.append(entity_copy)
    return wrapped_entities


@app.post("/classify")
async def classify_email(email_input: EmailInput):
    """Mask PII in the submitted email, classify the masked text, and respond.

    Returns a ``CustomFormattedJSONResponse`` with the keys
    ``input_email_body``, ``list_of_masked_entities``, ``masked_email`` and
    ``category_of_the_email``.

    Raises:
        HTTPException: status 500 with the error text as detail when masking,
            classification or response construction fails.
    """
    # NOTE(review): mask_pii and predict are called synchronously inside this
    # coroutine; if they are slow they will hold the event loop - confirm
    # whether they should be moved to a worker thread.
    try:
        # Mask PII in the email
        masked_email_str, masked_entities_list_of_dicts = mask_pii(
            email_input.input_email_body, model_manager.ner_pipeline
        )

        # Classify the masked email
        predicted_category_str = model_manager.predict(masked_email_str)

        response_data = {
            "input_email_body": email_input.input_email_body,
            "list_of_masked_entities": _wrap_positions(masked_entities_list_of_dicts),
            "masked_email": masked_email_str,
            "category_of_the_email": predicted_category_str,
        }

        # Use the custom response class
        return CustomFormattedJSONResponse(content=response_data)
    except Exception as e:
        # Record the full traceback server-side; the response detail is kept
        # as the exception text so existing clients see the same payload.
        logger.exception("Error in classify_email")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)