|
import streamlit as st |
|
from transformers import pipeline |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import os |
|
|
|
# Hugging Face access token passed to every pipeline() call; None when the
# HF_TOKEN environment variable is not set.
hf_token = os.getenv("HF_TOKEN")
|
|
|
st.title("Image Classification Web App")

# Intro paragraph. The adjacent string literals are joined by Python at
# compile time, so the rendered text is a single unbroken sentence run.
st.markdown(
    "This app uses Hugging Face's 'transformers' library to classify images "
    "using pre-trained models. The app uses three different models for image "
    "classification: swin, convnext and vit. Please select a model to classify "
    "the image you put on the left sidebar."
)
|
|
|
|
|
st.sidebar.markdown("**Please provide a Satellite image for classification**")

# The image to classify. It stays None until a URL fetch or an upload decodes
# successfully, so downstream code never sees an unbound name.
image = None

url = st.sidebar.text_input("Image URL")
if url:
    try:
        # The timeout keeps an unresponsive host from hanging the app.
        # raise_for_status() turns HTTP error pages (404, 500, ...) into exceptions
        # instead of handing their HTML body to PIL.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        # convert("RGB") forces the pixel data to be decoded here, so truncated or
        # corrupt data is caught below. It also normalizes palette/RGBA/grayscale
        # inputs to the 3-channel RGB the models presumably expect (TODO confirm).
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError, Image.DecompressionBombError):
        # RequestException: network/URL problems. OSError covers PIL's
        # UnidentifiedImageError and decode failures.
        st.sidebar.error("Invalid URL. Please enter a valid URL for an image.")
    else:
        st.sidebar.image(image, caption='Uploaded Image', use_column_width=True)

uploaded_file = st.sidebar.file_uploader("Or upload an image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # An uploaded file takes precedence over the URL image when both are given.
        uploaded_image = Image.open(uploaded_file).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # Leave any successfully loaded URL image in place; just report the bad upload.
        st.sidebar.error("Could not read the uploaded file as an image.")
    else:
        image = uploaded_image
        st.image(image, caption='Uploaded Image', use_column_width=True)
|
|
|
|
|
# Official Hugging Face documentation for each architecture offered by the
# classify buttons. ConvNeXt is a convolutional network, not a transformer,
# so it is not labelled as one.
_ARCHITECTURE_DOCS = (
    ("Vision Transformer (ViT)", "https://huggingface.co/docs/transformers/main/en/model_doc/vit"),
    ("ConvNeXt", "https://huggingface.co/docs/transformers/main/en/model_doc/convnext"),
    ("Swin Transformer", "https://huggingface.co/docs/transformers/main/en/model_doc/swin"),
)

st.sidebar.markdown("## Find more information about the model architectures at the links below:")
# Loop names are underscored so they cannot clobber the module-level `url`
# read from the sidebar text input.
for _arch_name, _doc_link in _ARCHITECTURE_DOCS:
    st.sidebar.markdown(f"*{_arch_name}*: [{_doc_link}]({_doc_link})")
|
|
|
|
|
|
|
@st.cache_resource
def _load_classifier(model_id):
    """Return an image-classification pipeline for *model_id*, built once per process.

    st.cache_resource keeps the pipeline alive across Streamlit reruns, so a
    model is loaded on first use rather than on every button click.
    """
    return pipeline("image-classification", model_id, token=hf_token)


def classify_image1(image):
    """Classify *image* with the Swin model `SolubleFish/swin_transformer-finetuned-eurosat`.

    Returns the pipeline's predictions (items expose 'label' and 'score').
    """
    return _load_classifier("SolubleFish/swin_transformer-finetuned-eurosat")(image)


def classify_image2(image):
    """Classify *image* with the ConvNeXt model `SolubleFish/image_classification_convnext`.

    Returns the pipeline's predictions (items expose 'label' and 'score').
    """
    return _load_classifier("SolubleFish/image_classification_convnext")(image)


def classify_image3(image):
    """Classify *image* with the ViT model `SolubleFish/image_classification_vit`.

    Returns the pipeline's predictions (items expose 'label' and 'score').
    """
    return _load_classifier("SolubleFish/image_classification_vit")(image)
|
|
|
|
|
|
|
def _render_results(column, results):
    """Write each prediction's label and confidence (as a percentage) into *column*.

    Shows an error in *column* instead when *results* is empty.
    """
    if not results:
        column.error("No results found.")
        return
    for result in results:
        column.markdown(
            f"Class name: **{result['label']}** \n\n"
            f" Confidence: **{result['score'] * 100:.2f}**%"
        )
    column.success("Classification completed.")


col1, col2, col3 = st.columns(3)

# The sidebar section may leave `image` unbound (e.g. when a URL fetch fails
# even though `url` is non-empty), so checking `url or uploaded_file` is not
# enough. Resolve it defensively instead of risking a NameError on click.
try:
    _selected_image = image
except NameError:
    _selected_image = None

# One button per model, created left to right in the original order; the
# button labels are unchanged so Streamlit widget identity is preserved.
for _column, _button_label, _classify in (
    (col1, "Classify Image by Swin", classify_image1),
    (col2, "Classify Image by ConvNext", classify_image2),
    (col3, "Classify Image by ViT", classify_image3),
):
    if not _column.button(_button_label):
        continue
    if _selected_image is None:
        _column.error("Please provide an image for classification.")
        continue
    _render_results(_column, _classify(_selected_image))