import torch
import gradio as gr
import random
import numpy as np
from PIL import Image
import imagehash
import cv2
import os
import spaces
import subprocess
import sys
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from typing import List
from collections import Counter
from datasets import load_dataset, concatenate_datasets

# Best-effort install of flash-attn at startup (the return code is deliberately not checked).
# FLASH_ATTENTION_SKIP_CUDA_BUILD is merged into the current environment instead of replacing it,
# so PATH, HOME, etc. stay visible to pip; `sys.executable -m pip` targets the interpreter that
# runs this app. NOTE(review): the variable presumably tells flash-attn to skip compiling its CUDA
# extension -- confirm against the flash-attn installation docs.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

DEVICE = torch.device("cuda")
MODEL_ID = "HuggingFaceM4/idefics2_raven_finetuned"

PROCESSOR = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=os.environ["HF_AUTH_TOKEN"],
)
MODEL = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    token=os.environ["HF_AUTH_TOKEN"],
).to(DEVICE)

# Length of the run of image tokens placed in the prompt for one image: the resampler's latent
# count when a perceiver resampler is used, otherwise one token per vision patch.
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2

BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# NOTE(review): these literals previously read as empty strings (`""`), which made the image
# placeholder in the prompt empty and passed empty sequences as bad words. They are restored to
# the Idefics image tokens -- confirm they match PROCESSOR.tokenizer's added tokens.
FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IMAGE_TOKEN = "<image>"
# Token-id sequences that generation must never emit (the image placeholder tokens).
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    [IMAGE_TOKEN, FAKE_IMAGE_TOKEN], add_special_tokens=False
).input_ids

DATASET = load_dataset("HuggingFaceM4/RAVEN_rendered", split="validation")


## Utils

def convert_to_rgb(image):
    """Return `image` as an RGB PIL image, compositing any transparency over white.

    `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
    for transparent images. The call to `alpha_composite` handles this case.
    Images already in "RGB" mode are returned unchanged.
    """
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


# The processor is the same as the Idefics processor except for the BICUBIC interpolation inside
# siglip, so this is a hack in order to redefine ONLY the transform method.
def custom_transform(x):
    """Preprocess one PIL image into a channels-first tensor for the model.

    Steps: convert to RGB, resize to 960x960 with BILINEAR resampling, scale by 1/255, normalize
    with the image processor's mean/std, and move channels first.
    """
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x


def pixel_difference(image1, image2):
    """Heuristically decide whether two panel images look alike.

    Three signals are computed and the images count as similar when at least two agree:

    * colour: most frequent pixel value other than white (255); difference < 10;
    * shape: perceptual hash (32x32) of the Canny edge map; int(hash distance / 7) < 10;
    * surface: number of non-white pixels; difference < 10% of a 160x160 area.

    Returns a bool.
    """

    def color(im):
        counts = Counter(np.array(im).flatten().tolist()).most_common(2)
        # Skip pure white when another value exists. A single-valued image has only one entry,
        # so indexing counts[1] unconditionally would raise IndexError.
        if counts[0][0] == 255 and len(counts) > 1:
            return counts[1][0]
        return counts[0][0]

    def canny_edges(im):
        edges = cv2.Canny(np.array(im), 50, 100)
        edges[edges != 0] = 255
        return Image.fromarray(edges)

    def phash(im):
        return imagehash.phash(canny_edges(im), hash_size=32)

    def surface(im):
        return (np.array(im) != 255).sum()

    color_diff = np.abs(color(image1) - color(image2))
    hash_diff = phash(image1) - phash(image2)
    surface_diff = np.abs(surface(image1) - surface(image2))

    similar_color = bool(color_diff < 10)
    similar_shape = int(hash_diff / 7) < 10
    similar_surface = int(surface_diff / (160 * 160) * 100) < 10
    # Majority vote: similar when at least two of the three signals agree.
    return sum((similar_color, similar_shape, similar_surface)) >= 2

# End of Utils


def load_sample():
    """Pick a random RAVEN validation sample and clear the prediction boxes.

    Returns (image, label, "", "", ""), matching the outputs
    [imagebox, solution, model_prediction, user_prediction, result].
    """
    # randrange excludes its upper bound; randint(0, n) is inclusive and could index past the end.
    sample = DATASET[random.randrange(len(DATASET))]
    return sample["image"], sample["label"], "", "", ""


