aswinkvj committed
Commit 0482489
Parent: a8d0c50
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. app.py +97 -0
  2. model.py +82 -0
  3. requirements.txt +8 -0
  4. samples/COCO_val2017_000000006771.jpg +0 -0
  5. samples/COCO_val2017_000000021903.jpg +0 -0
  6. samples/COCO_val2017_000000030213.jpg +0 -0
  7. samples/COCO_val2017_000000039956.jpg +0 -0
  8. samples/COCO_val2017_000000045472.jpg +0 -0
  9. samples/COCO_val2017_000000053505.jpg +0 -0
  10. samples/COCO_val2017_000000057597.jpg +0 -0
  11. samples/COCO_val2017_000000059386.jpg +0 -0
  12. samples/COCO_val2017_000000067406.jpg +0 -0
  13. samples/COCO_val2017_000000069795.jpg +0 -0
  14. samples/COCO_val2017_000000084431.jpg +0 -0
  15. samples/COCO_val2017_000000088432.jpg +0 -0
  16. samples/COCO_val2017_000000100238.jpg +0 -0
  17. samples/COCO_val2017_000000104619.jpg +0 -0
  18. samples/COCO_val2017_000000104803.jpg +0 -0
  19. samples/COCO_val2017_000000124442.jpg +0 -0
  20. samples/COCO_val2017_000000125936.jpg +0 -0
  21. samples/COCO_val2017_000000132703.jpg +0 -0
  22. samples/COCO_val2017_000000146155.jpg +0 -0
  23. samples/COCO_val2017_000000149770.jpg +0 -0
  24. samples/COCO_val2017_000000152120.jpg +0 -0
  25. samples/COCO_val2017_000000154431.jpg +0 -0
  26. samples/COCO_val2017_000000161609.jpg +0 -0
  27. samples/COCO_val2017_000000163258.jpg +0 -0
  28. samples/COCO_val2017_000000168593.jpg +0 -0
  29. samples/COCO_val2017_000000170116.jpg +0 -0
  30. samples/COCO_val2017_000000172330.jpg +0 -0
  31. samples/COCO_val2017_000000173371.jpg +0 -0
  32. samples/COCO_val2017_000000175535.jpg +0 -0
  33. samples/COCO_val2017_000000178469.jpg +0 -0
  34. samples/COCO_val2017_000000180188.jpg +0 -0
  35. samples/COCO_val2017_000000180296.jpg +0 -0
  36. samples/COCO_val2017_000000181969.jpg +0 -0
  37. samples/COCO_val2017_000000190676.jpg +0 -0
  38. samples/COCO_val2017_000000199055.jpg +0 -0
  39. samples/COCO_val2017_000000204186.jpg +0 -0
  40. samples/COCO_val2017_000000213547.jpg +0 -0
  41. samples/COCO_val2017_000000216497.jpg +0 -0
  42. samples/COCO_val2017_000000216739.jpg +0 -0
  43. samples/COCO_val2017_000000224675.jpg +0 -0
  44. samples/COCO_val2017_000000226903.jpg +0 -0
  45. samples/COCO_val2017_000000230983.jpg +0 -0
  46. samples/COCO_val2017_000000232684.jpg +0 -0
  47. samples/COCO_val2017_000000234757.jpg +0 -0
  48. samples/COCO_val2017_000000256195.jpg +0 -0
  49. samples/COCO_val2017_000000266409.jpg +0 -0
  50. samples/COCO_val2017_000000267946.jpg +0 -0
app.py ADDED
@@ -0,0 +1,97 @@
+ import io
+ import os
+
+ import requests
+ import streamlit as st
+ from PIL import Image
+
+ # Designing the interface
+ st.title("🖼️ Image Captioning Demo 📝")
+ st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")
+
+ st.sidebar.markdown(
+     """
+     An image captioning model built by combining a ViT model with a GPT2 model.
+     The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision Encoder Decoder
+     framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
+     The pretrained weights of both models are loaded, together with a set of randomly initialized cross-attention weights.
+     The model was trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
+     Follow-up work from the [Hugging Face JAX/Flax community event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
+     """
+ )
+
+ # Importing model.py downloads the checkpoint and warms up the jax.jit-compiled
+ # generate function, hence the spinner. The star import also brings in
+ # `predict`, `sample_dir`, `sample_image_ids` and `get_random_image_id`.
+ with st.spinner('Loading and compiling ViT-GPT2 model ...'):
+     from model import *
+
+ random_image_id = get_random_image_id()
+
+ st.sidebar.title("Select a sample image")
+ sample_image_id = st.sidebar.selectbox(
+     "Please choose a sample image",
+     sample_image_ids
+ )
+
+ if st.sidebar.button("Random COCO 2017 (val) images"):
+     random_image_id = get_random_image_id()
+     sample_image_id = "None"
+
+ bytes_data = None
+ with st.sidebar.form("file-uploader-form", clear_on_submit=True):
+     uploaded_file = st.file_uploader("Choose a file")
+     submitted = st.form_submit_button("Upload")
+     if submitted and uploaded_file is not None:
+         bytes_data = io.BytesIO(uploaded_file.getvalue())
+
+ if (bytes_data is None) and submitted:
+     st.write("No file was selected to upload")
+ else:
+     image_id = random_image_id
+     if sample_image_id != "None":
+         assert isinstance(sample_image_id, int)
+         image_id = sample_image_id
+
+     sample_name = f"COCO_val2017_{str(image_id).zfill(12)}.jpg"
+     sample_path = os.path.join(sample_dir, sample_name)
+
+     if bytes_data is not None:
+         image = Image.open(bytes_data)
+     elif os.path.isfile(sample_path):
+         image = Image.open(sample_path)
+     else:
+         url = f"http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg"
+         image = Image.open(requests.get(url, stream=True).raw)
+
+     # Display copy: cap the height at 384px and the width at 512px while
+     # preserving the aspect ratio. The original `image` is kept for the model.
+     width, height = image.size
+     resized = image.resize(size=(width, height))
+     if height > 384:
+         width = int(width / height * 384)
+         height = 384
+         resized = resized.resize(size=(width, height))
+     width, height = resized.size
+     if width > 512:
+         # Compute the new height before overwriting `width`,
+         # otherwise the aspect ratio is lost.
+         height = int(height / width * 512)
+         width = 512
+         resized = resized.resize(size=(width, height))
+
+     if bytes_data is None:
+         st.markdown(f"[{str(image_id).zfill(12)}.jpg](http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg)")
+     show = st.image(resized)
+     show.image(resized, '\n\nSelected Image')
+     resized.close()
+
+     # For newline
+     st.sidebar.write('\n')
+
+     with st.spinner('Generating image caption ...'):
+         caption = predict(image)
+
+     caption_en = caption
+     st.header('Predicted caption:\n\n')
+     st.subheader(caption_en)
+
+     st.sidebar.header("ViT-GPT2 predicts: ")
+     st.sidebar.write(f"{caption}")
+
+     image.close()
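
Note: the sidebar text in app.py describes pairing a pretrained ViT encoder with a pretrained GPT2 decoder via randomly initialized cross-attention. A minimal sketch of how such a model is typically assembled with the same Vision Encoder Decoder framework follows; the two base checkpoint names are assumptions for illustration, not confirmed to be the ones behind ydshieh/vit-gpt2-coco-en-ckpts.

    from transformers import FlaxVisionEncoderDecoderModel

    # Pair a pretrained vision encoder with a pretrained language decoder.
    # Both checkpoint names below are assumptions for illustration.
    model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        "google/vit-base-patch16-224-in21k",  # assumed ViT encoder
        "gpt2",                               # assumed GPT2 decoder
    )
    # The cross-attention weights linking encoder and decoder start out randomly
    # initialized, which is why the combined model is then fine-tuned (here on COCO 2017).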
model.py ADDED
@@ -0,0 +1,82 @@
+ import json
+ import os
+ import random
+ import shutil
+
+ from PIL import Image
+ import jax
+ from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
+ from huggingface_hub import hf_hub_download
+
+
+ # Create the target model directory.
+ model_dir = './models/'
+ os.makedirs(model_dir, exist_ok=True)
+
+ files_to_download = [
+     "config.json",
+     "flax_model.msgpack",
+     "merges.txt",
+     "special_tokens_map.json",
+     "tokenizer.json",
+     "tokenizer_config.json",
+     "vocab.json",
+     "preprocessor_config.json",
+ ]
+
+ # Copy the checkpoint files from the Hub into the local model directory.
+ for fn in files_to_download:
+     file_path = hf_hub_download("ydshieh/vit-gpt2-coco-en-ckpts", f"ckpt_epoch_3_step_6900/{fn}")
+     shutil.copyfile(file_path, os.path.join(model_dir, fn))
+
+ model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+ max_length = 16
+ num_beams = 4
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+
+ @jax.jit
+ def generate(pixel_values):
+     output_ids = model.generate(pixel_values, **gen_kwargs).sequences
+     return output_ids
+
+
+ def predict(image):
+     # The ViT feature extractor expects RGB input.
+     if image.mode != "RGB":
+         image = image.convert(mode="RGB")
+
+     pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
+
+     output_ids = generate(pixel_values)
+     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]
+
+     return preds[0]
+
+
+ def _compile():
+     # Run one prediction up front so jax.jit traces and compiles `generate`
+     # at import time rather than on the first user request.
+     image_path = 'samples/val_000000039769.jpg'
+     image = Image.open(image_path)
+     predict(image)
+     image.close()
+
+
+ _compile()
+
+
+ sample_dir = './samples/'
+ sample_image_ids = tuple(
+     ["None"]
+     + [int(f.replace('COCO_val2017_', '').replace('.jpg', ''))
+        for f in os.listdir(sample_dir) if f.startswith('COCO_val2017_')]
+ )
+
+ with open(os.path.join(sample_dir, "coco-val2017-img-ids.json"), "r", encoding="UTF-8") as fp:
+     coco_2017_val_image_ids = json.load(fp)
+
+
+ def get_random_image_id():
+     return random.choice(coco_2017_val_image_ids)
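
Note: model.py doubles as a small library; importing it has side effects (checkpoint download plus the `_compile()` jit warm-up), so the first import is slow but later predictions are fast. A minimal usage sketch, assuming one of the sample images added in this commit is available locally:

    from PIL import Image
    from model import predict  # import triggers download + jit warm-up

    image = Image.open("samples/COCO_val2017_000000006771.jpg")
    print(predict(image))  # a short English caption
    image.close()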
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ streamlit==0.84.1
+ Pillow
+ jax[cpu]
+ flax
+ transformers
+ huggingface_hub
+ googletrans==4.0.0-rc1
+ protobuf==3.20
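
Note: with these pins (Streamlit 0.84.1 era), the demo should start locally with `pip install -r requirements.txt` followed by `streamlit run app.py`. googletrans is pinned but not used in the code shown above; the `caption_en = caption` line in app.py suggests caption translation was planned.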
samples/COCO_val2017_000000006771.jpg ADDED
samples/COCO_val2017_000000021903.jpg ADDED
samples/COCO_val2017_000000030213.jpg ADDED
samples/COCO_val2017_000000039956.jpg ADDED
samples/COCO_val2017_000000045472.jpg ADDED
samples/COCO_val2017_000000053505.jpg ADDED
samples/COCO_val2017_000000057597.jpg ADDED
samples/COCO_val2017_000000059386.jpg ADDED
samples/COCO_val2017_000000067406.jpg ADDED
samples/COCO_val2017_000000069795.jpg ADDED
samples/COCO_val2017_000000084431.jpg ADDED
samples/COCO_val2017_000000088432.jpg ADDED
samples/COCO_val2017_000000100238.jpg ADDED
samples/COCO_val2017_000000104619.jpg ADDED
samples/COCO_val2017_000000104803.jpg ADDED
samples/COCO_val2017_000000124442.jpg ADDED
samples/COCO_val2017_000000125936.jpg ADDED
samples/COCO_val2017_000000132703.jpg ADDED
samples/COCO_val2017_000000146155.jpg ADDED
samples/COCO_val2017_000000149770.jpg ADDED
samples/COCO_val2017_000000152120.jpg ADDED
samples/COCO_val2017_000000154431.jpg ADDED
samples/COCO_val2017_000000161609.jpg ADDED
samples/COCO_val2017_000000163258.jpg ADDED
samples/COCO_val2017_000000168593.jpg ADDED
samples/COCO_val2017_000000170116.jpg ADDED
samples/COCO_val2017_000000172330.jpg ADDED
samples/COCO_val2017_000000173371.jpg ADDED
samples/COCO_val2017_000000175535.jpg ADDED
samples/COCO_val2017_000000178469.jpg ADDED
samples/COCO_val2017_000000180188.jpg ADDED
samples/COCO_val2017_000000180296.jpg ADDED
samples/COCO_val2017_000000181969.jpg ADDED
samples/COCO_val2017_000000190676.jpg ADDED
samples/COCO_val2017_000000199055.jpg ADDED
samples/COCO_val2017_000000204186.jpg ADDED
samples/COCO_val2017_000000213547.jpg ADDED
samples/COCO_val2017_000000216497.jpg ADDED
samples/COCO_val2017_000000216739.jpg ADDED
samples/COCO_val2017_000000224675.jpg ADDED
samples/COCO_val2017_000000226903.jpg ADDED
samples/COCO_val2017_000000230983.jpg ADDED
samples/COCO_val2017_000000232684.jpg ADDED
samples/COCO_val2017_000000234757.jpg ADDED
samples/COCO_val2017_000000256195.jpg ADDED
samples/COCO_val2017_000000266409.jpg ADDED
samples/COCO_val2017_000000267946.jpg ADDED