| import streamlit as st |
| from PIL import Image |
| import torch |
| import torchvision.transforms as transforms |
| import json |
| import sys |
| import os |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
| |
| from model.model import ResNet50 |
|
|
| |
@st.cache_data
def load_class_names():
    """Read the class-name lookup from ``imagenet_classes.json``.

    The file is opened relative to the current working directory and decoded
    as UTF-8. Control characters below U+0020, other than tab, line feed and
    carriage return, are removed before the text is handed to ``json.loads``
    (presumably to tolerate stray control bytes in the file).

    Returns:
        The parsed JSON value, or an empty dict when reading or parsing
        fails. Failures are reported to the page with ``st.error``.

    NOTE(review): the function returns instead of raising on failure, so the
    empty fallback is presumably stored by ``st.cache_data`` as well --
    confirm whether a cache clear is needed after fixing the file.
    """
    # Translation table that deletes every code point in 0..31 except
    # \n, \r and \t.
    unwanted = {
        code: None
        for code in range(32)
        if chr(code) not in '\n\r\t'
    }
    try:
        with open("imagenet_classes.json", 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        return json.loads(raw_text.translate(unwanted))
    except Exception as exc:
        st.error(f"Error loading class names: {exc!s}")
        return {}
|
|
| |
@st.cache_resource
def load_model():
    """Build a 1000-class ``ResNet50`` and load its weights from disk.

    The checkpoint is read from ``./checkpoints/model_best.pth`` (relative to
    the current working directory) with all tensors mapped to the CPU. The
    weights are expected under the ``"model_state_dict"`` key.

    Returns:
        The model switched to evaluation mode, or ``None`` when the
        checkpoint lacks ``"model_state_dict"`` or any step raises. In both
        failure cases a message is shown with ``st.error``.
    """
    try:
        network = ResNet50(num_classes=1000)
        cpu = torch.device("cpu")
        # NOTE(review): torch.load unpickles the file; if the checkpoint only
        # holds tensors and plain containers, consider weights_only=True --
        # confirm against how the checkpoint was saved.
        saved = torch.load("./checkpoints/model_best.pth", map_location=cpu)

        # Guard clause: reject checkpoints that do not carry the state dict.
        if "model_state_dict" not in saved:
            st.error("Invalid model checkpoint format")
            return None

        network.load_state_dict(saved["model_state_dict"])
        network.eval()
        return network
    except Exception as exc:
        st.error(f"Error loading model: {exc!s}")
        return None
|
|
| |
def preprocess_image(image):
    """Turn an image into a single-item batch tensor for the model.

    The image is resized to 224x224, converted to a tensor, and normalized
    per channel with the mean/std values below (these match the statistics
    conventionally used for ImageNet-trained models). A leading dimension of
    size 1 is then added with ``unsqueeze(0)``.

    Args:
        image: The image to transform; the caller passes an RGB PIL image.

    Returns:
        The transformed tensor with an extra leading dimension.
    """
    resize = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    pipeline = transforms.Compose([resize, to_tensor, normalize])

    tensor = pipeline(image)
    return tensor.unsqueeze(0)
|
|
| |
# ---- Page setup -------------------------------------------------------------
st.title("Image Classification with ResNet50")
class_names = load_class_names()
model = load_model()

# Nothing below can work without a model, so end this script run here.
if model is None:
    st.error("Failed to load the model. Please check the model file.")
    st.stop()

# Session state survives reruns; initialise each entry only once.
if 'show_upload' not in st.session_state:
    st.session_state.show_upload = True
# A file uploader keeps its file for as long as its widget key is unchanged.
# "New Image" bumps this counter so the next run gets a fresh, empty uploader.
if 'uploader_key' not in st.session_state:
    st.session_state.uploader_key = 0

# Single placeholder holding the whole page body so it can be cleared at once.
main_container = st.empty()

with main_container.container():
    if st.session_state.show_upload:
        uploaded_file = st.file_uploader(
            "Upload an image",
            type=["jpg", "jpeg", "png"],
            key=f"image_uploader_{st.session_state.uploader_key}",
        )

        if uploaded_file:
            # A file with an accepted extension can still be corrupt or not
            # an image at all; report that instead of crashing the page.
            try:
                image = Image.open(uploaded_file).convert("RGB")
            except (OSError, ValueError) as exc:
                image = None
                st.error(f"Could not read the uploaded image: {exc}")

            if image is not None:
                image_col, prediction_col = st.columns(2)

                with image_col:
                    st.markdown("### Uploaded Image")
                    st.image(image, use_container_width=True)

                with prediction_col:
                    st.markdown("### Predictions")

                    input_tensor = preprocess_image(image)
                    # Inference only: no gradients are needed.
                    with torch.no_grad():
                        outputs = model(input_tensor)
                        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
                        # Never ask topk for more entries than there are classes.
                        top_k = min(5, probabilities.numel())
                        top_prob, top_idx = torch.topk(probabilities, top_k)

                    # Fall back to a generic label when the lookup is empty
                    # (load_class_names returns {} on failure) or lacks the index.
                    results = [
                        {
                            "Rank": rank,
                            "Class": class_names.get(str(class_id), f"Class {class_id}"),
                            "Confidence": f"{prob * 100:.2f}%",
                        }
                        for rank, (prob, class_id) in enumerate(
                            zip(top_prob.tolist(), top_idx.tolist()), start=1
                        )
                    ]
                    st.table(results)

            # Spacer, then the reset button centred in the middle column.
            st.markdown("<br>", unsafe_allow_html=True)
            _, center_col, _ = st.columns([2, 1, 2])
            with center_col:
                if st.button("↻ New Image"):
                    main_container.empty()
                    st.session_state.show_upload = True
                    # Rotate the widget key so the uploader forgets its file.
                    st.session_state.uploader_key += 1
                    st.rerun()