koclip / image2text.py
jaketae's picture
feature: add image2text feature
8ff0261
raw history blame
No virus
1.58 kB
import streamlit as st
import numpy as np
import jax
import jax.numpy as jnp
from PIL import Image
from utils import load_model
def _parse_captions(raw):
    """Split a comma-separated caption string into a clean list.

    Whitespace around each entry is stripped and empty entries (for example
    those produced by a trailing comma or a doubled comma) are dropped.
    """
    return [part.strip() for part in raw.split(",") if part.strip()]


def app(model_name):
    """Render the zero-shot image classification page.

    The page lets the user upload an image and type a comma-separated list
    of candidate captions. When the query button is pressed, the image and
    captions are run through the model and every caption is listed with its
    softmax score, highest score first.

    Args:
        model_name: Name appended to ``"koclip/"`` to form the identifier
            passed to ``load_model``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(
        """
        Some text goes in here.
        """
    )

    query = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    # Label: "Please enter the captions to use, separated by commas."
    # Default value: "cat,dog,zelkova tree..."
    captions = st.text_input(
        "사용하실 캡션을 쉼표 단위로 구분해서 적어주세요",
        value="고양이,강아지,느티나무...",
    )

    # Button label: "Question (Query)". Nothing to do until it is pressed.
    if not st.button("질문 (Query)"):
        return

    if query is None:
        st.error("Please upload an image query.")
        return

    # PNG uploads can be RGBA, palette, or grayscale; normalise to 3-channel
    # RGB so preprocessing always receives the same channel layout.
    image = Image.open(query).convert("RGB")
    st.image(image)

    labels = _parse_captions(captions)
    if not labels:
        st.error("Please enter at least one caption.")
        return

    inputs = processor(text=labels, images=image, return_tensors="jax", padding=True)
    # Move axis 1 to the end of pixel_values.
    # NOTE(review): presumably channels-first -> channels-last for the Flax
    # model; confirm against load_model / the model definition.
    inputs["pixel_values"] = jnp.transpose(
        inputs["pixel_values"], axes=[0, 2, 3, 1]
    )
    outputs = model(**inputs)

    # Exactly one image is submitted, so only the first row is relevant.
    # Convert to NumPy once so sorting works on host scalars.
    probs = np.asarray(jax.nn.softmax(outputs.logits_per_image, axis=1))[0]

    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    for idx, prob in ranked:
        st.text(f"Score: `{prob}`, {labels[idx]}")