# NOTE(review): this file's line breaks have been lost (everything below sits on
# one physical line), so it cannot be parsed as Python until they are restored.
# Part of the text is also missing: it jumps from `if a` inside `pad_to_square`
# straight to `50: time.sleep(sleep_time)`. Recover both from version control
# before changing any logic.
#
# What the visible text contains, in order:
#   * Imports, then `headers`: a dict holding a browser-like User-Agent string.
#     It is passed to `requests.get` in the final `elif image_url:` branch.
#   * `preprocess`: a torchvision `Compose` of ToTensor -> Resize(224) ->
#     Normalize(mean, std). It is not referenced anywhere in the visible text.
#   * `resize_longer(image, longer_size=224)`: multiplies both entries of
#     `image.size` by `longer_size / max(image.size)`, truncates them to int,
#     resizes the image to that size and returns it.
#     NOTE(review): `Image.ANTIALIAS` was removed in Pillow 10 (`Image.LANCZOS`
#     is the replacement) -- confirm which Pillow version is pinned.
#   * `pad_to_square(image)`: starts by unpacking `image.shape[:2]`; the rest of
#     its body is missing.
#   * A trailing UI fragment: if `caption` or `image_url` is empty it reports an
#     error via `st.error`; otherwise it calls
#     `get_heatmap(image_url, caption, pixel_size, iterations)` inside a spinner
#     and displays the image, the heatmap, and `np.asarray(image) / 255.0 *
#     heatmap` with `st.image`, followed by `gc.collect()`. The `elif image_url:`
#     branch downloads the image with `requests`, converts it to RGB and shows
#     it. The nesting of these statements cannot be determined now that the
#     indentation is lost.
#
# `get_heatmap`, `col1`, `caption`, `image_url`, `pixel_size`, `iterations` and
# `sleep_time` are used but not defined in the visible text -- presumably they
# were defined in the missing section; verify against the original file.
import streamlit as st from text2image import get_model, get_tokenizer, get_image_transform from utils import text_encoder from torchvision import transforms from PIL import Image from jax import numpy as jnp import pandas as pd import numpy as np import requests import psutil import time import jax import gc headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36 Edge/18.19582" } preprocess = transforms.Compose( [ transforms.ToTensor(), transforms.Resize(224), transforms.Normalize( (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) ), ] ) def resize_longer(image, longer_size=224): old_size = image.size ratio = float(longer_size) / max(old_size) new_size = tuple([int(x * ratio) for x in old_size]) image = image.resize(new_size, Image.ANTIALIAS) return image def pad_to_square(image): (a,b)=image.shape[:2] if a 50: time.sleep(sleep_time) if not caption or not image_url: st.error("Please choose one image and at least one label") else: with st.spinner( "Computing... This might take up to a few minutes depending on the current load 😕 \n" "Otherwise, you can use this [Colab notebook](https://colab.research.google.com/drive/10neENr1DEAFq_GzsLqBDo0gZ50hOhkOr?usp=sharing)" ): heatmap, image = get_heatmap(image_url, caption, pixel_size, iterations) with col1: st.image(image, use_column_width=True) st.image(heatmap, use_column_width=True) st.image(np.asarray(image) / 255.0 * heatmap, use_column_width=True) gc.collect() elif image_url: image = requests.get( image_url, headers=headers, stream=True, ).raw image = Image.open(image).convert("RGB") with col1: st.image(image)