import time

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class ConversationManager:
    """Holds the loaded models and the state of the running two-model conversation.

    The module-level Gradio callbacks read and write these attributes directly.
    """

    def __init__(self):
        # model name -> (model, tokenizer); filled lazily by load_model().
        self.models = {}
        # Ordered list of (speaker, message) tuples.
        self.conversation = []
        # Number of one-second countdown steps shown after each turn (see chat()).
        self.delay = 3
        self.is_paused = False
        # Name of the model whose turn is in progress / was taken last.
        self.current_model = None
        self.initial_prompt = ""
        # Set by add_to_conversation() once a message contains "task complete?".
        self.task_complete = False
        # Speaker names of (first, second) model; chat() replaces these
        # placeholders with the real model names.
        self.model_names = ("Model 1", "Model 2")

    def load_model(self, model_name):
        """Return ``(model, tokenizer)`` for *model_name*, loading and caching it on first use.

        Returns ``None`` (after printing the error) when the name is empty or
        loading fails; no exception is propagated.
        """
        if not model_name:
            print("Error: Empty model name provided")
            return None
        if model_name in self.models:
            return self.models[model_name]
        try:
            print(f"Attempting to load model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # NOTE(review): 8-bit loading depends on bitsandbytes being installed
            # in the deployment environment — confirm.
            model = AutoModelForCausalLM.from_pretrained(
                model_name, device_map="auto", load_in_8bit=True
            )
            self.models[model_name] = (model, tokenizer)
            print(f"Successfully loaded model: {model_name}")
            return self.models[model_name]
        except Exception as e:
            # Best-effort: report and let the caller decide what to do with None.
            print(f"Failed to load model {model_name}: {e}")
            print(f"Error type: {type(e).__name__}")
            print(f"Error details: {str(e)}")
            return None

    def generate_response(self, model_name, prompt, max_new_tokens=200, max_prompt_tokens=1024):
        """Generate and return *model_name*'s reply to *prompt*.

        Only the newly generated tokens are decoded, so the reply does not echo
        the prompt. Echoing matters here because the prompt itself contains
        "Task complete?", which would otherwise end the conversation after one turn.

        Args:
            model_name: Hugging Face model id to use.
            prompt: Raw prompt text; it is wrapped in a chat template first.
            max_new_tokens: Upper bound on the length of the generated reply.
            max_prompt_tokens: Prompts longer than this keep only their last
                ``max_prompt_tokens`` tokens.

        Raises:
            RuntimeError: if the model cannot be loaded.
        """
        loaded = self.load_model(model_name)
        if loaded is None:
            raise RuntimeError(f"Model {model_name!r} could not be loaded")
        model, tokenizer = loaded

        if "llama" in model_name.lower():
            formatted_prompt = self.format_llama2_prompt(prompt)
        else:
            formatted_prompt = self.format_general_prompt(prompt)

        encoded = tokenizer(formatted_prompt, return_tensors="pt")
        # Keep the most recent tokens: the continuation instruction and the
        # assistant cue sit at the end of the prompt, so the tokenizer's default
        # right-side truncation would cut exactly those off. Tensors must also
        # live on the model's device.
        inputs = {
            name: tensor[:, -max_prompt_tokens:].to(model.device)
            for name, tensor in encoded.items()
        }
        prompt_length = inputs["input_ids"].shape[-1]
        pad_token_id = (
            tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        )

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # max_new_tokens bounds the reply alone; max_length would also
                # count the prompt and leave no room once the prompt is long.
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                do_sample=True,
                pad_token_id=pad_token_id,
            )
        # Decoder-only generate() returns prompt + continuation; drop the prompt.
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    def format_llama2_prompt(self, prompt):
        """Wrap *prompt* in the Llama 2 chat template with a fixed system prompt."""
        B_INST, E_INST = "[INST]", "[/INST]"
        B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
        system_prompt = "You are a helpful AI assistant. Please provide a concise and relevant response."
        return f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{prompt.strip()} {E_INST}"

    def format_general_prompt(self, prompt):
        """Wrap *prompt* in a plain Human/Assistant template."""
        return f"Human: {prompt.strip()}\n\nAssistant:"

    def add_to_conversation(self, model_name, response):
        """Append a turn; flag the task complete if the text asks "task complete?"."""
        self.conversation.append((model_name, response))
        if "task complete?" in response.lower():
            self.task_complete = True

    def get_conversation_history(self):
        """Render the conversation as ``speaker: message`` lines."""
        return "\n".join([f"{model}: {msg}" for model, msg in self.conversation])

    def clear_conversation(self):
        """Reset all state, including dropping the cached models."""
        self.conversation = []
        self.initial_prompt = ""
        self.models = {}
        self.model_names = ("Model 1", "Model 2")
        self.current_model = None
        self.task_complete = False

    def rewind_conversation(self, steps):
        """Drop the last *steps* turns; ``steps <= 0`` removes nothing."""
        # Guard needed: conversation[:-0] is conversation[:0], i.e. everything.
        if steps > 0:
            self.conversation = self.conversation[:-steps]
        self.task_complete = False

    def rewind_and_insert(self, steps, inserted_response):
        """Drop the last *steps* turns, then optionally append a user-written turn.

        The inserted turn is attributed to whichever model did not speak last
        (the first model after the user or after the second model).
        """
        if steps > 0:
            self.conversation = self.conversation[:-steps]
        if inserted_response.strip():
            last_model = self.conversation[-1][0] if self.conversation else "User"
            first, second = self.model_names
            # The "Model 1" placeholder is accepted for turns inserted before
            # chat() recorded the real model names.
            next_model = second if last_model in (first, "Model 1") else first
            self.conversation.append((next_model, inserted_response))
            self.current_model = last_model
        self.task_complete = False


