# Interactive chatbot demo: compare small causal language models in a Gradio UI.
import os

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Cache of already-loaded models so each checkpoint is fetched from disk only
# once per process, instead of on every chat turn.
_MODEL_CACHE = {}


def load_model(model_name):
    """Return a ``(tokenizer, model)`` pair for *model_name*.

    The pair is loaded with ``from_pretrained`` on first request and cached in
    ``_MODEL_CACHE`` thereafter — ``generate_response`` calls this on every
    user message, and reloading the weights each time is prohibitively slow.

    Args:
        model_name: Hugging Face model identifier (e.g. ``"distilgpt2"``).

    Returns:
        Tuple of ``(AutoTokenizer, AutoModelForCausalLM)`` instances.
    """
    if model_name not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        _MODEL_CACHE[model_name] = (tokenizer, model)
    return _MODEL_CACHE[model_name]
# Generate a response with adjustable sampling parameters and model-specific
# prompt formatting.
def generate_response(prompt, model_name, persona="I am a helpful assistant.", temperature=0.7, top_p=0.9, repetition_penalty=1.2, max_length=70):
    """Generate a chat reply for *prompt* using the selected model.

    Args:
        prompt: The user's message.
        model_name: Hugging Face model identifier; ``"microsoft/DialoGPT-small"``
            gets a conversational ``User:/Bot:`` prompt, others a persona prefix.
        persona: System-style prefix used for non-DialoGPT models.
        temperature, top_p, repetition_penalty: Sampling knobs forwarded to
            ``model.generate``.
        max_length: Total sequence length cap (prompt + generated tokens).

    Returns:
        The generated text with the prompt removed, or a fallback string when
        the model produced nothing new.
    """
    # Load (cached) model and tokenizer for the chosen checkpoint.
    tokenizer, model = load_model(model_name)

    # Adjust the prompt format for DialoGPT, which was trained on dialogue.
    if model_name == "microsoft/DialoGPT-small":
        full_prompt = f"User: {prompt}\nBot:"  # Structure as a conversation
    else:
        full_prompt = f"{persona}: {prompt}"  # Standard format for other models

    inputs = tokenizer(full_prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        # Explicit attention mask: with pad_token_id == eos_token_id the mask
        # cannot be inferred reliably from the input ids alone.
        attention_mask=inputs.attention_mask,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens generated AFTER the prompt. The previous
    # approach of splitting the decoded text on ":" broke whenever the
    # persona, the user prompt, or the reply itself contained a colon.
    prompt_token_count = inputs.input_ids.shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_token_count:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    ).strip()

    return response if response else "I'm not sure how to respond to that."
# Gradio callback: adapt widget values to the generation function.
def chat_interface(user_input, model_choice, persona="I am a helpful assistant", temperature=0.7, top_p=0.9, repetition_penalty=1.2, max_length=50):
    """Forward the UI widget values straight to ``generate_response``."""
    return generate_response(
        user_input,
        model_choice,
        persona,
        temperature,
        top_p,
        repetition_penalty,
        max_length,
    )
# Wire up the Gradio UI: one text box, a model selector, and sliders for each
# sampling parameter, in the same order as chat_interface's signature.
_ui_inputs = [
    gr.Textbox(label="User Input"),
    gr.Dropdown(choices=["distilgpt2", "gpt2", "microsoft/DialoGPT-small"], label="Model Choice", value="distilgpt2"),
    gr.Textbox(label="Persona", value="You are a helpful assistant."),
    gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1),
    gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.9, step=0.1),
    gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.1),
    gr.Slider(label="Max Length", minimum=10, maximum=100, value=50, step=5),
]

interface = gr.Interface(
    fn=chat_interface,
    inputs=_ui_inputs,
    outputs="text",
    title="Interactive Chatbot with Model Comparison",
    description="Chat with the bot! Select a model and adjust parameters to see how they affect the response.",
)

# Launch the Gradio app only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()