edugp commited on
Commit
089d2a3
1 Parent(s): 7b3d1d9

User BERTIN model

Browse files
Files changed (1) hide show
  1. app.py +19 -4
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import sys
3
 
 
4
  import streamlit as st
5
  import transformers
6
  from huggingface_hub import snapshot_download
@@ -10,7 +11,7 @@ LOCAL_PATH = snapshot_download("flax-community/clip-spanish")
10
  sys.path.append(LOCAL_PATH)
11
 
12
  from modeling_hybrid_clip import FlaxHybridCLIP
13
- from test_on_image import run_inference
14
 
15
 
16
  def save_file_to_disk(uplaoded_file):
@@ -22,16 +23,30 @@ def save_file_to_disk(uplaoded_file):
22
 
23
  @st.cache(
24
  hash_funcs={
25
- transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: id,
26
  FlaxHybridCLIP: id,
27
- }
 
28
  )
29
  def load_tokenizer_and_model():
30
  # load the saved model
31
- tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
32
  model = FlaxHybridCLIP.from_pretrained(LOCAL_PATH)
33
  return tokenizer, model
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  tokenizer, model = load_tokenizer_and_model()
37
 
1
  import os
2
  import sys
3
 
4
+ import jax
5
  import streamlit as st
6
  import transformers
7
  from huggingface_hub import snapshot_download
11
  sys.path.append(LOCAL_PATH)
12
 
13
  from modeling_hybrid_clip import FlaxHybridCLIP
14
+ from test_on_image import prepare_image, prepare_text
15
 
16
 
17
  def save_file_to_disk(uplaoded_file):
23
 
24
  @st.cache(
25
  hash_funcs={
26
+ transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast: id,
27
  FlaxHybridCLIP: id,
28
+ },
29
+ show_spinner=False
30
  )
31
  def load_tokenizer_and_model():
32
  # load the saved model
33
+ tokenizer = AutoTokenizer.from_pretrained("bertin-project/bertin-roberta-base-spanish")
34
  model = FlaxHybridCLIP.from_pretrained(LOCAL_PATH)
35
  return tokenizer, model
36
 
37
+ def run_inference(image_path, text, model, tokenizer):
38
+ pixel_values = prepare_image(image_path, model)
39
+ input_text = prepare_text(text, tokenizer)
40
+ model_output = model(
41
+ input_text["input_ids"],
42
+ pixel_values,
43
+ attention_mask=input_text["attention_mask"],
44
+ train=False,
45
+ return_dict=True,
46
+ )
47
+ logits = model_output["logits_per_image"]
48
+ score = jax.nn.sigmoid(logits)[0][0]
49
+ return score
50
 
51
  tokenizer, model = load_tokenizer_and_model()
52