| | import os |
| | import json |
| | import torch |
| | import numpy as np |
| | import tensorflow as tf |
| | from PIL import Image |
| | from torchvision import models, transforms |
| |
|
| | |
| |
|
# Anchor all resource paths to this file's absolute location. __file__ can be a
# relative path (e.g. a script launched as `python app.py` on Python < 3.9, or a
# relative sys.path entry); os.path.dirname() then returns a relative path or ""
# and every lookup below would silently depend on the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")  # model files scanned by load_models()
LABELS_DIR = os.path.join(BASE_DIR, "labels")  # <crop>_labels.json files
| |
|
| | |
| |
|
@tf.keras.utils.register_keras_serializable()
class FixedDropout(tf.keras.layers.Dropout):
    """Dropout layer registered with Keras under the class name ``FixedDropout``.

    It adds no behaviour of its own: every constructor argument is forwarded
    unchanged to ``tf.keras.layers.Dropout``. Its purpose appears to be letting
    saved models that reference a ``FixedDropout`` layer deserialize.
    NOTE(review): confirm that no custom noise-shape handling is required.
    """

    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        super().__init__(
            rate=rate,
            noise_shape=noise_shape,
            seed=seed,
            **kwargs,
        )
| |
|
| | |
@tf.keras.utils.register_keras_serializable()
class EfficientNetB3(tf.keras.Model):
    """Empty ``tf.keras.Model`` subclass registered under the name ``EfficientNetB3``.

    It defines no layers or methods. Presumably it exists only so that saved
    Keras models referring to an ``EfficientNetB3`` class can be resolved at
    load time. NOTE(review): verify that the saved model actually needs this.
    """
| |
|
| | |
| |
|
# Square input side length (in pixels) for each crop's Keras model, keyed by
# lower-cased crop name. Crops missing from this mapping are resized to 224
# by preprocess_keras.
KERAS_INPUT_SIZES = dict(
    corn=300,
)
| |
|
| | |
| |
|
def preprocess_pytorch(img, size=224):
    """Convert a PIL image into a normalized single-image batch for the PyTorch models.

    Args:
        img: PIL image in any mode. It is converted to RGB first, so the
            3-channel normalization below always applies.
        size: Side length in pixels of the square the image is resized to.

    Returns:
        A float tensor of shape (1, 3, size, size).
    """
    # Without this, grayscale ("L"), palette ("P") or RGBA images yield 1 or 4
    # channels after ToTensor, and the 3-element mean/std in Normalize fails
    # to broadcast. preprocess_keras performs the same conversion.
    img = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        # The standard ImageNet per-channel mean/std used by torchvision models.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    # Add the leading batch dimension expected by the model.
    return transform(img).unsqueeze(0)
| |
|
def preprocess_keras(img, crop_name):
    """Turn a PIL image into a float32 batch of one for the crop's Keras model.

    The image is converted to RGB and resized to a square whose side comes
    from KERAS_INPUT_SIZES (default 224). Pixel values are scaled from
    0-255 to 0-1.

    Returns:
        numpy array of shape (1, side, side, 3) and dtype float32.
    """
    rgb = img.convert("RGB")
    side = KERAS_INPUT_SIZES.get(crop_name, 224)
    rgb = rgb.resize((side, side))
    pixels = np.asarray(rgb, dtype="float32")
    scaled = pixels / 255.0
    return scaled[np.newaxis, ...]
| |
|
| | |
| |
|
# Module-level registries populated by load_models(), keyed by lower-cased crop name:
#   PYTORCH_MODELS -- crop -> ResNet-18 loaded from a .pth state dict, in eval mode
#   KERAS_MODELS   -- crop -> Keras model loaded from .keras/.h5 with compile=False
#   LABELS         -- crop -> parsed JSON from labels/<crop>_labels.json
PYTORCH_MODELS, KERAS_MODELS, LABELS = {}, {}, {}
| |
|
| | |
| |
|
# File extensions load_models() knows how to handle (matched case-sensitively).
_PYTORCH_EXTENSIONS = (".pth",)
_KERAS_EXTENSIONS = (".keras", ".h5")


