Dusduo's picture
small change for base image
6b216c6
raw
history blame
8.31 kB
#importing the libraries
import streamlit as st
from PIL import Image
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
import numpy as np
import pandas as pd
import time
import os
model_repository_id = "Dusduo/Pokemon-classification-1stGen"


# Streamlit re-executes this whole script on every widget interaction, so the
# heavy resources are loaded through cached helpers: the classifier is fetched
# and deserialised once per server process instead of on every rerun.
# show_spinner=False keeps these calls from drawing anything on the page
# before st.set_page_config() is called further down.
@st.cache_resource(show_spinner=False)
def _load_classifier(repository_id):
    """Return the ``(image_processor, model)`` pair stored under *repository_id*."""
    processor = AutoImageProcessor.from_pretrained(repository_id)
    classifier = AutoModelForImageClassification.from_pretrained(repository_id)
    return processor, classifier


@st.cache_data(show_spinner=False)
def _load_pokemon_info(csv_path):
    """Return the pokemon information table read from *csv_path*."""
    return pd.read_csv(csv_path)


# Loading the pokemon classifier model and its processor
image_processor, model = _load_classifier(model_repository_id)
# Loading the pokemon information table
pokemon_info_df = _load_pokemon_info('pokemon_info.csv')
# Small pokeball icon shown next to the pokedex number; the source file is
# closed as soon as the resized copy has been made.
with Image.open('pokeball.png') as _pokeball_source:
    pokeball_image = _pokeball_source.resize((20, 20))
# Functions to predict image
def preprocess(processor: "AutoImageProcessor", image, size=(200, 200)):
    """Convert a PIL image into inputs for the classifier.

    The image is converted to RGB and, unless ``size`` is None, resized to
    ``size`` before being passed to ``processor``.

    Args:
        processor: the image processor paired with the classifier model.
        image: a PIL image (anything with ``convert`` and ``resize``).
        size: ``(width, height)`` to resize to before processing. Defaults to
            the previously hard-coded ``(200, 200)``. Pass ``None`` to leave
            resizing entirely to ``processor``.

    Returns:
        Whatever ``processor(..., return_tensors="pt")`` returns.
    """
    # Force 3 channels so images with an alpha channel or a palette are accepted.
    rgb_image = image.convert("RGB")
    if size is not None:
        rgb_image = rgb_image.resize(size)
    return processor(rgb_image, return_tensors="pt")
def predict(model: "AutoModelForImageClassification", inputs, k=5, variance_threshold=0.001):
    """Classify one preprocessed image and reject probable non-Pokemon inputs.

    The softmax output of the model is used as a confidence signal. When the
    probabilities are spread almost evenly over the classes (their variance
    is below ``variance_threshold``), no class stands out, and the image is
    reported as not being a Pokemon.

    Args:
        model: the image classifier; called as ``model(**inputs)`` and read
            through ``.logits`` and ``.config.id2label``.
        inputs: the mapping returned by ``preprocess``. Only the first item
            of the batch is classified.
        k: currently unused; kept so existing ``predict(model, inputs, 5)``
            calls keep working.
        variance_threshold: minimum variance of the probability vector for a
            prediction to be accepted. Defaults to the previously hard-coded
            0.001.

    Returns:
        ``(label, probability)``, where probability is the percentage for the
        predicted class rounded to 2 decimals. Returns
        ``('not a pokemon', -1)`` when the image is rejected.
    """
    # Forward the image to the model; gradients are not needed for inference.
    with torch.no_grad():
        logits = model(**inputs).logits[0]

    # Convert the logits of the first image into one probability per class.
    probabilities = torch.softmax(logits, dim=0).tolist()

    # A confident model puts most of the mass on one class, which gives a
    # high variance. A near-uniform distribution (low variance) means the
    # model cannot single out any known class, which is what happens for
    # images that are not one of the Pokemon it was trained on.
    if np.var(probabilities) < variance_threshold:
        return 'not a pokemon', -1

    # Predicted class, taken from the same row as the probabilities above.
    predicted_id = torch.argmax(logits).item()
    predicted_label = model.config.id2label[predicted_id]
    # Probability of the predicted class as a percentage with 2 decimals.
    probability = round(probabilities[predicted_id] * 100, 2)
    return predicted_label, probability
# Designing the interface ------------------------------------------
# Wide layout: the app spans the whole browser width instead of a narrow,
# centred column.
st.set_page_config(layout="wide")

st.title("Gotta Classify 'Em All - 1st Generation Pokedex -")
st.write('\n')  # spacer line below the title

# Two columns with a 3:1 width ratio: the picture on the left, and on the
# right the pokedex entry that is filled in after a classification.
col1, col2 = st.columns([3, 1])

# Placeholder picture in the left column. `show` keeps a handle on this
# element so that an uploaded picture can later be drawn in its place.
image = Image.open('base.jpg')
show = col1.image(image, use_column_width=True)
# Display Sample images ----
st.subheader('Sample images')
sample_imgs_dir = "sample_imgs/"
# os.listdir returns entries in arbitrary order and can include stray
# non-image files (e.g. .DS_Store) that st.image cannot render, so keep only
# image files and show them in a stable alphabetical order.
_SAMPLE_IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp')
sample_imgs = sorted(
    file_name
    for file_name in os.listdir(sample_imgs_dir)
    if file_name.lower().endswith(_SAMPLE_IMG_EXTENSIONS)
)
n_cols = 4
# Split the file list into rows of at most n_cols images each.
groups = [sample_imgs[i:i + n_cols] for i in range(0, len(sample_imgs), n_cols)]
for group in groups:
    # zip stops at the shorter sequence, so a short last row leaves its
    # trailing columns empty.
    for col, image_file in zip(st.columns(n_cols), group):
        col.image(os.path.join(sample_imgs_dir, image_file))
