# Polos-Demo / app.py
# yuwd's picture
# init
# 03f6091
# raw
# history blame
# 1 kB
import streamlit as st
from PIL import Image
from polos.models import download_model, load_checkpoint
def _model_cache_decorator():
    """Return the Streamlit decorator used to cache the loaded model.

    ``st.cache`` is deprecated; ``st.cache_resource`` is its replacement for
    objects that should be shared rather than copied (such as a model).
    Prefer it when the installed Streamlit provides it, and fall back to the
    legacy call otherwise so older installations keep working.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource
    # Older Streamlit without cache_resource: keep the original behaviour.
    return st.cache(allow_output_mutation=True)


@_model_cache_decorator()
def load_model():
    """Download the "polos" checkpoint and load it.

    The result is wrapped in Streamlit's cache so that script reruns reuse
    the already-loaded object instead of downloading and loading it again.

    Returns:
        Whatever ``load_checkpoint`` returns for the downloaded path.
    """
    model_path = download_model("polos")
    return load_checkpoint(model_path)
# Obtain the scoring model (served from Streamlit's cache after the first run).
model = load_model()

# Fixed demo sample: one image and its reference captions.
default_image = Image.open("test.jpg").convert("RGB")
default_refs = [
    "there is a dog sitting on a couch with a person reaching out",
    "a dog laying on a couch with a person",
    "a dog is laying on a couch with a person",
]

# One-element batch handed to model.predict; "mt" starts empty and is
# overwritten with the user's sentence before scoring.
data = [dict(img=default_image, mt="", refs=default_refs)]
# Set up the Streamlit interface.
st.title("Polos Demo")

# Text field for the user's input sentence.
user_input = st.text_input("Enter the input sentence:", "")

# When there is input, use the model to compute a score.
if user_input:
    # Place the sentence into the single demo sample before scoring.
    data[0].update(mt=user_input)
    _, scores = model.predict(data, batch_size=1, cuda=False)
    st.write("Score:", scores)