"""Streamlit demo of M-SFANet crowd counting with the UCF-QNRF weights.

Clones the upstream model repository, downloads the paper's pretrained
weights, and shows an estimated crowd count plus a density-map visualisation
for an uploaded image.
"""
import os
import subprocess
import sys

import cv2
import gdown
import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms


def setup_env(path='Variations-of-SFANet-for-Crowd-Counting'):
    """Make the upstream SFANet model code importable.

    Clones the GitHub repository into ``path`` if it is not already present,
    and ensures ``path`` is on ``sys.path`` so ``from models import ...``
    resolves.

    Args:
        path: Local checkout directory; also used as the GitHub repo name.

    Returns:
        The same ``path``.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails.
    """
    if not os.path.exists(path):
        subprocess.run(
            [
                'git',
                'clone',
                f'https://github.com/Pongpisit-Thanasutives/{path}.git',
                f'{path}',
            ],
            capture_output=True,
            check=True,
        )
        # Write an empty models/__init__.py so ``models`` imports as a
        # regular package.
        with open(os.path.join(path, 'models', '__init__.py'), 'w') as f:
            f.write('')
    # Register the checkout on every call, not only right after cloning:
    # otherwise a restarted process that finds an existing clone could not
    # import ``models``. The membership check avoids duplicate entries when
    # the script is re-executed in the same process.
    if path not in sys.path:
        sys.path.append(path)
    return path


def get_model(path, weights):
    """Build M-SFANet and load pretrained weights for CPU inference.

    Args:
        path: Checkout directory returned by ``setup_env``. It is not read
            here; ``setup_env`` must already have run so that ``models`` is
            importable.
        weights: Path to the saved ``state_dict`` file.

    Returns:
        The model, switched to evaluation mode.
    """
    # Imported lazily: ``models`` is only on sys.path after setup_env().
    from models import M_SFANet_UCF_QNRF

    model = M_SFANet_UCF_QNRF.Model()
    model.load_state_dict(
        torch.load(weights, map_location=torch.device('cpu')))
    return model.eval()


def download_weights(
    url='https://drive.google.com/uc?id=1fGuH4o0hKbgdP1kaj9rbjX2HUL1IH0oo',
    out="Paper's_weights_UCF_QNRF.zip",
):
    """Download and extract the paper's pretrained weights if missing.

    Args:
        url: Google Drive download URL of the weights archive.
        out: Local file name for the downloaded zip archive.

    Returns:
        Path to the extracted ``.pth`` weights file.

    Raises:
        subprocess.CalledProcessError: if ``unzip`` fails.
        FileNotFoundError: if extraction did not produce the expected file.
    """
    # os.path.exists does not expand globs, so the '*' here must be a
    # literal character in the extracted file name.
    weights = "Paper's_weights_UCF_QNRF/best_M-SFANet*_UCF_QNRF.pth"
    if os.path.exists(weights):
        return weights
    gdown.download(url, out)
    # -o: overwrite without prompting, so a leftover partial extraction
    # cannot leave unzip waiting on an interactive prompt.
    subprocess.run(
        ['unzip', '-o', out],
        capture_output=True,
        check=True,
    )
    if not os.path.exists(weights):
        raise FileNotFoundError(
            f'{weights!r} not found after extracting {out!r}')
    return weights


def transform_image(img):
    """Convert a PIL image into a normalised batch tensor for the model.

    The image is resized so that height and width are multiples of 16
    (rounded to the nearest, at least 16). NOTE(review): presumably the
    network's down/up-sampling needs this; confirm against upstream.

    Args:
        img: PIL image; the caller converts it to RGB.

    Returns:
        Tensor of shape ``(1, C, H, W)``, normalised with the usual ImageNet
        channel mean/std.
    """
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    width, height = img.size
    # Clamp to 16: images under 8 px would otherwise round to a 0-sized
    # target, which cv2.resize rejects.
    height = max(16, round(height / 16) * 16)
    width = max(16, round(width / 16) * 16)
    # interpolation must be passed by keyword: the third positional
    # parameter of cv2.resize is ``dst``, so a positional INTER_CUBIC is
    # silently ignored and bilinear interpolation is used instead.
    resized = cv2.resize(np.array(img), (width, height),
                         interpolation=cv2.INTER_CUBIC)
    # ToTensor gives (C, H, W); [None, :] prepends the batch axis.
    return trans(Image.fromarray(resized))[None, :]


def main():
    """Render the demo page: upload an image, show its density map and count.

    NOTE(review): Streamlit re-executes the script on every interaction, so
    the model is rebuilt each time; consider caching it (e.g.
    ``st.cache_resource``) if the installed Streamlit version supports it.
    """
    st.write("Demo of [Encoder-Decoder Based Convolutional Neural Networks with Multi-Scale-Aware Modules for Crowd Counting](https://arxiv.org/abs/2003.05586)")  # noqa
    path = setup_env()
    weights = download_weights()
    model = get_model(path, weights)

    image_file = st.file_uploader(
        "Upload image", type=['png', 'jpg', 'jpeg'])
    if image_file is None:
        st.write("Example image to use that you can drag and drop:")
        st.image(Image.open('crowd.jpg').convert('RGB'))
        return

    image = Image.open(image_file).convert('RGB')
    st.image(image)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        density_map = model(transform_image(image))
    # First batch item, reordered from (C, H, W) to (H, W, C) for st.image.
    density_map_img = density_map.detach().numpy()[0].transpose(1, 2, 0)
    # Scale by the peak for display; an all-zero map would otherwise
    # divide by zero and produce NaNs.
    peak = density_map_img.max()
    if peak > 0:
        density_map_img = density_map_img / peak
    st.image(density_map_img)
    st.write("Estimated count: ", torch.sum(density_map).item())


# Streamlit executes the script as ``__main__``, so the app still runs under
# ``streamlit run`` while importing this module has no side effects.
if __name__ == '__main__':
    main()