Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, HTTPException | |
| import os | |
| import pickle | |
| from back_end.models.embedding_model import generate_embedding | |
| from back_end.schemas.request import TextRequest | |
| from sklearn.linear_model import LogisticRegression | |
| from scipy.spatial.distance import cosine | |
router = APIRouter()

# Resolve the model path relative to this file so loading does not depend on
# the process's current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(BASE_DIR, "..", "models", "logistic.pkl")

# Load the logistic-regression model once, at import time, so every request
# reuses the same object.
# NOTE(review): pickle.load can execute arbitrary code stored in the file; this
# is acceptable only because MODEL_PATH is a fixed path next to this module,
# never a user-supplied value.
try:
    with open(MODEL_PATH, "rb") as f:
        logistic_model = pickle.load(f)
except FileNotFoundError as err:
    raise RuntimeError(f"Model file not found at {MODEL_PATH}") from err
except (pickle.UnpicklingError, EOFError) as err:
    # pickle raises EOFError ("Ran out of input") for an empty or truncated
    # file, which is the same "corrupt model file" failure as UnpicklingError.
    raise RuntimeError(f"Error unpickling model file at {MODEL_PATH}") from err
def get_embedding(request: TextRequest):
    """Embed ``request.text`` and report the vector together with its length.

    NOTE(review): the vector length is whatever ``generate_embedding``
    produces (previously documented as 768) — confirm against the model.

    Raises:
        HTTPException: status 400 when ``request.text`` is empty.
    """
    text = request.text
    if text:
        vector = generate_embedding(text)
        return {"dimensions": len(vector), "embedding": vector}
    raise HTTPException(status_code=400, detail="Text cannot be empty")
def get_cosine_similarity(request: TextRequest):
    """Return the cosine similarity between ``request.text`` and ``request.text2``.

    Both texts are embedded with ``generate_embedding`` and compared as
    ``1 - scipy.spatial.distance.cosine(e1, e2)``.

    Raises:
        HTTPException: status 400 if either text is missing, ``None`` or empty.
    """
    # getattr with a None default keeps the old "attribute missing -> 400"
    # behavior, and additionally rejects None/"" values, which the previous
    # hasattr() check let through to the embedding model.
    # NOTE(review): if TextRequest does not declare ``text2``, this endpoint
    # always returns 400 — confirm against back_end.schemas.request.
    text = getattr(request, "text", None)
    text2 = getattr(request, "text2", None)
    if not text or not text2:
        raise HTTPException(status_code=400, detail="Both text inputs must be provided")
    embedding1 = generate_embedding(text)
    embedding2 = generate_embedding(text2)
    # scipy's cosine() is a distance (1 - similarity), so invert it.
    similarity = 1 - cosine(embedding1, embedding2)
    return {"cosine_similarity": float(similarity)}
def get_logistic_prediction(request: TextRequest):
    """Return the logistic-regression prediction for ``request.text``.

    The text is embedded with ``generate_embedding`` and passed to the
    module-level ``logistic_model`` as a single-sample batch.

    Raises:
        HTTPException: status 400 if ``request.text`` is empty; status 500
            if the model's ``predict`` call fails.
    """
    if not request.text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    embedding = generate_embedding(request.text)
    try:
        # predict() expects a batch of samples, so wrap the single embedding.
        prediction = logistic_model.predict([embedding])[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model prediction failed: {str(e)}") from e
    # scikit-learn's predict() returns a NumPy array, so the element is a NumPy
    # scalar (e.g. numpy.int64) that FastAPI's JSON encoder cannot serialize.
    # .item() converts it to the equivalent built-in Python value.
    if hasattr(prediction, "item"):
        prediction = prediction.item()
    return {"prediction": prediction}