---
language: id
---

# Image-captioning-Indonesia

This is an encoder-decoder image-captioning model that uses [CLIP](https://huggingface.co/transformers/model_doc/clip.html) as the visual encoder and [Marian](https://huggingface.co/transformers/model_doc/marian.html) as the textual decoder, trained on datasets with Indonesian captions.

This model was trained using HuggingFace's Flax framework and is part of the [JAX/Flax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organized by [HuggingFace](https://huggingface.co). All training was done on a TPUv3-8 VM sponsored by the Google Cloud team.

## How to use

At the time of writing, you will need to install the HuggingFace [`transformers`](https://github.com/huggingface/transformers) library from its latest master branch in order to load `FlaxMarian`.
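
A master-branch install might look like this (a sketch, assuming `pip` and `git` are available):

```bash
pip install git+https://github.com/huggingface/transformers.git
```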

You will also need to have the [`flax_clip_vision_marian` folder](https://github.com/indonesian-nlp/Indonesia-Image-Captioning/tree/main/flax_clip_vision_marian) in your project directory to load the model using the `FlaxCLIPVisionMarianForConditionalGeneration` class.
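
One way to obtain that folder (again a sketch, not the only option) is to clone the project repository and copy the package into your working directory:

```bash
git clone https://github.com/indonesian-nlp/Indonesia-Image-Captioning.git
cp -r Indonesia-Image-Captioning/flax_clip_vision_marian .
```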

```python
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode

import torch
import numpy as np
from transformers import MarianTokenizer
from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration

# Load the CLIP-Marian encoder-decoder model from the Hub
clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(clip_marian_model_name)

# The decoder was initialized from this Marian checkpoint, so we use its tokenizer
marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
tokenizer = MarianTokenizer.from_pretrained(marian_model_name)

config = model.config
image_size = config.clip_vision_config.image_size

# Image transformation: resize, center-crop, and normalize with CLIP's statistics
transforms = torch.nn.Sequential(
    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
    CenterCrop(image_size),
    ConvertImageDtype(torch.float),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)

# Generation hyperparameters
max_length = 8
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def generate_step(pixel_values):
    output_ids = model.generate(pixel_values, **gen_kwargs)
    token_ids = np.array(output_ids.sequences)[0]
    caption = tokenizer.decode(token_ids, skip_special_tokens=True)
    return caption

image_file_path = "path/to/your/image.jpg"  # replace with the image you want to caption
image = read_image(image_file_path, mode=ImageReadMode.RGB)
image = transforms(image)
# The Flax model expects channels-last pixel values as a NumPy array
pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()

caption = generate_step(pixel_values)
print(caption)
```
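
The snippet above captions one image at a time. As a small illustrative extension (not part of the original example), the same `model`, `tokenizer`, `transforms`, and `gen_kwargs` defined above can caption several images in a single batch:

```python
def caption_images(image_file_paths):
    # Preprocess every image and stack them into one channels-last batch
    images = [transforms(read_image(path, mode=ImageReadMode.RGB)) for path in image_file_paths]
    pixel_values = torch.stack(images).permute(0, 2, 3, 1).numpy()
    output_ids = model.generate(pixel_values, **gen_kwargs)
    # Decode all generated sequences at once
    return tokenizer.batch_decode(np.array(output_ids.sequences), skip_special_tokens=True)
```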

## Training data

The model was trained on translated versions of the COCO, Flickr, and VizWiz datasets. The captions were translated into Indonesian using Google Translate and Marian MT, and only two randomly selected captions per image were kept from each dataset.
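
The exact preprocessing code is not included here; as a purely hypothetical illustration, keeping two random captions per image could be as simple as:

```python
import random

def sample_captions(captions, k=2):
    # Keep at most k randomly chosen captions for a single image
    return random.sample(captions, k=min(k, len(captions)))
```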

## Training procedure

The model was trained on a TPUv3-8 VM provided by the Google Cloud team.

## Team members

- Cahya Wirawan ([@cahya](https://huggingface.co/cahya))
- Galuh Sahid ([@Galuh](https://huggingface.co/Galuh))
- Muhammad Agung Hambali ([@AyameRushia](https://huggingface.co/AyameRushia))
- Samsul Rahmadani ([@munggok](https://huggingface.co/munggok))