import os

import gradio as gr
from PIL import Image

from plant_disease_classifier import PlantDiseaseClassifier

# Directory containing test images, one subdirectory per group of images.
# Relative paths are resolved against the current working directory.
TEST_IMAGE_DIR = "test"

# Lower-case file extensions offered in the preloaded-image dropdown.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Checkpoint file for each model. The key is both the model type passed to
# PlantDiseaseClassifier and the label shown in the model dropdowns.
model_paths = {
    "resnet": "resnet50_ft.pth",
    "vit": "vit32b_ft.pth",
    "levit": "levit128s_ft.pth",
}
# Derived from model_paths so the type list can never drift out of alignment
# with the checkpoint mapping (kept as a module-level name for compatibility).
model_types = list(model_paths)

# One classifier per model, constructed once at startup.
classifiers = {
    model_type: PlantDiseaseClassifier(model_type, model_path)
    for model_type, model_path in model_paths.items()
}

model_choices = list(model_paths)
DEFAULT_MODEL = "resnet"


def get_subdirectories(directory):
    """Return the sorted names of the immediate subdirectories of *directory*.

    Returns an empty list if *directory* does not exist, so the app can still
    start without a test-image folder.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )


def get_images_in_subdirectory(subdirectory):
    """Return the sorted image file names inside TEST_IMAGE_DIR/<subdirectory>.

    Returns an empty list when no subdirectory is selected (None or empty
    string) or when the path is not an existing directory.
    """
    if not subdirectory:
        return []
    subdir_path = os.path.join(TEST_IMAGE_DIR, subdirectory)
    if not os.path.isdir(subdir_path):
        return []
    return sorted(
        name
        for name in os.listdir(subdir_path)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(subdir_path, name))
    )


def _resolve_test_image(subdirectory, image_name):
    """Return the absolute path of a selected test image, or None if invalid.

    Both names arrive from the browser/API, so the resolved path must stay
    inside TEST_IMAGE_DIR (rejecting "../" traversal and absolute paths) and
    must point at an existing regular file.
    """
    if not subdirectory or not image_name:
        return None
    root = os.path.realpath(TEST_IMAGE_DIR)
    path = os.path.realpath(os.path.join(root, subdirectory, image_name))
    if not path.startswith(root + os.sep) or not os.path.isfile(path):
        return None
    return path


def _load_rgb(path):
    """Open the image at *path*, convert it to RGB, and close the file handle."""
    with Image.open(path) as img:
        # convert() returns a new, fully decoded image, so it remains valid
        # after the source file is closed.
        return img.convert("RGB")


def predict(image, model_name):
    """Classify a PIL image with the classifier registered under *model_name*.

    Raises:
        gr.Error: if no image was given or the model name is unknown; Gradio
            shows the message to the user instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please provide an image to classify.")
    classifier = classifiers.get(model_name)
    if classifier is None:
        raise gr.Error(f"Unknown model: {model_name!r}")
    return classifier.predict(image)


def classify_preloaded_image(subdirectory, image_name, model_name):
    """Classify the preloaded test image selected in the dropdowns."""
    image_path = _resolve_test_image(subdirectory, image_name)
    if image_path is None:
        raise gr.Error("Please select a subdirectory and an image first.")
    return predict(_load_rgb(image_path), model_name)


def display_selected_image(subdirectory, image_name):
    """Return the selected test image as RGB, or None when nothing valid is selected."""
    image_path = _resolve_test_image(subdirectory, image_name)
    return None if image_path is None else _load_rgb(image_path)


def classify_uploaded_image(image, model_name):
    """Classify an image uploaded by the user."""
    return predict(image, model_name)


def update_images(subdirectory):
    """Refresh the image dropdown for a newly selected subdirectory.

    The current value is cleared as well: otherwise the dropdown keeps the file
    name chosen in the previous subdirectory, which does not exist in the new one.
    """
    return gr.update(choices=get_images_in_subdirectory(subdirectory), value=None)


# Define Gradio app
with gr.Blocks() as demo:
    gr.Markdown("# Plant Disease Classifier")

    with gr.Tab("Upload an Image"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload an Image")
            model_input_upload = gr.Dropdown(
                choices=model_choices, label="Select Model", value=DEFAULT_MODEL
            )
        classify_button_upload = gr.Button("Classify")
        output_text_upload = gr.Textbox(label="Predicted Class")

        classify_button_upload.click(
            classify_uploaded_image,
            inputs=[image_input, model_input_upload],
            outputs=output_text_upload,
        )

    with gr.Tab("Select a Preloaded Image"):
        with gr.Row():
            subdir_dropdown = gr.Dropdown(
                choices=get_subdirectories(TEST_IMAGE_DIR),
                label="Select a Subdirectory",
            )
            image_dropdown = gr.Dropdown(choices=[], label="Select an Image")
            model_input_preloaded = gr.Dropdown(
                choices=model_choices, label="Select Model", value=DEFAULT_MODEL
            )
        with gr.Row():
            image_display = gr.Image(label="Selected Image", interactive=False)
        classify_button_preloaded = gr.Button("Classify")
        output_text_preloaded = gr.Textbox(label="Predicted Class")

        # Update image dropdown based on selected subdirectory
        subdir_dropdown.change(
            update_images, inputs=subdir_dropdown, outputs=image_dropdown
        )
        # Update displayed image based on selected image
        image_dropdown.change(
            display_selected_image,
            inputs=[subdir_dropdown, image_dropdown],
            outputs=image_display,
        )
        classify_button_preloaded.click(
            classify_preloaded_image,
            inputs=[subdir_dropdown, image_dropdown, model_input_preloaded],
            outputs=output_text_preloaded,
        )

# Launch only when run as a script, so importing the module (e.g. Gradio's
# `gradio app.py` reload mode, which picks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()