language: id


This is an encoder-decoder image captioning model that uses CLIP as the visual encoder and Marian as the textual decoder. It was trained on datasets with Indonesian captions.
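
In an encoder-decoder captioner like this one, the text decoder usually sees the image through cross-attention. Each caption position produces a query, and that query attends over the per-patch features that come out of the vision encoder. The NumPy sketch below illustrates only that mechanism. It assumes a single attention head and random toy weights. The function name `cross_attention` and all sizes are hypothetical, and this is not the actual CLIP/Marian code used by this model.

```python
import numpy as np


def cross_attention(text_states, image_states, w_q, w_k, w_v):
    """Single-head scaled dot-product attention from caption positions to image patches.

    text_states:  (num_tokens, d_model)   decoder hidden states, used as queries
    image_states: (num_patches, d_model)  encoder patch features, used as keys/values
    Returns (output, weights) with shapes (num_tokens, d_model) and (num_tokens, num_patches).
    """
    q = text_states @ w_q                          # (num_tokens, d_model)
    k = image_states @ w_k                         # (num_patches, d_model)
    v = image_states @ w_v                         # (num_patches, d_model)
    scores = q @ k.T / np.sqrt(q.shape[-1])        # (num_tokens, num_patches)
    scores -= scores.max(axis=-1, keepdims=True)   # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over image patches
    return weights @ v, weights


rng = np.random.default_rng(0)
d_model, num_tokens, num_patches = 16, 5, 50       # toy sizes (hypothetical)
text = rng.normal(size=(num_tokens, d_model))
image = rng.normal(size=(num_patches, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))

out, attn = cross_attention(text, image, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (5, 16) (5, 50)
```

Each row of `attn` is a probability distribution over image patches. This lets every generated token draw on a different part of the image.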

This model was trained with JAX/Flax and HuggingFace Transformers as part of the JAX/Flax Community Week organized by HuggingFace. All training was done on a TPUv3-8 VM sponsored by the Google Cloud team.

How to use

At the time of writing, you need to install HuggingFace Transformers from its latest master branch in order to load FlaxMarian models.

You will also need to have the flax_clip_vision_marian folder in your project directory to load the model using the FlaxCLIPVisionMarianForConditionalGeneration class.

from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode

import torch
import numpy as np
from transformers import MarianTokenizer
from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration

clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(clip_marian_model_name)

marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
tokenizer = MarianTokenizer.from_pretrained(marian_model_name)

config = model.config
image_size = config.clip_vision_config.image_size

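# Image transformation: resize, center-crop, convert uint8 to float in [0, 1], then apply CLIP normalization
transforms = torch.nn.Sequential(
                    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
                    CenterCrop(image_size),
                    ConvertImageDtype(torch.float),
                    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                )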

# Hyperparameters
max_length = 8
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

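def generate_step(batch):
    # Generate caption token ids with beam search for a batch of pixel values
    output_ids = model.generate(batch, **gen_kwargs)
    # Take the first (and here only) sequence in the batch
    token_ids = np.array(output_ids.sequences)[0]
    # Convert ids back to text, dropping special tokens such as padding and </s>
    caption = tokenizer.decode(token_ids, skip_special_tokens=True)
    return caption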

image_file_path = '000000039769.jpg'  # replace with the path to your own image
image = read_image(image_file_path, mode=ImageReadMode.RGB)
image = transforms(image)
pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()

caption = generate_step(pixel_values)
print(caption)
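
The preprocessing in the snippet above does four things. It converts the uint8 image to floats in [0, 1] (which is what ConvertImageDtype(torch.float) does). It normalizes each channel with CLIP's mean and standard deviation from the Normalize(...) call. It adds a batch axis with torch.stack, and it switches from channels-first to channels-last with .permute(0, 2, 3, 1). The NumPy sketch below mimics that arithmetic on a random toy image so the shapes and values are easy to inspect. It skips resizing and cropping, and the helper name `normalize_and_batch` is hypothetical.

```python
import numpy as np

# CLIP normalization constants, copied from the Normalize(...) call in the snippet.
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def normalize_and_batch(image_chw_uint8):
    """Hypothetical helper: uint8 (C, H, W) image -> normalized float32 (1, H, W, C) batch."""
    # uint8 [0, 255] -> float32 [0, 1]
    image = image_chw_uint8.astype(np.float32) / 255.0
    # Per-channel (x - mean) / std; constants reshaped to (C, 1, 1) to broadcast over H and W
    image = (image - CLIP_MEAN[:, None, None]) / CLIP_STD[:, None, None]
    # Add a batch axis, like torch.stack([image]) -> (1, C, H, W)
    batch = image[None, ...]
    # Channels-first -> channels-last, like .permute(0, 2, 3, 1) -> (1, H, W, C)
    return batch.transpose(0, 2, 3, 1)


rng = np.random.default_rng(0)
toy_image = rng.integers(0, 256, size=(3, 224, 224)).astype(np.uint8)  # toy RGB image
pixel_values = normalize_and_batch(toy_image)
print(pixel_values.shape, pixel_values.dtype)  # (1, 224, 224, 3) float32
```

The final transpose mirrors the snippet's .permute(0, 2, 3, 1). This suggests that the custom Flax model consumes channels-last input, but the model code in flax_clip_vision_marian is the authority on that.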

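In the snippet, num_beams = 4 and max_length = 8 are passed to model.generate to control beam search. To make those two settings concrete, here is a toy beam search written without any framework, run over a hand-written next-token scorer. Several pieces are illustrative assumptions: the `toy_log_probs` and `beam_search` functions, the tiny vocabulary, the end-of-sequence handling and the length-normalized final selection. The sketch does not reproduce how Transformers implements generate.

```python
import math

EOS = "</s>"
# Hand-written toy vocabulary and target caption; these are not outputs of this model.
VOCAB = ["seekor", "kucing", "tidur", "di", "sofa", "meja", EOS]
IDEAL = ["seekor", "kucing", "tidur", "di", "sofa", EOS]


def toy_log_probs(prefix):
    """Hypothetical next-token scorer returning {token: log-probability}.

    It favors the next token of IDEAL (or EOS once IDEAL is exhausted).
    """
    target = IDEAL[len(prefix)] if len(prefix) < len(IDEAL) else EOS
    logits = {tok: (2.0 if tok == target else 0.0) for tok in VOCAB}
    log_norm = math.log(sum(math.exp(v) for v in logits.values()))
    return {tok: v - log_norm for tok, v in logits.items()}


def beam_search(score_fn, num_beams=4, max_length=8):
    """Keep the num_beams highest-scoring partial sequences for up to max_length steps."""
    beams = [([], 0.0)]  # (tokens, cumulative log-probability)
    finished = []
    for _ in range(max_length):
        # Expand every live beam by every possible next token.
        candidates = [
            (tokens + [tok], score + lp)
            for tokens, score in beams
            for tok, lp in score_fn(tokens).items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:num_beams]:
            if tokens[-1] == EOS:
                finished.append((tokens, score))  # this hypothesis is complete
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses that reached max_length without EOS
    # Length-normalized selection (an assumption here) so short captions are not favored.
    return max(finished, key=lambda c: c[1] / len(c[0]))


tokens, score = beam_search(toy_log_probs, num_beams=4, max_length=8)
print(" ".join(t for t in tokens if t != EOS))  # seekor kucing tidur di sofa
```

Setting num_beams=1 reduces this to greedy decoding. Dividing by sequence length when picking the final hypothesis is one common way to stop the search from favoring captions that end early. Without it, a short caption with only a few steps can outscore a longer, better one simply because it accumulates fewer negative log-probabilities.
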

Training data

To be added.

Training procedure

The model was trained on a TPUv3-8 VM provided by the Google Cloud team.

Team members