Galuh committed on
Commit 5948cd7
1 Parent(s): 3e827fe

Create README.md
README.md ADDED

---
language: id
---

# Image-captioning-Indonesia

This is an encoder-decoder image captioning model using [CLIP](https://huggingface.co/transformers/model_doc/clip.html) as the visual encoder and [Marian](https://huggingface.co/transformers/model_doc/marian.html) as the textual decoder on datasets with Indonesian captions.

This model was trained using HuggingFace's Flax framework and is part of the [JAX/Flax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organized by [HuggingFace](https://huggingface.co). All training was done on a TPUv3-8 VM sponsored by the Google Cloud team.
## How to use
At the time of writing, you will need to install [HuggingFace Transformers](https://github.com/huggingface/) from its latest master branch in order to load `FlaxMarian`.

You will also need to have the [`flax_clip_vision_marian` folder](https://github.com/indonesian-nlp/Indonesia-Image-Captioning/tree/main/flax_clip_vision_marian) in your project directory to load the model using the `FlaxCLIPVisionMarianForConditionalGeneration` class.
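A minimal sanity check for these prerequisites might look like the sketch below. It only assumes the two points above; this card does not pin a specific Transformers version, so the version printout is purely informational.

```python
# Sketch: verify that the prerequisites described above are in place.
import transformers
print(transformers.__version__)  # should be a master/recent build that ships Flax Marian

# Raises ImportError on releases that predate the Flax Marian port
from transformers import FlaxMarianMTModel

# The captioning class is not part of Transformers itself; it is imported from the
# flax_clip_vision_marian folder copied into your project directory.
from flax_clip_vision_marian.modeling_clip_vision_marian import (
    FlaxCLIPVisionMarianForConditionalGeneration,
)
```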
Assuming that you have an image file as input:
```python
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode

import torch
import numpy as np
from transformers import MarianTokenizer
from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration

# Load the captioning model and the Marian tokenizer used to decode generated ids
clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(clip_marian_model_name)

marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
tokenizer = MarianTokenizer.from_pretrained(marian_model_name)

config = model.config
image_size = config.clip_vision_config.image_size

# Image transformation: resize, center-crop, and normalize with the CLIP statistics
transforms = torch.nn.Sequential(
    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
    CenterCrop(image_size),
    ConvertImageDtype(torch.float),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)

# Generation hyperparameters
max_length = 8
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def generate_step(pixel_values):
    # Beam-search over the Marian decoder conditioned on the CLIP image features
    output_ids = model.generate(pixel_values, **gen_kwargs)
    token_ids = np.array(output_ids.sequences)[0]
    caption = tokenizer.decode(token_ids, skip_special_tokens=True)
    return caption

# Read and preprocess the image, then convert to a channels-last NumPy array
image_file_path = '000000039769.jpg'
image = read_image(image_file_path, mode=ImageReadMode.RGB)
image = transforms(image)
pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()

generated_caption = generate_step(pixel_values)

print(generated_caption)
```
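If you would rather fetch an image over HTTP and let a CLIP feature extractor handle the preprocessing instead of torchvision, a sketch along these lines should yield equivalent `pixel_values`. The `openai/clip-vit-base-patch32` extractor and the COCO image URL are assumptions, chosen because that extractor's crop size and normalization statistics match the values used above.

```python
# Alternative preprocessing sketch using PIL + CLIPFeatureExtractor (assumed to
# match the torchvision pipeline above: 224x224 crops with CLIP normalization).
import requests
from PIL import Image
from transformers import CLIPFeatureExtractor

feature_extractor = CLIPFeatureExtractor.from_pretrained('openai/clip-vit-base-patch32')

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'  # example COCO image
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')

# The extractor returns channels-first arrays; the Flax model expects
# channels-last input, hence the transpose.
pixel_values = feature_extractor(images=image, return_tensors='np').pixel_values
pixel_values = pixel_values.transpose(0, 2, 3, 1)

print(generate_step(pixel_values))
```

Either way, `generate_step` returns a single decoded Indonesian caption for the one-image batch.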
## Training data
To be added.

## Training procedure
The model was trained on a TPUv3-8 VM provided by the Google Cloud team.

## Team members
- Cahya Wirawan ([@cahya](https://huggingface.co/cahya))
- Galuh Sahid ([@Galuh](https://huggingface.co/Galuh))
- Muhammad Agung Hambali ([@AyameRushia](https://huggingface.co/AyameRushia))
- Samsul Rahmadani ([@munggok](https://huggingface.co/munggok))