defectdetection / app06.py
nazlicanto's picture
Update app06.py
5a3020e
raw
history blame contribute delete
No virus
1.34 kB
import streamlit as st
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import numpy as np
import torch
# Hugging Face Hub repository holding the fine-tuned SegFormer checkpoint
# and its image-processor config.
model_hub_path = "nazlicanto/model_defectdetection"


@st.cache_resource
def _load_segformer(hub_path):
    """Load the SegFormer model and its image processor from *hub_path*.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps one loaded copy per *hub_path*, shared
    across reruns and user sessions, instead of reloading the checkpoint
    each time.

    Returns:
        A ``(model, preprocessor)`` tuple.
    """
    seg_model = SegformerForSemanticSegmentation.from_pretrained(hub_path)
    seg_processor = SegformerImageProcessor.from_pretrained(hub_path)
    return seg_model, seg_processor


model, preprocessor = _load_segformer(model_hub_path)
st.title("PCB Defect Detection")
# Upload widget; returns None until the user picks a file. `type` is an
# extension filter, so "jpeg" is listed alongside "jpg" to accept JPEG
# files saved under either extension.
uploaded_file = st.file_uploader("Upload a PCB image", type=["jpg", "jpeg", "png"])
# Gray level painted for each predicted class id in the displayed mask.
# Class 0 is left at 0 (presumably background — confirm against the
# model's id2label); class ids not listed here keep their raw value.
DEFECT_GRAY_LEVELS = {1: 255, 2: 195, 3: 135, 4: 75}


def _labels_to_gray(label_map):
    """Return a uint8 grayscale copy of a 2-D class-id map for display.

    Each mask is computed from the untouched class ids, so the result does
    not depend on the order in which the gray levels are applied.
    """
    labels = np.asarray(label_map).astype(np.uint8)
    gray = labels.copy()
    for class_id, level in DEFECT_GRAY_LEVELS.items():
        gray[labels == class_id] = level
    return gray


if uploaded_file is not None:
    # Decode the upload. Pillow's UnidentifiedImageError and its
    # truncated-file error are both OSError subclasses, so one handler
    # reports any unreadable upload instead of showing a traceback.
    try:
        test_image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    inputs = preprocessor(images=test_image, return_tensors="pt")

    # Inference only: no_grad skips autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the prediction back to the input resolution. PIL's `size` is
    # (width, height), while target_sizes expects (height, width).
    semantic_map = preprocessor.post_process_semantic_segmentation(
        outputs, target_sizes=[test_image.size[::-1]]
    )[0]
    gray_mask = _labels_to_gray(semantic_map.cpu().numpy())

    # NOTE(review): newer Streamlit releases may prefer use_container_width
    # over use_column_width — confirm against the pinned Streamlit version.
    st.image(test_image, caption="Uploaded Image", use_column_width=True)
    # A 2-D uint8 array is rendered as grayscale; st.image's `channels`
    # argument only documents "RGB" and "BGR", so it is not passed here.
    st.image(gray_mask, caption="Predicted Defects", use_column_width=True)