# DialoGPT / app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load the shared tokenizer (both models use the DialoGPT-medium tokenizer)
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
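# DialoGPT has no dedicated pad token; reuse EOS so generate() has a valid pad_token_id
tokenizer.pad_token = tokenizer.eos_token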
# Define the model names, including the locally saved fine-tuned model
model_names = {
    "DialoGPT-med-FT": "DialoGPT-med-FT.bin",
    "DialoGPT-medium": "microsoft/DialoGPT-medium",
}
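# "DialoGPT-med-FT" maps to a local state_dict checkpoint (presumably saved with
# torch.save(model.state_dict(), ...)); "DialoGPT-medium" is the Hub id of the base model.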
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the default model at startup: base DialoGPT-medium architecture with the local fine-tuned weights
current_model_name = "DialoGPT-med-FT"
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
model.load_state_dict(torch.load(model_names[current_model_name], map_location=device))
model.to(device)
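model.eval()  # inference only: disable dropout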
def load_model(model_name):
    """Switch the global model if the requested one differs from the currently loaded one."""
    global model, current_model_name
    if model_name != current_model_name:
        if model_name == "DialoGPT-medium":
            # Base model straight from the Hugging Face Hub
            model = AutoModelForCausalLM.from_pretrained(model_names[model_name]).to(device)
        elif model_name == "DialoGPT-med-FT":
            # Base architecture plus the locally saved fine-tuned weights
            model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
            model.load_state_dict(torch.load(model_names[model_name], map_location=device))
            model.to(device)
        current_model_name = model_name
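        model.eval()  # keep the freshly loaded model in inference mode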
def respond(
    message,
    history: list[dict],
    model_choice,
    max_tokens,
    temperature,
    top_p,
):
    # Load the selected model if it differs from the current one
    load_model(model_choice)
    # Concatenate the chat history into a plain-text dialogue prompt
    input_text = ""
    for past_message in history:
        input_text += f"{past_message['role'].capitalize()}: {past_message['content']}\n"
    input_text += f"User: {message}\nAssistant:"
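    # Note: stock DialoGPT was trained with turns separated by the EOS token rather than
    # "User:"/"Assistant:" labels; this plain-text format is assumed to match the fine-tuned model.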
    # Tokenize the prompt with the shared tokenizer
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
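    # truncation=True caps the prompt at the tokenizer's model_max_length (1024 for DialoGPT)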
    # Generate the response with the selected DialoGPT model (no gradients needed at inference)
    with torch.no_grad():
        output_tokens = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode and return only the newly generated tokens (everything after the prompt)
    response = tokenizer.decode(output_tokens[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    yield response
# Define the Gradio interface
demo = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Dropdown(choices=["DialoGPT-med-FT", "DialoGPT-medium"], value="DialoGPT-med-FT", label="Model"),
        gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
if __name__ == "__main__":
    demo.launch()
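# A minimal sketch of calling this Space programmatically with gradio_client, assuming it is
# published under a Space id like "kdevoe/DialoGPT" (the id and endpoint name are assumptions):
#
#   from gradio_client import Client
#   client = Client("kdevoe/DialoGPT")
#   reply = client.predict("Hello there!", "DialoGPT-med-FT", 15, 0.7, 0.95, api_name="/chat")
#   print(reply)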