# koclip/most_relevant_part.py
import os
import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image
from utils import load_model


def split_image(im, num_rows=3, num_cols=3):
    """Split an image into a num_rows x num_cols grid of tiles.

    Pixels beyond the last full row/column of tiles are discarded.
    """
    im = np.array(im)
    row_size = im.shape[0] // num_rows
    col_size = im.shape[1] // num_cols
    tiles = [
        im[x : x + row_size, y : y + col_size]
        for x in range(0, num_rows * row_size, row_size)
        for y in range(0, num_cols * col_size, col_size)
    ]
    return tiles
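
# Illustration (not executed): with the default 3x3 grid, a 900 x 1200 RGB image
# gives row_size=300 and col_size=400, so split_image returns 9 tiles of shape
# (300, 400, 3); any leftover pixels beyond the grid are discarded.

# The commented-out variant below splits the image into fixed 224x224 patches
# instead of a num_rows x num_cols grid; it is kept for reference only.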
# def split_image(X):
# num_rows = X.shape[0] // 224
# num_cols = X.shape[1] // 224
# Xc = X[0:num_rows * 224, 0:num_cols * 224, :]
# patches = []
# for j in range(num_rows):
# for i in range(num_cols):
# patches.append(Xc[j * 224:(j + 1) * 224, i * 224:(i + 1) * 224, :])
# return patches


def app(model_name):
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Most Relevant Part of Image")
    st.markdown(
        """
        Given a piece of text, the CLIP model finds the part of an image that best matches the text.

        To try it out, you can

        1) Upload an image or paste an image URL
        2) Describe a part of the image in text

        The app will then show the most relevant tile from a 3x3 grid of the image.
        """
    )
    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter a query to find the most relevant part of the image",
        value="이건 서울의 경복궁 사진이다.",
    )
    if st.button("질문 (Query)"):
        if not any([query1, query2]):
            st.error("Please upload an image or paste an image URL.")
        else:
            image_data = (
                query2 if query2 is not None else requests.get(query1, stream=True).raw
            )
            image = Image.open(image_data)
            st.image(image)
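            # Split the image into a 3x3 grid and score each tile against the
            # caption with the loaded KoCLIP model.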
            images = split_image(image)
            inputs = processor(
                text=captions, images=images, return_tensors="jax", padding=True
            )
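            # The processor returns pixel_values in channels-first (NCHW) layout;
            # rearrange to channels-last (NHWC) as expected by the Flax vision
            # encoder used here.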
inputs["pixel_values"] = jnp.transpose(
inputs["pixel_values"], axes=[0, 2, 3, 1]
)
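            # logits_per_image has shape (num_tiles, num_captions); the softmax
            # over axis=0 turns the scores for the single caption into a
            # distribution over the 9 tiles.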
            outputs = model(**inputs)
            probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
            for idx, prob in sorted(enumerate(probs), key=lambda x: x[1], reverse=True):
                st.text(f"Score: {prob[0]:.3f}")
                st.image(images[idx])
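

# Minimal sketch of launching this page on its own, assuming a hypothetical
# model identifier; in the Space, app() is imported and dispatched by the
# main Streamlit script instead.
#
#     if __name__ == "__main__":
#         app("koclip-base")  # "koclip-base" is an assumed model name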