File size: 10,735 Bytes
e126020
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import math
import os
import subprocess
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image

# Fetch the PIXEL repository so that `pixel.src.pixel` can be imported below.
# Clone only when the checkout is not already present: after a restart the
# directory exists and a second `git clone` into it would just fail. An
# argument list (no shell) is used, and check=False keeps the step
# best-effort instead of raising on a non-zero git exit status.
if not os.path.isdir("pixel"):
    subprocess.run(["git", "clone", "https://github.com/xplip/pixel.git"], check=False)
# NOTE(review): presumably needed so modules inside the checkout can import
# each other by top-level name — confirm.
sys.path.append('./pixel')

from transformers import set_seed
from pixel.src.pixel import (
    PIXELConfig,
    PIXELForPreTraining,
    SpanMaskingGenerator,
    PyGameTextRenderer,
    get_transforms,
    resize_model_embeddings,
    truncate_decoder_pos_embeddings,
    get_attention_mask
)

# Hub checkpoint shared by the configuration, the model and the text renderer.
model_name_or_path = "Team-PIXEL/pixel-base"
# Patch sequence length used for rendering (and as `num_patches` for masking).
max_seq_length = 529

# Pretrained configuration and masked-autoencoding model built from it.
config = PIXELConfig.from_pretrained(model_name_or_path)
model = PIXELForPreTraining.from_pretrained(model_name_or_path, config=config)

# Renderer that turns input text into the patch image fed to the model.
text_renderer = PyGameTextRenderer.from_pretrained(
    model_name_or_path,
    max_seq_length=max_seq_length,
)

def clip(x: torch.Tensor) -> torch.Tensor:
    """Scale an image tensor by 255 and clamp it into [0, 255].

    The operation is purely element-wise, so it is independent of the axis
    layout: a (C, H, W) image works exactly like any other shape, and no
    channel-last permutation round trip is needed.

    Args:
        x: Image tensor (presumably with values nominally in [0, 1]).

    Returns:
        A new tensor with the same shape as ``x`` holding ``x * 255``
        clamped to the range [0, 255].
    """
    return torch.clip(x * 255, 0, 255)

def get_image(img: torch.Tensor, do_clip: bool = True):
    """Turn an image tensor into a PIL image for display in the demo.

    Args:
        img: Image tensor to convert.
        do_clip: When true, pass ``img`` through ``clip`` before gridding.

    Returns:
        A ``PIL.Image.Image`` built from the 8-bit, channel-last grid.
    """
    source = clip(img) if do_clip else img
    # make_grid is asked to normalize the values; the result is then mapped
    # to 0..255 and stored as bytes.
    grid = torchvision.utils.make_grid(source, normalize=True)
    as_bytes = grid.mul(255).clamp(0, 255).byte()
    # PIL expects height x width x channels.
    channel_last = as_bytes.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(channel_last)

def _patch_mask_to_pixels(patch_mask: torch.Tensor) -> torch.Tensor:
    """Expand a per-patch mask to image layout so it can multiply images.

    Each patch value is repeated ``pixels_per_patch ** 2 * 3`` times (every
    pixel of a patch, times 3 channels), folded back into image layout with
    ``model.unpatchify``, and singleton dimensions are squeezed away.
    """
    # Assumes patch_mask is (1, num_patches) — TODO confirm for batches > 1.
    per_pixel = patch_mask.unsqueeze(-1).repeat(1, 1, text_renderer.pixels_per_patch ** 2 * 3)
    return model.unpatchify(per_pixel).squeeze()


