import os
import time

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Hugging Face login. Only attempted when a token is configured, so the app
# can still start (and use public models) without one.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    print("HF_TOKEN is not set; skipping Hugging Face login (gated models may fail to load).")

print(f"CUDA is available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")


class ConversationManager:
    """Holds the loaded models and the state of a two-model conversation."""

    def __init__(self):
        self.models = {}  # model name -> (model, tokenizer) cache
        self.conversation = []  # list of (speaker, message) tuples
        self.delay = 3  # countdown seconds shown after each response
        self.is_paused = False
        self.current_model = None  # speaker of the most recent turn
        self.initial_prompt = ""
        self.task_complete = False
        # Labels of the two alternating participants; chat() replaces these
        # with the real model ids.
        self.participants = ["Model 1", "Model 2"]

    def load_model(self, model_name):
        """Load (or fetch from cache) a causal LM and its tokenizer.

        Returns:
            A ``(model, tokenizer)`` tuple, or ``None`` if loading failed.
        """
        if not model_name:
            print("Error: Empty model name provided")
            return None
        if model_name in self.models:
            return self.models[model_name]
        print(f"Attempting to load model: {model_name}")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # When the prompt is too long, keep the most recent turns.
            tokenizer.truncation_side = "left"
            model = self._load_weights(model_name)
        except Exception as e:
            print(f"Failed to load model {model_name}: {e}")
            print(f"Error type: {type(e).__name__}")
            print(f"Error details: {str(e)}")
            return None
        self.models[model_name] = (model, tokenizer)
        print(f"Successfully loaded model: {model_name}")
        return self.models[model_name]

    @staticmethod
    def _load_weights(model_name):
        """Load weights: 8-bit on CUDA when possible, otherwise full precision."""
        if not torch.cuda.is_available():
            # bitsandbytes 8-bit quantization needs a GPU; don't even try on CPU.
            return AutoModelForCausalLM.from_pretrained(model_name)
        try:
            return AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
        except (RuntimeError, ImportError, ValueError) as e:
            # e.g. bitsandbytes missing or unsupported on this GPU.
            print(f"8-bit quantization not available, falling back to full precision: {e}")
        return AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    def generate_response(self, model_name, prompt, max_new_tokens=200):
        """Generate one reply from ``model_name`` and return only the new text.

        Raises:
            RuntimeError: if the model cannot be loaded.
        """
        loaded = self.load_model(model_name)
        if loaded is None:
            raise RuntimeError(f"Model {model_name!r} could not be loaded")
        model, tokenizer = loaded
        formatted_prompt = f"Human: {prompt.strip()}\n\nAssistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt", max_length=1024, truncation=True)
        # Inputs must be on the same device as the model (GPU with device_map="auto").
        inputs = inputs.to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # max_new_tokens, unlike max_length, does not count prompt tokens.
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                do_sample=True,
                pad_token_id=pad_id,
            )
        # generate() returns prompt + continuation; decode only the continuation.
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    def add_to_conversation(self, model_name, response):
        """Append a turn; a model turn containing "task complete?" ends the task."""
        self.conversation.append((model_name, response))
        # Only model turns may signal completion; a user prompt quoting the
        # phrase must not stop the conversation before it starts.
        if model_name != "User" and "task complete?" in response.lower():
            self.task_complete = True

    def get_conversation_history(self):
        """Return the conversation as ``"speaker: message"`` lines."""
        return "\n".join(f"{model}: {msg}" for model, msg in self.conversation)

    def next_speaker_index(self):
        """Return which participant (0 or 1) speaks next.

        Participants alternate, starting with participant 0 after the user's
        prompt, so the parity of non-user turns decides the next speaker.
        """
        model_turns = sum(1 for speaker, _ in self.conversation if speaker != "User")
        return model_turns % 2

    def clear_conversation(self):
        """Reset everything, including the loaded-model cache and pause state."""
        self.conversation = []
        self.initial_prompt = ""
        self.models = {}
        self.current_model = None
        self.task_complete = False
        self.is_paused = False
        self.participants = ["Model 1", "Model 2"]
        if torch.cuda.is_available():
            # Return cached blocks of the just-dropped models to the GPU.
            torch.cuda.empty_cache()

    def rewind_conversation(self, steps):
        """Drop the last ``steps`` turns; ``steps <= 0`` removes nothing.

        The guard matters: ``conversation[:-0]`` would empty the whole list.
        """
        if steps > 0:
            self.conversation = self.conversation[:-steps]
        self.current_model = self.conversation[-1][0] if self.conversation else None
        self.task_complete = False

    def rewind_and_insert(self, steps, inserted_response):
        """Rewind ``steps`` turns, then optionally append a hand-written turn.

        The inserted turn is attributed to whichever participant would speak next.
        """
        self.rewind_conversation(steps)
        if inserted_response.strip():
            speaker = self.participants[self.next_speaker_index()]
            self.conversation.append((speaker, inserted_response))
            self.current_model = speaker
        self.task_complete = False


