ydshieh commited on
Commit
d79c24a
1 Parent(s): e79fa8b

fix model download

Browse files
Files changed (2) hide show
  1. model.py +13 -3
  2. requirements.txt +2 -1
model.py CHANGED
@@ -1,10 +1,11 @@
1
- import os, sys
2
  import numpy as np
3
  from PIL import Image
4
 
5
  import jax
6
  from transformers import ViTFeatureExtractor
7
  from transformers import GPT2Tokenizer
 
8
 
9
  current_path = os.path.dirname(os.path.abspath(__file__))
10
  sys.path.append(current_path)
@@ -12,8 +13,17 @@ sys.path.append(current_path)
12
  # Main model - ViTGPT2LM
13
  from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
14
 
15
- model_name_or_path = 'flax-community/vit-gpt2/checkpoints/ckpt_5'
16
- flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_name_or_path)
 
 
 
 
 
 
 
 
 
17
 
18
  def predict(image):
19
  return 'dummy caption!', ['dummy', 'caption', '!'], [1, 2, 3]
 
1
+ import os, sys, shutil
2
  import numpy as np
3
  from PIL import Image
4
 
5
  import jax
6
  from transformers import ViTFeatureExtractor
7
  from transformers import GPT2Tokenizer
8
+ from huggingface_hub import hf_hub_download
9
 
10
  current_path = os.path.dirname(os.path.abspath(__file__))
11
  sys.path.append(current_path)
 
13
  # Main model - ViTGPT2LM
14
  from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
15
 
16
+ # create target model directory
17
+ model_dir = './models/'
18
+ os.makedirs(model_dir, exist_ok=True)
19
+ # copy config file
20
+ filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/config.json")
21
+ shutil.copyfile(filepath, os.path.join(model_dir, 'config.json'))
22
+ # copy model file
23
+ filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/flax_model.msgpack")
24
+ shutil.copyfile(filepath, os.path.join('flax_model.msgpack'))
25
+
26
+ flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_dir)
27
 
28
  def predict(image):
29
  return 'dummy caption!', ['dummy', 'caption', '!'], [1, 2, 3]
requirements.txt CHANGED
@@ -2,4 +2,5 @@ streamlit==0.84.1
2
  Pillow
3
  jax[cpu]
4
  flax
5
- transformers
 
 
2
  Pillow
3
  jax[cpu]
4
  flax
5
+ transformers
6
+ huggingface_hub