File size: 2,196 Bytes
18ca2a9
 
 
089d2a3
6b1177a
52a4ec3
18ca2a9
 
 
 
 
 
 
089d2a3
18ca2a9
 
 
 
52a4ec3
18ca2a9
 
 
52a4ec3
 
 
089d2a3
52a4ec3
089d2a3
 
52a4ec3
75ecf13
 
089d2a3
75ecf13
 
18ca2a9
089d2a3
 
 
 
 
 
 
 
 
 
 
 
 
52a4ec3
75ecf13
18ca2a9
b9a0770
7b3d1d9
18ca2a9
6b1177a
18ca2a9
 
 
 
89ab2bc
52a4ec3
 
 
 
 
 
 
 
 
89ab2bc
18ca2a9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import sys

import jax
import streamlit as st
import transformers
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

LOCAL_PATH = snapshot_download("flax-community/clip-spanish")
sys.path.append(LOCAL_PATH)

from modeling_hybrid_clip import FlaxHybridCLIP
from test_on_image import prepare_image, prepare_text


def save_file_to_disk(uplaoded_file):
    """Write an uploaded file's bytes under ``/tmp`` and return the path.

    Args:
        uplaoded_file: Upload object exposing a ``name`` attribute and a
            ``getbuffer()`` method. The parameter keeps its original
            spelling so existing keyword callers continue to work.

    Returns:
        The path of the file that was written inside ``/tmp``.
    """
    # Keep only the final path component so a name that carries directory
    # parts cannot place the file outside /tmp.
    safe_name = os.path.basename(uplaoded_file.name)
    temp_file = os.path.join("/tmp", safe_name)
    with open(temp_file, "wb") as f:
        # Read from the argument itself, not from a module-level variable
        # that merely happens to share a similar name.
        f.write(uplaoded_file.getbuffer())
    return temp_file


@st.cache(
    show_spinner=False,
    hash_funcs={
        transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast: id,
        FlaxHybridCLIP: id,
    },
)
def load_tokenizer_and_model():
    """Load the Spanish RoBERTa tokenizer and the hybrid CLIP model.

    The result is wrapped in ``st.cache``; tokenizer and model objects are
    hashed by ``id`` and the spinner is disabled.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer is loaded from
        ``bertin-project/bertin-roberta-base-spanish`` and the model from
        the snapshot directory ``LOCAL_PATH``.
    """
    return (
        AutoTokenizer.from_pretrained("bertin-project/bertin-roberta-base-spanish"),
        FlaxHybridCLIP.from_pretrained(LOCAL_PATH),
    )

def run_inference(image_path, text, model, tokenizer):
    """Score a caption against an image with the given model.

    Args:
        image_path: Path of the image, forwarded to ``prepare_image``.
        text: Caption text, forwarded to ``prepare_text``.
        model: Callable model; it is also passed to ``prepare_image``.
        tokenizer: Tokenizer passed to ``prepare_text``.

    Returns:
        Element ``[0][0]`` of the sigmoid of the model's
        ``logits_per_image`` output.
    """
    image_tensor = prepare_image(image_path, model)
    encoded = prepare_text(text, tokenizer)
    outputs = model(
        encoded["input_ids"],
        image_tensor,
        attention_mask=encoded["attention_mask"],
        train=False,
        return_dict=True,
    )
    # Squash the image-to-text logit with a sigmoid and take the single
    # image/caption entry.
    return jax.nn.sigmoid(outputs["logits_per_image"])[0][0]

# Load (or fetch from the Streamlit cache) the tokenizer and model once per run.
tokenizer, model = load_tokenizer_and_model()

# --- Page layout: title plus the two user inputs. ---
st.title("Caption Scoring")
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg"])
text_input = st.text_input("Type a caption")

# Only score once both an image and a non-empty caption are available.
if uploaded_file is not None and text_input:
    local_image_path = None
    try:
        # The image is written to disk so it can be handed over as a path.
        local_image_path = save_file_to_disk(uploaded_file)
        raw_score = run_inference(local_image_path, text_input, model, tokenizer)
        score = raw_score.tolist()

        image_options = {
            "caption": text_input,
            "width": None,
            "use_column_width": None,
            "clamp": False,
            "channels": "RGB",
            "output_format": "auto",
        }
        st.image(uploaded_file, **image_options)
        st.write(f"## Score: {score:.2f}")
    finally:
        # Remove the temporary copy whether or not scoring succeeded.
        if local_image_path:
            os.remove(local_image_path)