Spaces:
Sleeping
Sleeping
File size: 2,610 Bytes
d538a8c 0680865 d538a8c 0680865 d538a8c 0680865 d538a8c 0680865 d538a8c 0680865 d538a8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import load_dataset
import random
import os
# Prefer a locally saved fine-tuned checkpoint when its directory is present;
# otherwise fall back to the public DialoGPT-medium weights.
if os.path.exists("./customer_support_chatbot"):
    model_path = "./customer_support_chatbot"
else:
    model_path = "microsoft/DialoGPT-medium"

# Tokenizer and model are loaded once, at import time, and reused by
# generate_response for every request.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# Customer-support conversations dataset.
# NOTE(review): not referenced anywhere else in this file's visible code —
# confirm whether downloading it at startup is still needed.
dataset = load_dataset("Victorano/customer-support-1k")
def generate_response(message, history, *, max_new_tokens=200):
    """Generate the support agent's reply to the customer's latest message.

    Earlier turns are rendered as a "Customer: ... / Support: ..." transcript,
    the new message is appended followed by a bare "Support:" cue, and the
    model continues the transcript from there.

    Args:
        message: The customer's newest message.
        history: Earlier ``(user_msg, bot_msg)`` pairs, oldest first, as kept
            by the Gradio Chatbot component.
        max_new_tokens: Maximum number of tokens generated for the reply,
            independent of how long the transcript already is.

    Returns:
        The generated reply text with surrounding whitespace stripped
        (may be empty if the model produced nothing usable).
    """
    # Render every turn in the same "Customer:/Support:" layout; a missing
    # bot reply is rendered as empty text instead of the string "None".
    turns = [
        f"Customer: {user_msg}\nSupport: {bot_msg or ''}\n"
        for user_msg, bot_msg in history
    ]
    turns.append(f"Customer: {message}\nSupport:")
    conversation = "".join(turns)

    input_ids = tokenizer.encode(conversation, return_tensors="pt")

    # Keep only the most recent prompt tokens so prompt + reply fit in the
    # model's positional window; the oldest turns are dropped first.
    context_limit = getattr(model.config, "max_position_embeddings", 1024)
    max_prompt_tokens = max(context_limit - max_new_tokens, 1)
    input_ids = input_ids[:, -max_prompt_tokens:]
    # Single unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            # Budget the reply itself. The previous max_length also counted the
            # prompt, so a long history left little or no room to answer.
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            # Without do_sample=True, generate() decodes greedily and silently
            # ignores temperature / top_k / top_p.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # The model may keep writing the dialogue; stop at an invented next
    # customer turn so only this reply is returned.
    response = response.split("Customer:")[0]
    return response.strip()
# Create the Gradio interface (footer hidden via custom CSS).
with gr.Blocks(css="footer {display: none !important}") as demo:
    gr.Markdown("""
# 🤖 Customer Support Chatbot
This chatbot is fine-tuned on customer support conversations using DialoGPT-medium.
""")
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=1"),
        height=500,
        show_copy_button=True,
    )
    with gr.Row():
        txt = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            container=False,
        )
        submit_btn = gr.Button("Send", variant="primary")

    def user_input(message, history):
        """Answer one chat submission and append the turn to the history.

        Args:
            message: Text currently in the input box.
            history: The Chatbot's list of ``[user_msg, bot_msg]`` pairs.

        Returns:
            A tuple ``("", new_history)``. The empty string clears the textbox.
        """
        history = history or []
        # Ignore empty / whitespace-only submissions instead of running a full
        # model generation and appending a blank turn.
        if not message or not message.strip():
            return "", history
        return "", history + [[message, generate_response(message, history)]]

    # Pressing Enter in the textbox and clicking "Send" share one handler.
    txt.submit(user_input, [txt, chatbot], [txt, chatbot])
    submit_btn.click(user_input, [txt, chatbot], [txt, chatbot])
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()