| | """
|
| | Simple FastAPI REST API for Milk Spoilage Classification
|
| |
|
| | This provides a clean REST endpoint for Custom GPT and other integrations.
|
| | """
|
| |
|
from pathlib import Path
from typing import Dict

import joblib
import numpy as np
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, Field
|
| |
|
| |
|
# Load the trained classifier once, at import time.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code -- only load a model.joblib that comes from a trusted source.
# The file is looked up relative to the current working directory first
# (the original behaviour); if it is not there, fall back to this module's
# own directory so the API also starts when launched from another directory.
_MODEL_FILENAME = "model.joblib"
_model_path = Path(_MODEL_FILENAME)
if not _model_path.exists():
    _model_path = Path(__file__).resolve().parent / _MODEL_FILENAME
model = joblib.load(_model_path)
|
| |
|
| |
|
# The ASGI application; title/description/version feed the OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Milk Spoilage Classification API",
    description="Predict milk spoilage type based on microbial count data",
)
|
| |
|
| |
|
# Allow cross-origin browser clients from any origin, with any method/header.
# Credentials (cookies, HTTP auth) are deliberately NOT allowed: combined with
# a "*" origin list, Starlette would otherwise reflect the caller's Origin
# together with Access-Control-Allow-Credentials, letting any website make
# credentialed requests. No endpoint in this module reads cookies or auth
# headers, so nothing depends on credentialed CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| |
|
| |
|
class PredictionInput(BaseModel):
    # Request body for POST /predict: six microbial counts in log CFU/mL.
    # Pydantic rejects any value outside [0.0, 10.0] via the ge/le bounds.
    spc_d7: float = Field(..., description="Standard Plate Count at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
    spc_d14: float = Field(..., description="Standard Plate Count at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
    spc_d21: float = Field(..., description="Standard Plate Count at Day 21 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d7: float = Field(..., description="Total Gram-Negative at Day 7 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d14: float = Field(..., description="Total Gram-Negative at Day 14 (log CFU/mL)", ge=0.0, le=10.0)
    tgn_d21: float = Field(..., description="Total Gram-Negative at Day 21 (log CFU/mL)", ge=0.0, le=10.0)

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`).
    # json_schema_extra merges this example into the model's JSON schema, so
    # it appears as the sample request body in the generated OpenAPI docs.
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "spc_d7": 4.0,
                "spc_d14": 5.0,
                "spc_d21": 6.0,
                "tgn_d7": 3.0,
                "tgn_d14": 4.0,
                "tgn_d21": 5.0,
            }
        }
    )
|
| |
|
class PredictionOutput(BaseModel):
    # Response body for POST /predict.
    # The predicted class label, rendered as a string.
    prediction: str = Field(default=..., description="Predicted spoilage class")
    # Mapping of each class label (as a string) to the model's probability for it.
    probabilities: Dict[str, float] = Field(default=..., description="Probability for each class")
    # The largest value found in `probabilities`.
    confidence: float = Field(default=..., description="Confidence score (max probability)")
|
| |
|
| |
|
@app.get("/")
async def root():
    """Describe the service and list the paths of its main endpoints."""
    endpoint_paths = {
        "predict": "/predict",
        "health": "/health",
        "docs": "/docs",
    }
    return {"message": "Milk Spoilage Classification API", "endpoints": endpoint_paths}
|
| |
|
| |
|
@app.post("/predict", response_model=PredictionOutput, tags=["Prediction"])
async def predict(input_data: PredictionInput):
    """
    Classify a milk sample's spoilage type from its microbial counts.

    Responds with the model's predicted class, the probability the model
    assigns to every class, and a confidence score equal to the highest of
    those probabilities.
    """
    # A single-row feature matrix: SPC at days 7/14/21, then TGN at days 7/14/21.
    sample = [
        input_data.spc_d7,
        input_data.spc_d14,
        input_data.spc_d21,
        input_data.tgn_d7,
        input_data.tgn_d14,
        input_data.tgn_d21,
    ]
    features = np.array([sample])

    # NOTE(review): inference runs synchronously inside an async handler and
    # so blocks the event loop while it runs -- consider a threadpool if the
    # model turns out to be slow.
    label = model.predict(features)[0]
    class_probs = model.predict_proba(features)[0]

    # Pair each class label (stringified for JSON keys) with its probability,
    # in the order of model.classes_.
    prob_dict = dict(zip(map(str, model.classes_), map(float, class_probs)))

    return PredictionOutput(
        prediction=str(label),
        probabilities=prob_dict,
        confidence=float(max(class_probs)),
    )
|
| |
|
| |
|
@app.get("/health", tags=["Health"])
async def health_check():
    """Report service health, whether a model is loaded, and its class labels.

    The class list is read defensively so this endpoint still answers
    (``model_loaded: false`` / an empty class list) instead of failing with a
    500 when the model is missing or exposes no ``classes_`` attribute.
    """
    model_loaded = model is not None
    # getattr(None, "classes_", None) is None, so this also covers a missing model.
    classes = getattr(model, "classes_", None)
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "classes": classes.tolist() if classes is not None else [],
    }
|
| |
|
| |
|
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn, bound to all
    # network interfaces on port 7860.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)
|
| |
|