"""Gradio demo: "Beat the AI" on RAVEN progressive-matrix puzzles.

Loads a fine-tuned Idefics2 checkpoint and the rendered RAVEN validation split,
shows a random puzzle, records the user's guess (A-H), asks the model for its
own guess, and reports who got it right.
"""

import torch
import gradio as gr
import random
import numpy as np
from PIL import Image
import imagehash
import cv2
import os
import spaces
import subprocess
import sys
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from typing import List
from collections import Counter
from datasets import load_dataset, concatenate_datasets

# Install flash-attn at startup without compiling its CUDA kernels. The current
# environment is inherited: passing a bare {"FLASH_ATTENTION_SKIP_CUDA_BUILD": ...}
# dict as `env` would wipe PATH/HOME/etc. for the child process. Invoking pip via
# sys.executable installs into the interpreter running this app. Best-effort, as
# before: a failed install does not stop the app.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

DEVICE = torch.device("cuda")
HF_TOKEN = os.environ["HF_AUTH_TOKEN"]
MODEL_ID = "HuggingFaceM4/idefics2_raven_finetuned"

PROCESSOR = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
)
MODEL = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
).to(DEVICE)

# Number of image placeholder tokens one image occupies in the prompt: the
# resampler's latent count if the model uses one, otherwise one token per patch.
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2

# Idefics-style image placeholders. Because the prompt is built by hand (the image
# processor is called separately with a custom transform), the text must contain
# FAKE_IMAGE_TOKEN + IMAGE_TOKEN * image_seq_len + FAKE_IMAGE_TOKEN for the image,
# mirroring what the stock Idefics2 processor inserts.
IMAGE_TOKEN = "<image>"
FAKE_IMAGE_TOKEN = "<fake_token_around_image>"

BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Token-id sequences the model must never generate (the image placeholders).
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    [IMAGE_TOKEN, FAKE_IMAGE_TOKEN], add_special_tokens=False
).input_ids

DATASET = load_dataset("HuggingFaceM4/RAVEN_rendered", split="validation", token=HF_TOKEN)


## Utils

def convert_to_rgb(image):
    """Return ``image`` as an RGB PIL image, compositing any transparency onto white.

    `image.convert("RGB")` would only work for .jpg images, as it creates a wrong
    background for transparent images. The call to `alpha_composite` handles this case.
    """
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


def custom_transform(x):
    """Preprocess one PIL image into a channels-first float tensor for the model.

    This replaces only the transform step of the image processor (it is passed as
    ``transform=``). It resizes to 960x960 with BILINEAR interpolation; per the
    original author's note the stock siglip path differs only in interpolation
    (BICUBIC). Rescaling and normalization use the processor's own settings.
    """
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x


def pixel_difference(image1, image2):
    """Heuristically decide whether two images show the same figure.

    Three similarity signals are computed, and the images count as matching when
    at least two of them hold (this is exactly what the previous nested
    if/elif chain evaluated):

    * colour: the most frequent non-255 pixel value differs by less than 10;
    * shape: ``int(phash_distance / 7) < 10``, where the perceptual hashes are
      taken over Canny edge maps;
    * surface: the difference in non-255 pixel counts, as a percentage of
      160*160 pixels, truncates to less than 10.

    Not called anywhere in the visible code.
    """

    def color(im):
        # Most frequent pixel value, skipping pure white (255) when another value exists.
        most_common = Counter(np.array(im).flatten().tolist()).most_common(2)
        if most_common[0][0] == 255 and len(most_common) > 1:
            return most_common[1][0]
        # Either the dominant value is not white, or the image is uniformly one
        # value (e.g. all white). The latter used to raise IndexError.
        return most_common[0][0]

    def canny_edges(im):
        edges = cv2.Canny(np.array(im), 50, 100)
        edges[edges != 0] = 255
        return Image.fromarray(edges)

    def phash(im):
        return imagehash.phash(canny_edges(im), hash_size=32)

    def surface(im):
        return (np.array(im) != 255).sum()

    color_close = bool(np.abs(color(image1) - color(image2)) < 10)
    hash_close = int((phash(image1) - phash(image2)) / 7) < 10
    surface_close = int(np.abs(surface(image1) - surface(image2)) / (160 * 160) * 100) < 10

    return sum((color_close, hash_close, surface_close)) >= 2

# End of Utils


def load_sample():
    """Pick a random validation puzzle and reset the guess/result widgets.

    Returns:
        ``(image, label, "", "", "")`` matching the outputs
        ``[imagebox, solution, model_prediction, user_prediction, result]``.
    """
    # randrange excludes len(DATASET); the former randint(0, n) could return n
    # and index one past the end of the dataset.
    sample = DATASET[random.randrange(len(DATASET))]
    return sample["image"], sample["label"], "", "", ""


