cocktailpeanut commited on
Commit
f72e31f
1 Parent(s): dea4744
Files changed (2) hide show
  1. app.py +2 -2
  2. models.py +8 -4
app.py CHANGED
@@ -226,8 +226,8 @@ def create_app():
226
  )
227
  def get_gpu_kind():
228
  device = jax.devices()[0]
229
- if not gradio_helpers.should_mock() and device.platform != 'gpu':
230
- raise gr.Error('GPU not visible to JAX!')
231
  return f'GPU={device.device_kind}'
232
  demo.load(get_gpu_kind, None, gpu_kind)
233
 
 
226
  )
227
  def get_gpu_kind():
228
  device = jax.devices()[0]
229
+ # if not gradio_helpers.should_mock() and device.platform != 'gpu':
230
+ # raise gr.Error('GPU not visible to JAX!')
231
  return f'GPU={device.device_kind}'
232
  demo.load(get_gpu_kind, None, gpu_kind)
233
 
models.py CHANGED
@@ -11,17 +11,21 @@ import gradio_helpers
11
  import paligemma_bv
12
 
13
 
14
- ORGANIZATION = 'google'
 
15
  BASE_MODELS = [
16
- ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
17
- ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
 
 
18
  ]
19
  MODELS = {
20
  **{
21
  model_name: (
22
  f'{ORGANIZATION}/{repo}',
23
  f'{model_name}.bf16.npz',
24
- 'bfloat16', # Model repo revision.
 
25
  )
26
  for repo, model_name in BASE_MODELS
27
  },
 
11
  import paligemma_bv
12
 
13
 
14
+ #ORGANIZATION = 'google'
15
+ ORGANIZATION = 'cocktailpeanut'
16
  BASE_MODELS = [
17
+ ('pg', 'paligemma-3b-mix-224'),
18
+ ('pg', 'paligemma-3b-mix-448'),
19
+ # ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
20
+ # ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
21
  ]
22
  MODELS = {
23
  **{
24
  model_name: (
25
  f'{ORGANIZATION}/{repo}',
26
  f'{model_name}.bf16.npz',
27
+ 'main',
28
+ #'bfloat16', # Model repo revision.
29
  )
30
  for repo, model_name in BASE_MODELS
31
  },