"""Streamlit front-end for the ActionNet human-action-recognition demo.

The page lets the user upload an image (or pick one of the bundled test
images under ``./inputs/``), runs it through the model supplied by the
project-local ``ModelClass`` module, and shows the most probable classes.
"""
import os
from glob import glob

import numpy as np
import requests
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image

import ModelClass

# Number of highest-probability classes shown in the results column.
TOP_K = 4


@st.cache_resource
def load_model():
    """Return the model from ``ModelClass.get_model()``.

    Wrapped in ``st.cache_resource`` so the model object is created once and
    reused across Streamlit reruns.
    """
    return ModelClass.get_model()


@st.cache_data
def get_images():
    """Map each file name under ``./inputs/`` to its path.

    Returns:
        dict[str, str]: ``{file_name: path}``, in sorted path order so the
        select box lists images deterministically.
    """
    # os.path.basename honours the platform separator; splitting on '/' would
    # leave the full path as the key on Windows.
    return {os.path.basename(path): path for path in sorted(glob('./inputs/*'))}


def infer(img):
    """Run the model on a PIL image and return per-class probabilities.

    Args:
        img: A ``PIL.Image.Image``; it is converted to RGB before the
            transform from ``ModelClass.get_transform()`` is applied.

    Returns:
        torch.Tensor: Softmax of the model output with the batch axis
        removed.
    """
    image = img.convert('RGB')
    image = ModelClass.get_transform()(image)
    image = image.unsqueeze(dim=0)  # add the batch axis

    model = load_model()
    model.eval()
    with torch.no_grad():
        out = model(image)
        # Explicit dim: nn.Softmax() without ``dim`` relies on deprecated
        # implicit-dimension selection and emits a warning.
        out = torch.softmax(out, dim=-1).squeeze(0)
    return out


st.set_page_config(
    page_title="ActionNet",
    page_icon="🧊",
    layout="centered",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': """ # This is a header. This is an *extremely* cool app! How how are you doin. --- I am fine """
    }
)

# fix sidebar
st.markdown(""" """, unsafe_allow_html=True
)

hide_st_style = """ """
#st.markdown(hide_st_style, unsafe_allow_html=True)


def predict(image):
    """Return random, normalised scores for two dummy classes.

    Placeholder that ignores ``image``; ``app`` uses ``infer`` instead.
    """
    # Dummy prediction
    classes = ['cat', 'dog']
    prediction = np.random.rand(len(classes))
    prediction /= np.sum(prediction)
    return dict(zip(classes, prediction))


def app():
    """Render the page: image selection, prediction button and results."""
    st.title('ActionNet')
    # st.markdown("[![View in W&B](https://img.shields.io/badge/View%20in-W%26B-blue)](https://wandb.ai//?workspace=user-)")
    st.markdown(
        'Human Action Recognition using CNN: A Computer Vision project that '
        'trains a ResNet model to classify human activities. '
        'The dataset contains 15 activity classes, and the model predicts '
        'the activity from input images.'
    )

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

    test_images = get_images()
    test_image = st.selectbox('Or choose a test image', list(test_images.keys()))

    st.markdown('#### Selected Image')
    left_column, right_column = st.columns([1.5, 2.5], gap="medium")

    with left_column:
        # An uploaded file takes precedence over the bundled test image.
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
        elif test_image is not None:
            image = Image.open(test_images[test_image])
        else:
            # No upload and ./inputs/ is empty: nothing to show or classify.
            image = None
            st.info('Upload an image to get started.')

        if image is not None:
            st.image(image, use_column_width=True)

        clicked = st.button(
            '✨ Get prediction from AI',
            type='primary',
            disabled=image is None,
        )
        if clicked and image is not None:
            prob = infer(image).numpy()

            # Highest probability first; never ask for more classes than the
            # model actually outputs.
            k = min(TOP_K, prob.size)
            top_indices = np.argsort(prob)[::-1][:k]

            right_column.markdown('#### Results')
            for i in top_indices:
                class_name = ModelClass.get_class(int(i)).replace('_', ' ').capitalize()
                class_probability = float(prob[i])
                right_column.write(f'{class_name}: {class_probability:.2%}')
                right_column.progress(class_probability)

    st.markdown("---")
    st.markdown("Built by [Shamim Ahamed](https://www.shamimahamed.com/). Data provided by [aiplanet](https://aiplanet.com/challenges/data-sprint-76-human-activity-recognition/233/overview/about)")


app()