import spaces
import torch
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

base_model_id = "NousResearch/Meta-Llama-3-8B-Instruct"
new_model_id = "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct"

# Load the shared tokenizer and both models (base and ORPO fine-tuned) for side-by-side comparison
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
new_model = AutoModelForCausalLM.from_pretrained(
    new_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

# Llama 3 can end a turn with either the EOS token or the <|eot_id|> token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]


@spaces.GPU(duration=120)
def generate_both(system_prompt, input_text, base_chatbot, new_chatbot,
                  max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
    base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    system_prompt_list = [{"role": "system", "content": system_prompt}]
    input_text_list = [{"role": "user", "content": input_text}]

    # Rebuild each model's chat history from the Gradio chatbot state
    base_chat_history = []
    for user, assistant in base_chatbot:
        base_chat_history.append({"role": "user", "content": user})
        base_chat_history.append({"role": "assistant", "content": assistant})

    new_chat_history = []
    for user, assistant in new_chatbot:
        new_chat_history.append({"role": "user", "content": user})
        new_chat_history.append({"role": "assistant", "content": assistant})

    base_messages = system_prompt_list + base_chat_history + input_text_list
    new_messages = system_prompt_list + new_chat_history + input_text_list

    base_input_ids = tokenizer.apply_chat_template(
        base_messages, add_generation_prompt=True, return_tensors="pt"
    ).to(base_model.device).long()
    new_input_ids = tokenizer.apply_chat_template(
        new_messages, add_generation_prompt=True, return_tensors="pt"
    ).to(new_model.device).long()

    base_generation_kwargs = dict(
        input_ids=base_input_ids,
        streamer=base_text_streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=temperature > 0,  # fall back to greedy decoding when temperature is 0
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    new_generation_kwargs = dict(
        input_ids=new_input_ids,
        streamer=new_text_streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )

    # Run generation in a background thread so tokens can be streamed to the UI as
    # they arrive. The two models run one after the other: the base model streams
    # to completion first, then the fine-tuned model.
    base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
    base_thread.start()

    base_chatbot.append([input_text, ""])
    new_chatbot.append([input_text, ""])

    for base_text in base_text_streamer:
        # The streamer does not strip special tokens, so trim the turn delimiter by hand
        if "<|eot_id|>" in base_text:
            eot_location = base_text.find("<|eot_id|>")
            base_text = base_text[:eot_location]
        base_chatbot[-1][-1] += base_text
        yield base_chatbot, new_chatbot

    new_thread = Thread(target=new_model.generate, kwargs=new_generation_kwargs)
    new_thread.start()

    for new_text in new_text_streamer:
        if "<|eot_id|>" in new_text:
            eot_location = new_text.find("<|eot_id|>")
            new_text = new_text[:eot_location]
        new_chatbot[-1][-1] += new_text
        yield base_chatbot, new_chatbot

    return base_chatbot, new_chatbot


def clear():
    return [], []


with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
    with gr.Column():
        gr.HTML("