"""Streamlit app that classifies an uploaded image with a Swin Transformer."""

import io

import streamlit as st
import torch
from PIL import Image
from transformers import AutoImageProcessor, SwinForImageClassification

MODEL_NAME = 'microsoft/swin-tiny-patch4-window7-224'


class ImageClassifier:
    """Wrap a Hugging Face image-classification model with top-k helpers."""

    def __init__(self, model):
        """Store *model*; its image processor is loaded on first use.

        Args:
            model: A ``transformers`` image-classification model exposing
                ``config.id2label`` and ``name_or_path``.
        """
        self.model = model
        # Loaded lazily and then reused: ``from_pretrained`` is expensive, so
        # it should not run again for every classified image.
        self._image_processor = None

    def get_top_5_predictions(self, image):
        """Return the five most likely labels for *image* (see get_top_k_predictions)."""
        return self.get_top_k_predictions(image, k=5)

    def get_top_k_predictions(self, image, k=5):
        """Return up to *k* predictions ordered by descending probability.

        Args:
            image: Input accepted by the model's image processor (e.g. a PIL image).
            k: Maximum number of predictions; clamped to the number of classes.

        Returns:
            A list of ``{'label': str, 'score': float}`` dicts.
        """
        probabilities = self.get_output_probabilities(image)
        # torch.topk raises if k exceeds the size of the dimension.
        k = min(k, probabilities.shape[-1])
        values, indices = torch.topk(probabilities, k)
        id2label = self.model.config.id2label
        return [
            {'label': id2label[index.item()], 'score': value.item()}
            for index, value in zip(indices, values)
        ]

    def get_output_probabilities(self, image):
        """Return the softmax distribution over classes for the first batch item."""
        output = self.classify_image(image)
        return torch.nn.functional.softmax(output.logits[0], dim=0)

    def classify_image(self, image):
        """Preprocess *image* and run the model without gradient tracking.

        Returns:
            The raw model output (its ``logits`` are used by callers).
        """
        inputs = self._get_image_processor()(image, return_tensors='pt')
        with torch.no_grad():
            return self.model(**inputs)

    def create_image_processor(self):
        """Load a fresh image processor matching the model's checkpoint."""
        return AutoImageProcessor.from_pretrained(self.model.name_or_path)

    def _get_image_processor(self):
        """Return the cached image processor, creating it on the first call."""
        # getattr keeps this working for subclasses that skip our __init__.
        if getattr(self, '_image_processor', None) is None:
            self._image_processor = self.create_image_processor()
        return self._image_processor


class ImageClassificationApp:
    """Streamlit page: upload an image, show it and its top predictions."""

    def __init__(self, title, classifier):
        """
        Args:
            title: Page title shown at the top of the app.
            classifier: Object providing ``get_top_5_predictions(image)``.
        """
        self.title = title
        self.classifier = classifier

    def render(self):
        """Draw the page; classify the image once one has been uploaded."""
        st.title(self.title)
        uploaded_image = self.get_uploaded_image()
        if uploaded_image is not None:
            self.show_image_and_results(uploaded_image)

    def get_uploaded_image(self):
        """Show the file picker; return the uploaded file or None."""
        return st.file_uploader('Choose an image...', type=['jpg', 'png', 'jpeg'])

    def show_image_and_results(self, uploaded_image):
        """Decode the upload, then display it alongside its predictions."""
        # getvalue() returns the whole buffer regardless of the current read
        # position, unlike read(), which depends on what st.image consumed.
        image_data = uploaded_image.getvalue()
        try:
            image = self.get_image(image_data)
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # unrecognized files and OSError for truncated/corrupt ones.
            st.error('The uploaded file could not be read as an image.')
            return
        self.show_uploaded_image(uploaded_image)
        self.show_classification_results(image)

    def show_uploaded_image(self, uploaded_image):
        """Render the uploaded file as an image."""
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; kept for compatibility with older ones.
        st.image(uploaded_image, caption='Uploaded Image', use_column_width=True)

    def show_classification_results(self, image):
        """Render a heading followed by the top predictions for *image*."""
        st.subheader('Classification Results:')
        self.write_top_5_predictions(image)

    def write_top_5_predictions(self, image):
        """Write one bullet line per prediction with its score to 4 decimals."""
        for prediction in self.classifier.get_top_5_predictions(image):
            st.write(f"- {prediction['label']}: {prediction['score']:.4f}")

    def get_image(self, image_data):
        """Decode raw bytes into a PIL image in RGB mode.

        Raises:
            OSError: If the bytes cannot be decoded as an image.
        """
        # Uploads may be RGBA/P/LA (common for PNG) or grayscale; the model
        # presumably expects 3-channel RGB input, so normalize the mode here.
        # convert() also forces the lazy decode so errors surface now.
        return Image.open(io.BytesIO(image_data)).convert('RGB')


@st.cache_resource
def load_classifier(model_name=MODEL_NAME):
    """Load *model_name* once per server process and wrap it in a classifier.

    Streamlit re-runs the whole script on every interaction; caching prevents
    the model (and its lazily loaded image processor) from being reloaded.
    """
    model = SwinForImageClassification.from_pretrained(model_name)
    return ImageClassifier(model)


if __name__ == '__main__':
    ImageClassificationApp(
        'Swin Image Classification App', load_classifier()
    ).render()