|
import logging

import numpy as np
import PIL
import torch
import torchvision.transforms.functional as F
from torch.nn import functional as Fx

from acfg.modelconfig import ModelConfig
from acfg.appconfig import CLF_MODEL, OOD_MODEL, ServiceConfig, get_device
from service.external import llm_strategy
|
|
|
|
|
def transform_for_prediction(img: "PIL.Image.Image") -> torch.Tensor:
    """Transform a PIL image into a normalized tensor for model inference.

    Steps applied, in order:

    1. Resize to ``ModelConfig.IMG_SIZE`` x ``ModelConfig.IMG_SIZE``.
    2. Convert to a ``(C, H, W)`` float tensor with ``F.to_tensor``.
    3. Normalize with ``ModelConfig.IMG_MEAN`` / ``ModelConfig.IMG_STD``.
    4. Move the tensor to the second element returned by ``get_device()``
       (presumably the torch device used for inference -- confirm).

    NOTE(review): the channel count C follows the PIL image mode, so
    ``normalize`` will only succeed when it matches the length of
    IMG_MEAN/IMG_STD. If non-RGB uploads (RGBA PNGs, grayscale) are possible,
    they may need an explicit ``convert("RGB")`` -- TODO confirm.

    Args:
        img: Input image to transform.

    Returns:
        torch.Tensor: Normalized image tensor without a batch dimension
        (callers add one with ``unsqueeze(0)``).
    """
    target_size = [ModelConfig.IMG_SIZE, ModelConfig.IMG_SIZE]
    tensor = F.to_tensor(F.resize(img, target_size))
    tensor = F.normalize(tensor, mean=ModelConfig.IMG_MEAN, std=ModelConfig.IMG_STD)
    return tensor.to(get_device()[1])
|
|
|
|
|
def classify_disease(image):
    """Return the disease label that ``CLF_MODEL`` predicts for one image.

    The image is preprocessed with ``transform_for_prediction`` and given a
    leading batch dimension of size 1. It is then scored by the classifier.
    The index of the highest score along dimension 1 is looked up in
    ``ServiceConfig.ID2LABEL``.

    Args:
        image: Image accepted by ``transform_for_prediction`` (presumably a
            PIL image -- confirm against callers).

    Returns:
        The ``ServiceConfig.ID2LABEL`` entry for the winning class index.
    """
    batch = transform_for_prediction(image).unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        scores = CLF_MODEL(batch)
        class_index = torch.max(scores, 1).indices.item()

    return ServiceConfig.ID2LABEL[class_index]
|
|
|
|
|
def img_in_distribution(image) -> bool:
    """Decide whether an image is close enough to the training distribution.

    The preprocessed image is passed through ``OOD_MODEL``. The mean-squared
    error between the model output and that same input tensor is compared
    against ``ServiceConfig.OOD_THRESHOLD``. (Presumably ``OOD_MODEL`` is a
    reconstruction model such as an autoencoder -- TODO confirm.)

    Args:
        image: Image accepted by ``transform_for_prediction``.

    Returns:
        bool: True when the reconstruction error is strictly below
        ``ServiceConfig.OOD_THRESHOLD``.
    """
    image_tensor = transform_for_prediction(image).unsqueeze(0)

    with torch.no_grad():
        output = OOD_MODEL(image_tensor)
        # .item() turns the 0-dim loss tensor into a plain Python float.
        mse_value = Fx.mse_loss(output, image_tensor).item()

    logging.getLogger(__name__).debug("MSE %s", mse_value)

    # Comparing Python floats yields a real bool rather than a 0-dim tensor.
    return mse_value < ServiceConfig.OOD_THRESHOLD
|
|
|
|
|
def workflow(image: "PIL.Image.Image"):
    """Run the full diagnosis pipeline on one image.

    1. If ``img_in_distribution`` rejects the image, return ``"Unknown"``
       with a fixed apology message. The LLM is not consulted.
    2. Otherwise classify the image with ``classify_disease``.
    3. If the predicted label contains ``"healthy"``, return a fixed
       "no remedy needed" message. Otherwise ask ``llm_strategy`` for a remedy.

    Args:
        image: Image to diagnose. It is forwarded to
            ``transform_for_prediction``, which calls torchvision's
            ``F.resize``. That accepts PIL images or tensors, not NumPy arrays.

    Returns:
        tuple: ``(disease_name, remedy)``. ``remedy`` is a fixed string for
        the unknown and healthy cases, and whatever ``llm_strategy`` returns
        otherwise.
    """
    log = logging.getLogger(__name__)

    if not img_in_distribution(image):
        # Return here so an unrecognized image is never sent to the LLM as a
        # fake "Unknown" disease, which would discard this message.
        disease_name = "Unknown"
        log.info("Predicted disease: %s", disease_name)
        return disease_name, "We do not know the remedy to this one. Sorry!"

    disease_name = classify_disease(image)
    log.info("Predicted disease: %s", disease_name)

    if "healthy" in disease_name:
        return disease_name, "No remedy needed. Plant is Healthy"

    return disease_name, llm_strategy(ServiceConfig.LLM_MODEL_KEY, disease_name)
|
|