Chillyblast's picture
Update app.py
801b717 verified
raw
history blame
2.21 kB
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset, Dataset
import pandas as pd
# Load the dataset once per server process and derive the label mappings.
# Streamlit re-executes this whole script on every widget interaction, so
# without caching the dataset would be reloaded on each click.
@st.cache_resource(show_spinner="Loading dataset...")
def _load_support_data():
    """Load the Bitext dataset and build the intent <-> label-id mappings.

    Returns:
        ``(ds, df, label2id, id2label)`` where ``ds`` is the raw dataset,
        ``df`` is its ``train`` split as a DataFrame restricted to the
        ``instruction``, ``label``, ``intent`` and ``response`` columns, and
        the two dicts map each intent name to an integer id (in order of
        first appearance in ``train``) and back.

    The result is cached with ``st.cache_resource`` and shared by all
    sessions, so callers must treat it as read-only.
    """
    dataset = load_dataset("bitext/Bitext-customer-support-llm-chatbot-training-dataset")
    frame = dataset['train'].to_pandas()
    # Ids follow first-appearance order of intents in the train split.
    # NOTE(review): this must match the label order the classifier was
    # trained with; confirm against the model's config.
    intent_to_id = {label: idx for idx, label in enumerate(frame['intent'].unique())}
    id_to_intent = {idx: label for label, idx in intent_to_id.items()}
    frame['label'] = frame['intent'].map(intent_to_id)
    # Keep only the columns the app uses.
    frame = frame[['instruction', 'label', 'intent', 'response']]
    return dataset, frame, intent_to_id, id_to_intent


ds, df, label2id, id2label = _load_support_data()
# Load the tokenizer and model once per server process. Streamlit re-executes
# this whole script on every widget interaction, so without caching the
# weights would be re-loaded on each click.
@st.cache_resource(show_spinner="Loading model...")
def _load_model(checkpoint="Chillyblast/Roberta_Question_Answer"):
    """Return ``(tokenizer, model)`` for *checkpoint*, with the model in eval mode.

    Args:
        checkpoint: HuggingFace Hub id (or local path) passed to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForSequenceClassification.from_pretrained``.

    The result is cached with ``st.cache_resource`` and shared by all
    sessions, so callers must not mutate it.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint)
    clf = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    # Switch to inference behaviour (e.g. dropout off) for stable predictions.
    clf.eval()
    return tok, clf


tokenizer, model = _load_model()
def get_intent_and_response(instruction):
    """Classify *instruction* and look up a canned reply for the predicted intent.

    Args:
        instruction: The user's free-text message.

    Returns:
        A ``(predicted_intent, response)`` tuple: the intent name decoded via
        ``id2label`` and the ``response`` text of the first ``df`` row whose
        ``intent`` equals it.
    """
    encoded = tokenizer(
        instruction,
        return_tensors="pt",
        truncation=True,
        padding='max_length',
        max_length=128,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(**encoded).logits
    best_id = scores.argmax(dim=1).item()
    intent = id2label[best_id]
    # The first example row for this intent supplies the canned reply.
    matching_rows = df.loc[df['intent'] == intent]
    reply = matching_rows['response'].iloc[0]
    return intent, reply
# Streamlit app setup
st.title("Customer Support Chatbot")
st.write("Ask a question, and I'll do my best to help you.")
instruction = st.text_input("You:")
if st.button("Submit"):
    # Treat whitespace-only input like empty input; otherwise "   " would
    # still be sent to the classifier and mapped to some intent.
    cleaned = instruction.strip()
    if cleaned:
        predicted_intent, response = get_intent_and_response(cleaned)
        st.write(f"**Predicted Intent:** {predicted_intent}")
        st.write(f"**Assistant:** {response}")
    else:
        st.write("Please enter an instruction.")
if st.button("Exit"):
    # NOTE(review): this only shows a message; it does not end the session.
    st.write("Exiting the chat.")