manager = ConversationManager()


def get_model(dropdown, custom):
    """Prefer a non-blank custom model name over the dropdown selection."""
    if custom and custom.strip():
        return custom.strip()
    return dropdown


def chat(model1, model2, user_input, history, inserted_response=""):
    """Run the alternating two-model conversation, streaming UI updates.

    This is a generator used as a Gradio handler. Every ``yield`` is a
    ``(conversation_text, status_message)`` pair. Gradio discards a generator's
    ``return`` value, so errors and final states are yielded too.

    Args:
        model1, model2: Model ids of the two participants. Custom-name
            resolution (``get_model``) must be done by the caller with live
            textbox values.
        user_input: Initial prompt; used only when no conversation exists.
        history: Conversation text currently shown in the UI.
        inserted_response: If non-empty, used verbatim as model 1's next turn
            instead of generating one.
    """
    try:
        print(f"Starting chat with models: {model1}, {model2}")
        print(f"User input: {user_input}")
        print(f"Selected models: {model1}, {model2}")

        if not manager.conversation and not (user_input and user_input.strip()):
            yield history, "Please enter an initial prompt."
            return

        if not manager.load_model(model1) or not manager.load_model(model2):
            yield history, "Error: Failed to load one or both models. Please check the model names and try again."
            return

        if not manager.conversation:
            # Fresh conversation: reset turn state but keep the models just
            # loaded (clear_conversation() would unload them and drop the prompt).
            manager.current_model = None
            manager.task_complete = False
            manager.initial_prompt = user_input
            manager.add_to_conversation("User", user_input)
            history = manager.get_conversation_history()

        models = [model1, model2]
        manager.participants = list(models)
        # The delay slider may deliver a float; range() needs an int.
        delay = max(0, int(manager.delay))

        while not manager.task_complete:
            if manager.is_paused:
                yield history, "Conversation paused."
                return
            index = manager.next_speaker_index()
            model = models[index]
            manager.current_model = model
            if inserted_response and index == 0:
                response = inserted_response
                inserted_response = ""
            else:
                conversation_history = manager.get_conversation_history()
                prompt = f"{conversation_history}\n\nPlease continue the conversation. If you believe the task is complete, end your response with 'Task complete?'"
                response = manager.generate_response(model, prompt)
            manager.add_to_conversation(model, response)
            history = manager.get_conversation_history()
            for remaining in range(delay, 0, -1):
                yield history, f"{model} is writing... {remaining}"
                time.sleep(1)
            yield history, ""
            if manager.task_complete:
                yield history, "Models believe the task is complete. Are you satisfied with the result? (Yes/No)"
                return

        # Reached only if the task was already complete when chat() was called.
        yield history, "Conversation completed."
    except Exception as e:
        print(f"Error in chat function: {str(e)}")
        print(f"Error type: {type(e).__name__}")
        print(f"Error details: {str(e)}")
        yield history, f"An error occurred: {str(e)}"


