"""Streamlit page that finds the tile of an image that best matches a text query."""

import io
import os

import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image

from utils import load_model

# Seconds to wait for the remote server when downloading an image by URL.
_URL_TIMEOUT_SECONDS = 10


def split_image(im, grid_size=3):
    """Split an image into a ``grid_size`` x ``grid_size`` grid of equal tiles.

    Args:
        im: A PIL image or anything ``np.array`` accepts; its first two axes
            are treated as (height, width).
        grid_size: Number of tiles along each axis. Defaults to 3.

    Returns:
        A list of ``grid_size ** 2`` numpy arrays in row-major order
        (left to right, then top to bottom). Leftover pixels at the right
        and bottom edges that do not fill a whole tile are dropped, so every
        tile has the same shape.

    Raises:
        ValueError: If ``grid_size`` is not positive, or the image is smaller
            than ``grid_size`` pixels along either axis.
    """
    im = np.array(im)
    if grid_size < 1:
        raise ValueError(f"grid_size must be positive, got {grid_size}.")
    tile_h = im.shape[0] // grid_size
    tile_w = im.shape[1] // grid_size
    if tile_h == 0 or tile_w == 0:
        raise ValueError(
            f"Image of size {im.shape[1]}x{im.shape[0]} is too small to split "
            f"into a {grid_size}x{grid_size} grid."
        )
    # Iterate over tile *indices*, not pixel offsets: stepping
    # range(0, height, tile_h) yields an extra, thinner row/column of tiles
    # whenever a dimension is not divisible by grid_size.
    return [
        im[row * tile_h:(row + 1) * tile_h, col * tile_w:(col + 1) * tile_w]
        for row in range(grid_size)
        for col in range(grid_size)
    ]


def app(model_name):
    """Render the "Most Relevant Part of Image" page.

    Loads ``koclip/{model_name}`` via ``load_model`` and shows inputs for an
    image (URL or upload) and a text query. When the query button is
    pressed, the image is split into a 3x3 grid, and every tile is shown
    together with its softmax score against the query, highest score first.

    Args:
        model_name: Model identifier, appended to the ``koclip/`` prefix.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Most Relevant Part of Image")
    st.markdown("""
        Given a piece of text, the CLIP model finds the part of an image that best explains the text.
        To try it out, you can
        1) Upload an image
        2) Explain a part of the image in text
        Which will yield the most relevant image tile from a 3x3 grid of the image
    """)

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg")
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter query to find most relevant part of image ",
        value="이건 서울의 경복궁 사진이다.",
    )

    if not st.button("질문 (Query)"):
        return
    if not any([query1, query2]):
        st.error("Please upload an image or paste an image URL.")
        return

    # An uploaded file takes precedence over the URL field.
    if query2 is not None:
        image_data = query2
    else:
        try:
            # A timeout keeps a slow or unreachable host from hanging the app.
            response = requests.get(query1, timeout=_URL_TIMEOUT_SECONDS)
            response.raise_for_status()
        except requests.RequestException as err:
            st.error(f"Could not download the image: {err}")
            return
        # Use .content rather than the .raw stream: .raw does not undo
        # gzip/deflate Content-Encoding, which would corrupt the image bytes.
        image_data = io.BytesIO(response.content)

    try:
        # Normalize RGBA/palette/grayscale inputs to 3-channel RGB, since the
        # tiles are passed to the processor as raw pixel arrays.
        image = Image.open(image_data).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the image: {err}")
        return
    st.image(image)

    try:
        tiles = split_image(image)
    except ValueError as err:
        st.error(str(err))
        return

    inputs = processor(text=captions, images=tiles, return_tensors="jax", padding=True)
    # Move pixel_values from channels-first to channels-last (axes 0,2,3,1);
    # presumably the loaded Flax model expects NHWC input — confirm in utils.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)

    # Softmax over axis 0 normalizes each caption's scores across the tiles.
    # logits_per_image is presumably (num_tiles, num_texts); with a single
    # caption, column 0 holds one score per tile.
    probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
    scores = np.asarray(probs[:, 0])
    # Stable descending sort: equal scores keep their grid order.
    for idx in np.argsort(-scores, kind="stable"):
        st.text(f"Score: {float(scores[idx]):.3f}")
        st.image(tiles[idx])