rishabh5752 commited on
Commit
ad887e6
1 Parent(s): b1a6758

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -56
app.py CHANGED
@@ -1,70 +1,47 @@
1
  import streamlit as st
 
2
  from PIL import Image
3
- import torch
4
- from torchvision import models, transforms
5
 
6
- # Load the pre-trained model
7
- model = models.densenet121(pretrained=True)
8
- model.eval()
 
9
 
10
- # Define the image transformations
11
- transform = transforms.Compose([
12
- transforms.Resize(256),
13
- transforms.CenterCrop(224),
14
- transforms.ToTensor(),
15
- transforms.Normalize(
16
- mean=[0.485, 0.456, 0.406],
17
- std=[0.229, 0.224, 0.225]
18
- ),
19
- ])
20
 
21
- # Define the class labels
22
- class_labels = ['Normal', 'Pneumonia']
23
 
24
- # Create a function to preprocess the image
25
- def preprocess_image(image):
26
- # Convert the image to RGB
27
- image = image.convert('RGB')
28
 
29
- # Resize the image to match the model's input shape
30
- image = image.resize((224, 224))
31
-
32
- # Convert the image to a tensor
33
- image_tensor = transform(image)
34
-
35
- # Add a batch dimension
36
- image_tensor = image_tensor.unsqueeze(0)
37
 
38
- return image_tensor
 
 
39
 
40
- # Create a function to make predictions
41
- def predict(image):
42
- # Preprocess the image
43
- preprocessed_image = preprocess_image(image)
44
 
45
- # Make the prediction
46
- with torch.no_grad():
47
- output = model(preprocessed_image)
48
- _, predicted_idx = torch.max(output, 1)
49
- predicted_label = class_labels[predicted_idx.item()]
50
-
51
- return predicted_label
52
 
53
- # Create the Streamlit app
54
- def main():
55
- st.title("Pneumonia Detection")
56
- st.write("Upload an image and the app will predict if it has pneumonia or not.")
57
-
58
- # Upload and display the image
59
- uploaded_image = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
60
-
61
- if uploaded_image is not None:
62
- image = Image.open(uploaded_image)
63
- st.image(image, caption="Uploaded Image", use_column_width=True)
64
-
65
- # Make a prediction
66
- predicted_label = predict(image)
67
- st.write("Prediction:", predicted_label)
68
 
69
  # Run the app
70
  if __name__ == '__main__':
 
1
  import streamlit as st
2
+ import pickle
3
  from PIL import Image
 
 
4
 
5
+ # Load the pretrained model from the pickle file
6
+ model_filename = 'model.pkl'
7
+ with open(model_filename, 'rb') as file:
8
+ model = pickle.load(file)
9
 
10
+ # Function to make predictions
11
+ def predict_pneumonia(image):
12
+ # Preprocess the image (you may need to resize or normalize it)
13
+ # preprocess_image(image)
 
 
 
 
 
 
14
 
15
+ # Make predictions using the loaded model
16
+ prediction = model.predict(image)
17
 
18
+ return prediction
 
 
 
19
 
20
+ # Streamlit app
21
+ def main():
22
+ # Set app title and layout
23
+ st.title("Pneumonia Detection")
24
+ st.markdown("---")
 
 
 
25
 
26
+ # Add an image uploader
27
+ st.header("Upload Chest X-ray Image")
28
+ uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
29
 
30
+ if uploaded_file is not None:
31
+ # Display the uploaded image
32
+ image = Image.open(uploaded_file)
33
+ st.image(image, caption="Uploaded Image", use_column_width=True)
34
 
35
+ # Make prediction when the user clicks the 'Predict' button
36
+ if st.button("Predict"):
37
+ # Perform prediction
38
+ prediction = predict_pneumonia(image)
 
 
 
39
 
40
+ # Display the prediction
41
+ if prediction == 1:
42
+ st.error("Prediction: Pneumonia detected")
43
+ else:
44
+ st.success("Prediction: No pneumonia detected")
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Run the app
47
  if __name__ == '__main__':