def user_satisfaction(satisfied, history):
    """Handle the Yes/No answer; anything but "yes" resumes the conversation."""
    if (satisfied or "").strip().lower() == "yes":
        return history, "Task completed successfully."
    manager.task_complete = False
    return history, "Continuing the conversation..."


def pause_conversation():
    """Ask the running chat() loop to stop before its next turn."""
    manager.is_paused = True
    return "Conversation paused. Press Resume to continue."


def resume_conversation():
    """Clear the pause flag so a subsequent chat() call proceeds."""
    manager.is_paused = False
    return "Conversation resumed."
def edit_response(edited_text):
    """Replace the text of the most recent turn, keeping its speaker label.

    Returns the updated conversation text for display.
    """
    if manager.conversation:
        speaker, _ = manager.conversation[-1]
        manager.conversation[-1] = (speaker, edited_text)
        manager.task_complete = False
    return manager.get_conversation_history()


def restart_conversation(model1, model2, user_input):
    """Reset all manager state and stream a brand-new conversation.

    Written as a generator (``yield from``) so Gradio treats it as a streaming
    handler; returning ``chat(...)`` from a plain function would hand Gradio an
    un-iterated generator object.
    """
    manager.clear_conversation()
    yield from chat(model1, model2, user_input, "")


def rewind_and_insert(steps, inserted_response, history):
    """Gradio wrapper around ``manager.rewind_and_insert``.

    ``steps`` comes from a slider and may be a float, hence ``int()``.
    ``history`` is unused but kept for the existing event signature.
    Returns ``(conversation_text, "")`` to refresh the history and clear the status box.
    """
    manager.rewind_and_insert(int(steps), inserted_response)
    return manager.get_conversation_history(), ""


# Decoder-only models loadable with AutoModelForCausalLM. (google/flan-ul2 was
# removed: it is an encoder-decoder T5-style model and can never load here.)
open_source_models = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "bigcode/starcoder2-15b",
    "bigcode/starcoder2-3b",
    "tiiuae/falcon-7b",
    "EleutherAI/gpt-neox-20b",
    "stabilityai/stablelm-zephyr-3b",
    "HuggingFaceH4/zephyr-7b-beta",
    "microsoft/phi-2",
    "google/gemma-7b-it",
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    "mosaicml/mpt-7b-chat",
    "databricks/dolly-v2-12b",
    "thebloke/Wizard-Vicuna-13B-Uncensored-HF",
    "bigscience/bloom-560m",
]

