# base_chat / app.py
# (Hugging Face Space page header preserved as comments so the file parses:
#  commit 0922636 "Update app.py", verified, by sbicy)
import os
from functools import lru_cache

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Cache loaded models: generate_response calls this on every request, and
# re-instantiating (or re-downloading) a model per request is very expensive.
@lru_cache(maxsize=None)
def load_model(model_name):
    """Load (and memoize) the tokenizer and model for a Hugging Face model id.

    Args:
        model_name: Hub identifier, e.g. "distilgpt2" or "microsoft/DialoGPT-small".

    Returns:
        A ``(tokenizer, model)`` tuple; repeated calls with the same name
        return the same cached pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model
# Generate a response with adjustable sampling parameters and model-specific prompt formats.
def generate_response(prompt, model_name, persona="I am a helpful assistant.", temperature=0.7, top_p=0.9, repetition_penalty=1.2, max_length=70):
    """Generate a chat reply from the selected causal LM.

    Args:
        prompt: The user's message.
        model_name: Hub model id; DialoGPT gets a conversational prompt format.
        persona: Persona text prepended for non-DialoGPT models.
        temperature, top_p, repetition_penalty: Sampling knobs passed to generate().
        max_length: Total token budget (prompt + completion).

    Returns:
        The generated text with the echoed prompt prefix stripped, or a
        fallback sentence if nothing useful remains.
    """
    # Load the chosen model and tokenizer (cached by load_model).
    tokenizer, model = load_model(model_name)
    # DialoGPT was trained on dialogue turns, so frame the prompt as a conversation.
    if model_name == "microsoft/DialoGPT-small":
        full_prompt = f"User: {prompt}\nBot:"  # Structure as a conversation
    else:
        full_prompt = f"{persona}: {prompt}"  # Standard format for other models
    # Tokenize and generate the response.
    inputs = tokenizer(full_prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        # Fix: pass the attention mask explicitly; omitting it triggers a
        # transformers warning and can mis-mask when pad == eos token.
        attention_mask=inputs.attention_mask,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    # Trim the echoed prompt prefix from the decoded text.
    if model_name == "microsoft/DialoGPT-small":
        response_without_prompt = response.split("Bot:", 1)[-1].strip()
    else:
        response_without_prompt = response.split(":", 1)[-1].strip()
    return response_without_prompt if response_without_prompt else "I'm not sure how to respond to that."
# Gradio callback: forwards the UI widget values straight through to the generator.
def chat_interface(user_input, model_choice, persona="I am a helpful assistant", temperature=0.7, top_p=0.9, repetition_penalty=1.2, max_length=50):
    """Thin adapter between the Gradio UI and generate_response."""
    return generate_response(
        user_input,
        model_choice,
        persona,
        temperature,
        top_p,
        repetition_penalty,
        max_length,
    )
# Assemble the Gradio UI: free-text input, a model picker, and one widget
# per sampling parameter, wired to the chat_interface callback.
_ui_inputs = [
    gr.Textbox(label="User Input"),
    gr.Dropdown(choices=["distilgpt2", "gpt2", "microsoft/DialoGPT-small"], label="Model Choice", value="distilgpt2"),
    gr.Textbox(label="Persona", value="You are a helpful assistant."),
    gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1),
    gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.9, step=0.1),
    gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2, step=0.1),
    gr.Slider(label="Max Length", minimum=10, maximum=100, value=50, step=5),
]

interface = gr.Interface(
    fn=chat_interface,
    inputs=_ui_inputs,
    outputs="text",
    title="Interactive Chatbot with Model Comparison",
    description="Chat with the bot! Select a model and adjust parameters to see how they affect the response.",
)

# Launch only when executed as a script (not when imported).
if __name__ == "__main__":
    interface.launch()