"""Gradio demo that scores how well a Spanish caption matches an image.

Downloads the ``flax-community/clip-spanish`` repository, imports its
``FlaxHybridCLIP`` model and preprocessing helpers from that snapshot, and
serves a Gradio interface returning a sigmoid-squashed image/text logit.
"""
import os
import sys
import tempfile

import gradio as gr
import jax
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoTokenizer

# The model class and preprocessing helpers ship inside the downloaded repo,
# so the snapshot directory must be on sys.path before importing them.
LOCAL_PATH = snapshot_download("flax-community/clip-spanish")
sys.path.append(LOCAL_PATH)

from modeling_hybrid_clip import FlaxHybridCLIP  # noqa: E402
from test_on_image import prepare_image, prepare_text  # noqa: E402


def save_file_to_disk(uploaded_file):
    """Write an uploaded image array to a fresh temporary JPEG file.

    Each call creates its own uniquely named file, so concurrent requests
    cannot overwrite each other's image. The caller owns the returned file
    and is responsible for deleting it.

    Args:
        uploaded_file: image array as delivered by the Gradio Image input.

    Returns:
        Path of the temporary JPEG file.
    """
    im = Image.fromarray(uploaded_file)
    # JPEG can only store modes such as RGB and L; alpha/palette/integer
    # modes would make ``save`` raise, so normalize those to RGB.
    if im.mode not in ("RGB", "L"):
        im = im.convert("RGB")
    fd, temp_file = tempfile.mkstemp(suffix=".jpeg")
    os.close(fd)  # PIL reopens the path itself; don't leak the descriptor.
    try:
        im.save(temp_file)
    except Exception:
        # Don't leave an empty temp file behind when encoding fails.
        os.remove(temp_file)
        raise
    return temp_file


def run_inference(image_path, text, model, tokenizer):
    """Score one image/text pair with the hybrid CLIP model.

    Args:
        image_path: path to the image file passed to ``prepare_image``.
        text: caption to compare against the image.
        model: the loaded ``FlaxHybridCLIP`` model.
        tokenizer: tokenizer used by ``prepare_text``.

    Returns:
        Sigmoid of the first image's logit against the first text, as a
        JAX scalar array.
    """
    pixel_values = prepare_image(image_path, model)
    input_text = prepare_text(text, tokenizer)
    model_output = model(
        input_text["input_ids"],
        pixel_values,
        attention_mask=input_text["attention_mask"],
        train=False,
        return_dict=True,
    )
    logits = model_output["logits_per_image"]
    # [0][0]: first image vs. first text. Only one of each is submitted.
    return jax.nn.sigmoid(logits)[0][0]


def load_tokenizer_and_model():
    """Load the Spanish RoBERTa tokenizer and the hybrid CLIP model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        "bertin-project/bertin-roberta-base-spanish"
    )
    model = FlaxHybridCLIP.from_pretrained(LOCAL_PATH)
    return tokenizer, model


tokenizer, model = load_tokenizer_and_model()


def score_image_caption_pair(uploaded_file, text_input):
    """Gradio callback: score an uploaded image against a caption.

    Args:
        uploaded_file: image array from the Gradio Image input.
        text_input: caption text from the Gradio text input.

    Returns:
        A ``({"Score": score}, "<score to 2 decimals>")`` pair feeding the
        "label" and "text" outputs.
    """
    local_image_path = save_file_to_disk(uploaded_file)
    try:
        score = run_inference(
            local_image_path, text_input, model, tokenizer
        ).tolist()
    finally:
        # Per-request temp file; remove it so /tmp doesn't fill up.
        os.remove(local_image_path)
    return {"Score": score}, f"{score:.2f}"


image = gr.inputs.Image(shape=(299, 299))
iface = gr.Interface(
    fn=score_image_caption_pair,
    inputs=[image, "text"],
    outputs=["label", "text"],
)
iface.launch()