sayakpaul HF staff commited on
Commit
0dd57bb
1 Parent(s): 1c951f1

feat: loading from HF hub

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -1,20 +1,19 @@
 
1
  import gradio as gr
2
  import tensorflow as tf
3
- import tensorflow_hub as hub
4
  from PIL import Image
5
 
6
  import utils
7
 
8
  _RESOLUTION = 224
9
- _MODEL_PATH = "gs://cait-tf/cait_xxs24_224"
10
 
11
 
12
  def get_model() -> tf.keras.Model:
13
- """Initiates a tf.keras.Model from TF-Hub."""
14
  inputs = tf.keras.Input((_RESOLUTION, _RESOLUTION, 3))
15
- hub_module = hub.KerasLayer(_MODEL_PATH)
16
 
17
- logits, sa_atn_score_dict, ca_atn_score_dict = hub_module(inputs)
18
 
19
  return tf.keras.Model(
20
  inputs, [logits, sa_atn_score_dict, ca_atn_score_dict]
 
1
+ from huggingface_hub.keras_mixin import from_pretrained_keras
2
  import gradio as gr
3
  import tensorflow as tf
 
4
  from PIL import Image
5
 
6
  import utils
7
 
8
  _RESOLUTION = 224
 
9
 
10
 
11
  def get_model() -> tf.keras.Model:
12
+ """Initiates a tf.keras.Model from HF Hub."""
13
  inputs = tf.keras.Input((_RESOLUTION, _RESOLUTION, 3))
14
+ hub_module = from_pretrained_keras("probing-vits/cait_xxs24_224_classification")
15
 
16
+ logits, sa_atn_score_dict, ca_atn_score_dict = hub_module(inputs, training=False)
17
 
18
  return tf.keras.Model(
19
  inputs, [logits, sa_atn_score_dict, ca_atn_score_dict]