4darsh-Dev commited on
Commit
965e7ba
·
verified ·
1 Parent(s): c4b9624

added conv block

Browse files
Files changed (1) hide show
  1. app.py +27 -114
app.py CHANGED
@@ -1,117 +1,3 @@
1
- # # detect.py
2
- # import torch
3
- # import torchvision.transforms as transforms
4
- # from torchvision.models import resnet50
5
- # from PIL import Image
6
- # import torch.nn as nn
7
-
8
- # # Define the class names - make sure these match your training classes
9
- # CLASS_NAMES = [
10
- # "Apple___Apple_scab",
11
- # "Apple___Black_rot",
12
- # # Add all your class names here...
13
- # ]
14
-
15
- # def load_model(model_path):
16
- # # Initialize the model architecture
17
- # model = resnet50(pretrained=False)
18
- # num_classes = len(CLASS_NAMES)
19
- # model.fc = nn.Linear(model.fc.in_features, num_classes)
20
-
21
- # # Load the state dict
22
- # state_dict = torch.load(model_path, map_location=torch.device('cpu'))
23
- # model.load_state_dict(state_dict)
24
- # model.eval()
25
- # return model
26
-
27
- # def predict_image(image_path, model):
28
- # """Predict the class of a given image"""
29
- # # Define the same transform as used during training
30
- # transform = transforms.Compose([
31
- # transforms.Resize((224, 224)),
32
- # transforms.ToTensor(),
33
- # ])
34
-
35
- # # Load and preprocess the image
36
- # image = Image.open(image_path).convert('RGB')
37
- # image_tensor = transform(image).unsqueeze(0)
38
-
39
- # # Make prediction
40
- # with torch.no_grad():
41
- # outputs = model(image_tensor)
42
- # _, predicted = torch.max(outputs, 1)
43
-
44
- # return CLASS_NAMES[predicted.item()]
45
-
46
- # # streamlit_app.py
47
- # import streamlit as st
48
- # import torch
49
- # import torchvision.transforms as transforms
50
- # from PIL import Image
51
- # import os
52
- # from detect import load_model, predict_image, CLASS_NAMES
53
-
54
- # # Set page config
55
- # st.set_page_config(page_title="Plant Disease Predictor", page_icon="🍃", layout="wide")
56
-
57
- # # Load the model
58
- # @st.cache_resource
59
- # def load_model_cached():
60
- # model_path = 'models/leaf_disease_res50_model_epoch_10.pth'
61
- # model = load_model(model_path)
62
- # return model
63
-
64
- # # Load model at startup
65
- # model = load_model_cached()
66
-
67
- # # Streamlit app
68
- # st.title("Plant Disease Predictor")
69
- # st.write("Upload an image of a plant leaf to predict if it has a disease.")
70
-
71
- # uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
72
-
73
- # if uploaded_file is not None:
74
- # image = Image.open(uploaded_file).convert('RGB')
75
- # st.image(image, caption='Uploaded Image', use_column_width=True)
76
-
77
- # if st.button('Predict'):
78
- # # Show prediction in progress
79
- # with st.spinner('Analyzing image...'):
80
- # # Save the uploaded file temporarily
81
- # with open("temp_image.jpg", "wb") as f:
82
- # f.write(uploaded_file.getbuffer())
83
-
84
- # # Make prediction
85
- # prediction = predict_image("temp_image.jpg", model)
86
-
87
- # # Remove temporary file
88
- # os.remove("temp_image.jpg")
89
-
90
- # # Display result
91
- # st.success(f"Prediction: {prediction}")
92
-
93
- # # Display confidence scores
94
- # transform = transforms.Compose([
95
- # transforms.Resize((224, 224)), # Match the training size
96
- # transforms.ToTensor(),
97
- # ])
98
-
99
- # with torch.no_grad():
100
- # img_tensor = transform(image).unsqueeze(0)
101
- # outputs = model(img_tensor)
102
- # probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
103
-
104
- # # Display top 5 predictions
105
- # top5_prob, top5_catid = torch.topk(probabilities, 5)
106
- # st.write("Top 5 Predictions:")
107
- # for i in range(top5_prob.size(0)):
108
- # st.write(f"{CLASS_NAMES[top5_catid[i]]}: {top5_prob[i].item()*100:.2f}%")
109
-
110
- # # Display list of detectable diseases
111
- # st.write("## List of Detectable Plant Diseases")
112
- # st.write("This model can detect the following plant diseases:")
113
- # for disease in CLASS_NAMES:
114
- # st.write(f"- {disease.replace('___', ' - ')}")
115
 
116
  import gradio as gr
117
  import torch
@@ -121,6 +7,33 @@ import json
121
  import os
122
  # from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  class ResNet9(ImageClassificationBase):
125
  def __init__(self, in_channels, num_diseases):
126
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  import gradio as gr
3
  import torch
 
7
  import os
8
  # from leaf_disease_predict import ResNet9, load_model, predict_image, CLASS_NAMES
9
 
10
+ class ImageClassificationBase(torch.nn.Module):
11
+ def validation_step(self, batch):
12
+ images, labels = batch
13
+ out = self(images)
14
+ loss = torch.nn.functional.cross_entropy(out, labels)
15
+ acc = accuracy(out, labels)
16
+ return {"val_loss": loss.detach(), "val_accuracy": acc}
17
+
18
+ def validation_epoch_end(self, outputs):
19
+ batch_losses = [x["val_loss"] for x in outputs]
20
+ batch_accuracy = [x["val_accuracy"] for x in outputs]
21
+ epoch_loss = torch.stack(batch_losses).mean()
22
+ epoch_accuracy = torch.stack(batch_accuracy).mean()
23
+ return {"val_loss": epoch_loss, "val_accuracy": epoch_accuracy}
24
+
25
+ def epoch_end(self, epoch, result):
26
+ print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
27
+ epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_accuracy']))
28
+
29
+ def ConvBlock(in_channels, out_channels, pool=False):
30
+ layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
31
+ torch.nn.BatchNorm2d(out_channels),
32
+ torch.nn.ReLU(inplace=True)]
33
+ if pool:
34
+ layers.append(torch.nn.MaxPool2d(4))
35
+ return torch.nn.Sequential(*layers)
36
+
37
  class ResNet9(ImageClassificationBase):
38
  def __init__(self, in_channels, num_diseases):
39
  super().__init__()