import io
from PIL import Image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ToTensor, Compose
from torchvision.transforms.functional import InterpolationMode
import torch
import numpy as np
from transformers import MarianTokenizer
from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration
import logging
import streamlit as st
from mtranslate import translate
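
# Streamlit demo: caption an uploaded image in Indonesian with a CLIP vision
# encoder + Marian decoder, then back-translate the caption to English.
# Note (assumption): logging defaults to WARNING, so configure it here to
# surface the logging.info progress messages used throughout this script.
logging.basicConfig(level=logging.INFO)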

class CaptionGenerator:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
        self.clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'

        self.config = None
        self.image_size = None
        self.custom_transforms = None

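    # Load the Marian tokenizer and the Flax CLIP-Marian model, then build the
    # image preprocessing pipeline from the model's configured input size.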
    def load(self):
        logging.info("Loading tokenizer...")
        self.tokenizer = MarianTokenizer.from_pretrained(self.marian_model_name)
        logging.info("Tokenizer loaded.")

        logging.info("Loading model...")
        self.model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(self.clip_marian_model_name)
        logging.info("Model loaded.")

        self.config = self.model.config
        self.image_size = self.config.clip_vision_config.image_size

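        # CLIP-style preprocessing: resize the shorter edge to the model's
        # input size, center-crop to a square, cast to float, and normalize
        # with the CLIP image mean/std.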
        self.custom_transforms = torch.nn.Sequential(
            Resize([self.image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(self.image_size),
            ConvertImageDtype(torch.float),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        )

    def process_image(self, file):
        logging.info("Loading image...")
        image_data = file.read()
        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        loader = Compose([ToTensor()])
        image = loader(input_image)
        image = self.custom_transforms(image)
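        # The Flax model expects channels-last input: permute NCHW -> NHWC,
        # then convert to a numpy array.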
        pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()
        logging.info("Image loaded.")

        return pixel_values
        
    def generate_step(self, pixel_values, max_len, num_beams):
        gen_kwargs = {"max_length": max_len, "num_beams": num_beams}

        logging.info("Generating caption...")
        output_ids = self.model.generate(pixel_values, **gen_kwargs)
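        # generate returns the best beam per input; take the single batch element.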
        token_ids = np.array(output_ids.sequences)[0]
        caption = self.tokenizer.decode(token_ids, skip_special_tokens=True)  # drops <pad>, </s>, etc.
        logging.info("Caption generated.")

        return caption

    def get_caption(self, file, max_len, num_beams):
        pixel_values = self.process_image(file)

        caption = self.generate_step(pixel_values, max_len, num_beams)
        return caption

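# Cache the loaded generator across Streamlit reruns so the model and
# tokenizer are only loaded once per session.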
@st.cache(allow_output_mutation=True)
def load_caption_generator():
    generator = CaptionGenerator()
    generator.load()
    return generator

def main():
    st.set_page_config(page_title="Indonesian Image Captioning Demo", page_icon="🖼️")
    generator = load_caption_generator()

    st.title("Indonesian Image Captioning Demo")

    st.markdown(
        """Indonesian image captioning demo, combining a [CLIP](https://huggingface.co/transformers/model_doc/clip.html) vision encoder with a [Marian](https://huggingface.co/transformers/model_doc/marian.html) decoder. Built as part of the [Hugging Face JAX/Flax community week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
        """
    )

    st.sidebar.subheader("Configurable parameters")

    max_len = st.sidebar.number_input(
        "Maximum length",
        value=8,
        help="The maximum length of the sequence (caption) to be generated."
    )

    num_beams = st.sidebar.number_input(
        "Number of beams",
        value=4,
        help="Number of beams for beam search. 1 means no beam search."
    )

    input_image = st.file_uploader("Upload an image")
    if st.button("Run"):
        with st.spinner(text="Getting results..."):
            if input_image:
                caption = generator.get_caption(file=input_image, max_len=max_len, num_beams=num_beams)
                st.subheader("Result")
                st.write(caption)
                st.text("English translation")
                # Back-translate the Indonesian caption to English for reference.
                st.write(translate(caption, "en", "id"))
            else:
                st.write("Please upload an image.")

if __name__ == '__main__':
    main()