plip commited on
Commit
e126020
1 Parent(s): 5a49eaa
Files changed (4) hide show
  1. README.md +4 -4
  2. app.py +192 -0
  3. packages.txt +8 -0
  4. requirements.txt +23 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
  title: PIXEL
3
- emoji: 🦀
4
- colorFrom: indigo
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.1.1
8
  app_file: app.py
9
- pinned: false
10
  license: apache-2.0
11
  ---
12
 
1
  ---
2
  title: PIXEL
3
+ emoji: 🐱
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.1.1
8
  app_file: app.py
9
+ pinned: true
10
  license: apache-2.0
11
  ---
12
 
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import sys
4
+
5
+ import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torch
10
+ import torchvision
11
+
12
+ os.system("git clone https://github.com/xplip/pixel.git")
13
+ sys.path.append('./pixel')
14
+
15
+ from transformers import set_seed
16
+ from pixel.src.pixel import (
17
+ PIXELConfig,
18
+ PIXELForPreTraining,
19
+ SpanMaskingGenerator,
20
+ PyGameTextRenderer,
21
+ get_transforms,
22
+ resize_model_embeddings,
23
+ truncate_decoder_pos_embeddings,
24
+ get_attention_mask
25
+ )
26
+
27
+ model_name_or_path = "Team-PIXEL/pixel-base"
28
+ max_seq_length = 529
29
+ text_renderer = PyGameTextRenderer.from_pretrained(model_name_or_path, max_seq_length=max_seq_length)
30
+ config = PIXELConfig.from_pretrained(model_name_or_path)
31
+ model = PIXELForPreTraining.from_pretrained(model_name_or_path, config=config)
32
+
33
+ def clip(x: torch.Tensor):
34
+ x = torch.einsum("chw->hwc", x)
35
+ x = torch.clip(x * 255, 0, 255)
36
+ x = torch.einsum("hwc->chw", x)
37
+ return x
38
+
39
+ def get_image(img: torch.Tensor, do_clip: bool = True):
40
+ if do_clip:
41
+ img = clip(img)
42
+ img = torchvision.utils.make_grid(img, normalize=True)
43
+ image = Image.fromarray(
44
+ img.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
45
+ )
46
+ return image
47
+
48
+ def inference(text: str, mask_ratio: float = 0.25, max_span_length: int = 6, seed: int = 42):
49
+ config.update({"mask_ratio": mask_ratio})
50
+ resize_model_embeddings(model, max_seq_length)
51
+ truncate_decoder_pos_embeddings(model, max_seq_length)
52
+
53
+ set_seed(seed)
54
+
55
+ transforms = get_transforms(
56
+ do_resize=True,
57
+ size=(text_renderer.pixels_per_patch, text_renderer.pixels_per_patch * text_renderer.max_seq_length),
58
+ )
59
+
60
+ encoding = text_renderer(text=text)
61
+ attention_mask = get_attention_mask(
62
+ num_text_patches=encoding.num_text_patches, seq_length=text_renderer.max_seq_length
63
+ )
64
+
65
+ img = transforms(Image.fromarray(encoding.pixel_values)).unsqueeze(0)
66
+ attention_mask = attention_mask.unsqueeze(0)
67
+ inputs = {"pixel_values": img.float(), "attention_mask": attention_mask}
68
+
69
+ mask_generator = SpanMaskingGenerator(
70
+ num_patches=text_renderer.max_seq_length,
71
+ num_masking_patches=math.ceil(mask_ratio * text_renderer.max_seq_length),
72
+ max_span_length=max_span_length,
73
+ spacing="span"
74
+ )
75
+ mask = torch.tensor(mask_generator(num_text_patches=(encoding.num_text_patches + 1))).unsqueeze(0)
76
+ inputs.update({"patch_mask": mask})
77
+
78
+ model.eval()
79
+ with torch.no_grad():
80
+ outputs = model(**inputs)
81
+
82
+ predictions = model.unpatchify(outputs["logits"]).detach().cpu().squeeze()
83
+
84
+ mask = outputs["mask"].detach().cpu()
85
+ mask = mask.unsqueeze(-1).repeat(1, 1, text_renderer.pixels_per_patch ** 2 * 3)
86
+ mask = model.unpatchify(mask).squeeze() # 1 is removing, 0 is keeping
87
+
88
+ attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, text_renderer.pixels_per_patch ** 2 * 3)
89
+ attention_mask = model.unpatchify(attention_mask).squeeze()
90
+
91
+ original_img = model.unpatchify(model.patchify(img)).squeeze()
92
+
93
+ im_masked = original_img * (1 - (torch.bitwise_and(mask == 1, attention_mask == 1)).long())
94
+
95
+ masked_predictions = predictions * mask * attention_mask
96
+
97
+ reconstruction = im_masked + masked_predictions
98
+
99
+ return [get_image(original_img), get_image(im_masked), get_image(masked_predictions, do_clip=False), get_image(reconstruction, do_clip=False)]
100
+
101
+ examples = [
102
+ ["Penguins are designed to be streamlined and hydrodynamic, so having long legs would add extra drag. Having short legs with webbed feet to act like rudders, helps to give them that torpedo-like figure. If we compare bird anatomy with humans, we would see something a bit peculiar. By taking a look at the side-by-side image in Figure 1, you can see how their leg bones compare to ours. What most people mistake for knees are actually the ankles of the birds. This gives the illusion that bird knees bend opposite of ours. The knees are actually tucked up inside the body cavity of the bird! So how does this look inside of a penguin? In the images below, you can see boxes surrounding the penguins’ knees.", 0.2, 6, 42],
103
+ ["Félicette didn’t seem like a typical astronaut. She weighed just five and a half pounds. She’d spent most of her life on the streets of Paris. And Félicette was a cat, one of 14 trained by French scientists for space flight. In 1963, she went where no feline had gone before. Chosen for her calm demeanor and low weight, Félicette was strapped into a rocket in October of that year. She spent 15 minutes on a dizzying flight to the stars before returning safely to earth. Her legacy, however, has been largely forgotten. While other space animals like Laika the dog and Ham the chimp have been celebrated, Félicette became a footnote of history. This is the story of the only cat to go to space.", 0.25, 4, 42],
104
+ ["In many, many ways, fish of the species Brienomyrus brachyistius do not speak at all like Barack Obama. For starters, they communicate not through a spoken language but through electrical pulses booped out by specialized organs found near the tail. Their vocabulary is also quite unpresidentially poor, with each individual capable of producing just one electric wave—a unique but monotonous signal. “It’s even simpler than Morse code,” Bruce Carlson, a biologist at Washington University in St. Louis who studies Brienomyrus fish, told me. In at least one significant way, though, fish of the species Brienomyrus brachyistius do speak a little bit like Barack Obama. When they want to send an important message… They stop, just for a moment. Those gaps tend to occur in very particular patterns, right before fishy phrases and sentences with “high-information content” about property, say, or courtship, Carlson said. Electric fish have, like the former president, mastered the art of the dramatic pause—a rhetorical trick that can help listeners cue in more strongly to what speakers have to say next, Carlson and his colleagues report in a study published today in Current Biology.", 0.5, 1, 42],
105
+ ]
106
+ placeholder_text = "Our message is simple. Because we truly believe in our peanut-loving hearts that peanuts make everything better. Peanuts are perfectly powerful because they're packed with nutrition and they bring people together. Our thirst for peanut knowledge is unquenchable, so we’re always sharing snackable news stories about the benefits of peanuts, recent stats, research, etc. Our passion for peanuts is infectious. We root for peanuts as if they were a home run away from winning it all. We care about peanuts and the people who grow them. We give shout-outs to those who lift up and promote peanuts and the peanut story. We’re an authority on peanuts and we're anything but boring."
107
+
108
+ demo = gr.Blocks(css="#output_image {width: auto; display: block; margin-left: auto; margin-right: auto;} #button {display: block; margin: 0 auto;}")
109
+
110
+ with demo:
111
+ gr.Markdown("## PIXEL Masked Autoencoding")
112
+ gr.Markdown("Gradio demo for [PIXEL](https://huggingface.co/Team-PIXEL/pixel-base), introduced in [Language Modelling with Pixels](https://arxiv.org/abs/2207.06991). To use it, simply input your piece of text or click one of the examples to load them. Read more at the links below.")
113
+ with gr.Row():
114
+ with gr.Column():
115
+ tb_text = gr.Textbox(
116
+ lines=1,
117
+ label="Text",
118
+ placeholder=placeholder_text)
119
+ sl_ratio = gr.Slider(
120
+ minimum=0.01,
121
+ maximum=1.0,
122
+ step=0.01,
123
+ value=0.25,
124
+ label="Span masking ratio",
125
+ )
126
+ sl_len = gr.Slider(
127
+ minimum=1,
128
+ maximum=6,
129
+ step=1,
130
+ value=6,
131
+ label="Masking max span length",
132
+ )
133
+ sl_seed = gr.Slider(
134
+ minimum=0,
135
+ maximum=1000,
136
+ step=1,
137
+ value=42,
138
+ label="Random seed"
139
+ )
140
+ with gr.Box().style(rounded=False):
141
+ btn = gr.Button("Run", variant="primary", elem_id="button")
142
+ with gr.Column():
143
+ with gr.Row():
144
+ with gr.Column():
145
+ with gr.Box().style(rounded=False):
146
+ gr.Markdown("**Original**")
147
+ out_original = gr.Image(
148
+ type="pil",
149
+ label="Original",
150
+ show_label=False,
151
+ elem_id="output_image"
152
+ )
153
+ with gr.Box().style(rounded=False):
154
+ gr.Markdown("**Masked Predictions**")
155
+ out_masked_pred = gr.Image(
156
+ type="pil",
157
+ label="Masked Predictions",
158
+ show_label=False,
159
+ elem_id="output_image"
160
+ )
161
+ with gr.Column():
162
+ with gr.Box().style(rounded=False):
163
+ gr.Markdown("**Masked**")
164
+ out_masked = gr.Image(
165
+ type="pil",
166
+ label="Masked",
167
+ show_label=False,
168
+ elem_id="output_image"
169
+ )
170
+ with gr.Box().style(rounded=False):
171
+ gr.Markdown("**Reconstruction**")
172
+ out_reconstruction = gr.Image(
173
+ type="pil",
174
+ label="Reconstruction",
175
+ show_label=False,
176
+ elem_id="output_image"
177
+ )
178
+ with gr.Row():
179
+ with gr.Box().style(rounded=False):
180
+ gr.Markdown("### Examples")
181
+ gr_examples = gr.Examples(
182
+ examples,
183
+ inputs=[tb_text, sl_ratio, sl_len, sl_seed],
184
+ outputs=[out_original, out_masked, out_masked_pred, out_reconstruction],
185
+ fn=inference,
186
+ cache_examples=True
187
+ )
188
+ gr.HTML("<p style='text-align: center'><a href='https://arxiv.org/abs/2207.06991' target='_blank'><b>Paper</b></a> | <a href='https://github.com/xplip/pixel' target='_blank'><b>Github</b></a></p>")
189
+ gr.HTML("<center><img src='https://visitor-badge.glitch.me/badge?page_id=Team-PIXEL/PIXEL' alt='visitor badge'></center>")
190
+
191
+ btn.click(fn=inference, inputs=[tb_text, sl_ratio, sl_len, sl_seed], outputs=[out_original, out_masked, out_masked_pred, out_reconstruction])
192
+ demo.launch(debug=True)
packages.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
1
+ libgirepository1.0-dev
2
+ gcc
3
+ libcairo2-dev
4
+ pkg-config
5
+ python3-dev
6
+ gir1.2-gtk-3.0
7
+ libpango1.0-dev
8
+ libpangocairo-1.0-0
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ bs4
2
+ fonttools
3
+ gradio
4
+ manimpango
5
+ matplotlib
6
+ nltk
7
+ numpy
8
+ opencv-python-headless
9
+ pandas
10
+ pyarrow
11
+ pycairo
12
+ pygame
13
+ PyGObject
14
+ pyyaml
15
+ scipy
16
+ seqeval
17
+ sklearn
18
+ spacy
19
+ submitit
20
+ torch
21
+ torchvision
22
+ transformers==4.17.0
23
+ wandb