File size: 3,033 Bytes
6985586
a01e989
 
 
 
90de990
a01e989
d05a223
 
6668c84
 
0466d84
 
 
 
 
6985586
 
a9e905c
bca94f8
 
 
a9e905c
 
 
 
 
 
d05a223
a9e905c
d05a223
bca94f8
 
a01e989
 
d05a223
a9e905c
90de990
6668c84
d05a223
a01e989
 
94e7a7b
a01e989
 
 
a9e905c
a01e989
705e8fa
a01e989
 
 
 
a9e905c
a01e989
 
 
 
d05a223
6c67d85
a01e989
 
 
 
 
 
 
90de990
a01e989
 
0466d84
90de990
a01e989
90de990
a01e989
 
6668c84
 
 
a01e989
 
 
367d052
a01e989
 
 
 
624b826
6668c84
6101644
 
0466d84
90de990
6101644
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
from text2image import get_model, get_tokenizer, get_image_transform
from utils import text_encoder, image_encoder
from PIL import Image
from jax import numpy as jnp
from io import BytesIO
import pandas as pd
import requests
import jax
import gc

# Browser-like User-Agent: some image hosts (e.g. Wikimedia) reject
# requests that arrive with the default python-requests agent string.
_USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
    "(KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36 Edge/18.19582"
)
headers = {"User-Agent": _USER_AGENT}


def _fetch_image(url):
    """Download *url* and return it decoded as an RGB PIL image.

    Sends the module-level browser-like ``headers`` (some hosts reject the
    default agent), uses a timeout so an unresponsive host cannot hang the
    UI, and raises ``requests.HTTPError`` on a non-2xx status instead of
    handing an HTML error page to the image decoder.
    """
    response = requests.get(url, headers=headers, stream=True, timeout=10)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def app():
    """Render the zero-shot image-classification demo page.

    The user provides an image URL and 1..MAX_CAP Italian text labels; the
    model scores each label against the image and the softmaxed scores are
    shown as a bar chart next to the image. With no classification requested
    yet, the image alone is previewed.
    """
    st.title("From Image to Text")
    st.markdown(
        """

        ### 👋 Ciao!

        Here you can find the captions or the labels that are most related to a given image. It is a zero-shot
        image classification task!

        🤌 Italian mode on! 🤌
        
        For example, try typing "gatto" (cat) in the space for label1 and "cane" (dog) in the space for label2 and click
        "classify"!

        """
    )

    image_url = st.text_input(
        "You can input the URL of an image",
        value="https://upload.wikimedia.org/wikipedia/commons/thumb/8/88/Ragdoll%2C_blue_mitted.JPG/1280px-Ragdoll%2C_blue_mitted.JPG",
    )

    MAX_CAP = 4

    col1, col2 = st.columns([0.75, 0.25])

    with col2:
        # The selectbox already bounds the count to 1..MAX_CAP.
        captions_count = st.selectbox(
            "Number of labels", options=range(1, MAX_CAP + 1), index=1
        )
        compute = st.button("CLASSIFY")

    with col1:
        captions = [
            st.text_input(f"Insert label {idx + 1}") for idx in range(captions_count)
        ]

    if compute:
        # Drop labels the user left blank.
        captions = [c for c in captions if c]

        if not captions or not image_url:
            st.error("Please choose one image and at least one label")
        else:
            try:
                image = _fetch_image(image_url)
            except (requests.RequestException, OSError) as err:
                # RequestException: network/HTTP failure; OSError covers
                # PIL's UnidentifiedImageError on undecodable content.
                st.error(f"Could not load the image: {err}")
                image = None

            if image is not None:
                with st.spinner("Computing..."):
                    model = get_model()
                    tokenizer = get_tokenizer()

                    # text_encoder returns a batch; [0] is the embedding row.
                    text_embeds = jnp.array(
                        [e for c in captions for e in text_encoder(c, model, tokenizer)[0]]
                    )

                    transform = get_image_transform(model.config.vision_config.image_size)
                    image_embed, _ = image_encoder(transform(image), model)

                    # Softmax over the image/label similarities so the chart
                    # reads as a probability distribution over the labels.
                    probs = jax.nn.softmax(jnp.matmul(image_embed, text_embeds.T))

                    chart_data = pd.Series(probs[0], index=captions)

                    col1, col2 = st.columns(2)
                    with col1:
                        st.bar_chart(chart_data)

                    with col2:
                        st.image(image, use_column_width=True)
        # Free model/embedding memory between reruns of the script.
        gc.collect()

    elif image_url:
        # No classification requested yet: just preview the image.
        try:
            st.image(_fetch_image(image_url))
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load the image: {err}")