manager = ConversationManager()


def get_model(dropdown, custom):
    """Return the custom model name if one was typed, otherwise the dropdown choice."""
    return custom.strip() if custom and custom.strip() else dropdown


def chat(model1, model2, user_input, history, inserted_response=""):
    """Run the two-model conversation, streaming ``(conversation, status)`` updates.

    This is a generator, and Gradio ignores values ``return``ed from a
    generator. Every message, including errors, is therefore yielded.

    Args:
        model1, model2: Model names. Callers are expected to have resolved
            them already (e.g. via get_model()).
        user_input: Initial prompt; only used when no conversation exists yet.
        history: Conversation text currently shown in the UI.
        inserted_response: Optional text used verbatim as the first model's
            next turn (consumed once).
    """
    try:
        print(f"Starting chat with models: {model1}, {model2}")
        print(f"User input: {user_input}")

        starting_fresh = not manager.conversation
        if starting_fresh and not (user_input or "").strip():
            yield history, "Please enter an initial prompt."
            return

        if not manager.load_model(model1) or not manager.load_model(model2):
            yield history, "Error: Failed to load one or both models. Please check the model names and try again."
            return
        manager.model_names = (model1, model2)

        if starting_fresh:
            # Reset per-conversation state only: clear_conversation() would also
            # drop the models loaded just above and wipe initial_prompt.
            manager.task_complete = False
            manager.current_model = None
            manager.initial_prompt = user_input
            manager.add_to_conversation("User", user_input)

        models = [model1, model2]
        # Whoever did not speak last goes next. "Model 1"/"Model 2" placeholder
        # labels (from rewind_and_insert) are handled the same way.
        last_speaker = manager.conversation[-1][0]
        current_model_index = 1 if last_speaker in (model1, "Model 1") else 0

        while not manager.task_complete:
            if manager.is_paused:
                yield history, "Conversation paused."
                return

            model = models[current_model_index]
            manager.current_model = model

            if inserted_response and current_model_index == 0:
                # A user-supplied reply stands in for the first model's turn, once.
                response = inserted_response
                inserted_response = ""
            else:
                conversation_history = manager.get_conversation_history()
                prompt = f"{conversation_history}\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
                response = manager.generate_response(model, prompt)

            manager.add_to_conversation(model, response)
            history = manager.get_conversation_history()

            # int(): the delay can arrive from a UI slider as a float, which range() rejects.
            for i in range(int(manager.delay), 0, -1):
                yield history, f"{model} is writing... {i}"
                time.sleep(1)
            yield history, ""

            if manager.task_complete:
                yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
                return

            current_model_index = (current_model_index + 1) % 2

        yield history, "Conversation completed."
    except Exception as e:
        print(f"Error in chat function: {str(e)}")
        print(f"Error type: {type(e).__name__}")
        print(f"Error details: {str(e)}")
        yield history, f"An error occurred: {str(e)}"


