File size: 2,154 Bytes
6985586
a01e989
 
 
 
 
 
6985586
 
bca94f8
 
 
 
 
 
6c67d85
 
bca94f8
 
 
 
a01e989
 
 
 
 
 
 
 
 
 
 
 
6c67d85
a01e989
 
 
 
 
 
6c67d85
a01e989
 
 
 
 
6c67d85
a01e989
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import streamlit as st
from text2image import get_model, get_tokenizer, get_image_transform
from utils import text_encoder, image_encoder
from PIL import Image
from jax import numpy as jnp
import pandas as pd


def app():
    """Render the "From Image to Text" zero-shot image classification page.

    The page lets the user upload a jpg/jpeg/png image and type up to
    ``MAX_CAP`` candidate labels. When "Compute" is pressed, every non-blank
    label is encoded with ``text_encoder`` and the image with
    ``image_encoder``; the product of the image embedding with the transposed
    label embeddings is plotted as a bar chart beside the uploaded image.

    Nothing is returned; all output goes through Streamlit widgets. If no
    image or no usable label was provided, an error message is shown instead.
    """
    # ``st.beta_columns`` was renamed to ``st.columns`` and subsequently
    # removed from Streamlit. Prefer the current name and fall back to the
    # beta one so the page works on both old and new releases.
    make_columns = getattr(st, "columns", None) or st.beta_columns

    st.title("From Image to Text")
    st.markdown(
        """

        ### 👋 Ciao!

        Here you can find the captions or the labels that are most related to a given image. It is a zero-shot
        image classification task!

        🤌 Italian mode on! 🤌

        """
    )

    filename = st.file_uploader(
        "Choose an image from your computer", type=["jpg", "jpeg", "png"]
    )

    # Upper bound on how many label inputs the user can request.
    MAX_CAP = 4

    labels_col, controls_col = make_columns([3, 1])

    with controls_col:
        captions_count = st.selectbox(
            "Number of labels", options=range(1, MAX_CAP + 1)
        )
        compute = st.button("Compute")

    with labels_col:
        # The selectbox only offers 1..MAX_CAP, so no extra clamping is needed.
        captions = [
            st.text_input(f"Insert label {idx + 1}")
            for idx in range(captions_count)
        ]

    if not compute:
        return

    # Drop inputs that were left empty or contain only whitespace; they would
    # otherwise show up as unlabeled bars in the chart.
    captions = [c for c in captions if c.strip()]

    if not captions or not filename:
        st.error("Please choose one image and at least one label")
        return

    with st.spinner("Computing..."):
        model = get_model()
        tokenizer = get_tokenizer()

        # ``text_encoder`` output is iterated and its items collected, so the
        # stacked array has one row per item it yields
        # (presumably one embedding per label — confirm in ``utils``).
        text_embeds = []
        for caption in captions:
            text_embeds.extend(text_encoder(caption, model, tokenizer))
        text_embeds = jnp.array(text_embeds)

        image = Image.open(filename).convert("RGB")
        transform = get_image_transform(model.config.vision_config.image_size)
        image_embed = image_encoder(transform(image), model)

        # Raw dot products between the image and each label embedding; these
        # are cosine similarities only if the encoders return unit-normalized
        # vectors (NOTE(review): confirm in ``utils``).
        # We could have a softmax here.
        cos_similarities = jnp.matmul(image_embed, text_embeds.T)

        # First row of the similarity matrix, indexed by label.
        chart_data = pd.Series(cos_similarities[0], index=captions)

        chart_col, image_col = make_columns(2)
        with chart_col:
            st.bar_chart(chart_data)

        with image_col:
            st.image(image)