Commit 5948cd7 by Galuh (1 parent: 3e827fe)
Create README.md

README.md ADDED
@@ -0,0 +1,76 @@
---
language: id
---

# Image-captioning-Indonesia

This is an encoder-decoder image captioning model that uses [CLIP](https://huggingface.co/transformers/model_doc/clip.html) as the visual encoder and [Marian](https://huggingface.co/transformers/model_doc/marian.html) as the textual decoder, trained on datasets with Indonesian captions.

This model was trained with Flax and HuggingFace Transformers as part of the [JAX/Flax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organized by [HuggingFace](https://huggingface.co). All training was done on a TPUv3-8 VM sponsored by the Google Cloud team.

## How to use
At the time of writing, you need to install the [HuggingFace](https://github.com/huggingface/) Transformers library from its latest master branch in order to load `FlaxMarian`.

You also need the [`flax_clip_vision_marian` folder](https://github.com/indonesian-nlp/Indonesia-Image-Captioning/tree/main/flax_clip_vision_marian) in your project directory so that the model can be loaded with the `FlaxCLIPVisionMarianForConditionalGeneration` class.
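
If the `flax_clip_vision_marian` folder lives somewhere other than your project directory, one option is to put its parent directory on `sys.path` before importing. The sketch below is only an illustration: the helper name `add_repo_to_path` and the clone location `Indonesia-Image-Captioning` are assumptions, not part of the repository.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_dir):
    """Make packages inside repo_dir importable and return the resolved path.

    Hypothetical helper: repo_dir is assumed to be a local clone that
    contains the flax_clip_vision_marian folder.
    """
    repo_path = Path(repo_dir).expanduser().resolve()
    package_dir = repo_path / "flax_clip_vision_marian"
    if not package_dir.is_dir():
        raise FileNotFoundError(f"{package_dir} not found; check the clone location")
    # Avoid adding the same entry twice when the helper is called repeatedly.
    if str(repo_path) not in sys.path:
        sys.path.insert(0, str(repo_path))
    return repo_path
```

Calling `add_repo_to_path("Indonesia-Image-Captioning")` before the imports would then let `from flax_clip_vision_marian.modeling_clip_vision_marian import ...` resolve, assuming the clone sits at that path.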

Assuming that you have an image file on disk as input:

```python
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode

import torch
import numpy as np
from transformers import MarianTokenizer
from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration

# Load the captioning model and the Marian tokenizer
clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(clip_marian_model_name)

marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
tokenizer = MarianTokenizer.from_pretrained(marian_model_name)

config = model.config
image_size = config.clip_vision_config.image_size

# Image transformation: resize, center-crop, scale to [0, 1], normalize per channel
transforms = torch.nn.Sequential(
    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
    CenterCrop(image_size),
    ConvertImageDtype(torch.float),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)

# Generation hyperparameters
max_length = 8
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def generate_step(pixel_values):
    # Run beam search and decode the first sequence in the batch
    output_ids = model.generate(pixel_values, **gen_kwargs)
    token_ids = np.array(output_ids.sequences)[0]
    caption = tokenizer.decode(token_ids)
    return caption

# Path to the input image
image_file_path = '000000039769.jpg'
image = read_image(image_file_path, mode=ImageReadMode.RGB)
image = transforms(image)
# Add a batch dimension and move channels last: (1, C, H, W) -> (1, H, W, C)
pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()

caption = generate_step(pixel_values)

print(caption)
```
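
The last preprocessing steps above scale the image to [0, 1], normalize each channel, and then use `permute(0, 2, 3, 1)` to move from a channels-first `(batch, channels, height, width)` layout to a channels-last `(batch, height, width, channels)` layout. A minimal NumPy-only sketch of those steps on a dummy image follows. The helper name `to_pixel_values` and the 224x224 dummy input are illustrative assumptions (the real size comes from `config.clip_vision_config.image_size`), and resizing and cropping are omitted.

```python
import numpy as np

# Per-channel mean and std copied from the Normalize(...) call above.
MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def to_pixel_values(image_chw_uint8):
    """Scale a uint8 (C, H, W) image to [0, 1], normalize, return (1, H, W, C)."""
    image = image_chw_uint8.astype(np.float32) / 255.0
    image = (image - MEAN[:, None, None]) / STD[:, None, None]
    batch = image[None, ...]            # add batch axis: (1, C, H, W)
    return batch.transpose(0, 2, 3, 1)  # channels last: (1, H, W, C)


# Stand-in for an already resized and cropped RGB image; 224 is a placeholder size.
dummy = np.zeros((3, 224, 224), dtype=np.uint8)
pixel_values = to_pixel_values(dummy)
print(pixel_values.shape)  # (1, 224, 224, 3)
```

This is only meant to make the array shapes explicit; the torchvision pipeline above remains the reference for actual preprocessing.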

## Training data
To be added.

## Training procedure
The model was trained on a TPUv3-8 VM provided by the Google Cloud team.

## Team members
- Cahya Wirawan ([@cahya](https://huggingface.co/cahya))
- Galuh Sahid ([@Galuh](https://huggingface.co/Galuh))
- Muhammad Agung Hambali ([@AyameRushia](https://huggingface.co/AyameRushia))
- Samsul Rahmadani ([@munggok](https://huggingface.co/munggok))