---
language: id
---

# Image-captioning-Indonesia

This is an encoder-decoder image captioning model that uses [CLIP](https://huggingface.co/transformers/model_doc/clip.html) as the visual encoder and [Marian](https://huggingface.co/transformers/model_doc/marian.html) as the textual decoder, trained on datasets with Indonesian captions.

This model was trained with JAX/Flax using HuggingFace Transformers and is part of the [JAX/Flax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organized by [HuggingFace](https://huggingface.co). All training was done on a TPUv3-8 VM sponsored by the Google Cloud team.

## How to use

At the time of writing, you need to install [HuggingFace Transformers](https://github.com/huggingface/) from its latest master branch in order to load `FlaxMarian`. You also need the [`flax_clip_vision_marian` folder](https://github.com/indonesian-nlp/Indonesia-Image-Captioning/tree/main/flax_clip_vision_marian) in your project directory to load the model with the `FlaxCLIPVisionMarianForConditionalGeneration` class.
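The example below passes the model a float32 `pixel_values` batch laid out as (batch, height, width, channels). As a rough illustration of what its torchvision pipeline produces, here is a minimal NumPy sketch of the crop, scaling, normalization and layout steps. It assumes the input is a (C, H, W) `uint8` RGB array that has already been resized so its shorter side equals `image_size` (the bicubic `Resize` step is omitted), and the helper names `center_crop` and `to_pixel_values` are hypothetical, not part of the model's code.

```python
import numpy as np

# CLIP normalization constants, copied from the Normalize(...) call in the usage example
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def center_crop(image_chw: np.ndarray, size: int) -> np.ndarray:
    """Return the central size x size window of a (C, H, W) array (hypothetical helper)."""
    _, h, w = image_chw.shape
    top = (h - size) // 2
    left = (w - size) // 2
    return image_chw[:, top:top + size, left:left + size]


def to_pixel_values(image_chw_uint8: np.ndarray, size: int) -> np.ndarray:
    """Turn one already-resized (C, H, W) uint8 RGB image into a (1, size, size, C) float32 batch."""
    cropped = center_crop(image_chw_uint8, size)
    # Scale uint8 values to [0, 1], as ConvertImageDtype(torch.float) does for uint8 input
    scaled = cropped.astype(np.float32) / 255.0
    # Per-channel normalization with the CLIP mean and standard deviation
    normalized = (scaled - CLIP_MEAN[:, None, None]) / CLIP_STD[:, None, None]
    # Add a batch dimension and move channels last, like .permute(0, 2, 3, 1)
    return normalized[None].transpose(0, 2, 3, 1)
```

For example, `to_pixel_values(np.zeros((3, 240, 224), dtype=np.uint8), 224)` returns an array of shape `(1, 224, 224, 3)`. In practice the torchvision version is preferable because it also handles the bicubic resize.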
```python
import numpy as np
import torch
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import MarianTokenizer

from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration

# Load the CLIP (vision encoder) + Marian (text decoder) captioning model
clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(clip_marian_model_name)

# The Marian tokenizer is used to decode the generated token ids into text
marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
tokenizer = MarianTokenizer.from_pretrained(marian_model_name)

config = model.config
image_size = config.clip_vision_config.image_size

# Image transformation: resize, center-crop, scale to [0, 1], normalize with CLIP statistics
transforms = torch.nn.Sequential(
    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
    CenterCrop(image_size),
    ConvertImageDtype(torch.float),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)

# Hyperparameters
max_length = 8
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


def generate_step(pixel_values):
    # Beam-search generation, then decode the first (and only) sequence in the batch
    output_ids = model.generate(pixel_values, **gen_kwargs)
    token_ids = np.array(output_ids.sequences)[0]
    caption = tokenizer.decode(token_ids, skip_special_tokens=True)
    return caption


image_file_path = '000000039769.jpg'  # replace with the path to your image
image = read_image(image_file_path, mode=ImageReadMode.RGB)
image = transforms(image)
# Add a batch dimension and reorder to (batch, height, width, channels)
pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()

caption = generate_step(pixel_values)
print(caption)
```

## Training data

To be added.

## Training procedure

The model was trained on a TPUv3-8 VM provided by the Google Cloud team.

## Team members

- Cahya Wirawan ([@cahya](https://huggingface.co/cahya))
- Galuh Sahid ([@Galuh](https://huggingface.co/Galuh))
- Muhammad Agung Hambali ([@AyameRushia](https://huggingface.co/AyameRushia))
- Samsul Rahmadani ([@munggok](https://huggingface.co/munggok))