# Hugging Face Spaces status: Running
import logging
import random

import streamlit as st
from huggingface_hub import from_pretrained_fastai
from streamlit_image_select import image_select

import zip_files
@st.cache_resource
def get_model(repo_id="danbiagini/hockey_breeds"):
    """Load the Hockey Breeds fastai model from the Hugging Face Hub.

    Streamlit re-executes this whole script on every user interaction, so an
    uncached loader would download/deserialize the model again on each rerun.
    ``st.cache_resource`` keeps one shared instance per distinct ``repo_id``
    across reruns and sessions.

    Args:
        repo_id: Hub repository holding the exported learner. The default is
            the original Hockey Breeds model, so ``get_model()`` is unchanged.

    Returns:
        The object returned by ``from_pretrained_fastai`` — presumably a
        fastai ``Learner`` (callers use its ``.predict``); confirm if reused.
    """
    return from_pretrained_fastai(repo_id)
def classify_image(
    learn,
    img,
    categories=("Hockey Goalie", "Hockey Player", "Hockey Referee"),
):
    """Classify *img* and return the probability assigned to every category.

    Args:
        learn: Object exposing ``predict(img) -> (pred, idx, probs)``, such as
            the learner returned by ``get_model``.
        img: Image in any form ``learn.predict`` accepts.
        categories: Display labels, one per model output, in the same order
            as ``probs``. NOTE(review): presumably this must follow the
            learner's class order — confirm whenever the model is retrained.

    Returns:
        dict mapping each label to its probability as a plain ``float``,
        in ``categories`` order.

    Raises:
        ValueError: if the number of labels differs from the number of
            probabilities produced (``zip`` would otherwise silently drop
            the extras and mislabel nothing visibly).
    """
    # Only the per-class probabilities are needed; the predicted label and
    # index are derivable from them.
    _pred, _idx, probs = learn.predict(img)
    # float() converts tensor scalars to plain Python numbers for display.
    probs = [float(p) for p in probs]
    if len(probs) != len(categories):
        raise ValueError(
            f"classify_image got {len(probs)} probabilities "
            f"for {len(categories)} categories"
        )
    return dict(zip(categories, probs))
def reroll_sample_images():
    """Pick one random sample image per hockey role and store them in session state.

    Each role's archive is unpacked with ``zip_files.extract_files_from_zip``
    (treated as a mapping, as the original code did via ``.keys()`` and
    indexing) and one entry is chosen uniformly at random. The picks replace
    ``st.session_state.sample`` as a dict keyed ``"player"``, ``"goalie"``,
    ``"referee"`` in that insertion order.

    Raises:
        ValueError: if an archive yields no files.
    """
    archives = {
        "player": "src/images/samples/player-samples.zip",
        "goalie": "src/images/samples/goalie-samples.zip",
        "referee": "src/images/samples/referee-samples.zip",
    }
    # Unpack everything first, then pick, then publish in one assignment so a
    # failure never leaves session state holding a half-filled dict.
    extracted = {
        role: zip_files.extract_files_from_zip(path)
        for role, path in archives.items()
    }
    samples = {}
    for role, files in extracted.items():
        names = list(files.keys())
        if not names:
            raise ValueError(f"No sample images found in {archives[role]}")
        samples[role] = files[random.choice(names)]
    st.session_state.sample = samples
# st.set_page_config is documented as having to be the first Streamlit command
# on the page, so configure the page before touching st.session_state.
st.set_page_config(page_title='Hockey Breeds', layout="wide",
                   page_icon=":frame_with_picture:")

# Streamlit reruns this script on every interaction; only draw samples on a
# session's first run so they persist until the user re-rolls them.
if 'sample' not in st.session_state:
    reroll_sample_images()
def _show_section(heading, body_md):
    """Render a Streamlit subheader followed by a markdown body."""
    st.subheader(heading)
    st.markdown(body_md)


st.title('Hockey Breeds - Hello Computer Vision')

# Intro: what image classification is and where the "breeds" idea comes from.
img_class = '''Image Classification in Computer Vision is the act of determining the most appropriate label for an entire image from a set of fixed labels.
A popular topic of image classification in Computer Vision introductions and courses is to use an example problem of training a model to label images of various pet breeds.
*Hockey Breeds* is a hockey slant on this common theme in Computer Vision educational materials.'''
_show_section('Image Classification', img_class)

# How the model was built, plus a sample of the labelled training batch.
desc = '''This "Hockey Breeds" model was built using 50 hockey related images found on the web and in my own collection. I started with a pretrained *ResNet18* model (resnet18 is trained on *ImageNet*, a very large dataset with millions of images). I fine tuned the model by labeling the hockey photos, then training using python (*Fast.ai* & *PyTorch* libraries).
The total training time for this was approximately 5 minutes running on a low-end GPU. It’s impressive how accurate this quick / small model can be!'''
_show_section("Hockey Image Classification", desc)
st.image("src/images/samples/sampl_batch.png")

# Validation summary and its confusion matrix.
_show_section(
    "Validation Results",
    'Validation of the model\'s performance was done using 26 images not included in the training set. The model performed fairly well against the validation dataset, with only 1 misclassified image.',
)
st.image("src/images/artifacts/confusion_matrix_v1.png", caption="Confusion Matrix for Hockey Breeds ")
st.subheader("Try It Out")
img = image_select(
    label="Select an image and hockey breeds will guess a label. See if you can find some incorrect guesses...",
    images=list(st.session_state.sample.values()),
)
st.button("Re-roll Samples", on_click=reroll_sample_images)

model = get_model()
# The selected image has whatever type zip_files.extract_files_from_zip
# produced (not visible here). If it is array-like, `if img:` would raise an
# ambiguous-truth-value error, so test for "nothing selected" explicitly.
if img is not None:
    res = classify_image(model, img)
    # Show only the most likely label. On ties max() keeps the first entry,
    # i.e. the earlier label in classify_image's category order.
    top_label, top_prob = max(res.items(), key=lambda item: item[1])
    st.metric(label=top_label, value=f"{top_prob * 100:.2f}%")