import torch
import gradio as gr
import random
import numpy as np
from PIL import Image
import imagehash
import cv2
import os
import spaces
import subprocess
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from typing import List
from collections import Counter
from datasets import load_dataset, concatenate_datasets
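
# flash-attn is installed at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD avoids compiling the CUDA
# kernels during pip install, which would fail in the Space's CPU-only build environment.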
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

DEVICE = torch.device("cuda")
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2_raven_finetuned",
    token=os.environ["HF_AUTH_TOKEN"],
)
MODEL = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/idefics2_raven_finetuned",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    token=os.environ["HF_AUTH_TOKEN"],
).to(DEVICE)
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
DATASET = load_dataset("HuggingFaceM4/RAVEN_rendered", split="validation", token=os.environ["HF_AUTH_TOKEN"])

## Utils

def convert_to_rgb(image):
    # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
    # for transparent images. The call to `alpha_composite` handles this case
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


# The processor is the same as the Idefics processor except for the BICUBIC interpolation inside siglip,
# so this is a hack in order to redefine ONLY the transform method
def custom_transform(x):
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x
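
# pixel_difference decides whether two answer panels look "the same" by combining three cheap heuristics:
# the dominant non-white pixel value (shape colour), a perceptual hash of the Canny edge map (shape outline),
# and the count of non-white pixels (shape area). When one metric is close, at least one of the other two
# must also agree before the panels are treated as matching. It is kept as a utility and is not wired into
# the demo below.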
def pixel_difference(image1, image2):
    def color(im):
        arr = np.array(im).flatten()
        arr_list = arr.tolist()
        counts = Counter(arr_list)
        most_common = counts.most_common(2)
        if most_common[0][0] == 255:
            return most_common[1][0]
        else:
            return most_common[0][0]

    def canny_edges(im):
        im = cv2.Canny(np.array(im), 50, 100)
        im[im != 0] = 255
        return Image.fromarray(im)

    def phash(im):
        return imagehash.phash(canny_edges(im), hash_size=32)

    def surface(im):
        return (np.array(im) != 255).sum()

    color_diff = np.abs(color(image1) - color(image2))
    hash_diff = phash(image1) - phash(image2)
    surface_diff = np.abs(surface(image1) - surface(image2))

    if int(hash_diff / 7) < 10:
        return color_diff < 10 or int(surface_diff / (160 * 160) * 100) < 10
    elif color_diff < 10:
        return int(surface_diff / (160 * 160) * 100) < 10 or int(hash_diff / 7) < 10
    elif int(surface_diff / (160 * 160) * 100) < 10:
        return int(hash_diff / 7) < 10 or color_diff < 10
    else:
        return False

# End of Utils
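
# load_sample draws a random puzzle from the RAVEN validation split and clears the three text fields
# (AI guess, user guess, result) so the board is reset for a new round.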
def load_sample():
    n = len(DATASET)
    found_sample = False
    while not found_sample:
        idx = random.randint(0, n - 1)
        sample = DATASET[idx]
        found_sample = True
    return sample["image"], sample["label"], "", "", ""
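
# model_inference builds the Idefics2-style prompt (image placeholder tokens followed by the question),
# runs greedy generation, and keeps only the last character of the decoded text, which is the predicted
# answer letter (A-H).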
def model_inference(
    image,
):
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")
    # return "A"
    inputs = PROCESSOR.tokenizer(
        f"{BOS_TOKEN}User:<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>Which figure should complete the logical sequence?<end_of_utterance>\nAssistant:",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )
    inputs = {
        k: v.to(DEVICE)
        for k, v in inputs.items()
    }
    generation_kwargs = dict(
        inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_length=128,
    )

    # Regular generation version
    generated_ids = MODEL.generate(**generation_kwargs)
    generated_text = PROCESSOR.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )[0]
    return generated_text[-1]
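
# Output widgets are created up front and rendered later inside the Blocks layout.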
model_prediction = gr.TextArea(
    label="AI's guess",
    visible=True,
    lines=1,
    max_lines=1,
    interactive=False,
)
user_prediction = gr.TextArea(
    label="Your guess",
    visible=True,
    lines=1,
    max_lines=1,
    interactive=False,
)
result = gr.TextArea(
    label="Win or lose?",
    visible=True,
    lines=1,
    max_lines=1,
    interactive=False,
)
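
# Small CSS tweaks plus the page layout: a "Load new sample" button, the puzzle image on the left,
# the eight answer buttons (A-H) on the right, and text areas for the user's guess, the AI's guess,
# the hidden solution, and the win/lose result.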
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
with gr.Blocks(title="Beat the AI", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Are you smarter than the AI?"
    )
    load_new_sample = gr.Button(value="Load new sample")
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                image_mode="L",
                type="pil",
                visible=True,
                sources=None,
            )
        with gr.Column(scale=4):
            with gr.Row():
                a = gr.Button(value="A", min_width=1)
                b = gr.Button(value="B", min_width=1)
                c = gr.Button(value="C", min_width=1)
                d = gr.Button(value="D", min_width=1)
            with gr.Row():
                e = gr.Button(value="E", min_width=1)
                f = gr.Button(value="F", min_width=1)
                g = gr.Button(value="G", min_width=1)
                h = gr.Button(value="H", min_width=1)
    with gr.Row():
        user_prediction.render()
        model_prediction.render()
        solution = gr.TextArea(
            label="Solution",
            visible=False,
            lines=1,
            max_lines=1,
            interactive=False,
        )
    with gr.Row():
        result.render()
    def result_string(model_pred, user_pred, solution):
        solution_letter = chr(ord('A') + int(solution))
        solution_string = f"The correct answer is {solution_letter}."
        win_or_lose = "🥇" if user_pred == solution_letter else "😭"
        if user_pred == solution_letter and model_pred == solution_letter:
            comparison_string = "Both you and the AI got it correctly!"
        elif user_pred == solution_letter and model_pred != solution_letter:
            comparison_string = "You beat the AI!"
        elif user_pred != solution_letter and model_pred != solution_letter:
            comparison_string = "Both you and the AI got it wrong!"
        else:
            comparison_string = "The AI beat you!"
        return f"{win_or_lose} {comparison_string}\n{solution_string}"
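
    # Event wiring: the load button swaps in a fresh puzzle; every answer button records the user's
    # guess, runs the model on the current image, and then composes the win/lose message from both
    # guesses and the hidden solution.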
    load_new_sample.click(
        fn=load_sample,
        inputs=[],
        outputs=[imagebox, solution, model_prediction, user_prediction, result],
    )
    gr.on(
        triggers=[
            a.click,
            b.click,
            c.click,
            d.click,
            e.click,
            f.click,
            g.click,
            h.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[model_prediction],
    ).then(
        fn=result_string,
        inputs=[model_prediction, user_prediction, solution],
        outputs=[result],
    )

    a.click(fn=lambda: "A", inputs=[], outputs=[user_prediction])
    b.click(fn=lambda: "B", inputs=[], outputs=[user_prediction])
    c.click(fn=lambda: "C", inputs=[], outputs=[user_prediction])
    d.click(fn=lambda: "D", inputs=[], outputs=[user_prediction])
    e.click(fn=lambda: "E", inputs=[], outputs=[user_prediction])
    f.click(fn=lambda: "F", inputs=[], outputs=[user_prediction])
    g.click(fn=lambda: "G", inputs=[], outputs=[user_prediction])
    h.click(fn=lambda: "H", inputs=[], outputs=[user_prediction])
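
    # demo.load() with no fn is effectively a no-op here; the queue caps pending requests at 40 and
    # api_open=False keeps the queued endpoints from being exposed as a public API.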
    demo.load()

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)