def _read_labels(crop_name):
    """Return the parsed JSON from LABELS_DIR/<crop_name>_labels.json.

    Raises:
        FileNotFoundError: if the label file does not exist.
    """
    label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
    if not os.path.exists(label_path):
        raise FileNotFoundError(f"Missing label file: {label_path}")
    # JSON is UTF-8 by spec; do not depend on the platform's default encoding.
    with open(label_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _load_pytorch_model(model_path, num_classes):
    """Build a ResNet-18 with a num_classes-way head, load its weights and set eval mode."""
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider torch.load(..., weights_only=True) once the
    # minimum supported torch version provides it.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    return model


def _load_keras_model(model_path):
    """Load a Keras model for inference only (no compile step) with this module's custom objects."""
    return tf.keras.models.load_model(
        model_path,
        custom_objects={
            "swish": tf.keras.activations.swish,
            "FixedDropout": FixedDropout,
            "EfficientNetB3": EfficientNetB3,
        },
        compile=False,
    )


def load_models():
    """Register every supported model file in MODELS_DIR together with its labels.

    For each ``.pth`` (PyTorch ResNet-18 state dict) or ``.keras``/``.h5``
    (Keras model) file, the crop name is the file stem with ``"_model"``
    removed and lower-cased (e.g. ``Corn_model.h5`` -> ``corn``).
    The crop's labels go into LABELS and the model into PYTORCH_MODELS or
    KERAS_MODELS. Files with any other extension are ignored.

    Raises:
        FileNotFoundError: if MODELS_DIR does not exist, or a supported model
            file has no matching ``<crop>_labels.json`` in LABELS_DIR.
    """
    for file in os.listdir(MODELS_DIR):
        name, ext = os.path.splitext(file)
        # Filter by extension BEFORE touching labels, so stray files such as a
        # README or .gitkeep neither abort startup with a missing-label error
        # nor leave entries in LABELS.
        if ext not in _PYTORCH_EXTENSIONS and ext not in _KERAS_EXTENSIONS:
            continue

        crop_name = name.replace("_model", "").lower()
        model_path = os.path.join(MODELS_DIR, file)
        LABELS[crop_name] = _read_labels(crop_name)

        if ext in _PYTORCH_EXTENSIONS:
            # The classifier head size must match the number of labels.
            PYTORCH_MODELS[crop_name] = _load_pytorch_model(
                model_path, len(LABELS[crop_name])
            )
        else:
            KERAS_MODELS[crop_name] = _load_keras_model(model_path)


# Populate the registries once, at import time.
load_models()
| |
|
| | |
| |
|
def predict(image, crop_name):
    """Classify *image* with the model registered for *crop_name*.

    The crop name is stripped of surrounding whitespace and lower-cased before
    lookup. A PyTorch model takes precedence over a Keras model registered
    under the same crop.

    Returns:
        (label, score): the label at the arg-max index of the model output,
        plus a score. For PyTorch the score is the softmax probability. For
        Keras it is the raw model output at that index (presumably a
        probability if the model ends in softmax; not verified here).

    Raises:
        ValueError: if no model is registered for the crop.
    """
    key = crop_name.strip().lower()

    if key in PYTORCH_MODELS:
        net = PYTORCH_MODELS[key]
        class_names = LABELS[key]
        batch = preprocess_pytorch(image)
        with torch.no_grad():
            logits = net(batch)
        scores = torch.softmax(logits[0], dim=0)
        best = scores.argmax().item()
        return class_names[best], float(scores[best])

    if key in KERAS_MODELS:
        net = KERAS_MODELS[key]
        class_names = LABELS[key]
        batch = preprocess_keras(image, key)
        scores = net.predict(batch, verbose=0)[0]
        best = int(np.argmax(scores))
        return class_names[best], float(scores[best])

    raise ValueError(f"No model found for crop: {key}")
| |
|