added conv block
Browse files
app.py
CHANGED
@@ -1,117 +1,3 @@
|
|
1 |
-
# # detect.py
|
2 |
-
# import torch
|
3 |
-
# import torchvision.transforms as transforms
|
4 |
-
# from torchvision.models import resnet50
|
5 |
-
# from PIL import Image
|
6 |
-
# import torch.nn as nn
|
7 |
-
|
8 |
-
# # Define the class names - make sure these match your training classes
|
9 |
-
# CLASS_NAMES = [
|
10 |
-
# "Apple___Apple_scab",
|
11 |
-
# "Apple___Black_rot",
|
12 |
-
# # Add all your class names here...
|
13 |
-
# ]
|
14 |
-
|
15 |
-
# def load_model(model_path):
|
16 |
-
# # Initialize the model architecture
|
17 |
-
# model = resnet50(pretrained=False)
|
18 |
-
# num_classes = len(CLASS_NAMES)
|
19 |
-
# model.fc = nn.Linear(model.fc.in_features, num_classes)
|
20 |
-
|
21 |
-
# # Load the state dict
|
22 |
-
# state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
23 |
-
# model.load_state_dict(state_dict)
|
24 |
-
# model.eval()
|
25 |
-
# return model
|
26 |
-
|
27 |
-
# def predict_image(image_path, model):
|
28 |
-
# """Predict the class of a given image"""
|
29 |
-
# # Define the same transform as used during training
|
30 |
-
# transform = transforms.Compose([
|
31 |
-
# transforms.Resize((224, 224)),
|
32 |
-
# transforms.ToTensor(),
|
33 |
-
# ])
|
34 |
-
|
35 |
-
# # Load and preprocess the image
|
36 |
-
# image = Image.open(image_path).convert('RGB')
|
37 |
-
# image_tensor = transform(image).unsqueeze(0)
|
38 |
-
|
39 |
-
# # Make prediction
|
40 |
-
# with torch.no_grad():
|
41 |
-
# outputs = model(image_tensor)
|
42 |
-
# _, predicted = torch.max(outputs, 1)
|
43 |
-
|
44 |
-
# return CLASS_NAMES[predicted.item()]
|
45 |
-
|
46 |
-
# # streamlit_app.py
|
47 |
-
# import streamlit as st
|
48 |
-
# import torch
|
49 |
-
# import torchvision.transforms as transforms
|
50 |
-
# from PIL import Image
|
51 |
-
# import os
|
52 |
-
# from detect import load_model, predict_image, CLASS_NAMES
|
53 |
-
|
54 |
-
# # Set page config
|
55 |
-
# st.set_page_config(page_title="Plant Disease Predictor", page_icon="🍃", layout="wide")
|
56 |
-
|
57 |
-
# # Load the model
|
58 |
-
# @st.cache_resource
|
59 |
-
# def load_model_cached():
|
60 |
-
# model_path = 'models/leaf_disease_res50_model_epoch_10.pth'
|
61 |
-
# model = load_model(model_path)
|
62 |
-
# return model
|
63 |
-
|
64 |
-
# # Load model at startup
|
65 |
-
# model = load_model_cached()
|
66 |
-
|
67 |
-
# # Streamlit app
|
68 |
-
# st.title("Plant Disease Predictor")
|
69 |
-
# st.write("Upload an image of a plant leaf to predict if it has a disease.")
|
70 |
-
|
71 |
-
# uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
72 |
-
|
73 |
-
# if uploaded_file is not None:
|
74 |
-
# image = Image.open(uploaded_file).convert('RGB')
|
75 |
-
# st.image(image, caption='Uploaded Image', use_column_width=True)
|
76 |
-
|
77 |
-
# if st.button('Predict'):
|
78 |
-
# # Show prediction in progress
|
79 |
-
# with st.spinner('Analyzing image...'):
|
80 |
-
# # Save the uploaded file temporarily
|
81 |
-
# with open("temp_image.jpg", "wb") as f:
|
82 |
-
# f.write(uploaded_file.getbuffer())
|
83 |
-
|
84 |
-
# # Make prediction
|
85 |
-
# prediction = predict_image("temp_image.jpg", model)
|
86 |
-
|
87 |
-
# # Remove temporary file
|
88 |
-
# os.remove("temp_image.jpg")
|
89 |
-
|
90 |
-
# # Display result
|
91 |
-
# st.success(f"Prediction: {prediction}")
|
92 |
-
|
93 |
-
# # Display confidence scores
|
94 |
-
# transform = transforms.Compose([
|
95 |
-
# transforms.Resize((224, 224)), # Match the training size
|
96 |
-
# transforms.ToTensor(),
|
97 |
-
# ])
|
98 |
-
|
99 |
-
# with torch.no_grad():
|
100 |
-
# img_tensor = transform(image).unsqueeze(0)
|
101 |
-
# outputs = model(img_tensor)
|
102 |
-
# probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
|
103 |
-
|
104 |
-
# # Display top 5 predictions
|
105 |
-
# top5_prob, top5_catid = torch.topk(probabilities, 5)
|
106 |
-
# st.write("Top 5 Predictions:")
|
107 |
-
# for i in range(top5_prob.size(0)):
|
108 |
-
# st.write(f"{CLASS_NAMES[top5_catid[i]]}: {top5_prob[i].item()*100:.2f}%")
|
109 |
-
|
110 |
-
# # Display list of detectable diseases
|
111 |
-
# st.write("## List of Detectable Plant Diseases")
|
112 |
-
# st.write("This model can detect the following plant diseases:")
|
113 |
-
# for disease in CLASS_NAMES:
|
114 |
-
# st.write(f"- {disease.replace('___', ' - ')}")
|
115 |
|
116 |
import gradio as gr
|
117 |
import torch
|
@@ -121,6 +7,33 @@ import json
|
|
121 |
import os
|
122 |
# from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
class ResNet9(ImageClassificationBase):
|
125 |
def __init__(self, in_channels, num_diseases):
|
126 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
import gradio as gr
|
3 |
import torch
|
|
|
7 |
import os
|
8 |
# from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
|
9 |
|
10 |
+
class ImageClassificationBase(torch.nn.Module):
|
11 |
+
def validation_step(self, batch):
|
12 |
+
images, labels = batch
|
13 |
+
out = self(images)
|
14 |
+
loss = torch.nn.functional.cross_entropy(out, labels)
|
15 |
+
acc = accuracy(out, labels)
|
16 |
+
return {"val_loss": loss.detach(), "val_accuracy": acc}
|
17 |
+
|
18 |
+
def validation_epoch_end(self, outputs):
|
19 |
+
batch_losses = [x["val_loss"] for x in outputs]
|
20 |
+
batch_accuracy = [x["val_accuracy"] for x in outputs]
|
21 |
+
epoch_loss = torch.stack(batch_losses).mean()
|
22 |
+
epoch_accuracy = torch.stack(batch_accuracy).mean()
|
23 |
+
return {"val_loss": epoch_loss, "val_accuracy": epoch_accuracy}
|
24 |
+
|
25 |
+
def epoch_end(self, epoch, result):
|
26 |
+
print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
|
27 |
+
epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_accuracy']))
|
28 |
+
|
29 |
+
def ConvBlock(in_channels, out_channels, pool=False):
|
30 |
+
layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
31 |
+
torch.nn.BatchNorm2d(out_channels),
|
32 |
+
torch.nn.ReLU(inplace=True)]
|
33 |
+
if pool:
|
34 |
+
layers.append(torch.nn.MaxPool2d(4))
|
35 |
+
return torch.nn.Sequential(*layers)
|
36 |
+
|
37 |
class ResNet9(ImageClassificationBase):
|
38 |
def __init__(self, in_channels, num_diseases):
|
39 |
super().__init__()
|