# ai_raven / app.py
# Author: VictorSanh
# Commit message: add community grant
# Commit: 12089a1
import torch
import gradio as gr
import random
import numpy as np
from PIL import Image
import imagehash
import cv2
import os
import spaces
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from typing import List
from PIL import Image
from collections import Counter
from datasets import load_dataset, concatenate_datasets
# Target device for the model weights and for every input tensor.
DEVICE = torch.device("cuda")

# Fine-tuned Idefics2 checkpoint. The processor and the model are both fetched
# with the access token taken from the HF_AUTH_TOKEN environment variable
# (a KeyError is raised at import time if it is unset).
_CHECKPOINT = "HuggingFaceM4/idefics2_raven_finetuned"
_HF_TOKEN = os.environ["HF_AUTH_TOKEN"]

PROCESSOR = AutoProcessor.from_pretrained(_CHECKPOINT, token=_HF_TOKEN)
MODEL = AutoModelForCausalLM.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    token=_HF_TOKEN,
).to(DEVICE)

# Number of `<image>` placeholder tokens the prompt must contain per image:
# the resampler's latent count when a perceiver resampler is used, otherwise
# one token per vision patch, i.e. (image_size // patch_size) ** 2.
image_seq_len = (
    MODEL.config.perceiver_config.resampler_n_latents
    if MODEL.config.use_resampler
    else (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2
)

BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Token ids of the image placeholder strings, so generation can be kept from
# emitting them.
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    ["<image>", "<fake_token_around_image>"],
    add_special_tokens=False,
).input_ids
# Validation split of the rendered RAVEN puzzles; game rounds are sampled from it.
DATASET = load_dataset("HuggingFaceM4/RAVEN_rendered", split="validation")
## Utils
def convert_to_rgb(image):
    """Return *image* in RGB mode, flattening any transparency onto white.

    A plain ``image.convert("RGB")`` only behaves for opaque images (e.g. .jpg);
    for transparent ones it produces a wrong background. Non-RGB inputs are
    therefore promoted to RGBA and alpha-composited over a white canvas of the
    same size before the final RGB conversion.

    Args:
        image: a PIL image in any mode.

    Returns:
        The same object when it is already RGB, otherwise a new RGB image.
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        return Image.alpha_composite(white_canvas, rgba).convert("RGB")
    return image
def custom_transform(x):
    """Turn one PIL image into a normalized, channel-first ``torch.Tensor``.

    The processor is the same as the Idefics one except for the interpolation
    used inside SigLIP (BICUBIC there). This hack therefore redefines ONLY the
    transform step: the image is resized to a fixed 960x960 with BILINEAR
    interpolation, rescaled to [0, 1], and normalized with the image processor's
    own mean/std.
    NOTE(review): confirm that BILINEAR (not BICUBIC) is what the fine-tuned
    checkpoint expects.

    Args:
        x: a PIL image in any mode (converted to RGB first).

    Returns:
        A tensor laid out channels-first.
    """
    image_processor = PROCESSOR.image_processor
    pixels = to_numpy_array(convert_to_rgb(x))
    pixels = resize(pixels, (960, 960), resample=PILImageResampling.BILINEAR)
    pixels = image_processor.rescale(pixels, scale=1 / 255)
    pixels = image_processor.normalize(
        pixels,
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    )
    channels_first = to_channel_dimension_format(pixels, ChannelDimension.FIRST)
    return torch.tensor(channels_first)
def pixel_difference(image1, image2):
    """Heuristically decide whether two panel images look alike.

    Despite its name, this returns ``True`` when the images are *similar*.
    Three independent tests are computed:

    * color   -- the most frequent non-white (!= 255) pixel value of each image
                 differs by less than 10;
    * shape   -- the perceptual hash (phash, 32x32) of each image's Canny edge
                 map differs such that ``int(hash_diff / 7) < 10``;
    * surface -- the number of non-white pixels differs by less than 10% of a
                 160x160 area (assumes 160x160 panels -- TODO confirm).

    The images are judged similar when at least two of the three tests pass.

    Args:
        image1, image2: images accepted by ``np.array`` and ``cv2.Canny``
            (presumably 8-bit grayscale PIL images -- confirm with callers).

    Returns:
        bool: True if at least two similarity tests pass.
    """
    background = 255  # white background value excluded from color/surface

    def color(im):
        # Most frequent pixel value that is not the background. Only the two
        # most common values are inspected, as in the original heuristic.
        counts = Counter(np.array(im).flatten().tolist())
        for value, _ in counts.most_common(2):
            if value != background:
                return value
        # All-white (or empty) image: there is no foreground color. This used
        # to raise IndexError because most_common(2) had < 2 entries.
        return background

    def canny_edges(im):
        edges = cv2.Canny(np.array(im), 50, 100)
        edges[edges != 0] = 255
        return Image.fromarray(edges)

    def phash(im):
        return imagehash.phash(canny_edges(im), hash_size=32)

    def surface(im):
        # Count of non-background pixels.
        return (np.array(im) != background).sum()

    color_diff = np.abs(color(image1) - color(image2))
    hash_diff = phash(image1) - phash(image2)
    surface_diff = np.abs(surface(image1) - surface(image2))

    color_close = bool(color_diff < 10)
    shape_close = int(hash_diff / 7) < 10
    surface_close = int(surface_diff / (160 * 160) * 100) < 10

    # The original if/elif cascade is exactly a majority vote over the tests.
    return sum((color_close, shape_close, surface_close)) >= 2
# End of Utils
def load_sample(*, dataset=None):
    """Pick a uniformly random puzzle for a new round.

    Args:
        dataset: optional indexable collection of samples exposing ``"image"``
            and ``"label"`` keys; defaults to the module-level ``DATASET``.
            Keyword-only, so callers that pass no arguments (such as the
            "Load new sample" button) are unaffected.

    Returns:
        A 5-tuple ``(image, label, "", "", "")``: the puzzle image, its label,
        and three empty strings that clear the AI-guess, user-guess and result
        boxes.

    Raises:
        ValueError: if the dataset is empty.
    """
    samples = DATASET if dataset is None else dataset
    n = len(samples)
    if n == 0:
        raise ValueError("Cannot load a sample from an empty dataset.")
    # randrange excludes its upper bound; the former randint(0, n) could
    # return n and index one past the end of the dataset.
    sample = samples[random.randrange(n)]
    return sample["image"], sample["label"], "", "", ""
@spaces.GPU(duration=180)
def model_inference(image):
    """Ask the fine-tuned model which figure completes the puzzle.

    Args:
        image: the puzzle image (a PIL image, as delivered by the image box).

    Returns:
        str: the last visible character of the decoded sequence, which is the
        model's answer letter; "" if nothing printable was decoded.

    Raises:
        ValueError: if ``image`` is None (e.g. no sample has been loaded yet).
    """
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")

    prompt = (
        f"{BOS_TOKEN}User:<fake_token_around_image>{'<image>' * image_seq_len}"
        "<fake_token_around_image>Which figure should complete the logical sequence?"
        "<end_of_utterance>\nAssistant:"
    )
    inputs = PROCESSOR.tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # `max_length` counts the prompt as well, and this prompt holds
    # `image_seq_len` image tokens, so `max_length=4` was always exceeded.
    # Newer transformers releases raise a ValueError for that; older ones warned
    # and stopped after a single new token. `max_new_tokens=1` states that
    # single-token answer explicitly.
    generated_ids = MODEL.generate(
        **inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_new_tokens=1,
    )
    generated_text = PROCESSOR.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )[0]
    # The answer letter is the final character after "Assistant:"; ignore any
    # trailing whitespace, and return "" rather than raising if nothing is left.
    return generated_text.rstrip()[-1:]
def _readonly_line(label):
    """Create a visible, single-line, non-editable text area titled *label*."""
    return gr.TextArea(
        label=label,
        visible=True,
        lines=1,
        max_lines=1,
        interactive=False,
    )


# Built outside the Blocks context and placed into the layout with `.render()`.
model_prediction = _readonly_line("AI's guess")
user_prediction = _readonly_line("Your guess")
result = _readonly_line("Win or lose?")

# Page CSS: cap the container width, center h1 content, and animate
# width/flex-grow changes.
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
def _answer_setter(letter):
    """Return a zero-argument Gradio callback that always yields *letter*.

    Using a factory binds each letter eagerly, which avoids the late-binding
    closure pitfall of defining lambdas inside a loop.
    """
    return lambda: letter


def _reveal_result(user_guess, solution_index):
    """Build the text shown in the "Win or lose?" box.

    Args:
        user_guess: the letter the user clicked ("A"-"H").
        solution_index: zero-based index of the correct answer, as held by the
            hidden "Solution" box (e.g. "3"); may be empty if no sample was
            loaded.

    Returns:
        "🥇" when the user picked the correct answer, otherwise a message
        revealing the correct letter. (The previous version compared the AI's
        guess with the user's guess instead of comparing against the solution.)
    """
    try:
        correct = chr(ord("A") + int(solution_index))
    except (TypeError, ValueError):
        return "No solution is available for this image."
    if user_guess == correct:
        return "🥇"
    return f"💩 The solution is {correct}"


with gr.Blocks(title="Beat the AI", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown("Are you smarter than the AI?")
    load_new_sample = gr.Button(value="Load new sample")
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250):
            # NOTE(review): `sources=None` keeps Gradio's default image sources
            # enabled, so a user-supplied image would have no matching
            # solution -- confirm this is intended.
            imagebox = gr.Image(
                image_mode="L",
                type="pil",
                visible=True,
                sources=None,
            )
        with gr.Column(scale=4):
            # Answer buttons A-H, laid out in two rows of four.
            answer_buttons = {}
            for button_row in ("ABCD", "EFGH"):
                with gr.Row():
                    for option in button_row:
                        answer_buttons[option] = gr.Button(value=option, min_width=1)
            with gr.Row():
                model_prediction.render()
                user_prediction.render()
            # Hidden field holding the current sample's zero-based answer index.
            solution = gr.TextArea(
                label="Solution",
                visible=False,
                lines=1,
                max_lines=1,
                interactive=False,
            )
            with gr.Row():
                result.render()

    load_new_sample.click(
        fn=load_sample,
        inputs=[],
        outputs=[imagebox, solution, model_prediction, user_prediction, result],
    )

    # Each click records the user's choice, then runs the model, then judges.
    # Chaining with .then() guarantees that the judge sees this click's choice;
    # before, the choice was written by a separate, unordered listener.
    for option, button in answer_buttons.items():
        button.click(
            fn=_answer_setter(option),
            inputs=[],
            outputs=[user_prediction],
        ).then(
            fn=model_inference,
            inputs=[imagebox],
            outputs=[model_prediction],
        ).then(
            fn=_reveal_result,
            inputs=[user_prediction, solution],
            outputs=[result],
        )
# Queue at most 40 pending requests with the HTTP API closed, then serve the
# app with up to 400 worker threads (Blocks.queue returns the Blocks itself).
demo.queue(max_size=40, api_open=False).launch(max_threads=400)