import copy
import hashlib
import os
import re
import spaces
import subprocess
import torch
import PIL

from pathlib import Path
from threading import Thread
from typing import List, Optional, Tuple
from urllib.parse import urlparse
from PIL import Image

import gradio as gr
from gradio import processing_utils
from gradio_client.client import DEFAULT_TEMP_DIR
from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer, logging

from utils import create_model_inputs


subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
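
# All candidate checkpoints are loaded once at startup and kept on the GPU; the model
# dropdown further below only selects which of these already-loaded models serves a request.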

DEVICE = torch.device("cuda")
MODELS = {
    "284 - neftune - opt 18'500": AutoModelForCausalLM.from_pretrained(
        "HuggingFaceM4/idefics2",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        token=os.environ["HF_AUTH_TOKEN"],
        revision="1e05755c1c5cb2077a0f60b83ea1368c22a17282",
    ).to(DEVICE),
    "279bis - baseline - opt 18'500": AutoModelForCausalLM.from_pretrained(
        "HuggingFaceM4/idefics2",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        token=os.environ["HF_AUTH_TOKEN"],
        revision="5cd3c3a3eb5e0ea664f5ac09e73c9ef42da93a86",
    ).to(DEVICE),
    "286 - mix6 tables - opt 20'000": AutoModelForCausalLM.from_pretrained(
        "HuggingFaceM4/idefics2",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        token=os.environ["HF_AUTH_TOKEN"],
        revision="b473d49caa964991b40b79fe7cb27d51d4d023f6",
    ).to(DEVICE),
    "285 - continued pretraining on text sft - opt 2'000": AutoModelForCausalLM.from_pretrained(
        "HuggingFaceM4/idefics2",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        token=os.environ["HF_AUTH_TOKEN"],
        revision="b0a2a564e5dc311591886bb375e8d5a1aeaade83",
    ).to(DEVICE),
}
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2",
    token=os.environ["HF_AUTH_TOKEN"],
)
FAKE_TOK_AROUND_IMAGE = "<fake_token_around_image>"
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
EOS_WORDS_IDS = PROCESSOR.tokenizer(["<end_of_utterance>", "\nUser:"], add_special_tokens=False).input_ids
IMAGE_SEQ_LEN = list(MODELS.values())[0].config.perceiver_config.resampler_n_latents
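# Each image is rendered in the text prompt as IMAGE_SEQ_LEN "<image>" placeholder tokens
# (one per latent of the perceiver resampler), framed by FAKE_TOK_AROUND_IMAGE.
# BAD_WORDS_IDS keeps generation from emitting those placeholder tokens itself, and
# EOS_WORDS_IDS lists the sequences that mark the end of an assistant turn.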

SYSTEM_PROMPT = []

API_TOKEN = os.getenv("HF_AUTH_TOKEN")

BOT_AVATAR = "IDEFICS_logo.png"


def hash_bytes(bytes: bytes):
    sha1 = hashlib.sha1()
    sha1.update(bytes)
    return sha1.hexdigest()


def pil_to_temp_file(img: PIL.Image.Image, dir: str = DEFAULT_TEMP_DIR, format: str = "png") -> str:
    """Save a PIL image into a temp file"""
    bytes_data = processing_utils.encode_pil_to_bytes(img, format)
    temp_dir = Path(dir) / hash_bytes(bytes_data)
    temp_dir.mkdir(exist_ok=True, parents=True)
    filename = str(temp_dir / f"image.{format}")
    if not os.path.exists(filename):
        img.save(filename, pnginfo=processing_utils.get_pil_metadata(img))
    return filename


def add_file(file):
    return file.name, gr.update(label='🖼️ Uploaded!')


def split_str_on_im_markdown(string: str) -> List[str]:
    """
    Extract from a string (typically the user prompt string) the images embedded as markdown,
    returning a list that alternates between text chunks and image paths/urls.
    Example:
    - `User:![](/file=/my_temp/chicken_on_money.png)Describe this image.` becomes
      `["User:", "/my_temp/chicken_on_money.png", "Describe this image."]`
    """
    IMAGES_PATTERN = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
    parts = []
    cursor = 0
    for pattern in IMAGES_PATTERN.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append(string[cursor:start])
        image_url = pattern.group(1)
        if image_url.startswith("/file="):
            image_url = image_url[len("/file="):]  # Strip the "/file=" prefix Gradio adds to local paths
        parts.append(image_url)
        cursor = pattern.end()
    if cursor != len(string):
        parts.append(string[cursor:])
    return parts


def is_image(string: str) -> bool:
    """
    An image can be referenced in two ways: a local image path or a url.
    """
    return is_url(string) or string.startswith(DEFAULT_TEMP_DIR)


def is_url(string: str) -> bool:
    """
    Checks that the passed string contains a valid url and nothing else. For instance, if a space
    is included, the url is immediately invalidated.
    """
    if " " in string:
        return False
    result = urlparse(string)
    return all([result.scheme, result.netloc])


def isolate_images_urls(prompt_list: List) -> List:
    """
    Convert a full string prompt to the list format expected by the processor.
    In particular, image urls (as delimited by <fake_token_around_image>) should become their own list elements.
    From:
    ```
    [
        "bonjour<fake_token_around_image><image:IMG_URL><fake_token_around_image>hello",
        PIL.Image.Image,
        "Aurevoir",
    ]
    ```
    to:
    ```
    [
        "bonjour",
        IMG_URL,
        "hello",
        PIL.Image.Image,
        "Aurevoir",
    ]
    ```
    """
    linearized_list = []
    for prompt in prompt_list:
        if isinstance(prompt, PIL.Image.Image):
            linearized_list.append(prompt)
        elif isinstance(prompt, str):
            if "<fake_token_around_image>" not in prompt:
                linearized_list.append(prompt)
            else:
                prompt_splitted = prompt.split("<fake_token_around_image>")
                for ps in prompt_splitted:
                    if ps == "":
                        continue
                    if ps.startswith("<image:"):
                        linearized_list.append(ps[7:-1])  # Strip the "<image:" prefix and the closing ">"
                    else:
                        linearized_list.append(ps)
        else:
            raise TypeError(
                f"Unrecognized type for `prompt`. Got {type(prompt)}. Was expecting something in [`str`,"
                " `PIL.Image.Image`]"
            )
    return linearized_list


def fetch_images(url_list: List[str]) -> List[PIL.Image.Image]:
    """Fetch images from a list of urls"""
    return PROCESSOR.image_processor.fetch_images(url_list)


def handle_manual_images_in_user_prompt(user_prompt: str) -> List[str]:
    """
    Handle images manually typed in the user prompt (i.e. the
    `<fake_token_around_image><image:IMG_URL><fake_token_around_image>` syntax) by fetching them,
    saving them locally and replacing the whole sub-sequence with the local image path.
    """
    if "<fake_token_around_image>" in user_prompt:
        splitted_user_prompt = isolate_images_urls([user_prompt])
        resulting_user_prompt = []
        for u_p in splitted_user_prompt:
            if is_url(u_p):
                img = fetch_images([u_p])[0]
                tmp_file = pil_to_temp_file(img)
                resulting_user_prompt.append(tmp_file)
            else:
                resulting_user_prompt.append(u_p)
        return resulting_user_prompt
    else:
        return [user_prompt]


def prompt_list_to_markdown(prompt_list: List[str]) -> str:
    """
    Convert a user prompt in the list format (i.e. elements are either a PIL image or a string) into
    the markdown format that is used for the chatbot history and rendering.
    """
    resulting_string = ""
    for elem in prompt_list:
        if is_image(elem):
            if is_url(elem):
                resulting_string += f"![]({elem})"
            else:
                resulting_string += f"![](/file={elem})"
        else:
            resulting_string += elem
    return resulting_string


def prompt_list_to_model_input(prompt_list: List[str]) -> Tuple[str, List[Image.Image]]:
    """
    Create the final input string and image list to feed to the model's processor.
    """
    images = []
    for idx, part in enumerate(prompt_list):
        if is_image(part):
            if is_url(part):
                images.append(fetch_images([part])[0])
            else:
                images.append(Image.open(part))
            prompt_list[idx] = f"{FAKE_TOK_AROUND_IMAGE}{'<image>' * IMAGE_SEQ_LEN}{FAKE_TOK_AROUND_IMAGE}"
    input_text = "".join(prompt_list)
    input_text = input_text.replace(FAKE_TOK_AROUND_IMAGE * 2, FAKE_TOK_AROUND_IMAGE)
    input_text = BOS_TOKEN + input_text.strip()
    return input_text, images
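
# Note on the prompt layout produced above: for a single-image turn, `input_text` looks roughly like
# "<s>User:<fake_token_around_image><image>...<image><fake_token_around_image>Describe this image.<end_of_utterance>\nAssistant:"
# (assuming the tokenizer's BOS token is "<s>"), with exactly IMAGE_SEQ_LEN "<image>" placeholders;
# the duplicate-marker replace makes back-to-back images share a single <fake_token_around_image>.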


def remove_spaces_around_token(text: str) -> str:
    pattern = r"\s*(<fake_token_around_image>)\s*"
    replacement = r"\1"
    result = re.sub(pattern, replacement, text)
    return result


def format_user_prompt_with_im_history_and_system_conditioning(
    current_user_prompt_str: str, current_image: Optional[str], history: List[Tuple[str, str]]
) -> Tuple[List[str], List[str]]:
    """
    Produces the resulting list that needs to go inside the processor.
    It handles the potential image box input, the history and the system conditioning.
    """
    resulting_list = copy.deepcopy(SYSTEM_PROMPT)

    # Format history
    for turn in history:
        user_utterance, assistant_utterance = turn
        splitted_user_utterance = split_str_on_im_markdown(user_utterance)

        optional_space = ""
        if not is_image(splitted_user_utterance[0]):
            optional_space = " "
        resulting_list.append(f"\nUser:{optional_space}")
        resulting_list.extend(splitted_user_utterance)
        resulting_list.append(f"<end_of_utterance>\nAssistant: {assistant_utterance}")

    # Format current input
    current_user_prompt_str = remove_spaces_around_token(current_user_prompt_str)
    if current_image is None:
        if "![](" in current_user_prompt_str:
            current_user_prompt_list = split_str_on_im_markdown(current_user_prompt_str)
        else:
            current_user_prompt_list = handle_manual_images_in_user_prompt(current_user_prompt_str)

        optional_space = ""
        if not is_image(current_user_prompt_list[0]):
            # Only add a space after "User:" when the turn does not start with an image
            optional_space = " "
        resulting_list.append(f"\nUser:{optional_space}")
        resulting_list.extend(current_user_prompt_list)
        resulting_list.append("<end_of_utterance>\nAssistant:")
    else:
        # The image comes from the image box: it is placed right after "User:", followed by the text
        resulting_list.extend(["\nUser:", current_image, f"{current_user_prompt_str}<end_of_utterance>\nAssistant:"])
        current_user_prompt_list = [current_user_prompt_str]

    return resulting_list, current_user_prompt_list
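
# The returned `resulting_list` interleaves plain strings and image paths/urls, e.g.
# ["\nUser:", "/tmp/gradio/<hash>/image.png", "What is in this picture?<end_of_utterance>\nAssistant:"]
# (the path shown is only illustrative), and is later linearized by `prompt_list_to_model_input`.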


textbox = gr.Textbox(
    placeholder="Upload an image and send a message",
    show_label=False,
    visible=True,
    container=False,
    label="Text input",
    scale=6,
)

with gr.Blocks(title="IDEFICS Playground", theme=gr.themes.Base()) as demo:
    gr.HTML("""<h1 align="center">🐶 IDEFICS Playground</h1>""")

    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="284 - neftune - opt 18'500",
            interactive=True,
            show_label=False,
            container=False,
            label="Model",
            visible=True,
        )

    imagebox = gr.Image(type="filepath", label="Image input", visible=False)

    with gr.Row():
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            label="IDEFICS",
            visible=True,
            height=750,
            avatar_images=[None, BOT_AVATAR],
        )

    with gr.Group():
        with gr.Row():
            textbox.render()
            submit_btn = gr.Button(value="▶️ Submit", visible=True)
            clear_btn = gr.ClearButton([textbox, imagebox, chatbot], value="🧹 Clear")
            regenerate_btn = gr.Button(value="🔄 Regenerate", visible=True)
            upload_btn = gr.UploadButton("📁 Upload image", file_types=["image"])

    with gr.Row():
        with gr.Accordion("Advanced settings", open=False, visible=True) as parameter_row:
            max_new_tokens = gr.Slider(
                minimum=8,
                maximum=1024,
                value=512,
                step=1,
                interactive=True,
                label="Maximum number of new tokens to generate",
            )
            repetition_penalty = gr.Slider(
                minimum=0.01,
                maximum=5.0,
                value=1.0,
                step=0.01,
                interactive=True,
                label="Repetition penalty",
                info="1.0 is equivalent to no penalty",
            )
            decoding_strategy = gr.Radio(
                [
                    "Greedy",
                    "Top P Sampling",
                ],
                value="Greedy",
                label="Decoding strategy",
                interactive=True,
                info="Higher values are equivalent to sampling more low-probability tokens.",
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=5.0,
                value=0.4,
                step=0.1,
                interactive=True,
                visible=False,
                label="Sampling temperature",
                info="Higher values will produce more diverse outputs.",
            )
            decoding_strategy.change(
                fn=lambda selection: gr.Slider(
                    visible=(
                        selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
                    )
                ),
                inputs=decoding_strategy,
                outputs=temperature,
            )
            top_p = gr.Slider(
                minimum=0.01,
                maximum=0.99,
                value=0.8,
                step=0.01,
                interactive=True,
                visible=False,
                label="Top P",
                info="Higher values are equivalent to sampling more low-probability tokens.",
            )
            decoding_strategy.change(
                fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
                inputs=decoding_strategy,
                outputs=top_p,
            )
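
    # The temperature and top_p sliders start hidden; the .change handlers above reveal them only
    # when a sampling-based strategy is selected (returning a gr.Slider(visible=...) from the
    # callback acts as an update to the corresponding output component).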

    @spaces.GPU(duration=180)
    def model_inference(
        model_selector,
        user_prompt_str,
        chat_history,
        image,
        decoding_strategy,
        temperature,
        max_new_tokens,
        repetition_penalty,
        top_p,
    ):
        if user_prompt_str.strip() == "" and image is None:
            return "", None, chat_history

        formated_prompt_list, user_prompt_list = format_user_prompt_with_im_history_and_system_conditioning(
            current_user_prompt_str=user_prompt_str.strip(),
            current_image=image,
            history=chat_history,
        )

        streamer = TextIteratorStreamer(
            PROCESSOR.tokenizer,
            skip_prompt=True,
        )
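        # The streamer is handed to `generate`, which runs in a background thread further below;
        # iterating over it on the main thread yields decoded text chunks as soon as they are produced.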

        # Common parameters to all decoding strategies
        generation_args = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "bad_words_ids": BAD_WORDS_IDS,
            "streamer": streamer,
        }

        assert decoding_strategy in [
            "Greedy",
            "Top P Sampling",
        ]
        if decoding_strategy == "Greedy":
            generation_args["do_sample"] = False
        elif decoding_strategy == "Top P Sampling":
            generation_args["temperature"] = temperature
            generation_args["do_sample"] = True
            generation_args["top_p"] = top_p

        if image is None:
            # Case where there is no image from the image box (any image is already in the prompt string)
            chat_history.append([prompt_list_to_markdown(user_prompt_list), ''])
        else:
            # Case where the image is passed through the image box:
            # prepend the image to the user prompt markdown so it is rendered in the chatbot history
            chat_history.append(
                [
                    f"{prompt_list_to_markdown([image] + user_prompt_list)}",
                    '',
                ]
            )

        # Creating model inputs
        input_text, images = prompt_list_to_model_input(formated_prompt_list)
        inputs = create_model_inputs([input_text], [images])
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        generation_args.update(inputs)

        thread = Thread(
            target=MODELS[model_selector].generate,
            kwargs=generation_args,
        )

        thread.start()
        acc_text = ""
        for idx, text_token in enumerate(streamer):
            acc_text += text_token
            last_turn = chat_history.pop(-1)
            last_turn[-1] += acc_text
            if last_turn[-1].endswith("\nUser"):
                # Keep the model from starting a new "\nUser:" turn: trim the trailing "\nUser"
                # before showing the partial answer
                last_turn[-1] = last_turn[-1][:-5]
            chat_history.append(last_turn)
            yield "", None, chat_history
            acc_text = ""

    def process_example(message, image):
        """
        Same as `model_inference` but with greedy decoding, no streaming and the default checkpoint.
        Specifically meant for pre-computing the default examples.
        """
        model_selector = "284 - neftune - opt 18'500"
        user_prompt_str = message
        chat_history = []
        max_new_tokens = 512

        formated_prompt_list, user_prompt_list = format_user_prompt_with_im_history_and_system_conditioning(
            current_user_prompt_str=user_prompt_str.strip(),
            current_image=image,
            history=chat_history,
        )

        generation_args = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": None,
            "bad_words_ids": BAD_WORDS_IDS,
            "do_sample": False,
        }

        if image is None:
            chat_history.append([prompt_list_to_markdown(user_prompt_list), ''])
        else:
            chat_history.append(
                [
                    f"{prompt_list_to_markdown([image] + user_prompt_list)}",
                    '',
                ]
            )

        input_text, images = prompt_list_to_model_input(formated_prompt_list)
        inputs = create_model_inputs([input_text], [images])
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        generation_args.update(inputs)

        generated_ids = MODELS[model_selector].generate(**generation_args)
        generated_text = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]

        if generated_text.endswith("\nUser"):
            generated_text = generated_text[:-5]

        last_turn = chat_history.pop(-1)
        last_turn[-1] += generated_text
        chat_history.append(last_turn)
        return "", None, chat_history

    textbox.submit(
        fn=model_inference,
        inputs=[
            model_selector,
            textbox,
            chatbot,
            imagebox,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
        ],
        outputs=[textbox, imagebox, chatbot],
    )
    submit_btn.click(
        fn=model_inference,
        inputs=[
            model_selector,
            textbox,
            chatbot,
            imagebox,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
        ],
        outputs=[
            textbox,
            imagebox,
            chatbot,
        ],
    )

    def remove_last_turn(chat_history):
        if len(chat_history) == 0:
            return gr.update(), gr.update()
        last_interaction = chat_history[-1]
        chat_history = chat_history[:-1]
        chat_update = gr.update(value=chat_history)
        text_update = gr.update(value=last_interaction[0])
        return chat_update, text_update
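
    # Regenerate pops the last exchange from the chatbot and puts the corresponding user message
    # back into the textbox, then the chained .then() below re-runs model_inference on it.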

    regenerate_btn.click(fn=remove_last_turn, inputs=chatbot, outputs=[chatbot, textbox]).then(
        fn=model_inference,
        inputs=[
            model_selector,
            textbox,
            chatbot,
            imagebox,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
        ],
        outputs=[
            textbox,
            imagebox,
            chatbot,
        ],
    )

    upload_btn.upload(add_file, [upload_btn], [imagebox, upload_btn], queue=False)
    submit_btn.click(lambda: gr.update(label='📁 Upload image', interactive=True), [], upload_btn)
    textbox.submit(lambda: gr.update(label='📁 Upload image', interactive=True), [], upload_btn)
    clear_btn.click(lambda: gr.update(label='📁 Upload image', interactive=True), [], upload_btn)


demo.queue(max_size=40)
demo.launch()