with gr.Blocks() as demo:
    gr.Markdown("# ConversAI Playground")
    with gr.Row():
        with gr.Column(scale=1):
            model1_dropdown = gr.Dropdown(choices=open_source_models, label="Model 1")
            model1_custom = gr.Textbox(label="Custom Model 1")
        with gr.Column(scale=1):
            model2_dropdown = gr.Dropdown(choices=open_source_models, label="Model 2")
            model2_custom = gr.Textbox(label="Custom Model 2")
    user_input = gr.Textbox(label="Initial prompt", lines=2)
    chat_history = gr.Textbox(label="Conversation", lines=20)
    current_response = gr.Textbox(label="Current model response", lines=3)
    with gr.Row():
        pause_btn = gr.Button("Pause")
        edit_btn = gr.Button("Edit")
        rewind_btn = gr.Button("Rewind")
        resume_btn = gr.Button("Resume")
        restart_btn = gr.Button("Restart")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        rewind_steps = gr.Slider(0, 10, 1, label="Steps to rewind")
        inserted_response = gr.Textbox(label="Insert response after rewind", lines=2)
    delay_slider = gr.Slider(0, 10, 3, label="Response Delay (seconds)")
    user_satisfaction_input = gr.Textbox(label="Are you satisfied with the result? (Yes/No)", visible=False)
    gr.Markdown("""
    ## Button Descriptions
    - **Pause**: Temporarily stops the conversation. The current model will finish its response.
    - **Edit**: Allows you to modify the last response in the conversation.
    - **Rewind**: Removes the specified number of last responses from the conversation, then appends the "Insert response after rewind" text if any.
    - **Resume**: Continues the conversation from where it was paused.
    - **Restart**: Begins a new conversation with the same or different models, keeping the initial prompt.
    - **Clear**: Resets everything, including loaded models, conversation history, and initial prompt.
    """)

    def on_chat_update(history, response):
        """Show the satisfaction prompt (and hide Start) once the models signal completion."""
        if response and "Models believe the task is complete" in response:
            return gr.update(visible=True), gr.update(visible=False)
        return gr.update(visible=False), gr.update(visible=True)

    def run_chat(model1, model1_custom_name, model2, model2_custom_name, prompt, history, inserted=""):
        """Resolve custom names from the live textbox values, then stream chat()."""
        yield from chat(
            get_model(model1, model1_custom_name),
            get_model(model2, model2_custom_name),
            prompt,
            history,
            inserted,
        )

    def run_restart(model1, model1_custom_name, model2, model2_custom_name, prompt):
        """Resolve custom names from the live textbox values, then stream a restart."""
        yield from restart_conversation(
            get_model(model1, model1_custom_name),
            get_model(model2, model2_custom_name),
            prompt,
        )

    def clear_all():
        """Reset the manager and blank the three bound textboxes."""
        manager.clear_conversation()
        return "", "", ""

    def set_delay(seconds):
        """Store the slider value as whole seconds (chat() feeds it to range())."""
        manager.delay = int(seconds)

    # Components are passed (rather than read via `.value`) so handlers get the
    # user's current input, not the component's initial value.
    model_inputs = [model1_dropdown, model1_custom, model2_dropdown, model2_custom]
    chat_inputs = model_inputs + [user_input, chat_history]
    chat_outputs = [chat_history, current_response]

    start_btn = gr.Button("Start Conversation")
    visibility_outputs = [user_satisfaction_input, start_btn]

    start_btn.click(
        run_chat, inputs=chat_inputs, outputs=chat_outputs
    ).then(
        on_chat_update, inputs=chat_outputs, outputs=visibility_outputs
    )

    user_satisfaction_input.submit(
        user_satisfaction,
        inputs=[user_satisfaction_input, chat_history],
        outputs=chat_outputs,
    ).then(
        run_chat, inputs=chat_inputs, outputs=chat_outputs
    ).then(
        on_chat_update, inputs=chat_outputs, outputs=visibility_outputs
    )

    pause_btn.click(pause_conversation, outputs=[current_response])
    # Clear the pause flag first; otherwise chat() would immediately re-pause.
    resume_btn.click(
        resume_conversation, outputs=[current_response]
    ).then(
        run_chat, inputs=chat_inputs + [inserted_response], outputs=chat_outputs
    ).then(
        on_chat_update, inputs=chat_outputs, outputs=visibility_outputs
    )
    edit_btn.click(edit_response, inputs=[current_response], outputs=[chat_history])
    rewind_btn.click(
        rewind_and_insert,
        inputs=[rewind_steps, inserted_response, chat_history],
        outputs=chat_outputs,
    )
    restart_btn.click(
        run_restart, inputs=model_inputs + [user_input], outputs=chat_outputs
    ).then(
        on_chat_update, inputs=chat_outputs, outputs=visibility_outputs
    )
    clear_btn.click(clear_all, outputs=[chat_history, current_response, user_input])
    delay_slider.change(set_delay, inputs=[delay_slider])

if __name__ == "__main__":
    # Streaming (generator) handlers require the queue.
    demo.queue()
    demo.launch()