import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image

from utils import load_model


def split_image(im, num_rows=3, num_cols=3):
    """Split an image into a (num_rows x num_cols) grid of equally sized tiles."""
    im = np.array(im)
    row_size = im.shape[0] // num_rows
    col_size = im.shape[1] // num_cols
    tiles = [
        im[row : row + row_size, col : col + col_size]
        for row in range(0, num_rows * row_size, row_size)
        for col in range(0, num_cols * col_size, col_size)
    ]
    return tiles


def app(model_name):
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Patch-based Relevance Ranking")
    st.markdown(
        """
        Given a piece of text, the CLIP model finds the part of an image that
        best matches the text. To try it out, you can:

        1. Upload an image.
        2. Describe a part of the image in text, which will yield the most
           relevant image tile from a grid of the image.

        You can control how granular the search is by choosing the number of
        rows and columns that make up the image grid.

        ---
        """
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter a prompt to query the image.",
        value="이건 서울의 경복궁 사진이다.",  # "This is a photo of Gyeongbokgung Palace in Seoul."
    )

    col1, col2 = st.beta_columns(2)
    with col1:
        num_rows = st.slider(
            "Number of rows", min_value=1, max_value=5, value=3, step=1
        )
    with col2:
        num_cols = st.slider(
            "Number of columns", min_value=1, max_value=5, value=3, step=1
        )

    if st.button("질문 (Query)"):
        if not any([query1, query2]):
            st.error("Please upload an image or paste an image URL.")
        else:
            st.markdown("""---""")
            with st.spinner("Computing..."):
                # Prefer the uploaded file; otherwise stream the image from the URL.
                image_data = (
                    query2
                    if query2 is not None
                    else requests.get(query1, stream=True).raw
                )
                image = Image.open(image_data)
                st.image(image)

                # Split the image into a grid of tiles and score each tile
                # against the text prompt.
                images = split_image(image, num_rows, num_cols)
                inputs = processor(
                    text=captions, images=images, return_tensors="jax", padding=True
                )
                # The processor returns channels-first pixel values; the Flax
                # vision model expects channels-last (NHWC), so transpose.
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                outputs = model(**inputs)
                # Softmax over the tile axis so the scores sum to 1 across tiles.
                probs = jax.nn.softmax(outputs.logits_per_image, axis=0)

                # Show tiles from most to least relevant.
                for idx, prob in sorted(
                    enumerate(probs), key=lambda x: x[1], reverse=True
                ):
                    st.text(f"Score: {prob[0]:.3f}")
                    st.image(images[idx])
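
# A minimal usage sketch, not part of the original app: `app()` is presumably
# invoked from a top-level Streamlit launcher that lets the user pick a model.
# The checkpoint name "koclip-base" below is an assumption; substitute whatever
# names `utils.load_model` actually supports.
if __name__ == "__main__":
    app("koclip-base")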