File size: 2,115 Bytes
6f0178d
c951094
6f0178d
 
 
686f21e
 
c951094
d79c24a
8d9306e
 
d79c24a
 
 
943681e
c951094
 
 
 
 
 
 
 
ac444cc
c951094
 
 
 
679d099
c951094
 
 
 
 
 
 
 
943681e
 
 
 
c951094
 
 
943681e
 
e755009
8d9306e
0ca7ab6
 
 
c951094
8f85ccf
c951094
 
 
8d9306e
c951094
943681e
 
c951094
943681e
 
 
8f85ccf
943681e
 
8d9306e
c951094
 
9a6a97f
 
144ec50
6f0178d
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import json
import os, shutil
import random


from PIL import Image
import jax
from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from huggingface_hub import hf_hub_download


# Local folder that receives a flat copy of the checkpoint files, so that the
# from_pretrained() calls can all load from one plain directory.
model_dir = './models/'
os.makedirs(model_dir, exist_ok=True)

# Everything needed to rebuild the model, the ViT feature extractor and the
# tokenizer from a single checkpoint sub-folder of the Hub repository.
files_to_download = [
    "config.json",
    "flax_model.msgpack",
    "merges.txt",
    "special_tokens_map.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "vocab.json",
    "preprocessor_config.json",
]

# hf_hub_download() yields a local path for the requested repo file; copy that
# file into model_dir under its bare name.
for fn in files_to_download:
    file_path = hf_hub_download(
        "ydshieh/vit-gpt2-coco-en-ckpts",
        f"ckpt_epoch_3_step_6900/{fn}",
    )
    shutil.copyfile(file_path, os.path.join(model_dir, fn))

# Decoding settings forwarded to model.generate(): maximum output length and
# number of beams for beam search.
max_length = 16
num_beams = 4
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

# Rebuild the encoder-decoder model and its pre/post-processing components
# from the files copied into model_dir.
model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)


@jax.jit
def generate(pixel_values):
    """Generate caption token ids for a batch of preprocessed images.

    Compiled with ``jax.jit``. ``model`` and ``gen_kwargs`` are read from
    module scope, so they are captured at trace time rather than passed in.

    Args:
        pixel_values: array produced by ``feature_extractor(...).pixel_values``.

    Returns:
        The ``sequences`` field of the ``model.generate`` output (token ids).
    """
    return model.generate(pixel_values, **gen_kwargs).sequences


def predict(image):
    """Caption a single PIL image and return the decoded text.

    Non-RGB images are converted to RGB first. The image is preprocessed with
    ``feature_extractor``, passed through ``generate``, and the first decoded
    sequence is returned with special tokens removed and whitespace stripped.
    """
    rgb_image = image if image.mode == "RGB" else image.convert(mode="RGB")

    batch = feature_extractor(images=rgb_image, return_tensors="np")
    token_ids = generate(batch.pixel_values)

    captions = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return captions[0].strip()


def _compile():
    """Run one throw-away prediction on a bundled sample image.

    The first call to the ``jax.jit``-decorated ``generate`` triggers tracing
    and compilation, so calling this once up front pays that cost early.
    NOTE(review): jit recompiles for new input shapes — this assumes the
    feature extractor always yields the same pixel_values shape; confirm.
    The resulting caption is discarded.
    """
    image_path = 'samples/val_000000039769.jpg'
    # ``with`` guarantees the image's file handle is released even if
    # ``predict`` raises (the original explicit close() was skipped then).
    with Image.open(image_path) as image:
        predict(image)


# Warm-up at import time: runs one prediction so the jitted ``generate`` is
# traced and compiled now, presumably so later predict() calls avoid that cost.
_compile()


# Directory holding the bundled sample images (``COCO_val2017_<id>.jpg``) and
# the JSON file of COCO 2017 validation image ids.
sample_dir = './samples/'


def _parse_sample_image_id(filename, prefix='COCO_val2017_', suffix='.jpg'):
    """Return the integer image id encoded in a sample file name, or ``None``.

    ``COCO_val2017_000000039769.jpg`` -> ``39769``. Names that are not exactly
    ``<prefix><decimal digits><suffix>`` yield ``None`` instead of raising, so
    an unrelated file (e.g. ``COCO_val2017_notes.txt``) in ``sample_dir`` no
    longer crashes module import inside ``int()``.
    """
    if not (filename.startswith(prefix) and filename.endswith(suffix)):
        return None
    # Slice off prefix/suffix instead of str.replace(), which would also remove
    # occurrences elsewhere in the name.
    digits = filename[len(prefix):len(filename) - len(suffix)]
    return int(digits) if digits.isdecimal() else None


# The literal string "None" followed by the ids of the bundled sample images,
# sorted ascending because os.listdir() returns entries in arbitrary order.
sample_image_ids = tuple(
    ["None"]
    + sorted(
        image_id
        for image_id in map(_parse_sample_image_id, os.listdir(sample_dir))
        if image_id is not None
    )
)

# Pool of COCO 2017 validation image ids that get_random_image_id() draws from.
with open(os.path.join(sample_dir, "coco-val2017-img-ids.json"), "r", encoding="UTF-8") as fp:
    coco_2017_val_image_ids = json.load(fp)


def get_random_image_id(image_ids=None):
    """Return one image id chosen uniformly at random.

    Args:
        image_ids: optional sequence to draw from. When omitted, the
            module-level ``coco_2017_val_image_ids`` (loaded from
            ``coco-val2017-img-ids.json``) is used, matching the old behavior.
            Assumes a list-like sequence (presumably a JSON list of ints).

    Returns:
        One element of the chosen pool.

    Raises:
        ValueError: if the pool is empty (the same exception type the previous
            ``random.sample(..., k=1)`` implementation raised).
    """
    pool = coco_2017_val_image_ids if image_ids is None else image_ids
    if len(pool) == 0:
        raise ValueError("cannot pick an image id from an empty collection")
    # random.choice is the direct idiom for "pick one element".
    return random.choice(pool)