---
language: id
---

# Image-captioning-Indonesia

This is an encoder-decoder image captioning model that uses [CLIP](https://huggingface.co/transformers/model_doc/clip.html) as the visual encoder and [Marian](https://huggingface.co/transformers/model_doc/marian.html) as the textual decoder, trained on datasets with Indonesian captions.

This model was trained using HuggingFace's Flax framework and is part of the [JAX/Flax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organized by [HuggingFace](https://huggingface.co). All training was done on a TPUv3-8 VM sponsored by the Google Cloud team.

## How to use
At the time of writing, you need to install the [HuggingFace](https://github.com/huggingface/) `transformers` library from its latest master branch in order to load `FlaxMarian`.

You will also need to have the [`flax_clip_vision_marian` folder](https://github.com/indonesian-nlp/Indonesia-Image-Captioning/tree/main/flax_clip_vision_marian) in your project directory to load the model using the `FlaxCLIPVisionMarianForConditionalGeneration` class.

```python
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode

import torch
import numpy as np
from transformers import MarianTokenizer
from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration

clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(clip_marian_model_name)

marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
tokenizer = MarianTokenizer.from_pretrained(marian_model_name)

config = model.config
image_size = config.clip_vision_config.image_size

# Image transformation: resize the shorter side, center-crop, scale to [0, 1],
# then normalize with CLIP's per-channel mean and standard deviation
transforms = torch.nn.Sequential(
    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
    CenterCrop(image_size),
    ConvertImageDtype(torch.float),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
)

# Hyperparameters
max_length = 8
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

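def generate_step(pixel_values):
    # Generate caption token ids with beam search
    output_ids = model.generate(pixel_values, **gen_kwargs)
    token_ids = np.array(output_ids.sequences)[0]
    # Drop special tokens such as <pad> and </s> from the decoded text
    caption = tokenizer.decode(token_ids, skip_special_tokens=True)
    return caption

image_file_path = "path/to/image.jpg"  # replace with the path to your own image
image = read_image(image_file_path, mode=ImageReadMode.RGB)
image = transforms(image)
# Add a batch dimension and convert from channels-first (NCHW)
# to the channels-last (NHWC) layout used for this model's Flax inputs
pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()

caption = generate_step(pixel_values)

print(caption)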
```
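The usage example feeds the model a batch of shape `(1, image_size, image_size, 3)`: a center-cropped, CLIP-normalized float32 image in channels-last order. The NumPy sketch below reproduces the crop, scaling, normalization, and layout steps without torchvision, to make that input contract explicit. It assumes the image has already been resized so that its shorter side equals `image_size` (the bicubic `Resize` step is omitted), and the helper names `center_crop` and `to_pixel_values` are hypothetical. The value 224 is only an example; read the real one from `config.clip_vision_config.image_size`.

```python
import numpy as np

# CLIP normalization constants, copied from the Normalize transform in the usage example
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def center_crop(image_hwc: np.ndarray, size: int) -> np.ndarray:
    """Crop the central size x size window from an H x W x C array (hypothetical helper)."""
    height, width = image_hwc.shape[:2]
    top = (height - size) // 2
    left = (width - size) // 2
    return image_hwc[top:top + size, left:left + size]


def to_pixel_values(image_hwc_uint8: np.ndarray, size: int) -> np.ndarray:
    """Turn one already-resized RGB uint8 image into a (1, size, size, 3) float32 batch."""
    cropped = center_crop(image_hwc_uint8, size)
    scaled = cropped.astype(np.float32) / 255.0  # same [0, 1] scaling as ConvertImageDtype(torch.float)
    normalized = (scaled - CLIP_MEAN) / CLIP_STD  # per-channel, broadcast over H and W
    return normalized[np.newaxis, ...]  # add a batch axis; layout stays channels-last (NHWC)


# Example with a dummy gray image whose shorter side is already 224 pixels
dummy = np.full((240, 224, 3), 128, dtype=np.uint8)
batch = to_pixel_values(dummy, 224)
print(batch.shape)  # (1, 224, 224, 3)
```

Keeping the array in H x W x C order throughout avoids the `permute(0, 2, 3, 1)` step, because the crop and per-channel normalization give the same values in either layout.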

## Training data
The model was trained on Indonesian translations of the COCO, Flickr, and VizWiz datasets, each translated using Google Translate and Marian MT. For each dataset, we randomly sampled only 2 captions per image.
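The per-image caption subsampling described above can be sketched as follows. This is not the team's preprocessing script: the `captions_by_image` dictionary layout, the `sample_captions` function, and the fixed `seed` are hypothetical choices made only for illustration, and the card does not say whether the original sampling was seeded.

```python
import random
from typing import Dict, List


def sample_captions(
    captions_by_image: Dict[str, List[str]], k: int = 2, seed: int = 0
) -> Dict[str, List[str]]:
    """Keep at most k randomly chosen captions for every image (hypothetical helper)."""
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    sampled = {}
    for image_id, captions in captions_by_image.items():
        if len(captions) <= k:
            sampled[image_id] = list(captions)  # nothing to drop
        else:
            sampled[image_id] = rng.sample(captions, k)  # k distinct captions, no repeats
    return sampled


# Hypothetical translated captions keyed by image file name
captions = {
    "img_001.jpg": [
        "Seekor anjing berlari di rumput.",
        "Anjing cokelat bermain di taman.",
        "Seekor anjing sedang berlari.",
        "Anjing di atas rumput hijau.",
        "Seekor hewan berlari di luar ruangan.",
    ],
    "img_002.jpg": ["Seorang pria mengendarai sepeda."],
}
subset = sample_captions(captions, k=2)
print({image_id: len(caps) for image_id, caps in subset.items()})
# {'img_001.jpg': 2, 'img_002.jpg': 1}
```

Sampling without replacement guarantees the two kept captions are distinct, and images with fewer than two captions are kept unchanged rather than padded.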

## Training procedure 
The model was trained on a TPUv3-8 VM provided by the Google Cloud team.

## Team members
- Cahya Wirawan ([@cahya](https://huggingface.co/cahya))
- Galuh Sahid ([@Galuh](https://huggingface.co/Galuh))
- Muhammad Agung Hambali ([@AyameRushia](https://huggingface.co/AyameRushia))
- Samsul Rahmadani ([@munggok](https://huggingface.co/munggok))