# CDIApp / service/predict.py
# Commit 32cc554 ("HF Changes") by sdutta28.
import PIL
import numpy as np
import torch
from acfg.modelconfig import ModelConfig
import torchvision.transforms.functional as F
from torch.nn import functional as Fx
from acfg.appconfig import CLF_MODEL, OOD_MODEL, ServiceConfig, get_device
from service.external import llm_strategy
def transform_for_prediction(img: "PIL.Image.Image | np.ndarray"):
    """Transform an image into a normalized tensor for model prediction.

    Steps applied:
    1. A NumPy array input is first converted to a PIL image, because
       torchvision's ``F.resize`` only accepts PIL images or tensors.
    2. Resize to ``ModelConfig.IMG_SIZE`` x ``ModelConfig.IMG_SIZE``.
    3. Convert to a tensor with ``F.to_tensor``.
    4. Normalize with ``ModelConfig.IMG_MEAN`` / ``ModelConfig.IMG_STD``.
    5. Move the tensor to the device given by ``get_device()[1]``.

    Args:
        img: Input image. A PIL image is used as-is. A NumPy array is
            converted with ``PIL.Image.fromarray``; this assumes a uint8
            HxW or HxWxC array (as typical image-upload widgets produce) —
            confirm against the caller.

    Returns:
        torch.Tensor: Transformed (C, H, W) image tensor ready for model
        inference, located on the configured device.
    """
    # A bare ``import PIL`` does not load the Image submodule, so import it
    # explicitly rather than relying on torchvision having done so.
    import PIL.Image

    if isinstance(img, np.ndarray):
        img = PIL.Image.fromarray(img)
    z = F.resize(img, [ModelConfig.IMG_SIZE, ModelConfig.IMG_SIZE])
    z = F.to_tensor(z)
    z = F.normalize(z, mean=ModelConfig.IMG_MEAN, std=ModelConfig.IMG_STD)
    return z.to(get_device()[1])
def classify_disease(image):
    """Predict the disease label for a single image.

    The image is preprocessed with ``transform_for_prediction`` into a batch
    of one and scored by ``CLF_MODEL`` without gradient tracking. The index of
    the highest score along dimension 1 is then looked up in
    ``ServiceConfig.ID2LABEL``.

    Args:
        image: Image accepted by ``transform_for_prediction``.

    Returns:
        The entry of ``ServiceConfig.ID2LABEL`` for the top-scoring class.
    """
    batch = transform_for_prediction(image).unsqueeze(0)
    with torch.no_grad():
        scores = CLF_MODEL(batch)
    # argmax along the class dimension; ties resolve to the first maximum,
    # exactly like the index returned by torch.max(scores, 1).
    top_class = torch.argmax(scores, dim=1).item()
    return ServiceConfig.ID2LABEL[top_class]
def img_in_distribution(image):
    """Decide whether an image is in-distribution for the OOD model.

    The preprocessed image is passed through ``OOD_MODEL``. The mean squared
    error between the model's output and its own input is printed and then
    compared with ``ServiceConfig.OOD_THRESHOLD``; a low error counts as
    in-distribution. (OOD_MODEL is presumably an autoencoder, given that its
    output is compared against its input — confirm.)

    Args:
        image: Image accepted by ``transform_for_prediction``.

    Returns:
        bool: True when the MSE is strictly below the threshold.
    """
    image_tensor = transform_for_prediction(image).unsqueeze(0)
    with torch.no_grad():
        output = OOD_MODEL(image_tensor)
        mse_loss_value = Fx.mse_loss(output, image_tensor)
    print("MSE", mse_loss_value)
    # Compare a plain Python float so callers get a real bool rather than a
    # 0-dim (possibly GPU-resident) tensor.
    return mse_loss_value.item() < ServiceConfig.OOD_THRESHOLD
def workflow(image: np.array):
    """Run the full diagnosis pipeline on one image.

    The image is first screened with ``img_in_distribution``. Images that
    fail the screen are reported as unknown and the classifier is skipped.
    Otherwise the label from ``classify_disease`` is printed and returned
    with a remedy:
    - labels containing "healthy" get a fixed "no remedy needed" message;
    - any other label gets a remedy from ``llm_strategy`` using
      ``ServiceConfig.LLM_MODEL_KEY``.

    Args:
        image: Image forwarded unchanged to ``img_in_distribution`` and
            ``classify_disease``.

    Returns:
        tuple: ``(disease_name, remedy)``.
    """
    if not img_in_distribution(image):
        # Out-of-distribution input: do not trust the classifier at all.
        return "Unknown", "We do not know the remedy to this one. Sorry!"

    label = classify_disease(image)
    print(label)
    if "healthy" in label:
        return label, "No remedy needed. Plant is Healthy"
    # Diseased plant: delegate the remedy text to the configured LLM.
    return label, llm_strategy(ServiceConfig.LLM_MODEL_KEY, label)