def inference(text: str, mask_ratio: float = 0.25, max_span_length: int = 6, seed: int = 42):
    """Render ``text``, mask random spans of patches and reconstruct them with PIXEL.

    Args:
        text: Input text to render into patches.
        mask_ratio: Fraction of the ``max_seq_length`` patches to mask.
        max_span_length: Maximum span length passed to ``SpanMaskingGenerator``.
        seed: Seed passed to ``transformers.set_seed`` before masking.

    Returns:
        A list of four PIL images: the original rendering, the masked input,
        the predictions restricted to masked patches, and the masked input
        with those predictions filled in.
    """
    # Values coming from UI sliders are not guaranteed to be Python ints;
    # np.random.seed (called by set_seed) rejects floats, so normalize here.
    seed = int(seed)
    max_span_length = int(max_span_length)

    config.update({"mask_ratio": mask_ratio})
    resize_model_embeddings(model, max_seq_length)
    truncate_decoder_pos_embeddings(model, max_seq_length)

    set_seed(seed)

    transforms = get_transforms(
        do_resize=True,
        size=(text_renderer.pixels_per_patch, text_renderer.pixels_per_patch * text_renderer.max_seq_length),
    )

    encoding = text_renderer(text=text)
    attention_mask = get_attention_mask(
        num_text_patches=encoding.num_text_patches, seq_length=text_renderer.max_seq_length
    ).unsqueeze(0)

    img = transforms(Image.fromarray(encoding.pixel_values)).unsqueeze(0)
    inputs = {"pixel_values": img.float(), "attention_mask": attention_mask}

    mask_generator = SpanMaskingGenerator(
        num_patches=text_renderer.max_seq_length,
        num_masking_patches=math.ceil(mask_ratio * text_renderer.max_seq_length),
        max_span_length=max_span_length,
        spacing="span",
    )
    # NOTE(review): the +1 presumably covers an extra end-of-text patch — confirm.
    inputs["patch_mask"] = torch.tensor(
        mask_generator(num_text_patches=(encoding.num_text_patches + 1))
    ).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = model.unpatchify(outputs["logits"]).detach().cpu().squeeze()

    mask = _patch_mask_to_pixels(outputs["mask"].detach().cpu())  # 1 is removing, 0 is keeping
    attention_mask = _patch_mask_to_pixels(attention_mask)

    original_img = model.unpatchify(model.patchify(img)).squeeze()

    # Blank out pixels that are masked AND inside the attended region.
    hidden = ((mask == 1) & (attention_mask == 1)).long()
    im_masked = original_img * (1 - hidden)

    masked_predictions = predictions * mask * attention_mask
    reconstruction = im_masked + masked_predictions

    return [
        get_image(original_img),
        get_image(im_masked),
        get_image(masked_predictions, do_clip=False),
        get_image(reconstruction, do_clip=False),
    ]
  
# Example rows for gr.Examples. Each row follows the input order used with
# `inference` in the UI: [text, mask_ratio, max_span_length, seed].
examples = [
    ["Penguins are designed to be streamlined and hydrodynamic, so having long legs would add extra drag. Having short legs with webbed feet to act like rudders, helps to give them that torpedo-like figure. If we compare bird anatomy with humans, we would see something a bit peculiar. By taking a look at the side-by-side image in Figure 1, you can see how their leg bones compare to ours. What most people mistake for knees are actually the ankles of the birds. This gives the illusion that bird knees bend opposite of ours. The knees are actually tucked up inside the body cavity of the bird! So how does this look inside of a penguin? In the images below, you can see boxes surrounding the penguins’ knees.", 0.2, 6, 42],
    ["Félicette didn’t seem like a typical astronaut. She weighed just five and a half pounds. She’d spent most of her life on the streets of Paris. And Félicette was a cat, one of 14 trained by French scientists for space flight. In 1963, she went where no feline had gone before. Chosen for her calm demeanor and low weight, Félicette was strapped into a rocket in October of that year. She spent 15 minutes on a dizzying flight to the stars before returning safely to earth. Her legacy, however, has been largely forgotten. While other space animals like Laika the dog and Ham the chimp have been celebrated, Félicette became a footnote of history. This is the story of the only cat to go to space.", 0.25, 4, 42],
    ["In many, many ways, fish of the species Brienomyrus brachyistius do not speak at all like Barack Obama. For starters, they communicate not through a spoken language but through electrical pulses booped out by specialized organs found near the tail. Their vocabulary is also quite unpresidentially poor, with each individual capable of producing just one electric wave—a unique but monotonous signal. “It’s even simpler than Morse code,” Bruce Carlson, a biologist at Washington University in St. Louis who studies Brienomyrus fish, told me. In at least one significant way, though, fish of the species Brienomyrus brachyistius do speak a little bit like Barack Obama. When they want to send an important message… They stop, just for a moment. Those gaps tend to occur in very particular patterns, right before fishy phrases and sentences with “high-information content” about property, say, or courtship, Carlson said. Electric fish have, like the former president, mastered the art of the dramatic pause—a rhetorical trick that can help listeners cue in more strongly to what speakers have to say next, Carlson and his colleagues report in a study published today in Current Biology.", 0.5, 1, 42],
]
# Hint text shown in the empty input Textbox (used as the tb_text placeholder).
placeholder_text = "Our message is simple. Because we truly believe in our peanut-loving hearts that peanuts make everything better. Peanuts are perfectly powerful because they're packed with nutrition and they bring people together. Our thirst for peanut knowledge is unquenchable, so we’re always sharing snackable news stories about the benefits of peanuts, recent stats, research, etc. Our passion for peanuts is infectious. We root for peanuts as if they were a home run away from winning it all. We care about peanuts and the people who grow them. We give shout-outs to those who lift up and promote peanuts and the peanut story. We’re an authority on peanuts and we're anything but boring."

