import base64
from pathlib import Path

from concrete.ml.deployment import FHEModelServer
from fastapi import FastAPI, HTTPException
from joblib import load
from pydantic import BaseModel

# Directory that contains this source file.
current_dir = Path(__file__).parent

# Load the server-side FHE models, one per rating type.
# NOTE(review): both directories are given as relative path strings and are
# not built from current_dir -- confirm this matches how the server is launched.
fhe_model = FHEModelServer("deployment/financial_rating")
fhe_legal_model = FHEModelServer("deployment/legal_rating")

class PredictRequest(BaseModel):
    """Request body accepted by the prediction endpoints.

    Both fields carry base64 text; the endpoint handlers decode them to raw
    bytes before handing them to the FHE model server.
    """

    # Base64 text of the evaluation key forwarded to the model's run() call.
    evaluation_key: str
    # Base64 text of the encrypted input forwarded to the model's run() call.
    encrypted_encoding: str

# Create the FastAPI application on which the route decorators below register
app = FastAPI()

# Index route: replies with a fixed welcome payload.
@app.get("/")
def root():
    # The greeting is a constant; nothing is read from the request.
    greeting = "Welcome to Your Sentiment Classification FHE Model Server!"
    return {"message": greeting}

@app.post("/predict_sentiment")
def predict_sentiment(query: PredictRequest):
    """Run the financial-rating FHE model on an encrypted input.

    The request's ``encrypted_encoding`` and ``evaluation_key`` are base64
    strings. They are decoded to bytes, passed to the model server, and the
    model's output is returned base64-encoded under ``encrypted_prediction``.

    A body whose fields cannot be base64-decoded is rejected with HTTP 400.
    """
    try:
        encrypted_encoding = base64.b64decode(query.encrypted_encoding)
        evaluation_key = base64.b64decode(query.evaluation_key)
    except ValueError as exc:
        # binascii.Error (bad padding) is a ValueError subclass, and non-ASCII
        # text raises ValueError directly; both mean the client sent bad input,
        # so answer with a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="encrypted_encoding and evaluation_key must be valid base64 strings",
        ) from exc

    prediction = fhe_model.run(encrypted_encoding, evaluation_key)

    # The model output is re-encoded as base64 text so it can travel in JSON.
    return {"encrypted_prediction": base64.b64encode(prediction).decode()}


@app.post("/legal_rating")
def predict_legal_rating(query: PredictRequest):
    """Run the legal-rating FHE model on an encrypted input.

    The request's ``encrypted_encoding`` and ``evaluation_key`` are base64
    strings. They are decoded to bytes, passed to the model server, and the
    model's output is returned base64-encoded under ``encrypted_prediction``.

    A body whose fields cannot be base64-decoded is rejected with HTTP 400.
    """
    # NOTE: this handler previously reused the name ``predict_sentiment``,
    # which redefined the other endpoint's function at module level. The
    # route path is unchanged; only the Python name is now unique.
    try:
        encrypted_encoding = base64.b64decode(query.encrypted_encoding)
        evaluation_key = base64.b64decode(query.evaluation_key)
    except ValueError as exc:
        # binascii.Error (bad padding) is a ValueError subclass, and non-ASCII
        # text raises ValueError directly; both mean the client sent bad input,
        # so answer with a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="encrypted_encoding and evaluation_key must be valid base64 strings",
        ) from exc

    prediction = fhe_legal_model.run(encrypted_encoding, evaluation_key)

    # The model output is re-encoded as base64 text so it can travel in JSON.
    return {"encrypted_prediction": base64.b64encode(prediction).decode()}