VictorSanh committed on
Commit
f20057b
•
1 Parent(s): 157a0b7
Files changed (2)
  1. app.py +261 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,261 @@
+import torch
+import gradio as gr
+import random
+import numpy as np
+from PIL import Image
+import imagehash
+import cv2
+import os
+
+from transformers import AutoProcessor, AutoModelForCausalLM
+from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
+from transformers.image_transforms import resize, to_channel_dimension_format
+
+from typing import List
+from collections import Counter
+
+from datasets import load_dataset, concatenate_datasets
+
+
+DEVICE = torch.device("cuda")
+PROCESSOR = AutoProcessor.from_pretrained(
+    "HuggingFaceM4/idefics2_raven_finetuned",
+    token=os.environ["HF_AUTH_TOKEN"],
+)
+MODEL = AutoModelForCausalLM.from_pretrained(
+    "HuggingFaceM4/idefics2_raven_finetuned",
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
+    token=os.environ["HF_AUTH_TOKEN"],
+).to(DEVICE)
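+# `image_seq_len` is the number of `<image>` placeholder tokens the prompt needs per image:
+# the perceiver resampler's latent count when the resampler is used, otherwise one token per
+# vision patch.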
+if MODEL.config.use_resampler:
+    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
+else:
+    image_seq_len = (
+        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
+    ) ** 2
+BOS_TOKEN = PROCESSOR.tokenizer.bos_token
+BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
+DATASET = load_dataset("HuggingFaceM4/RAVEN_rendered", split="validation")
+
+## Utils
+
+def convert_to_rgb(image):
+    # `image.convert("RGB")` only works reliably for .jpg images: for transparent images it
+    # produces a wrong background. Compositing onto a white canvas with `alpha_composite`
+    # handles that case.
+    if image.mode == "RGB":
+        return image
+
+    image_rgba = image.convert("RGBA")
+    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
+    alpha_composite = Image.alpha_composite(background, image_rgba)
+    alpha_composite = alpha_composite.convert("RGB")
+    return alpha_composite
+
+# The processor is the same as the Idefics processor except for the BICUBIC interpolation
+# inside siglip, so this is a hack to redefine ONLY the transform method
+def custom_transform(x):
+    x = convert_to_rgb(x)
+    x = to_numpy_array(x)
+    x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
+    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
+    x = PROCESSOR.image_processor.normalize(
+        x,
+        mean=PROCESSOR.image_processor.image_mean,
+        std=PROCESSOR.image_processor.image_std,
+    )
+    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
+    x = torch.tensor(x)
+    return x
+
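+# Fuzzy image-equality check combining three cues: dominant non-white gray level, perceptual
+# hash of the Canny edge map, and non-white surface area. It returns True when at least two
+# of the three differences fall under their thresholds. (Not used by the demo below.)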
+def pixel_difference(image1, image2):
+    def color(im):
+        arr = np.array(im).flatten()
+        arr_list = arr.tolist()
+        counts = Counter(arr_list)
+        most_common = counts.most_common(2)
+        if most_common[0][0] == 255:
+            return most_common[1][0]
+        else:
+            return most_common[0][0]
+
+    def canny_edges(im):
+        im = cv2.Canny(np.array(im), 50, 100)
+        im[im != 0] = 255
+        return Image.fromarray(im)
+
+    def phash(im):
+        return imagehash.phash(canny_edges(im), hash_size=32)
+
+    def surface(im):
+        return (np.array(im) != 255).sum()
+
+    color_diff = np.abs(color(image1) - color(image2))
+    hash_diff = phash(image1) - phash(image2)
+    surface_diff = np.abs(surface(image1) - surface(image2))
+
+    if int(hash_diff / 7) < 10:
+        return color_diff < 10 or int(surface_diff / (160 * 160) * 100) < 10
+    elif color_diff < 10:
+        return int(surface_diff / (160 * 160) * 100) < 10 or int(hash_diff / 7) < 10
+    elif int(surface_diff / (160 * 160) * 100) < 10:
+        return int(hash_diff / 7) < 10 or color_diff < 10
+    else:
+        return False
+
+# End of Utils
+
+
+def load_sample():
+    n = len(DATASET)
+    found_sample = False
+    while not found_sample:
+        idx = random.randint(0, n - 1)  # randint is inclusive on both ends
+        sample = DATASET[idx]
+        found_sample = True
+    return sample["image"], sample["label"], "", "", ""
+
+
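+# Prompts the model with one `<image>` placeholder per visual token (wrapped in
+# <fake_token_around_image>), generates a short completion, and returns its last character
+# as the predicted answer letter.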
+# @spaces.GPU(duration=180)
+def model_inference(
+    image,
+):
+    if image is None:
+        raise ValueError("`image` is None. It should be a PIL image.")
+
+    # return "A"
+    inputs = PROCESSOR.tokenizer(
+        f"{BOS_TOKEN}User:<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>Which figure should complete the logical sequence?<end_of_utterance>\nAssistant:",
+        return_tensors="pt",
+        add_special_tokens=False,
+    )
+    inputs["pixel_values"] = PROCESSOR.image_processor(
+        [image],
+        transform=custom_transform,
+    )
+    inputs = {
+        k: v.to(DEVICE)
+        for k, v in inputs.items()
+    }
+    generation_kwargs = dict(
+        inputs,
+        bad_words_ids=BAD_WORDS_IDS,
+        max_length=4,
+    )
+    # Regular generation version
+    generated_ids = MODEL.generate(**generation_kwargs)
+    generated_text = PROCESSOR.batch_decode(
+        generated_ids,
+        skip_special_tokens=True,
+    )[0]
+    return generated_text[-1]
+
+
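+# The guess/result boxes are created up front and placed into the layout later via `.render()`.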
+model_prediction = gr.TextArea(
+    label="AI's guess",
+    visible=True,
+    lines=1,
+    max_lines=1,
+    interactive=False,
+)
+user_prediction = gr.TextArea(
+    label="Your guess",
+    visible=True,
+    lines=1,
+    max_lines=1,
+    interactive=False,
+)
+result = gr.TextArea(
+    label="Win or lose?",
+    visible=True,
+    lines=1,
+    max_lines=1,
+    interactive=False,
+)
+
+
+
+css = """
+.gradio-container{max-width: 1000px!important}
+h1{display: flex;align-items: center;justify-content: center;gap: .25em}
+*{transition: width 0.5s ease, flex-grow 0.5s ease}
+"""
+
+
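+# UI: the puzzle image, eight answer buttons (A-H), and read-only boxes for the AI's guess,
+# the player's guess, the hidden ground-truth solution, and the win/lose result.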
+with gr.Blocks(title="Beat the AI", theme=gr.themes.Base(), css=css) as demo:
+    gr.Markdown(
+        "Are you smarter than the AI?"
+    )
+    load_new_sample = gr.Button(value="Load new sample")
+    with gr.Row(equal_height=True):
+        with gr.Column(scale=4, min_width=250) as upload_area:
+            imagebox = gr.Image(
+                image_mode="L",
+                type="pil",
+                visible=True,
+                sources=None,
+            )
+        with gr.Column(scale=4):
+            with gr.Row():
+                a = gr.Button(value="A", min_width=1)
+                b = gr.Button(value="B", min_width=1)
+                c = gr.Button(value="C", min_width=1)
+                d = gr.Button(value="D", min_width=1)
+            with gr.Row():
+                e = gr.Button(value="E", min_width=1)
+                f = gr.Button(value="F", min_width=1)
+                g = gr.Button(value="G", min_width=1)
+                h = gr.Button(value="H", min_width=1)
+            with gr.Row():
+                model_prediction.render()
+                user_prediction.render()
+                solution = gr.TextArea(
+                    label="Solution",
+                    visible=False,
+                    lines=1,
+                    max_lines=1,
+                    interactive=False,
+                )
+            with gr.Row():
+                result.render()
+
+
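+    # Event wiring: "Load new sample" swaps in a fresh RAVEN puzzle, stores its label in the
+    # hidden `solution` box, and clears the guesses; each answer button records the player's
+    # pick, asks the model for its own guess, and then scores the round.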
+    load_new_sample.click(
+        fn=load_sample,
+        inputs=[],
+        outputs=[imagebox, solution, model_prediction, user_prediction, result],
+    )
+    gr.on(
+        triggers=[
+            a.click,
+            b.click,
+            c.click,
+            d.click,
+            e.click,
+            f.click,
+            g.click,
+            h.click,
+        ],
+        fn=model_inference,
+        inputs=[imagebox],
+        outputs=[model_prediction],
+    ).then(
+        fn=lambda x, y, z: "🥇" if x == y else f"💩 The solution is {chr(ord('A') + int(z))}",
+        inputs=[model_prediction, user_prediction, solution],
+        outputs=[result],
+    )
+
+    a.click(fn=lambda: "A", inputs=[], outputs=[user_prediction])
+    b.click(fn=lambda: "B", inputs=[], outputs=[user_prediction])
+    c.click(fn=lambda: "C", inputs=[], outputs=[user_prediction])
+    d.click(fn=lambda: "D", inputs=[], outputs=[user_prediction])
+    e.click(fn=lambda: "E", inputs=[], outputs=[user_prediction])
+    f.click(fn=lambda: "F", inputs=[], outputs=[user_prediction])
+    g.click(fn=lambda: "G", inputs=[], outputs=[user_prediction])
+    h.click(fn=lambda: "H", inputs=[], outputs=[user_prediction])
+
+    demo.load()
+
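+# Queue requests (at most 40 waiting, API access closed) and launch the app.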
+demo.queue(max_size=40, api_open=False)
+demo.launch(max_threads=400)
requirements.txt ADDED
@@ -0,0 +1,7 @@
+opencv-python
+torch
+imagehash
+transformers
+datasets
+pillow
+numpy