def user_satisfaction(satisfied, history):
    """Handle the user's Yes/No answer after the models report completion.

    Any answer other than "yes" (case- and whitespace-insensitive) re-opens the
    task so the conversation can continue.
    """
    if (satisfied or "").strip().lower() == "yes":
        return history, "Task completed successfully."
    manager.task_complete = False
    return history, "Continuing the conversation..."


def pause_conversation():
    """Ask the running chat() loop to stop before its next turn."""
    manager.is_paused = True
    return "Conversation paused. Press Resume to continue."


def resume_conversation():
    """Clear the pause flag so chat() can run again."""
    manager.is_paused = False
    return "Conversation resumed."
def edit_response(edited_text):
    """Replace the text of the most recent turn, keeping its original speaker.

    Returns the refreshed conversation text.
    """
    if manager.conversation:
        # Keep the existing speaker: manager.current_model can name a different
        # speaker than the one who produced the last entry.
        speaker = manager.conversation[-1][0]
        manager.conversation[-1] = (speaker, edited_text)
        manager.task_complete = False
    return manager.get_conversation_history()


def restart_conversation(model1, model2, user_input):
    """Discard the current conversation and start a new one from *user_input*.

    This is a generator that delegates to chat(), so Gradio streams its updates.
    Already-loaded models are kept, so a restart does not reload them. Any
    pending pause is lifted.
    """
    manager.conversation = []
    manager.task_complete = False
    manager.current_model = None
    manager.is_paused = False
    yield from chat(model1, model2, user_input, "")


def rewind_and_insert(steps, inserted_response, history):
    """Apply a rewind (plus optional inserted turn) and return the refreshed UI text.

    ``history`` is unused; the UI passes it. The second return value clears
    the current-response box.
    """
    manager.rewind_and_insert(int(steps), inserted_response)
    return manager.get_conversation_history(), ""


# NOTE(review): google/flan-ul2 is, as far as I know, an encoder-decoder model
# and may fail to load via AutoModelForCausalLM — confirm or drop it.
open_source_models = [
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Llama-2-70b-chat-hf",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "bigcode/starcoder2-15b",
    "bigcode/starcoder2-3b",
    "tiiuae/falcon-7b",
    "tiiuae/falcon-40b",
    "EleutherAI/gpt-neox-20b",
    "google/flan-ul2",
    "stabilityai/stablelm-zephyr-3b",
    "HuggingFaceH4/zephyr-7b-beta",
    "microsoft/phi-2",
    "google/gemma-7b-it",
]