# Sidebar work and model outputs ---------------


def _type_badge(type_name, color):
    """Return centred HTML for one coloured Pokemon-type badge."""
    return (
        '<div style="display: flex; justify-content: center; align-items: center;">'
        '<div style="display: inline-block; padding: 5px; margin: 0 5px; '
        f'border-radius: 5px; background-color: {color}; color: white;">{type_name}</div>'
        '</div>'
    )


def _stat_line(label, value):
    """Return HTML for one bold-labelled line of the pokedex box."""
    return f'<div style="font-size: 1.4rem;"><b>{label}:</b> {value}</div>'


def _render_pokedex_entry(entry):
    """Draw the pokedex box for one pokemon in the right-hand column (col2).

    ``entry`` is one row of ``pokemon_info_df`` as a sequence of values. It is
    unpacked by position, so it relies on the column order of
    ``pokemon_info.csv``.
    """
    (_, pokedex_number, english_name, romaji_name, katakana_name,
     weight_kg, height_m, type1, type2, color1, color2,
     classification, evolve_from, evolve_into, is_legendary) = entry
    with col2:
        # pokedex box
        with st.container(border=True):
            # first row: pokeball icon, pokedex number and English name
            with st.container():
                pokeball_image_col, pokedex_number_col, pokemon_name_col = st.columns([1, 1, 8])
                pokeball_image_col.image(pokeball_image)
                pokedex_number_col.markdown(f'<div style="text-align: left; font-size: 1.4rem;"><b>{pokedex_number}</b></div>', unsafe_allow_html=True)
                pokemon_name_col.markdown(f'<div style="text-align: right; font-size: 1.4rem;"><b>{english_name}</b></div>', unsafe_allow_html=True)
            # second row: classification, written in the primary colour
            with st.container():
                st.markdown(f'<div style="text-align: center; color: {color1}; font-size: 1.2rem;"><b>{classification}</b></div>', unsafe_allow_html=True)
            # third row: one badge per type (type2 is missing for single-type pokemon)
            with st.container():
                if pd.isna(type2):
                    st.write('\n')
                    st.markdown(_type_badge(type1, color1), unsafe_allow_html=True)
                else:
                    type1_col, type2_col = st.columns(2)
                    type1_col.markdown(_type_badge(type1, color1), unsafe_allow_html=True)
                    type2_col.markdown(_type_badge(type2, color2), unsafe_allow_html=True)
                st.write('\n')
            # fourth row: size and evolution chain
            with st.container():
                st.markdown(_stat_line('Height', f'{height_m}m'), unsafe_allow_html=True)
                st.write('\n')
                st.markdown(_stat_line('Weight', f'{weight_kg}kg'), unsafe_allow_html=True)
                st.write('\n')
                if not pd.isna(evolve_from):
                    st.markdown(_stat_line('Evolves from', evolve_from), unsafe_allow_html=True)
                    st.write('\n')
                if not pd.isna(evolve_into):
                    st.markdown(_stat_line('Evolves into', evolve_into), unsafe_allow_html=True)
                    st.write('\n')


st.sidebar.title("Upload Image")

# Choose your own image. The widget gets a real label for accessibility, but
# the label is hidden because the sidebar title already explains the widget.
uploaded_file = st.sidebar.file_uploader(
    "Upload an image of a 1st Generation Pokemon",
    type=['png', 'jpg', 'jpeg'],
    accept_multiple_files=False,
    label_visibility="collapsed",
)

# Model inputs for the uploaded image; stays None until a readable image is uploaded.
model_inputs = None
if uploaded_file is not None:
    try:
        u_img = Image.open(uploaded_file)
        # Replace the picture in the main column with the upload.
        show.image(u_img, 'Uploaded Image', width=400)
        # Preprocess the image for the model
        model_inputs = preprocess(image_processor, u_img)
    except (OSError, Image.DecompressionBombError) as err:
        # A corrupt or non-image upload would otherwise crash the whole page.
        st.sidebar.error(f"Could not read the uploaded image: {err}")

# For newline
st.sidebar.write('\n')

if st.sidebar.button("Click Here to Classify"):
    if uploaded_file is None:
        st.sidebar.write("Please upload an Image to Classify")
    elif model_inputs is None:
        st.sidebar.write("The uploaded file could not be read as an image, please upload another one.")
    else:
        with st.spinner('Classifying ...'):
            # Get prediction
            prediction, probability = predict(model, model_inputs, 5)
            # Artificial 2 s pause, presumably so the spinner is noticeable.
            time.sleep(2)
            st.sidebar.success('Done!')
        st.sidebar.header("Model predicts: ")
        # Display prediction; predict() signals a rejected image with probability -1.
        if probability == -1:
            st.sidebar.write("It seems like it is not a picture of a 1st Generation Pokemon alone.", '\n',
                             "There might be too many entities on the image.")
        else:
            st.sidebar.write(f" It's a(n) {prediction} picture.", '\n')
            st.sidebar.write('Probability:', probability, '%')
            # Retrieve predicted pokemon information
            matches = pokemon_info_df[pokemon_info_df['name'] == prediction]
            if matches.empty:
                # Label absent from the CSV: report it instead of raising IndexError.
                st.sidebar.warning(f"No pokedex entry found for {prediction}.")
            else:
                _render_pokedex_entry(matches.values[0])