ydshieh (HF staff) committed
Commit f6d2d95
1 Parent(s): 4628220

Update README.md

Files changed (1)
  1. README.md +65 -10
README.md CHANGED
@@ -12,35 +12,90 @@ as a proof-of-concept for the 🤗 FlaxVisionEncoderDecoder Framework.

The model can be used as follows:

+ In PyTorch
```python

+ import torch
import requests
from PIL import Image
- from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
+ from transformers import ViTFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel

loc = "ydshieh/vit-gpt2-coco-en"

feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
- model = FlaxVisionEncoderDecoderModel.from_pretrained(loc)
+ model = VisionEncoderDecoderModel.from_pretrained(loc)
+ model.eval()

- # We will verify our results on an image of cute cats
- url = "http://images.cocodataset.org/val2017/000000039769.jpg"
- with Image.open(requests.get(url, stream=True).raw) as img:
-     pixel_values = feature_extractor(images=img, return_tensors="np").pixel_values
-
- def generate_step(pixel_values):
-     output_ids = model.generate(pixel_values, max_length=16, num_beams=4).sequences
+ def predict(image):
+     pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+
+     with torch.no_grad():
+         output_ids = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True).sequences

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]

    return preds

- preds = generate_step(pixel_values)
- print(preds)
+ # We will verify our results on an image of cute cats
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ with Image.open(requests.get(url, stream=True).raw) as image:
+     preds = predict(image)
+
+ print(preds)
# should produce
# ['a cat laying on top of a couch next to another cat']

```
+
+ In Flax
+ ```python
+ import jax
+ import requests
+ from PIL import Image
+ from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
+
+ loc = "ydshieh/vit-gpt2-coco-en"
+
+ feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
+ tokenizer = AutoTokenizer.from_pretrained(loc)
+ model = FlaxVisionEncoderDecoderModel.from_pretrained(loc)
+
+ gen_kwargs = {"max_length": 16, "num_beams": 4}
+
+ # This takes some time when compiling the first time, but subsequent inference will be much faster
+ @jax.jit
+ def generate(pixel_values):
+     output_ids = model.generate(pixel_values, **gen_kwargs).sequences
+     return output_ids
+
+ def predict(image):
+     pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
+     output_ids = generate(pixel_values)
+     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]
+
+     return preds
+
+ # We will verify our results on an image of cute cats
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ with Image.open(requests.get(url, stream=True).raw) as image:
+     preds = predict(image)
+
+ print(preds)
+ # should produce
+ # ['a cat laying on top of a couch next to another cat']
+ ```
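
A usage note beyond what the commit adds: `ViTFeatureExtractor` accepts a list of images and stacks them along the batch dimension, and `generate` decodes every sequence in the batch, so batched captioning needs no model changes. A minimal sketch against the PyTorch variant (the repeated URL is only a stand-in for a real batch):

```python
import torch
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel

loc = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = VisionEncoderDecoderModel.from_pretrained(loc)
model.eval()

# Any list of PIL images works here; the same COCO image is repeated only
# as a stand-in for a real batch.
urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "http://images.cocodataset.org/val2017/000000039769.jpg",
]
images = [Image.open(requests.get(u, stream=True).raw) for u in urls]

# The feature extractor stacks the images into a (2, 3, 224, 224) tensor.
pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values

with torch.no_grad():
    output_ids = model.generate(
        pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
    ).sequences

preds = [p.strip() for p in tokenizer.batch_decode(output_ids, skip_special_tokens=True)]
print(preds)  # one caption per input image
```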
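
The comment above the Flax `generate` helper deserves one more sentence: `jax.jit` traces and compiles the function with XLA on the first call for a given input shape, then reuses the cached program for later calls with the same shape. A rough way to observe this, assuming the Flax setup above (the random input and timings are illustrative only):

```python
import time

import jax
import numpy as np
from transformers import FlaxVisionEncoderDecoderModel

model = FlaxVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")

@jax.jit
def generate(pixel_values):
    # Same generation settings as the committed example.
    return model.generate(pixel_values, max_length=16, num_beams=4).sequences

# ViT-base expects (batch, 3, 224, 224) float inputs; random pixels are
# enough to trigger compilation.
pixel_values = np.random.randn(1, 3, 224, 224).astype(np.float32)

start = time.time()
generate(pixel_values).block_until_ready()  # first call: trace + XLA compile
print(f"first call:  {time.time() - start:.1f}s")

start = time.time()
generate(pixel_values).block_until_ready()  # same shape: cached program
print(f"second call: {time.time() - start:.3f}s")

# A new batch size or image size is a new shape, so jit compiles again.
```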