with gr.Blocks() as demo:
    gr.Markdown("# ConversAI Playground")
    with gr.Row():
        with gr.Column(scale=1):
            model1_dropdown = gr.Dropdown(choices=open_source_models, label="Model 1")
            model1_custom = gr.Textbox(label="Custom Model 1")
        with gr.Column(scale=1):
            model2_dropdown = gr.Dropdown(choices=open_source_models, label="Model 2")
            model2_custom = gr.Textbox(label="Custom Model 2")
    user_input = gr.Textbox(label="Initial prompt", lines=2)
    chat_history = gr.Textbox(label="Conversation", lines=20)
    current_response = gr.Textbox(label="Current model response", lines=3)
    with gr.Row():
        pause_btn = gr.Button("Pause")
        edit_btn = gr.Button("Edit")
        rewind_btn = gr.Button("Rewind")
        resume_btn = gr.Button("Resume")
        restart_btn = gr.Button("Restart")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        # step=1: both values are used as whole numbers (turn count, seconds).
        rewind_steps = gr.Slider(0, 10, 1, step=1, label="Steps to rewind")
        inserted_response = gr.Textbox(label="Insert response after rewind", lines=2)
    delay_slider = gr.Slider(0, 10, 3, step=1, label="Response Delay (seconds)")
    user_satisfaction_input = gr.Textbox(label="Are you satisfied with the result? (Yes/No)", visible=False)
    gr.Markdown(
        "## Button Descriptions\n"
        "- **Pause**: Temporarily stops the conversation. The current model will finish its response.\n"
        "- **Edit**: Allows you to modify the last response in the conversation.\n"
        "- **Rewind**: Removes the specified number of last responses from the conversation.\n"
        "- **Resume**: Continues the conversation from where it was paused.\n"
        "- **Restart**: Begins a new conversation with the same or different models, keeping the initial prompt.\n"
        "- **Clear**: Resets everything, including loaded models, conversation history, and initial prompt.\n"
    )

    def on_chat_update(history, response):
        """Show the Yes/No box (and hide Start) once the models report completion."""
        if response and "Models believe the task is complete" in response:
            return gr.update(visible=True), gr.update(visible=False)
        return gr.update(visible=False), gr.update(visible=True)

    def run_chat(model1_choice, model1_name, model2_choice, model2_name, prompt, history, inserted=""):
        """Resolve the live custom-model textbox values, then stream chat()."""
        yield from chat(
            get_model(model1_choice, model1_name),
            get_model(model2_choice, model2_name),
            prompt,
            history,
            inserted,
        )

    def run_restart(model1_choice, model1_name, model2_choice, model2_name, prompt):
        """Resolve the live custom-model textbox values, then stream a restart."""
        yield from restart_conversation(
            get_model(model1_choice, model1_name),
            get_model(model2_choice, model2_name),
            prompt,
        )

    def clear_all():
        """Reset the manager and blank the three text outputs wired to Clear."""
        manager.clear_conversation()
        return "", "", ""

    # Custom textboxes are passed as inputs so their *current* values are used.
    chat_inputs = [
        model1_dropdown,
        model1_custom,
        model2_dropdown,
        model2_custom,
        user_input,
        chat_history,
    ]
    chat_outputs = [chat_history, current_response]

    start_btn = gr.Button("Start Conversation")
    start_btn.click(run_chat, inputs=chat_inputs, outputs=chat_outputs).then(
        on_chat_update,
        inputs=[chat_history, current_response],
        outputs=[user_satisfaction_input, start_btn],
    )

    user_satisfaction_input.submit(
        user_satisfaction,
        inputs=[user_satisfaction_input, chat_history],
        outputs=chat_outputs,
    ).then(run_chat, inputs=chat_inputs, outputs=chat_outputs).then(
        on_chat_update,
        inputs=[chat_history, current_response],
        outputs=[user_satisfaction_input, start_btn],
    )

    pause_btn.click(pause_conversation, outputs=[current_response])
    # Clear the pause flag first; otherwise chat() sees is_paused and stops at once.
    resume_btn.click(resume_conversation, outputs=[current_response]).then(
        run_chat,
        inputs=chat_inputs + [inserted_response],
        outputs=chat_outputs,
    ).then(
        on_chat_update,
        inputs=[chat_history, current_response],
        outputs=[user_satisfaction_input, start_btn],
    )
    edit_btn.click(edit_response, inputs=[current_response], outputs=[chat_history])
    rewind_btn.click(
        rewind_and_insert,
        inputs=[rewind_steps, inserted_response, chat_history],
        outputs=chat_outputs,
    )
    restart_btn.click(
        run_restart,
        inputs=[model1_dropdown, model1_custom, model2_dropdown, model2_custom, user_input],
        outputs=chat_outputs,
    ).then(
        on_chat_update,
        inputs=[chat_history, current_response],
        outputs=[user_satisfaction_input, start_btn],
    )
    # clear_conversation() itself returns None, but three outputs are wired here.
    clear_btn.click(clear_all, outputs=[chat_history, current_response, user_input])
    # int(): chat() feeds the delay to range(), which rejects floats.
    delay_slider.change(lambda x: setattr(manager, "delay", int(x)), inputs=[delay_slider])

if __name__ == "__main__":
    # Streaming (generator) callbacks need the request queue; Gradio 3.x
    # requires queue() for that explicitly.
    demo.queue().launch()