@spaces.GPU(duration=180)
def model_inference(
    image,
):
    """Ask the model which answer panel completes the puzzle shown in `image`.

    Returns the first non-whitespace character of the generated continuation (expected to be a
    letter A-H), or "" if the model generated only whitespace / special tokens.
    Raises gr.Error (shown in the UI) when no image is loaded.
    """
    if image is None:
        raise gr.Error("`image` is None. It should be a PIL image.")

    prompt = (
        f"{BOS_TOKEN}User:"
        f"{FAKE_IMAGE_TOKEN}{IMAGE_TOKEN * image_seq_len}{FAKE_IMAGE_TOKEN}"
        "Which figure should complete the logical sequence?\nAssistant:"
    )
    inputs = PROCESSOR.tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = PROCESSOR.image_processor([image], transform=custom_transform)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # max_new_tokens bounds only the generated tokens. A max_length of 4 would count the whole
    # prompt, which is far longer, leaving only the single token `generate` always emits.
    generated_ids = MODEL.generate(
        **inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_new_tokens=4,
    )
    # The returned sequences contain prompt + continuation; decode only the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    generated_text = PROCESSOR.batch_decode(
        generated_ids[:, prompt_len:], skip_special_tokens=True
    )[0]
    return generated_text.strip()[:1]


def compare_predictions(model_guess, user_guess, solution):
    """Build the "Win or lose?" message.

    The player wins when their guess matches the model's guess. Otherwise the solution letter is
    shown (label 0 -> "A", 1 -> "B", ...). If `solution` is not a valid integer (for example an
    image that did not come from `load_sample`), the letter is omitted instead of raising.
    NOTE(review): the win condition compares against the model, not against `solution` --
    confirm this is intended.
    """
    if model_guess == user_guess:
        return "🥇"
    try:
        solution_letter = chr(ord("A") + int(solution))
    except (TypeError, ValueError):
        return "💩"
    return f"💩 The solution is {solution_letter}"


def _make_guess_setter(letter):
    """Return a zero-argument callback that yields `letter` (avoids late-binding in loops)."""
    return lambda: letter


model_prediction = gr.TextArea(
    label="AI's guess",
    visible=True,
    lines=1,
    max_lines=1,
    interactive=False,
)
user_prediction = gr.TextArea(
    label="Your guess",
    visible=True,
    lines=1,
    max_lines=1,
    interactive=False,
)
result = gr.TextArea(
    label="Win or lose?",
    visible=True,
    lines=1,
    max_lines=1,
    interactive=False,
)

css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""

with gr.Blocks(title="Beat the AI", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Are you smarter than the AI?"
    )
    load_new_sample = gr.Button(value="Load new sample")
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                image_mode="L",
                type="pil",
                visible=True,
                sources=None,
            )
        with gr.Column(scale=4):
            with gr.Row():
                a = gr.Button(value="A", min_width=1)
                b = gr.Button(value="B", min_width=1)
                c = gr.Button(value="C", min_width=1)
                d = gr.Button(value="D", min_width=1)
            with gr.Row():
                e = gr.Button(value="E", min_width=1)
                f = gr.Button(value="F", min_width=1)
                g = gr.Button(value="G", min_width=1)
                h = gr.Button(value="H", min_width=1)
            with gr.Row():
                model_prediction.render()
                user_prediction.render()
            solution = gr.TextArea(
                label="Solution",
                visible=False,
                lines=1,
                max_lines=1,
                interactive=False,
            )
            with gr.Row():
                result.render()

    load_new_sample.click(
        fn=load_sample,
        inputs=[],
        outputs=[imagebox, solution, model_prediction, user_prediction, result],
    )

    letter_buttons = [a, b, c, d, e, f, g, h]

    # One shared listener runs inference for any answer button; the verdict is computed only if
    # inference succeeded (a failed inference previously still triggered the comparison).
    gr.on(
        triggers=[button.click for button in letter_buttons],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[model_prediction],
    ).success(
        fn=compare_predictions,
        inputs=[model_prediction, user_prediction, solution],
        outputs=[result],
    )

    # Each button also records the player's guess.
    for button, letter in zip(letter_buttons, "ABCDEFGH"):
        button.click(fn=_make_guess_setter(letter), inputs=[], outputs=[user_prediction])

    demo.load()

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)