import streamlit as st
import numpy as np
from PIL import Image
import requests
import ModelClass
from glob import glob
import torch
import torch.nn as nn
import numpy as np
@st.cache_resource
def load_model():
    """Return the classifier built by ``ModelClass.get_model()``.

    Wrapped in ``st.cache_resource`` so the model object is created once and
    reused across Streamlit reruns instead of being rebuilt on every
    interaction.
    """
    model = ModelClass.get_model()
    return model
@st.cache_data
def get_images():
    """Map each test-image file name in ``./inputs`` to its path.

    Returns:
        dict[str, str]: file name -> path, ordered by name so the test-image
        selectbox lists entries in a stable order (glob's own order depends
        on the file system). Sub-directories are skipped because they cannot
        be opened as images.
    """
    import os  # local import: os is not imported at file level

    paths = sorted(p for p in glob('./inputs/*') if os.path.isfile(p))
    # os.path.basename honours the platform separator; splitting on '/' left
    # keys like 'inputs\\a.jpg' on Windows, where glob joins with '\\'.
    return {os.path.basename(p): p for p in paths}
def infer(img):
    """Run the cached model on one PIL image and return class probabilities.

    Args:
        img: PIL image. It is converted to RGB first, so grayscale, RGBA or
            palette uploads go through the same transform as RGB ones.

    Returns:
        torch.Tensor: softmax probabilities for the single image. This
        assumes the model returns logits shaped (batch, num_classes) —
        TODO confirm against ModelClass — in which case the result is 1-D
        with length num_classes.
    """
    image = ModelClass.get_transform()(img.convert('RGB'))
    batch = image.unsqueeze(dim=0)  # prepend a batch axis: a batch of one
    model = load_model()
    model.eval()
    with torch.no_grad():
        logits = model(batch)
    # Softmax explicitly over the class axis. nn.Softmax() without `dim`
    # relies on a deprecated implicit choice and warns on every call.
    return torch.softmax(logits, dim=1).squeeze(0)
# Page-level settings for the ActionNet app (browser tab title, icon, layout,
# and the "About" entry of the top-right menu).
st.set_page_config(
    page_title="ActionNet",
    page_icon="🧊",
    layout="centered",
    initial_sidebar_state="expanded",
    # Only "About" is customised. The former "Get Help" / "Report a bug"
    # entries pointed at a placeholder example domain, so those menu items
    # are left at Streamlit's defaults.
    menu_items={
        # Blank line before '---' so it renders as a rule, not a heading.
        'About': """
# ActionNet
Human Action Recognition using a CNN: a ResNet model trained to classify human activities (15 classes) from images.

---
Built by Shamim Ahamed.
""",
    },
)
# NOTE(review): this call was meant to inject custom HTML/CSS for the sidebar
# ("fix sidebar"), but the markup string is empty, so it currently renders
# nothing. Fill in the intended CSS or remove the call.
st.markdown("""
""", unsafe_allow_html=True
)
# Placeholder for CSS that presumably hides Streamlit's default page chrome
# (going by the name). It is empty and the st.markdown call that would apply
# it is commented out, so it has no effect at present.
hide_st_style = """
"""
#st.markdown(hide_st_style, unsafe_allow_html=True)
def predict(image):
    """Return a random placeholder prediction over the labels 'cat' and 'dog'.

    ``image`` is ignored. Each label gets a uniform random score from NumPy's
    global RNG, and the scores are normalised to sum to 1. ``app()`` does not
    use this function; it calls ``infer`` instead.
    """
    labels = ['cat', 'dog']
    scores = np.random.rand(len(labels))
    scores = scores / scores.sum()
    return {label: score for label, score in zip(labels, scores)}
def app():
    """Render the ActionNet page.

    The user either uploads an image or picks one of the test images found
    in ./inputs; an upload takes precedence. The chosen image is previewed in
    the left column. Pressing the prediction button runs ``infer`` on it and
    lists up to four most probable classes, highest first, each with a
    progress bar, in the right column.
    """
    st.title('ActionNet')
    st.markdown('Human Action Recognition using CNN: A Computer Vision project that trains a ResNet model to classify human activities. The dataset contains 15 activity classes, and the model predicts the activity from input images.')
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    test_images = get_images()
    # When ./inputs is empty there are no options and no selection (None).
    test_image = st.selectbox('Or choose a test image', list(test_images.keys()))
    st.markdown('#### Selected Image')
    left_column, right_column = st.columns([1.5, 2.5], gap="medium")

    image = None
    predict_clicked = False
    with left_column:
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
        elif test_image is not None:
            image = Image.open(test_images[test_image])

        if image is None:
            # Previously test_images[None] raised KeyError here.
            st.info('Upload an image or add test images to ./inputs to get a prediction.')
        else:
            st.image(image, use_column_width=True)
            predict_clicked = st.button('✨ Get prediction from AI', type='primary')

    if predict_clicked:
        prob = infer(image).numpy()
        # Rank classes by descending probability so the most likely one is
        # listed first; argpartition picked the top-k but left them
        # unordered. min() keeps this valid for models with < 4 classes.
        top_k = min(4, prob.size)
        top_indices = np.argsort(prob)[::-1][:top_k]
        right_column.markdown('#### Results')
        for i in top_indices:
            class_name = ModelClass.get_class(i).replace('_', ' ').capitalize()
            class_probability = float(prob[i])
            right_column.write(f'{class_name}: {class_probability:.2%}')
            right_column.progress(class_probability)

    st.markdown("---")
    st.markdown("Built by [Shamim Ahamed](https://www.shamimahamed.com/). Data provided by [aiplanet](https://aiplanet.com/challenges/data-sprint-76-human-activity-recognition/233/overview/about)")
# `streamlit run` executes this file as __main__, so the page still renders.
# Importing the module (e.g. from tests) no longer builds the page body as a
# side effect.
if __name__ == "__main__":
    app()