import os
from io import BytesIO

import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image

from utils import load_model

# Upper bound (seconds) on the remote image download so that an unresponsive
# host cannot hang the Streamlit script run indefinitely.
_REQUEST_TIMEOUT_SECONDS = 10


def split_image(im, num_rows=3, num_cols=3):
    """Cut an image into a ``num_rows`` x ``num_cols`` grid of equal tiles.

    Tiles are returned in row-major order (left to right, then top to
    bottom). Each tile is ``height // num_rows`` by ``width // num_cols``
    pixels; any remainder pixels on the bottom/right edge are dropped.

    Args:
        im: Anything ``np.array`` can turn into an array whose first two
            axes are height and width (e.g. a PIL image).
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.

    Returns:
        A list of ``num_rows * num_cols`` NumPy array views into the image.

    Raises:
        ValueError: If the grid is finer than the image, i.e. a tile would
            be less than one pixel high or wide.
    """
    pixels = np.array(im)
    row_size = pixels.shape[0] // num_rows
    col_size = pixels.shape[1] // num_cols
    if row_size < 1 or col_size < 1:
        # Without this check a zero tile size surfaces as the cryptic
        # "range() arg 3 must not be zero".
        raise ValueError(
            f"Cannot split a {pixels.shape[1]}x{pixels.shape[0]} image into "
            f"{num_rows} rows and {num_cols} columns."
        )
    return [
        pixels[row : row + row_size, col : col + col_size]
        for row in range(0, num_rows * row_size, row_size)
        for col in range(0, num_cols * col_size, col_size)
    ]


# Alternative implementation with fixed 224-pixel patches, kept for reference:
#
# def split_image(X):
#     num_rows = X.shape[0] // 224
#     num_cols = X.shape[1] // 224
#     Xc = X[0:num_rows * 224, 0:num_cols * 224, :]
#     patches = []
#     for j in range(num_rows):
#         for i in range(num_cols):
#             patches.append(Xc[j * 224:(j + 1) * 224, i * 224:(i + 1) * 224, :])
#     return patches


def _load_image(url, uploaded_file):
    """Return the query image as an RGB PIL image.

    The uploaded file takes precedence; ``url`` is only downloaded when no
    file was uploaded.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
        OSError: If the data cannot be decoded as an image.
    """
    if uploaded_file is not None:
        source = uploaded_file
    else:
        response = requests.get(url, timeout=_REQUEST_TIMEOUT_SECONDS)
        # Fail here rather than handing an HTML error page to PIL.
        response.raise_for_status()
        # ``response.content`` is fully downloaded and content-decoded,
        # unlike the undecoded ``response.raw`` stream.
        source = BytesIO(response.content)
    # Normalise RGBA / palette / greyscale inputs so every tile has three
    # channels (presumably what the processor's normalisation expects).
    return Image.open(source).convert("RGB")


def app(model_name):
    """Render the "Most Relevant Part of Image" Streamlit page.

    Loads ``koclip/{model_name}``, lets the user supply an image (URL or
    upload), a text query and a grid size, then shows every grid tile
    ordered from most to least relevant to the query.

    Args:
        model_name: Name appended to ``"koclip/"`` and passed to
            ``load_model``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Most Relevant Part of Image")
    st.markdown(
        """
        Given a piece of text, the CLIP model finds the part of an image that best explains the text.
        To try it out, you can

        1. Upload an image
        2. Explain a part of the image in text

        which will yield the most relevant image tile from a grid of the image. You can specify how
        granular you want to be with your search by specifying the number of rows and columns that
        make up the image grid.
    """
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter query to find most relevant part of image ",
        value="이건 서울의 경복궁 사진이다.",
    )
    num_rows = st.slider("Number of rows", min_value=1, max_value=5, value=3, step=1)
    num_cols = st.slider("Number of columns", min_value=1, max_value=5, value=3, step=1)

    if not st.button("질문 (Query)"):
        return

    if not query1 and query2 is None:
        st.error("Please upload an image or paste an image URL.")
        return

    try:
        image = _load_image(query1, query2)
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load the image: {err}")
        return
    st.image(image)

    try:
        images = split_image(image, num_rows, num_cols)
    except ValueError as err:
        st.error(str(err))
        return

    inputs = processor(
        text=captions, images=images, return_tensors="jax", padding=True
    )
    # Reorder axes 1..3 of the pixel tensor; presumably channels-first ->
    # channels-last for the Flax model (TODO confirm against load_model).
    inputs["pixel_values"] = jnp.transpose(
        inputs["pixel_values"], axes=[0, 2, 3, 1]
    )
    outputs = model(**inputs)

    # Softmax across tiles (axis 0), so the scores of all tiles sum to 1 for
    # the query. Column 0 is the score column the page displays.
    probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
    scores = np.asarray(probs)[:, 0].tolist()

    # Sort on plain floats rather than on 1-element JAX arrays.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    for idx, score in ranked:
        st.text(f"Score: {score:.3f}")
        st.image(images[idx])