Spaces:
Runtime error
Runtime error
import html
import logging
import os
import re

import pandas as pd
import streamlit as st
from transformers import pipeline
# Configure logging
logging.basicConfig(level=logging.DEBUG)

# Retrieve Hugging Face API token from environment variables
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HUGGINGFACEHUB_API_TOKEN:
    # Without a token the model cannot be loaded: report it both in the
    # server log and in the page, then halt this script run.
    logging.error("Hugging Face API token is missing or invalid.")
    st.error("Hugging Face API token is missing or invalid.")
    st.stop()

MODEL_NAME = "mistralai/Mistral-Large-Instruct-2407"


@st.cache_resource
def _load_llm_pipeline(model_name, token):
    """Build the text-generation pipeline for ``model_name``.

    Streamlit re-executes this script on every widget interaction;
    ``st.cache_resource`` keeps the pipeline object between reruns so the
    model is not rebuilt each time the user sends a message.

    Args:
        model_name: Hugging Face Hub model identifier.
        token: Hugging Face access token used to download the model.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    # `token` replaces the deprecated `use_auth_token` keyword.
    # `pipeline` is already imported at the top of the file.
    return pipeline("text-generation", model=model_name, token=token)


# NOTE(review): no other code visible in this file references llm_pipeline;
# confirm it is needed, since loading it runs as soon as the script starts.
llm_pipeline = _load_llm_pipeline(MODEL_NAME, HUGGINGFACEHUB_API_TOKEN)
# Load datasets
def load_datasets():
    """Read the two product catalogues from CSV files.

    ``electronics.csv`` and ``fashion.csv`` are opened by relative path, in
    that order.

    Returns:
        A ``(electronics, fashion)`` tuple of DataFrames.

    If reading raises any exception, the error is logged, a message is shown
    in the Streamlit page, and ``st.stop()`` is called.
    """
    csv_paths = ('electronics.csv', 'fashion.csv')
    try:
        # Generator unpacking keeps the reads sequential: a failure on the
        # first file means the second one is never opened.
        electronics_frame, fashion_frame = (pd.read_csv(path) for path in csv_paths)
    except Exception as e:
        logging.error(f"Error loading datasets: {e}")
        st.error("Error loading datasets.")
        st.stop()
    else:
        return electronics_frame, fashion_frame


electronics_df, fashion_df = load_datasets()
# Keywords for routing queries.
# determine_category() tests each entry as a substring of the lower-cased
# query, checking the electronics list before the fashion list.
electronics_keywords = [
    # generic terms
    'electronics',
    'device',
    'gadget',
    'battery',
    'performance',
    # products and features
    'phone',
    'mobile',
    'laptop',
    'tv',
    'bluetooth',
    'speakers',
    'washing machine',
    'headphones',
    'camera',
    'tablet',
    'charger',
    'smartwatch',
    'refrigerator',
]

fashion_keywords = [
    # generic terms
    'fashion',
    'clothing',
    'size',
    'fit',
    'material',
    # garments and accessories
    'shirt',
    'pants',
    'coats',
    'shoes',
    'girls dress',
    'sarees',
    'skirts',
    'jackets',
    'sweaters',
    'suits',
    'accessories',
    't-shirts',
]
def determine_category(query):
    """Classify a query as 'electronics', 'fashion' or 'general'.

    The query is lower-cased and scanned for any entry of
    ``electronics_keywords`` first, then ``fashion_keywords``; matching is by
    substring. The chosen category is logged at DEBUG level.

    Args:
        query: The user's message.

    Returns:
        'electronics', 'fashion', or 'general' when no keyword is present.
    """
    lowered = query.lower()

    def mentions_any(keywords):
        # True as soon as one keyword occurs anywhere in the query text.
        return any(word in lowered for word in keywords)

    if mentions_any(electronics_keywords):
        logging.debug(f"Query '{query}' categorized as 'electronics'.")
        return 'electronics'

    if mentions_any(fashion_keywords):
        logging.debug(f"Query '{query}' categorized as 'fashion'.")
        return 'fashion'

    logging.debug(f"Query '{query}' categorized as 'general'.")
    return 'general'
def format_electronics_response(row):
    """Render one electronics row as a Markdown-formatted product summary.

    Args:
        row: Mapping-like object indexed by the keys 'ProductName',
            'Description', 'Price', 'Brand', 'Model', 'Department',
            'Reviews' and 'Ratings'.

    Returns:
        The summary text, ending with a newline.
    """
    # Empty entries become blank lines once joined; the final one supplies
    # the trailing newline.
    lines = [
        f"**Product Name:** {row['ProductName']}",
        "",
        "**Description:**",
        f"{row['Description']}",
        "",
        f"**Price:** ${row['Price']}",
        f"**Brand:** {row['Brand']}",
        f"**Model:** {row['Model']}",
        f"**Department:** {row['Department']}",
        "",
        "**Reviews:**",
        f"{row['Reviews']}",
        "",
        f"**Ratings:** {row['Ratings']} / 5",
        "",
    ]
    return "\n".join(lines)
def format_fashion_response(row):
    """Render one fashion row as a Markdown-formatted product summary.

    Args:
        row: Mapping-like object indexed by the keys 'Name', 'Description',
            'Price', 'Brand', 'Model' and 'Rating'.

    Returns:
        The summary text, ending with a newline.
    """
    template = (
        "**Product Name:** {name}\n\n"
        "**Description:**\n{description}\n\n"
        "**Price:** ${price}\n"
        "**Brand:** {brand}\n"
        "**Model:** {model}\n\n"
        "**Rating:** {rating} / 5\n"
    )
    # Fields are looked up in the order they appear in the text.
    return template.format(
        name=row['Name'],
        description=row['Description'],
        price=row['Price'],
        brand=row['Brand'],
        model=row['Model'],
        rating=row['Rating'],
    )
def extract_filters(query, category_df=None):
    """Extract DataFrame filters from the user's query.

    Args:
        query: The user's message.
        category_df: Optional DataFrame the filters are meant for. It is only
            consulted for "best ... rating" queries, to look up the highest
            rating present. When omitted, no rating filter is produced.

    Returns:
        A dict mapping a column name to the value rows must match. Possible
        keys are the rating column ('Rating' or 'Ratings') and 'Category'.
    """
    filters = {}
    query_lower = query.lower()

    wants_best_rating = 'best' in query_lower and 'rating' in query_lower
    if wants_best_rating and category_df is not None:
        # The formatters in this file read 'Ratings' from electronics rows and
        # 'Rating' from fashion rows, so use whichever column this frame has.
        for column in ('Rating', 'Ratings'):
            if column in category_df.columns:
                top_rating = category_df[column].max()
                # An empty or all-missing column yields NaN; filtering on it
                # would discard every row, so skip the filter instead.
                if not pd.isna(top_rating):
                    filters[column] = top_rating
                break

    # NOTE(review): these are substring checks, so "headphones" also selects
    # the 'phone' category — confirm whether that is intended.
    if 'phones' in query_lower:
        filters['Category'] = 'phone'
    elif 'laptops' in query_lower:
        filters['Category'] = 'laptop'
    return filters


def apply_filters(df, filters):
    """Apply filters to a DataFrame based on the provided filter dictionary.

    Args:
        df: DataFrame to narrow down.
        filters: Mapping of column name to required value. Keys that are not
            columns of ``df`` are ignored. String values keep rows whose
            column contains the value (case-insensitive, missing cells
            excluded); numeric values keep rows equal to the value. Values of
            any other type are ignored.

    Returns:
        The filtered DataFrame (``df`` itself when nothing applied).
    """
    for key, value in filters.items():
        if key not in df.columns:
            continue
        if isinstance(value, str):
            df = df[df[key].str.contains(value, case=False, na=False)]
        elif pd.api.types.is_number(value):
            # is_number also accepts numpy scalars such as the np.int64 that
            # Series.max() returns, which are not instances of int.
            df = df[df[key] == value]
    return df


def fetch_response_from_df(query, category_df, format_response_func):
    """Fetch response from the dataset based on the category and filters.

    Args:
        query: The user's message.
        category_df: DataFrame holding the products of the chosen category.
        format_response_func: Callable turning one row into response text.

    Returns:
        The formatted rows joined by blank lines, or an apology message when
        no row survives the filters.
    """
    filters = extract_filters(query, category_df)
    filtered_df = apply_filters(category_df, filters)
    if filtered_df.empty:
        return "Sorry, I couldn't find an answer in our records."
    return "\n\n".join(
        format_response_func(row) for _, row in filtered_df.iterrows()
    )
def electronics_response(query):
    """Answer ``query`` from the electronics catalogue.

    Delegates to ``fetch_response_from_df`` with the module-level
    ``electronics_df`` and the electronics row formatter.
    """
    return fetch_response_from_df(
        query,
        category_df=electronics_df,
        format_response_func=format_electronics_response,
    )
def fashion_response(query):
    """Answer ``query`` from the fashion catalogue.

    Delegates to ``fetch_response_from_df`` with the module-level
    ``fashion_df`` and the fashion row formatter.
    """
    return fetch_response_from_df(
        query,
        category_df=fashion_df,
        format_response_func=format_fashion_response,
    )
# Matches "hi" / "hello" only as standalone words. A plain substring test
# would treat "shirt", "this" or "washing machine" as greetings because they
# contain the letters "hi".
_GREETING_PATTERN = re.compile(r"\b(?:hi|hello)\b", re.IGNORECASE)


def get_response(user_input):
    """Determine the category and fetch the appropriate response.

    Args:
        user_input: The user's message.

    Returns:
        A welcome message when the input contains a greeting word; otherwise
        the answer from the electronics or fashion catalogue, or an apology
        when the query belongs to neither category.
    """
    if _GREETING_PATTERN.search(user_input):
        return "Hi, welcome to the customer support chatbot. How can I help you?"

    category = determine_category(user_input)
    if category == 'electronics':
        return electronics_response(user_input)
    if category == 'fashion':
        return fashion_response(user_input)
    return "Sorry, I couldn't find an answer in our records."
# Streamlit Interface

# Custom CSS for chat bubbles
_CHAT_CSS = """
<style>
.chat-container {
    max-width: 800px;
    margin: 0 auto;
    padding: 20px;
}
.chat-bubble {
    border-radius: 15px;
    padding: 10px;
    margin: 5px 0;
    max-width: 70%;
    display: inline-block;
    word-wrap: break-word;
}
.user-bubble {
    background-color: #DCF8C6;
    float: right;
    text-align: left;
}
.assistant-bubble {
    background-color: #FFFFFF;
    float: left;
    text-align: left;
}
.chat-history {
    max-height: 500px;
    overflow-y: auto;
}
</style>
"""


def _handle_send():
    """Handle a click on "Send": record the exchange and clear the input.

    Registered as the button's ``on_click`` callback. Streamlit runs such
    callbacks before re-executing the script, so the new messages are already
    in the history when it is rendered, and the text box (keyed
    ``user_input``) may still be reset here — assigning to a widget's key
    after the widget has been created in a run is not allowed.
    """
    user_input = st.session_state.get('user_input', '')
    if user_input:
        if 'chat_history' not in st.session_state:
            st.session_state.chat_history = []
        response_message = get_response(user_input)
        st.session_state.chat_history.append({"role": "user", "content": user_input})
        st.session_state.chat_history.append({"role": "assistant", "content": response_message})
    # Clear the input after sending
    st.session_state.user_input = ''


def main():
    """Render the chat page: title, styles, message history and input row."""
    st.title("Customer Support Chatbot")
    st.markdown(_CHAT_CSS, unsafe_allow_html=True)

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # NOTE(review): each st.markdown call is rendered as its own element, so
    # these opening/closing <div> tags probably do not wrap the elements
    # emitted between them — confirm in the rendered page.
    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    st.markdown('<div class="chat-history">', unsafe_allow_html=True)
    for message in st.session_state.chat_history:
        bubble_class = 'user-bubble' if message['role'] == 'user' else 'assistant-bubble'
        # Messages are injected into raw HTML, so escape them: user text (and
        # catalogue text) must not be able to add markup or scripts.
        safe_content = html.escape(message['content'], quote=False)
        st.markdown(
            f'<div class="chat-bubble {bubble_class}">{safe_content}</div>',
            unsafe_allow_html=True,
        )
    st.markdown('</div>', unsafe_allow_html=True)

    st.text_input("Type your message here:", key="user_input")
    st.button("Send", on_click=_handle_send)
    st.markdown('</div>', unsafe_allow_html=True)


if __name__ == "__main__":
    main()