"""Streamlit page that finds the image tile most relevant to a text query.

The image is cut into a grid of tiles, every tile is scored against the
query with a KoCLIP model, and the tiles are shown from best to worst match.
"""

import io
import os

import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image

from utils import load_model

# Upper bound (seconds) for downloading a user-supplied image URL, so that an
# unresponsive host cannot hang the page indefinitely.
_REQUEST_TIMEOUT_SECONDS = 10


def split_image(im, num_rows=3, num_cols=3):
    """Split an image into a ``num_rows`` x ``num_cols`` grid of tiles.

    Args:
        im: Anything ``np.array`` can convert into an array whose first two
            axes are height and width (e.g. a PIL image).
        num_rows: Number of tiles along the height axis.
        num_cols: Number of tiles along the width axis.

    Returns:
        A list of ``num_rows * num_cols`` array tiles in row-major order
        (left to right, then top to bottom). Pixels left over when the image
        size is not divisible by the grid size are dropped from the
        bottom/right edges.

    Raises:
        ValueError: If the grid dimensions are not positive, or the image is
            smaller than the grid so that a tile would be empty.
    """
    if num_rows < 1 or num_cols < 1:
        raise ValueError("num_rows and num_cols must be positive")

    im = np.array(im)
    row_size = im.shape[0] // num_rows
    col_size = im.shape[1] // num_cols
    if row_size == 0 or col_size == 0:
        raise ValueError(
            f"image of size {im.shape[0]}x{im.shape[1]} is too small for a "
            f"{num_rows}x{num_cols} grid"
        )

    # The upper bounds are multiples of the tile size, so exactly
    # num_rows / num_cols start offsets are produced even when the image
    # size is not evenly divisible.
    return [
        im[top : top + row_size, left : left + col_size]
        for top in range(0, num_rows * row_size, row_size)
        for left in range(0, num_cols * col_size, col_size)
    ]


# NOTE: earlier fixed-size (224x224) patch variant, kept for reference only.
# def split_image(X):
#     num_rows = X.shape[0] // 224
#     num_cols = X.shape[1] // 224
#     Xc = X[0:num_rows * 224, 0:num_cols * 224, :]
#     patches = []
#     for j in range(num_rows):
#         for i in range(num_cols):
#             patches.append(Xc[j * 224:(j + 1) * 224, i * 224:(i + 1) * 224, :])
#     return patches


def _load_image(url, uploaded_file):
    """Open the uploaded file if there is one, otherwise download ``url``.

    The result is always converted to RGB so that every tile has three
    channels regardless of the source format (RGBA, palette, grayscale).

    Raises:
        requests.RequestException: If the download fails or returns an HTTP
            error status.
        OSError: If the data cannot be decoded as an image.
    """
    if uploaded_file is not None:
        source = uploaded_file
    else:
        response = requests.get(url, timeout=_REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        # Buffer the decoded body: PIL needs a seekable file object, and
        # ``response.raw`` would bypass content-encoding handling.
        source = io.BytesIO(response.content)
    return Image.open(source).convert("RGB")


def app(model_name):
    """Render the "Most Relevant Part of Image" page.

    Args:
        model_name: Name of the model under the ``koclip/`` namespace; it is
            passed to ``load_model`` as ``f"koclip/{model_name}"``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Most Relevant Part of Image")
    st.markdown(
        """
        Given a piece of text, the CLIP model finds the part of an image that best explains the text.
        To try it out, you can

        1) Upload an image
        2) Explain a part of the image in text

        Which will yield the most relevant image tile from a 3x3 grid of the image
        """
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter query to find most relevant part of image ",
        value="이건 서울의 경복궁 사진이다.",
    )

    if not st.button("질문 (Query)"):
        return

    if not (query1 or query2):
        st.error("Please upload an image or paste an image URL.")
        return

    # Surface download/decoding problems in the UI instead of a traceback.
    # RequestException is itself an OSError subclass, so it is caught first.
    try:
        image = _load_image(query1, query2)
    except requests.RequestException as err:
        st.error(f"Could not download the image: {err}")
        return
    except OSError as err:
        st.error(f"Could not read the image: {err}")
        return

    st.image(image)

    images = split_image(image)
    inputs = processor(
        text=captions, images=images, return_tensors="jax", padding=True
    )
    # Reorder the pixel tensor from channels-first to channels-last before
    # calling the model.
    inputs["pixel_values"] = jnp.transpose(
        inputs["pixel_values"], axes=[0, 2, 3, 1]
    )
    outputs = model(**inputs)

    # Softmax over axis 0 normalises across tiles; a single text query is
    # submitted, so column 0 holds the score of each tile.
    probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
    scores = np.asarray(probs)[:, 0]

    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    for idx, score in ranked:
        st.text(f"Score: {float(score):.3f}")
        st.image(images[idx])