---
language: id
---
# Image-captioning-Indonesia
This is an encoder-decoder image captioning model that uses [CLIP](https://huggingface.co/transformers/model_doc/clip.html) as the visual encoder and [Marian](https://huggingface.co/transformers/model_doc/marian.html) as the textual decoder. It was trained on datasets with Indonesian captions.
This model was trained using HuggingFace's Flax framework and is part of the [JAX/Flax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organized by [HuggingFace](https://huggingface.co). All training was done on a TPUv3-8 VM sponsored by the Google Cloud team.
## How to use
At the time of writing, you will need to install [HuggingFace](https://github.com/huggingface/) Transformers from its latest master branch in order to load `FlaxMarian`.
You will also need to have the [`flax_clip_vision_marian` folder](https://github.com/indonesian-nlp/Indonesia-Image-Captioning/tree/main/flax_clip_vision_marian) in your project directory to load the model using the `FlaxCLIPVisionMarianForConditionalGeneration` class.
```python
import numpy as np
import torch
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import MarianTokenizer

from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration

# Load the captioning model and the Marian tokenizer used to decode its output
clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(clip_marian_model_name)

marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
tokenizer = MarianTokenizer.from_pretrained(marian_model_name)

config = model.config
image_size = config.clip_vision_config.image_size

# Image transformation: resize, center-crop, scale to [0, 1], apply CLIP normalization
transforms = torch.nn.Sequential(
    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
    CenterCrop(image_size),
    ConvertImageDtype(torch.float),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)

# Hyperparameters
max_length = 8
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def generate_step(pixel_values):
    # Run beam search and decode the caption for the first image in the batch
    output_ids = model.generate(pixel_values, **gen_kwargs)
    token_ids = np.array(output_ids.sequences)[0]
    caption = tokenizer.decode(token_ids)
    return caption


# Placeholder: replace with the path to your own image
image_file_path = "path/to/your/image.jpg"
image = read_image(image_file_path, mode=ImageReadMode.RGB)
image = transforms(image)

# Flax expects channels-last input: (batch, height, width, channels)
pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()

caption = generate_step(pixel_values)
print(caption)
```
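The last preprocessing steps in the snippet above (scaling to `[0, 1]`, CLIP normalization, and the move from channels-first to channels-last layout) can be sketched with NumPy alone to make the expected input shape explicit. This is an illustrative sketch, not part of the repository: `to_pixel_values` is a hypothetical helper, the 224-pixel image size is an assumption (the real value comes from `config.clip_vision_config.image_size`), and resizing and cropping are skipped.

```python
import numpy as np

# CLIP normalization constants, copied from the usage snippet above
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def to_pixel_values(image_chw_uint8):
    """Hypothetical helper: uint8 (C, H, W) image -> float32 (1, H, W, C) batch."""
    # Scale to [0, 1], as ConvertImageDtype(torch.float) does for uint8 input
    scaled = image_chw_uint8.astype(np.float32) / 255.0
    # Normalize each channel with the CLIP mean and standard deviation
    normalized = (scaled - CLIP_MEAN[:, None, None]) / CLIP_STD[:, None, None]
    # Move channels last and add a leading batch dimension
    return normalized.transpose(1, 2, 0)[None, ...]


# Assumed 224x224 all-white dummy image standing in for a real photo
dummy = np.full((3, 224, 224), 255, dtype=np.uint8)
pixel_values = to_pixel_values(dummy)
print(pixel_values.shape)  # (1, 224, 224, 3)
```

An array of this shape is what `model.generate` receives in the usage snippet.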
## Training data
The model was trained on translated versions of COCO, Flickr, and VizWiz. Each dataset was translated using Google Translate and Marian MT. For each dataset, we kept only 2 randomly selected captions per image.
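Selecting two random captions per image could look roughly like the sketch below. It is not the team's actual preprocessing code: the `sample_captions` helper, the dictionary layout, the fixed seed, and the example captions are all assumptions made for illustration.

```python
import random


def sample_captions(captions_by_image, k=2, seed=0):
    """Hypothetical helper: keep at most k randomly chosen captions per image."""
    rng = random.Random(seed)  # fixed seed so the selection is reproducible
    sampled = {}
    for image_id, captions in captions_by_image.items():
        # Images with fewer than k captions keep all of them
        sampled[image_id] = rng.sample(captions, min(k, len(captions)))
    return sampled


# Made-up example data: image id -> list of translated captions
example = {
    "img_001": ["caption a", "caption b", "caption c", "caption d", "caption e"],
    "img_002": ["caption f"],
}
subset = sample_captions(example)
print({image_id: len(caps) for image_id, caps in subset.items()})
# {'img_001': 2, 'img_002': 1}
```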
## Training procedure
The model was trained on a TPUv3-8 VM provided by the Google Cloud team.
## Team members
- Cahya Wirawan ([@cahya](https://huggingface.co/cahya))
- Galuh Sahid ([@Galuh](https://huggingface.co/Galuh))
- Muhammad Agung Hambali ([@AyameRushia](https://huggingface.co/AyameRushia))
- Samsul Rahmadani ([@munggok](https://huggingface.co/munggok))