Spaces:
Runtime error
Runtime error
File size: 2,196 Bytes
18ca2a9 089d2a3 6b1177a 52a4ec3 18ca2a9 089d2a3 18ca2a9 52a4ec3 18ca2a9 52a4ec3 089d2a3 52a4ec3 089d2a3 52a4ec3 75ecf13 089d2a3 75ecf13 18ca2a9 089d2a3 52a4ec3 75ecf13 18ca2a9 b9a0770 7b3d1d9 18ca2a9 6b1177a 18ca2a9 89ab2bc 52a4ec3 89ab2bc 18ca2a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import os
import sys
import jax
import streamlit as st
import transformers
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
# Fetch the "flax-community/clip-spanish" repository snapshot from the Hugging
# Face Hub; the returned value is used below as a local directory path.
LOCAL_PATH = snapshot_download("flax-community/clip-spanish")
# Make the snapshot directory importable. This must run before the two imports
# below, which is why they cannot be moved up with the other imports.
sys.path.append(LOCAL_PATH)
# NOTE(review): these modules are presumably shipped inside the downloaded
# snapshot (they are not installed packages) -- confirm against the repo.
from modeling_hybrid_clip import FlaxHybridCLIP
from test_on_image import prepare_image, prepare_text
def save_file_to_disk(uplaoded_file):
    """Write an uploaded file's bytes to ``/tmp`` and return the new path.

    Args:
        uplaoded_file: Upload object exposing a ``name`` attribute and a
            ``getbuffer()`` method. The parameter keeps its original
            spelling so existing keyword callers continue to work.

    Returns:
        The path of the file written under ``/tmp``.
    """
    # The name comes from the client; keep only its last path component so a
    # value such as "../x" cannot place the file outside /tmp.
    safe_name = os.path.basename(uplaoded_file.name)
    temp_file = os.path.join("/tmp", safe_name)
    with open(temp_file, "wb") as f:
        # Read from the argument itself, not from a similarly named
        # module-level global that may not exist or may be a different file.
        f.write(uplaoded_file.getbuffer())
    return temp_file
@st.cache(
    show_spinner=False,
    # Hash both heavyweight objects by identity instead of by content.
    hash_funcs={
        FlaxHybridCLIP: id,
        transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast: id,
    },
)
def load_tokenizer_and_model():
    """Load the Spanish RoBERTa tokenizer and the hybrid CLIP model.

    The result is wrapped in ``st.cache`` with the spinner disabled.

    Returns:
        A ``(tokenizer, model)`` tuple: the tokenizer comes from
        "bertin-project/bertin-roberta-base-spanish" and the model is
        loaded from ``LOCAL_PATH``.
    """
    return (
        AutoTokenizer.from_pretrained("bertin-project/bertin-roberta-base-spanish"),
        FlaxHybridCLIP.from_pretrained(LOCAL_PATH),
    )
def run_inference(image_path, text, model, tokenizer):
    """Score a caption against an image with the given model.

    Args:
        image_path: Path of the image, passed to ``prepare_image``.
        text: Caption text, passed to ``prepare_text``.
        model: Callable model; also handed to ``prepare_image``.
        tokenizer: Tokenizer handed to ``prepare_text``.

    Returns:
        Element ``[0][0]`` of the sigmoid of the model's
        ``logits_per_image`` output.
    """
    # Image first, then text: same preparation order as before.
    pixels = prepare_image(image_path, model)
    tokens = prepare_text(text, tokenizer)

    outputs = model(
        tokens["input_ids"],
        pixels,
        attention_mask=tokens["attention_mask"],
        train=False,
        return_dict=True,
    )

    # Squash the image-to-text logits into (0, 1) and take the first entry.
    probabilities = jax.nn.sigmoid(outputs["logits_per_image"])
    return probabilities[0][0]
# Obtain the tokenizer and model through the cached loader.
tokenizer, model = load_tokenizer_and_model()

# --- Page layout: title, image uploader, caption box (in that order) ---
st.title("Caption Scoring")
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg"])
text_input = st.text_input("Type a caption")

# Only score once both a caption and an image have been provided.
if text_input and uploaded_file is not None:
    local_image_path = None
    try:
        # The scorer takes a path, so the upload is written to disk first.
        local_image_path = save_file_to_disk(uploaded_file)
        raw_score = run_inference(local_image_path, text_input, model, tokenizer)
        score = raw_score.tolist()

        image_options = dict(
            caption=text_input,
            width=None,
            use_column_width=None,
            clamp=False,
            channels="RGB",
            output_format="auto",
        )
        st.image(uploaded_file, **image_options)
        st.write(f"## Score: {score:.2f}")
    finally:
        # Remove the temporary copy whether or not scoring succeeded; the path
        # is still None if saving itself failed.
        if local_image_path:
            os.remove(local_image_path)
|