def _titled_image(title):
    """Create a bordered panel with a bold heading above a PIL output image.

    Must be called inside an active ``gr.Blocks`` layout context. ``title`` is
    used both as the heading text and as the (hidden) component label.
    """
    with gr.Box().style(rounded=False):
        gr.Markdown(f"**{title}**")
        panel_image = gr.Image(
            type="pil",
            label=title,
            show_label=False,
            elem_id="output_image",
        )
    return panel_image


demo = gr.Blocks(css="#output_image {width: auto; display: block; margin-left: auto; margin-right: auto;} #button {display: block; margin: 0 auto;}")

with demo:
    gr.Markdown("## PIXEL Masked Autoencoding")
    gr.Markdown("Gradio demo for [PIXEL](https://huggingface.co/Team-PIXEL/pixel-base), introduced in [Language Modelling with Pixels](https://arxiv.org/abs/2207.06991). To use it, simply input your piece of text or click one of the examples to load them. Read more at the links below.")
    with gr.Row():
        # Left column: text input, masking controls and the run button.
        with gr.Column():
            tb_text = gr.Textbox(lines=1, label="Text", placeholder=placeholder_text)
            sl_ratio = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.25, label="Span masking ratio")
            sl_len = gr.Slider(minimum=1, maximum=6, step=1, value=6, label="Masking max span length")
            sl_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label="Random seed")
            with gr.Box().style(rounded=False):
                btn = gr.Button("Run", variant="primary", elem_id="button")
        # Right column: a 2x2 grid of result images.
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    out_original = _titled_image("Original")
                    out_masked_pred = _titled_image("Masked Predictions")
                with gr.Column():
                    out_masked = _titled_image("Masked")
                    out_reconstruction = _titled_image("Reconstruction")

    # Wiring shared by the examples and the button: the order matches
    # inference(text, mask_ratio, max_span_length, seed) and its returned
    # list [original, masked, masked predictions, reconstruction].
    io_inputs = [tb_text, sl_ratio, sl_len, sl_seed]
    io_outputs = [out_original, out_masked, out_masked_pred, out_reconstruction]

    with gr.Row():
        with gr.Box().style(rounded=False):
            gr.Markdown("### Examples")
            gr_examples = gr.Examples(
                examples,
                inputs=io_inputs,
                outputs=io_outputs,
                fn=inference,
                cache_examples=True,
            )
    gr.HTML("<p style='text-align: center'><a href='https://arxiv.org/abs/2207.06991' target='_blank'><b>Paper</b></a> | <a href='https://github.com/xplip/pixel' target='_blank'><b>Github</b></a></p>")
    gr.HTML("<center><img src='https://visitor-badge.glitch.me/badge?page_id=Team-PIXEL/PIXEL' alt='visitor badge'></center>")

    btn.click(fn=inference, inputs=io_inputs, outputs=io_outputs)
demo.launch(debug=True)