"""Gradio demo for a customer-support chatbot built on DialoGPT.

Loads a locally fine-tuned checkpoint from ``./customer_support_chatbot`` when
that directory exists, otherwise falls back to the base
``microsoft/DialoGPT-medium`` model, and serves a simple chat UI.
"""

import os
import random

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Check if fine-tuned model exists, otherwise use base model
FINE_TUNED_PATH = "./customer_support_chatbot"
BASE_MODEL = "microsoft/DialoGPT-medium"
model_path = FINE_TUNED_PATH if os.path.exists(FINE_TUNED_PATH) else BASE_MODEL

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# Load the customer support dataset.
# NOTE(review): `dataset` is not referenced anywhere in this script, but loading
# it costs a download at startup. Kept in case other code imports it; remove if
# nothing does.
dataset = load_dataset("Victorano/customer-support-1k")

# Default upper bound on the length of each generated reply, in new tokens.
MAX_NEW_TOKENS = 200
# Longest sequence (prompt + reply) the model can attend to. GPT-2-style configs
# such as DialoGPT's expose this as `max_position_embeddings` (an alias of
# `n_positions`); 1024 is the fallback if the config lacks it.
MODEL_MAX_LENGTH = getattr(model.config, "max_position_embeddings", None) or 1024


def _build_prompt(message, history):
    """Render the chat history plus the new message as one plain-text prompt.

    Each past turn becomes ``Customer: ...\\nSupport: ...\\n`` and the prompt
    ends with ``Support:`` so the model continues with the agent's reply.
    """
    turns = [
        f"Customer: {user_msg}\nSupport: {bot_msg}\n"
        for user_msg, bot_msg in history
    ]
    turns.append(f"Customer: {message}\nSupport:")
    return "".join(turns)


def _extract_reply(generated_text):
    """Return the generated reply, cut off before any invented next turn.

    The model may keep going and write the next ``Customer:`` line itself;
    everything from that marker on is dropped.
    """
    return generated_text.split("Customer:", 1)[0].strip()


def generate_response(message, history, max_new_tokens=MAX_NEW_TOKENS):
    """Generate the support agent's reply to ``message``.

    Args:
        message: The customer's latest message.
        history: Prior turns as ``[user_msg, bot_msg]`` pairs, as passed in by
            the ``gr.Chatbot`` component below.
        max_new_tokens: Upper bound on the number of tokens in the reply.

    Returns:
        The reply text with surrounding whitespace stripped.
    """
    conversation = _build_prompt(message, history)
    input_ids = tokenizer.encode(conversation, return_tensors="pt")

    # Keep only the most recent prompt tokens so prompt + reply fit in the
    # model's context window. A fixed total `max_length` left a long chat
    # with no room to generate anything.
    max_prompt_len = max(1, MODEL_MAX_LENGTH - max_new_tokens)
    input_ids = input_ids[:, -max_prompt_len:]
    attention_mask = torch.ones_like(input_ids)

    # Generate response
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # Without do_sample=True, generate() decodes greedily and silently
            # ignores temperature / top_k / top_p.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return _extract_reply(generated)


# Create the Gradio interface
with gr.Blocks(css="footer {display: none !important}") as demo:
    gr.Markdown("""
# 🤖 Customer Support Chatbot
This chatbot is fine-tuned on customer support conversations using DialoGPT-medium.
""")

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=1"),
        height=500,
        show_copy_button=True,
    )

    with gr.Row():
        txt = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            container=False,
        )
        submit_btn = gr.Button("Send", variant="primary")

    # Handle user input and generate response
    def user_input(message, history):
        """Append the customer's message and the bot's reply to the chat.

        Returns an empty string (to clear the textbox) and the updated
        history. Blank submissions leave the history unchanged instead of
        asking the model to reply to nothing.
        """
        if not message or not message.strip():
            return "", history
        return "", history + [[message, generate_response(message, history)]]

    # Connect the interface components
    txt.submit(user_input, [txt, chatbot], [txt, chatbot])
    submit_btn.click(user_input, [txt, chatbot], [txt, chatbot])

if __name__ == "__main__":
    demo.launch()