---

tags:
- image-to-text
library_name: generic
---


## Example

The model is by no means state-of-the-art, but it nevertheless produces reasonable
image captioning results. It was fine-tuned mainly as a proof of concept for the
🤗 FlaxVisionEncoderDecoder framework.

The model can be used as follows:

**In PyTorch**
```python
import torch
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel

loc = "ydshieh/vit-gpt2-coco-en"

feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = VisionEncoderDecoderModel.from_pretrained(loc)
model.eval()


def predict(image):
    # Preprocess the image into the pixel values expected by the ViT encoder
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        output_ids = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True).sequences

    # Decode the generated token ids into caption strings
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]

    return preds


# We will verify our results on an image of cute cats
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
with Image.open(requests.get(url, stream=True).raw) as image:
    preds = predict(image)

print(preds)
# should produce
# ['a cat laying on top of a couch next to another cat']
```

**In Flax**
```python
import jax
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

loc = "ydshieh/vit-gpt2-coco-en"

feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = FlaxVisionEncoderDecoderModel.from_pretrained(loc)

gen_kwargs = {"max_length": 16, "num_beams": 4}


# This takes some time when compiling the first time, but subsequent inference will be much faster
@jax.jit
def generate(pixel_values):
    output_ids = model.generate(pixel_values, **gen_kwargs).sequences
    return output_ids


def predict(image):
    # Preprocess the image into the pixel values expected by the ViT encoder
    pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
    output_ids = generate(pixel_values)
    # Decode the generated token ids into caption strings
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds


# We will verify our results on an image of cute cats
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
with Image.open(requests.get(url, stream=True).raw) as image:
    preds = predict(image)

print(preds)
# should produce
# ['a cat laying on top of a couch next to another cat']
```
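The one-time compilation cost noted in the Flax example can be observed with a toy jitted function, independent of the model (a minimal sketch; the function and sizes here are arbitrary, chosen only to illustrate `jax.jit` tracing on the first call and reuse of the compiled program afterwards):

```python
import time
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    # Trivial stand-in for the real generate() function above
    return jnp.sum(x * 2.0)


x = jnp.ones((1000,))

t0 = time.perf_counter()
first = scaled_sum(x).block_until_ready()   # triggers tracing + XLA compilation
t1 = time.perf_counter()
second = scaled_sum(x).block_until_ready()  # reuses the cached compiled program
t2 = time.perf_counter()

print(float(first))  # 2000.0
print(f"first call (with compile): {t1 - t0:.4f}s, second call: {t2 - t1:.4f}s")
```

The same caching applies to the jitted `generate` above: as long as the input shapes stay the same, only the very first call pays the compilation cost.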