File size: 2,576 Bytes
03f6091
 
 
 
a005919
 
03f6091
 
 
 
 
 
 
 
 
ad5defb
a005919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03f6091
a005919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad5defb
a005919
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import streamlit as st
from PIL import Image
from polos.models import download_model, load_checkpoint

# Load the model (cached through st.cache_resource so the work is not repeated
# on every script rerun).
@st.cache_resource()
def load_model():
    """Download the "polos" checkpoint and return the model loaded from it."""
    return load_checkpoint(download_model("polos"))

model = load_model()

# Configure the Streamlit interface.
st.title('Polos Demo')

# Initialise session state: each key is seeded only when it is missing, so
# values already stored in the session are left untouched.
_SESSION_DEFAULTS = {
    'image': None,
    'user_input': '',
    'user_refs': [
        "there is a dog sitting on a couch with a person reaching out",
        "a dog laying on a couch with a person",
        'a dog is laying on a couch with a person'
    ],
    'score': None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Fetch the default image.
@st.cache_resource()
def get_default_image():
    """Return the image used when no upload is available.

    Opens ``test.jpg`` (relative path) and converts it to RGB. If the file
    does not exist, a 200x200 solid gray RGB placeholder is returned instead.
    """
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # closes it once convert() has produced an in-memory RGB image.
        with Image.open("test.jpg") as source:
            return source.convert("RGB")
    except FileNotFoundError:
        # Placeholder used when the default image file cannot be found.
        return Image.new('RGB', (200, 200), color='gray')

default_image = get_default_image()

# Widget for uploading an image.
uploaded_image = st.file_uploader("Upload your image:", type=["jpg", "jpeg", "png"])
if uploaded_image is None:
    # Nothing uploaded on this run: fall back to the default image only when
    # the session has no image stored yet.
    if st.session_state.image is None:
        st.session_state.image = default_image
else:
    st.session_state.image = Image.open(uploaded_image).convert("RGB")

# Always display the current image.
st.image(
    st.session_state.image,
    caption="Displayed Image",
    use_column_width=True,
)

# Input field for the reference sentences (one per line). The stored list is
# joined to pre-fill the box and the box content is split back into a list.
user_refs = st.text_area(
    "Enter reference sentences (separate each by a newline):",
    value="\n".join(st.session_state.user_refs),
)
st.session_state.user_refs = user_refs.split("\n")

# Text field for the user's input sentence; mirrored into session state.
user_input = st.text_input(
    "Enter the input sentence:",
    value=st.session_state.user_input,
)
st.session_state.user_input = user_input

# Compute button.
if st.button('Compute'):
    candidate = st.session_state.user_input
    # Splitting the text area on "\n" yields empty strings for blank or
    # trailing lines; drop those (and surrounding whitespace) so they are not
    # sent to the model as references.
    stripped_refs = (ref.strip() for ref in st.session_state.user_refs)
    references = [ref for ref in stripped_refs if ref]

    if not candidate.strip():
        # Previously an empty input made the click a silent no-op.
        st.warning("Please enter an input sentence before computing.")
    elif not references:
        st.warning("Please enter at least one reference sentence.")
    else:
        # Prepare the data: a single sample with image, candidate and refs.
        data = [
            {
                "img": st.session_state.image,
                "mt": candidate,
                "refs": references,
            }
        ]

        # Model prediction on CPU; the second returned value is indexed to get
        # the score for the single sample.
        _, scores = model.predict(data, batch_size=1, cuda=False)
        st.session_state.score = scores[0]

# Display the most recently computed score, if any.
if st.session_state.score is not None:
    st.metric(label="Score", value=f"{st.session_state.score:.5f}")