import spaces from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer import torch import gradio as gr from threading import Thread import subprocess subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) models_available = [ "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct", "silma-ai/SILMA-9B-Instruct-v0.1.1", "inceptionai/jais-adapted-7b-chat", # "inceptionai/jais-adapted-13b-chat", "inceptionai/jais-family-6p7b-chat", # "inceptionai/jais-family-13b-chat", "NousResearch/Meta-Llama-3.1-8B-Instruct", "unsloth/gemma-2-9b-it", "NousResearch/Meta-Llama-3-8B-Instruct", ] tokenizer_a, model_a = None, None tokenizer_b, model_b = None, None def load_model_a(model_id): global tokenizer_a, model_a tokenizer_a = AutoTokenizer.from_pretrained(model_id) print(f"model A: {tokenizer_a.eos_token}") try: model_a = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2", trust_remote_code=True, ).eval() except: print(f"Using default attention implementation in {model_id}") model_a = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ).eval() return gr.update(label=model_id) def load_model_b(model_id): global tokenizer_b, model_b tokenizer_b = AutoTokenizer.from_pretrained(model_id) print(f"model B: {tokenizer_b.eos_token}") try: model_b = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2", trust_remote_code=True, ).eval() except: print(f"Using default attention implementation in {model_id}") model_b = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ).eval() return gr.update(label=model_id) @spaces.GPU() def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1): text_streamer_a = TextIteratorStreamer(tokenizer_a, skip_prompt=True) text_streamer_b = TextIteratorStreamer(tokenizer_b, skip_prompt=True) system_prompt_list = [{"role": "system", "content": system_prompt}] if system_prompt else [] input_text_list = [{"role": "user", "content": input_text}] chat_history_a = [] for user, assistant in chatbot_a: chat_history_a.append({"role": "user", "content": user}) chat_history_a.append({"role": "assistant", "content": assistant}) chat_history_b = [] for user, assistant in chatbot_b: chat_history_b.append({"role": "user", "content": user}) chat_history_b.append({"role": "assistant", "content": assistant}) base_messages = system_prompt_list + chat_history_a + input_text_list new_messages = system_prompt_list + chat_history_b + input_text_list input_ids_a = tokenizer_a.apply_chat_template( base_messages, add_generation_prompt=True, return_tensors="pt" ).to(model_a.device) input_ids_b = tokenizer_b.apply_chat_template( new_messages, add_generation_prompt=True, return_tensors="pt" ).to(model_b.device) generation_kwargs_a = dict( input_ids=input_ids_a, streamer=text_streamer_a, max_new_tokens=max_new_tokens, pad_token_id=tokenizer_a.eos_token_id, do_sample=True if temperature > 0 else False, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, ) generation_kwargs_b = dict( input_ids=input_ids_b, streamer=text_streamer_b, max_new_tokens=max_new_tokens, pad_token_id=tokenizer_b.eos_token_id, do_sample=True if temperature > 0 else False, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, ) thread_a = Thread(target=model_a.generate, kwargs=generation_kwargs_a) thread_b = Thread(target=model_b.generate, kwargs=generation_kwargs_b) thread_a.start() thread_b.start() chatbot_a.append([input_text, ""]) chatbot_b.append([input_text, ""]) finished_a = False finished_b = False while not (finished_a and finished_b): if not finished_a: try: text_a = next(text_streamer_a) if tokenizer_a.eos_token in text_a: eot_location = text_a.find(tokenizer_a.eos_token) text_a = text_a[:eot_location] finished_a = True chatbot_a[-1][-1] += text_a yield chatbot_a, chatbot_b except StopIteration: finished_a = True if not finished_b: try: text_b = next(text_streamer_b) if tokenizer_b.eos_token in text_b: eot_location = text_b.find(tokenizer_b.eos_token) text_b = text_b[:eot_location] finished_b = True chatbot_b[-1][-1] += text_b yield chatbot_a, chatbot_b except StopIteration: finished_b = True return chatbot_a, chatbot_b def clear(): return [], [] arena_notes = """Important Notes: - `gemma-2` model doesn't have system prompt, so it's make the system prompt field empty for the model to work. - Sometimes an error may occur when generating the response, in this case, please try again. """ with gr.Blocks(title="Arabic-ORPO-Llama3") as demo: with gr.Column(): gr.HTML("