Runtime error
Update app.py
app.py CHANGED
@@ -7,8 +7,10 @@ from transformers import pipeline
 # Configure logging
 logging.basicConfig(level=logging.DEBUG)
 
+# Set Hugging Face token
+os.environ['HUGGINGFACEHUB_API_TOKEN'] = hf_token
 # Load Hugging Face model
-MODEL_NAME = "mistralai/Mistral-Large-Instruct-2407"
+MODEL_NAME = "mistralai/Mistral-Large-Instruct-2407"
 llm_pipeline = pipeline("text-generation", model=MODEL_NAME)
 
 # Load datasets
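Review note: this hunk assigns `hf_token` into `os.environ`, but `hf_token` is never defined anywhere in the visible diff, and `mistralai/Mistral-Large-Instruct-2407` is a gated, 123B-parameter checkpoint that ordinary Space hardware cannot load; either is a plausible cause of the "Runtime error" status above. A minimal sketch of a safer token setup, assuming the token lives in a Space secret named HF_TOKEN (the secret name is hypothetical):

    import os

    # Read the token from the Space's secrets instead of an undefined variable.
    hf_token = os.environ.get("HF_TOKEN")  # hypothetical secret name
    if not hf_token:
        raise RuntimeError("HF_TOKEN secret is not configured for this Space.")
    os.environ['HUGGINGFACEHUB_API_TOKEN'] = hf_token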
@@ -38,16 +40,6 @@ fashion_keywords = [
     'jackets', 'sweaters', 'suits', 'accessories', 't-shirts'
 ]
 
-# LLM-based function
-def query_llm(prompt):
-    """Query the LLM for responses."""
-    try:
-        responses = llm_pipeline(prompt, max_length=150, num_return_sequences=1)
-        return responses[0]['generated_text'].strip()
-    except Exception as e:
-        logging.error(f"Error querying the LLM: {e}")
-        return "Sorry, I'm having trouble processing your request right now."
-
 def determine_category(query):
     """Determine the category based on the query."""
     query_lower = query.lower()
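Review note: with `query_llm` removed, nothing in the file calls `llm_pipeline` any more, yet the first hunk still builds it at import time, so every cold start downloads and loads the full model for no benefit. A minimal sketch of deferring that cost, assuming the pipeline is worth keeping at all:

    # Build the text-generation pipeline lazily, only on first use.
    _llm_pipeline = None

    def get_llm_pipeline():
        global _llm_pipeline
        if _llm_pipeline is None:
            from transformers import pipeline
            _llm_pipeline = pipeline("text-generation", model=MODEL_NAME)
        return _llm_pipeline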
@@ -61,7 +53,84 @@ def determine_category(query):
     logging.debug(f"Query '{query}' categorized as 'general'.")
     return 'general'
 
-
+def format_electronics_response(row):
+    """Format response using data from an Electronics DataFrame row."""
+    response = (
+        f"**Product Name:** {row['ProductName']}\n\n"
+        f"**Description:**\n{row['Description']}\n\n"
+        f"**Price:** ${row['Price']}\n"
+        f"**Brand:** {row['Brand']}\n"
+        f"**Model:** {row['Model']}\n"
+        f"**Department:** {row['Department']}\n\n"
+        f"**Reviews:**\n{row['Reviews']}\n\n"
+        f"**Ratings:** {row['Ratings']} / 5\n"
+    )
+    return response
+
+def format_fashion_response(row):
+    """Format response using data from a Fashion DataFrame row."""
+    response = (
+        f"**Product Name:** {row['Name']}\n\n"
+        f"**Description:**\n{row['Description']}\n\n"
+        f"**Price:** ${row['Price']}\n"
+        f"**Brand:** {row['Brand']}\n"
+        f"**Model:** {row['Model']}\n\n"
+        f"**Rating:** {row['Rating']} / 5\n"
+    )
+    return response
+
+def extract_filters(query):
+    """Extract filters from the user's query."""
+    filters = {}
+    query_lower = query.lower()
+
+    if 'best' in query_lower:
+        if 'rating' in query_lower:
+            filters['Rating'] = category_df['Rating'].max()
+
+    if 'phones' in query_lower:
+        filters['Category'] = 'phone'
+    elif 'laptops' in query_lower:
+        filters['Category'] = 'laptop'
+
+    return filters
+
+def apply_filters(df, filters):
+    """Apply filters to a DataFrame based on the provided filter dictionary."""
+    for key, value in filters.items():
+        if key in df.columns:
+            if isinstance(value, str):
+                df = df[df[key].str.contains(value, case=False, na=False)]
+            elif isinstance(value, (int, float)):
+                df = df[df[key] == value]
+    return df
+
+def fetch_response_from_df(query, category_df, format_response_func):
+    """Fetch response from the dataset based on the category and filters."""
+    filters = extract_filters(query)
+
+    filtered_df = apply_filters(category_df, filters)
+
+    if filtered_df.empty:
+        return "Sorry, I couldn't find an answer in our records."
+
+    responses = []
+    for _, row in filtered_df.iterrows():
+        responses.append(format_response_func(row))
+
+    if responses:
+        return "\n\n".join(responses)
+
+    return "Sorry, I couldn't find an answer in our records."
+
+def electronics_response(query):
+    """Get response from the electronics dataset."""
+    return fetch_response_from_df(query, electronics_df, format_electronics_response)
+
+def fashion_response(query):
+    """Get response from the fashion dataset."""
+    return fetch_response_from_df(query, fashion_df, format_fashion_response)
+
 def get_response(user_input):
     """Determine the category and fetch the appropriate response."""
     if 'hi' in user_input.lower() or 'hello' in user_input.lower():
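Review note: there is a scoping bug in this hunk. `extract_filters` reads `category_df`, but that name only exists as a parameter of `fetch_response_from_df`, so any query containing "best" and "rating" raises `NameError` unless a global `category_df` happens to exist. A minimal sketch that passes the DataFrame in explicitly (the signature change is a suggestion, not part of this commit):

    def extract_filters(query, df):
        """Extract filters from the user's query against the given DataFrame."""
        filters = {}
        query_lower = query.lower()

        # Guard on the column name: the electronics frame uses 'Ratings'
        # while the fashion frame uses 'Rating', so check before indexing.
        if 'best' in query_lower and 'rating' in query_lower and 'Rating' in df.columns:
            filters['Rating'] = df['Rating'].max()

        if 'phones' in query_lower:
            filters['Category'] = 'phone'
        elif 'laptops' in query_lower:
            filters['Category'] = 'laptop'

        return filters

    # fetch_response_from_df would then call:
    #     filters = extract_filters(query, category_df)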
@@ -74,18 +143,60 @@ def get_response(user_input):
     elif category == 'fashion':
         response = fashion_response(user_input)
     else:
-
-        response = query_llm(user_input)
+        response = "Sorry, I couldn't find an answer in our records."
 
     return response
 
-# Streamlit Interface
+# Streamlit Interface
 def main():
     st.title("Customer Support Chatbot")
 
+    # Custom CSS for chat bubbles
+    st.markdown("""
+        <style>
+        .chat-container {
+            max-width: 800px;
+            margin: 0 auto;
+            padding: 20px;
+        }
+        .chat-bubble {
+            border-radius: 15px;
+            padding: 10px;
+            margin: 5px 0;
+            max-width: 70%;
+            display: inline-block;
+            word-wrap: break-word;
+        }
+        .user-bubble {
+            background-color: #DCF8C6;
+            float: right;
+            text-align: left;
+        }
+        .assistant-bubble {
+            background-color: #FFFFFF;
+            float: left;
+            text-align: left;
+        }
+        .chat-history {
+            max-height: 500px;
+            overflow-y: auto;
+        }
+        </style>
+    """, unsafe_allow_html=True)
+
     if 'chat_history' not in st.session_state:
         st.session_state.chat_history = []
 
+    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
+
+    st.markdown('<div class="chat-history">', unsafe_allow_html=True)
+    for message in st.session_state.chat_history:
+        if message['role'] == 'user':
+            st.markdown(f'<div class="chat-bubble user-bubble">{message["content"]}</div>', unsafe_allow_html=True)
+        else:
+            st.markdown(f'<div class="chat-bubble assistant-bubble">{message["content"]}</div>', unsafe_allow_html=True)
+    st.markdown('</div>', unsafe_allow_html=True)
+
     user_input = st.text_input("Type your message here:")
 
     if st.button("Send"):
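Review note: the history loop added here runs before the `Send` handler in the same script pass, so a new exchange only becomes visible on the next interaction. A minimal sketch of forcing an immediate redraw, assuming a Streamlit version with `st.rerun()` (1.27+; older versions use `st.experimental_rerun()`):

    if st.button("Send"):
        if user_input:
            response_message = get_response(user_input)
            st.session_state.chat_history.append({"role": "user", "content": user_input})
            st.session_state.chat_history.append({"role": "assistant", "content": response_message})
            st.rerun()  # restart the script so the history loop above shows the new messages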
@@ -93,9 +204,11 @@ def main():
             response_message = get_response(user_input)
             st.session_state.chat_history.append({"role": "user", "content": user_input})
             st.session_state.chat_history.append({"role": "assistant", "content": response_message})
+
+            # Clear the input after sending
+            st.session_state.user_input = ''
 
-    for message in st.session_state.chat_history:
-        st.markdown(f"{message['role'].capitalize()}: {message['content']}")
+    st.markdown('</div>', unsafe_allow_html=True)
 
 if __name__ == "__main__":
     main()
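Review note: the new `st.session_state.user_input = ''` line does not actually clear the box. The `text_input` above has no `key`, so the assignment writes an unrelated session entry; and if `key="user_input"` were added, assigning to it here would raise a `StreamlitAPIException` because the widget is already instantiated. A minimal sketch of the usual fix, resetting the keyed widget inside an `on_click` callback:

    def send_message():
        user_input = st.session_state.get("user_input", "")
        if user_input:
            response = get_response(user_input)
            st.session_state.chat_history.append({"role": "user", "content": user_input})
            st.session_state.chat_history.append({"role": "assistant", "content": response})
        # Resetting a widget's state is allowed inside an on_click callback,
        # because callbacks run before the widgets are re-instantiated.
        st.session_state.user_input = ""

    st.text_input("Type your message here:", key="user_input")
    st.button("Send", on_click=send_message)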
|