# Spaces status (captured together with this source): Runtime error
import io | |
import torch | |
import streamlit as st | |
from PIL import Image | |
from transformers import AutoImageProcessor, SwinForImageClassification | |
class ImageClassifier(object):
    """Wrap a ``transformers`` image-classification model for top-k inference.

    The image processor matching the model checkpoint is created lazily on the
    first classification and then reused, so repeated calls do not reload it
    via ``from_pretrained``.
    """

    def __init__(self, model, image_processor=None):
        """Store the model and, optionally, a pre-built image processor.

        Args:
            model: An image-classification model; ``config.id2label`` and
                ``name_or_path`` are read from it, and it is called with the
                processor's outputs as keyword arguments.
            image_processor: Optional callable ``(image, return_tensors='pt')``
                returning model inputs. When ``None`` (the default), one is
                created from ``model.name_or_path`` on first use.
        """
        self.model = model
        self._image_processor = image_processor

    def get_top_5_predictions(self, image):
        """Return the (up to) five most probable ``{'label', 'score'}`` dicts."""
        return self.get_top_k_predictions(image, k=5)

    def get_top_k_predictions(self, image, k=5):
        """Return up to *k* ``{'label': str, 'score': float}`` dicts, best first.

        *k* is clamped to ``[0, number of classes]`` so ``torch.topk`` does not
        raise for models that have fewer than *k* labels.
        """
        probabilities = self.get_output_probabilities(image)
        k = max(0, min(k, probabilities.numel()))
        values, indices = torch.topk(probabilities, k)
        id2label = self.model.config.id2label
        return [
            {'label': id2label[i.item()], 'score': v.item()}
            for i, v in zip(indices, values)
        ]

    def get_output_probabilities(self, image):
        """Return softmax class probabilities for *image* as a 1-D tensor."""
        output = self.classify_image(image)
        # logits[0]: the processor is given a single image, so take batch item 0.
        return torch.nn.functional.softmax(output.logits[0], dim=0)

    def classify_image(self, image):
        """Preprocess *image* and run the model without gradient tracking."""
        inputs = self._get_image_processor()(image, return_tensors='pt')
        with torch.no_grad():
            return self.model(**inputs)

    def create_image_processor(self):
        """Load a fresh image processor matching the model checkpoint."""
        return AutoImageProcessor.from_pretrained(self.model.name_or_path)

    def _get_image_processor(self):
        """Return the cached image processor, creating it on the first call."""
        # getattr tolerates instances built without __init__ (e.g. via __new__).
        if getattr(self, '_image_processor', None) is None:
            self._image_processor = self.create_image_processor()
        return self._image_processor
class ImageClassificationApp(object):
    """Streamlit page: upload an image, show it, and list its top-5 labels."""

    def __init__(self, title, classifier):
        """Store the page title and a classifier exposing get_top_5_predictions."""
        self.title = title
        self.classifier = classifier

    def render(self):
        """Draw the page; classify only once the user has uploaded a file."""
        st.title(self.title)
        uploaded_image = self.get_uploaded_image()
        if uploaded_image is not None:
            self.show_image_and_results(uploaded_image)

    def get_uploaded_image(self):
        """Return the uploaded file object, or None if nothing is uploaded yet."""
        return st.file_uploader('Choose an image...', type=['jpg', 'png', 'jpeg'])

    def show_image_and_results(self, uploaded_image):
        """Display the upload, then decode it and display the predictions."""
        self.show_uploaded_image(uploaded_image)
        # getvalue() returns the whole buffer regardless of where an earlier
        # reader (such as st.image) left the stream position; read() does not.
        image_bytes = uploaded_image.getvalue()
        self.show_classification_results(self.get_image(image_bytes))

    def show_uploaded_image(self, uploaded_image):
        """Render the uploaded file as an image."""
        st.image(uploaded_image, caption='Uploaded Image', use_column_width=True)

    def show_classification_results(self, image):
        """Render a results header followed by the prediction list."""
        st.subheader('Classification Results:')
        self.write_top_5_predictions(image)

    def write_top_5_predictions(self, image):
        """Write one bullet line per prediction, score to 4 decimal places."""
        for prediction in self.classifier.get_top_5_predictions(image):
            st.write(f"- {prediction['label']}: {prediction['score']:.4f}")

    def get_image(self, image_data):
        """Decode raw image bytes into an RGB PIL image.

        PNG uploads may be RGBA, palette or grayscale. Converting to RGB gives
        the 3-channel input that RGB checkpoints such as Swin are trained on.
        """
        return Image.open(io.BytesIO(image_data)).convert('RGB')
if __name__ == '__main__':
    @st.cache_resource
    def _load_classifier():
        """Load the Swin checkpoint once and wrap it in an ImageClassifier."""
        model = SwinForImageClassification.from_pretrained(
            'microsoft/swin-tiny-patch4-window7-224'
        )
        return ImageClassifier(model)

    # Streamlit re-executes this whole script on every user interaction;
    # st.cache_resource keeps the model from being reloaded on each rerun.
    classifier = _load_classifier()
    ImageClassificationApp(
        'Swin Image Classification App',
        classifier
    ).render()