|
import streamlit as st |
|
from PIL import Image |
|
import torch |
|
from torchvision import models, transforms |
|
|
|
|
|
# DenseNet-121 backbone initialised with torchvision's pretrained weights.
model = models.densenet121(pretrained=True)

# The stock classifier emits 1000 ImageNet scores, but this app only knows the
# two labels in `class_labels`. Taking the argmax over 1000 scores and using it
# to index a two-element list raises IndexError for almost every input, so the
# head is replaced with a two-output layer.
# NOTE(review): the replacement layer is randomly initialised. Load fine-tuned
# pneumonia weights (e.g. via model.load_state_dict(...)) right after this line
# before trusting any prediction.
model.classifier = torch.nn.Linear(model.classifier.in_features, 2)

# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()
|
|
|
|
|
# Per-channel (R, G, B) statistics handed to Normalize. These are the values
# torchvision documents for its ImageNet-pretrained models.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every uploaded image:
#   1. scale the shorter side to 256 px,
#   2. take the central 224x224 patch,
#   3. convert the PIL image to a float tensor,
#   4. normalise each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
|
|
|
|
|
# Human-readable names for the model's output classes. `predict` uses the
# winning class index directly as an index into this list, so the order here
# must match the order of the model's outputs.
class_labels = [
    "Normal",
    "Pneumonia",
]
|
|
|
|
|
def preprocess_image(image):
    """Turn a PIL image into a batched tensor ready for the model.

    Args:
        image: A ``PIL.Image.Image`` in any mode (grayscale, RGBA, ...).

    Returns:
        The output of the module-level ``transform`` with a leading batch
        dimension of size 1 added, i.e. shape ``(1, 3, 224, 224)`` given the
        RGB conversion and the ``CenterCrop(224)`` step in ``transform``.
    """
    # Normalize is configured with three channel means, so force three
    # channels regardless of the uploaded file's mode.
    rgb_image = image.convert('RGB')

    # No manual resize here: `transform` already performs Resize(256) followed
    # by CenterCrop(224). Squashing the picture to 224x224 first would distort
    # its aspect ratio and resample it a second time for no benefit.
    image_tensor = transform(rgb_image)

    # The model consumes batches; wrap the single image as a batch of one.
    return image_tensor.unsqueeze(0)
|
|
|
|
|
def predict(image):
    """Classify a PIL image and return the matching entry of ``class_labels``.

    Args:
        image: The ``PIL.Image.Image`` to classify; it is passed through
            ``preprocess_image`` before reaching the model.

    Returns:
        The label string whose position in ``class_labels`` equals the index
        of the model's highest score.

    Note:
        The winning index is used directly to index ``class_labels``, so the
        model must emit exactly one score per label; a larger index raises
        ``IndexError``.
    """
    batch = preprocess_image(image)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        scores = model(batch)

    # Position of the largest score along the class dimension.
    best_index = scores.argmax(dim=1).item()
    return class_labels[best_index]
|
|
|
|
|
def main():
    """Render the Streamlit page: header, upload widget, preview and prediction."""
    st.title("Pneumonia Detection")
    st.write("Upload an image and the app will predict if it has pneumonia or not.")

    upload = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])

    # Nothing uploaded yet: only the header and the uploader are shown.
    if upload is None:
        return

    picture = Image.open(upload)
    st.image(picture, caption="Uploaded Image", use_column_width=True)

    # Run the classifier and show its verdict beneath the preview.
    st.write("Prediction:", predict(picture))
|
|
|
|
|
# Build the page only when this file is executed as a script (which is how
# `streamlit run` invokes it), not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|