@spaces.GPU(duration=180)
def model_inference(
    image,
):
    """Ask the model which option completes the puzzle shown in ``image``.

    Args:
        image: The puzzle as a PIL image (from ``imagebox``), or None if nothing
            is loaded yet.

    Returns:
        The last character of the decoded generation, expected to be the answer letter.

    Raises:
        gr.Error: If no puzzle image has been loaded.
    """
    if image is None:
        raise gr.Error("Load a new sample first!")

    prompt = (
        f"{BOS_TOKEN}User:"
        f"{FAKE_IMAGE_TOKEN}{IMAGE_TOKEN * image_seq_len}{FAKE_IMAGE_TOKEN}"
        "Which figure should complete the logical sequence?\nAssistant:"
    )
    inputs = PROCESSOR.tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # NOTE(review): max_length counts prompt tokens too (including image_seq_len
    # placeholders); if the prompt ever reaches 128 tokens nothing new is generated.
    # Consider max_new_tokens if the checkpoint's image_seq_len is large.
    generation_kwargs = dict(
        inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_length=128,
    )
    generated_ids = MODEL.generate(**generation_kwargs)
    generated_text = PROCESSOR.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )[0]
    return generated_text[-1]


def result_string(model_pred, user_pred, solution):
    """Build the win/lose message shown once both guesses are in.

    Args:
        model_pred: Letter predicted by the model.
        user_pred: Letter chosen by the user.
        solution: Integer index of the correct option as text (0 -> "A"), or ""
            when no puzzle has been loaded.

    Returns:
        "" when there is no puzzle, otherwise "<medal> <comparison> The correct answer is X."
    """
    if solution == "":
        return ""

    solution_letter = chr(ord('A') + int(solution))
    solution_string = f"The correct answer is {solution_letter}."

    user_correct = user_pred == solution_letter
    model_correct = model_pred == solution_letter
    win_or_lose = "🥇" if user_correct else "🙈"

    if user_correct and model_correct:
        comparison_string = "Both you and the AI got it correctly!"
    elif user_correct:
        comparison_string = "You beat the AI!"
    elif model_correct:
        comparison_string = "The AI beat you!"
    else:
        comparison_string = "Both you and the AI got it wrong!"
    return f"{win_or_lose} {comparison_string} {solution_string}"


def _make_guess_setter(letter):
    """Return a no-argument Gradio callback that outputs ``letter``."""
    def set_guess():
        return letter
    return set_guess


model_prediction = gr.Label(
    label="AI's guess",
    visible=True,
)
user_prediction = gr.Label(
    label="Your guess",
    visible=True,
)
result = gr.TextArea(
    label="Win or lose?",
    visible=True,
    lines=1,
    max_lines=1,
    interactive=False,
)

css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""

with gr.Blocks(title="Beat the AI", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        """
        # Can you beat the AI at RAVEN puzzles?

        *This demo features an early fine-tuned version of our forthcoming Idefics2 model (read about idefics1 [here](https://huggingface.co/HuggingFaceM4/idefics-80b-instruct)). The model was specifically fine-tuned on the [RAVEN](https://huggingface.co/datasets/HuggingFaceM4/RAVEN) dataset and reached 91% accuracy on the validation set.*

        Raven's Progressive Matrices are abstract visual reasoning puzzles. The panels describe logical sequences of shapes and colors (row by row). One is asked to find the option that completes the 3rd sequence following the same logic described by the first two sequences.

        We recommend looking at the images on a full screen with enough brightness given that some options differ by small differences in sizes and nuances of colors.

        To get started, load a new puzzle. 🧠
        """
    )
    load_new_sample = gr.Button(value="Load a new puzzle")
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                image_mode="L",
                type="pil",
                visible=True,
                sources=[],
            )
        with gr.Column(scale=4):
            with gr.Row():
                a = gr.Button(value="A", min_width=1)
                b = gr.Button(value="B", min_width=1)
                c = gr.Button(value="C", min_width=1)
                d = gr.Button(value="D", min_width=1)
            with gr.Row():
                e = gr.Button(value="E", min_width=1)
                f = gr.Button(value="F", min_width=1)
                g = gr.Button(value="G", min_width=1)
                h = gr.Button(value="H", min_width=1)
            with gr.Row():
                user_prediction.render()
                model_prediction.render()
            solution = gr.TextArea(
                label="Solution",
                visible=False,
                lines=1,
                max_lines=1,
                interactive=False,
            )
            with gr.Row():
                result.render()

    load_new_sample.click(
        fn=load_sample,
        inputs=[],
        outputs=[imagebox, solution, model_prediction, user_prediction, result],
    )

    # One chain per answer button: record the user's guess, then query the model,
    # then score. Chaining with .then() guarantees user_prediction is written before
    # result_string reads it. Previously the guess came from a separate,
    # unsynchronized click handler. All buttons share one concurrency queue for the
    # model call, as the single gr.on(...) listener did.
    for letter, button in zip("ABCDEFGH", (a, b, c, d, e, f, g, h)):
        button.click(
            fn=_make_guess_setter(letter),
            inputs=[],
            outputs=[user_prediction],
        ).then(
            fn=model_inference,
            inputs=[imagebox],
            outputs=[model_prediction],
            concurrency_id="model_inference",
        ).then(
            fn=result_string,
            inputs=[model_prediction, user_prediction, solution],
            outputs=[result],
        )

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)