Thota02 commited on
Commit
41ad847
·
verified ·
1 Parent(s): d15d153

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -17
app.py CHANGED
@@ -7,8 +7,10 @@ from transformers import pipeline
7
  # Configure logging
8
  logging.basicConfig(level=logging.DEBUG)
9
 
 
 
10
  # Load Hugging Face model
11
- MODEL_NAME = "mistralai/Mistral-Large-Instruct-2407" # Hugging Face Model name
12
  llm_pipeline = pipeline("text-generation", model=MODEL_NAME)
13
 
14
  # Load datasets
@@ -38,16 +40,6 @@ fashion_keywords = [
38
  'jackets', 'sweaters', 'suits', 'accessories', 't-shirts'
39
  ]
40
 
41
- # LLM-based function
42
- def query_llm(prompt):
43
- """Query the LLM for responses."""
44
- try:
45
- responses = llm_pipeline(prompt, max_length=150, num_return_sequences=1)
46
- return responses[0]['generated_text'].strip()
47
- except Exception as e:
48
- logging.error(f"Error querying the LLM: {e}")
49
- return "Sorry, I'm having trouble processing your request right now."
50
-
51
  def determine_category(query):
52
  """Determine the category based on the query."""
53
  query_lower = query.lower()
@@ -61,7 +53,84 @@ def determine_category(query):
61
  logging.debug(f"Query '{query}' categorized as 'general'.")
62
  return 'general'
63
 
64
- # Fetch response based on query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def get_response(user_input):
66
  """Determine the category and fetch the appropriate response."""
67
  if 'hi' in user_input.lower() or 'hello' in user_input.lower():
@@ -74,18 +143,60 @@ def get_response(user_input):
74
  elif category == 'fashion':
75
  response = fashion_response(user_input)
76
  else:
77
- # Use LLM for more complex queries
78
- response = query_llm(user_input)
79
 
80
  return response
81
 
82
- # Streamlit Interface remains the same
83
  def main():
84
  st.title("Customer Support Chatbot")
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  if 'chat_history' not in st.session_state:
87
  st.session_state.chat_history = []
88
 
 
 
 
 
 
 
 
 
 
 
89
  user_input = st.text_input("Type your message here:")
90
 
91
  if st.button("Send"):
@@ -93,9 +204,11 @@ def main():
93
  response_message = get_response(user_input)
94
  st.session_state.chat_history.append({"role": "user", "content": user_input})
95
  st.session_state.chat_history.append({"role": "assistant", "content": response_message})
 
 
 
96
 
97
- for message in st.session_state.chat_history:
98
- st.markdown(f"{message['role'].capitalize()}: {message['content']}")
99
 
100
  if __name__ == "__main__":
101
  main()
 
7
  # Configure logging
8
  logging.basicConfig(level=logging.DEBUG)
9
 
10
+ # Set Hugging Face token
11
+ os.environ['HUGGINGFACEHUB_API_TOKEN'] = hf_token
12
  # Load Hugging Face model
13
+ MODEL_NAME = "mistralai/Mistral-Large-Instruct-2407"
14
  llm_pipeline = pipeline("text-generation", model=MODEL_NAME)
15
 
16
  # Load datasets
 
40
  'jackets', 'sweaters', 'suits', 'accessories', 't-shirts'
41
  ]
42
 
 
 
 
 
 
 
 
 
 
 
43
  def determine_category(query):
44
  """Determine the category based on the query."""
45
  query_lower = query.lower()
 
53
  logging.debug(f"Query '{query}' categorized as 'general'.")
54
  return 'general'
55
 
56
+ def format_electronics_response(row):
57
+ """Format response using data from an Electronics DataFrame row."""
58
+ response = (
59
+ f"**Product Name:** {row['ProductName']}\n\n"
60
+ f"**Description:**\n{row['Description']}\n\n"
61
+ f"**Price:** ${row['Price']}\n"
62
+ f"**Brand:** {row['Brand']}\n"
63
+ f"**Model:** {row['Model']}\n"
64
+ f"**Department:** {row['Department']}\n\n"
65
+ f"**Reviews:**\n{row['Reviews']}\n\n"
66
+ f"**Ratings:** {row['Ratings']} / 5\n"
67
+ )
68
+ return response
69
+
70
+ def format_fashion_response(row):
71
+ """Format response using data from a Fashion DataFrame row."""
72
+ response = (
73
+ f"**Product Name:** {row['Name']}\n\n"
74
+ f"**Description:**\n{row['Description']}\n\n"
75
+ f"**Price:** ${row['Price']}\n"
76
+ f"**Brand:** {row['Brand']}\n"
77
+ f"**Model:** {row['Model']}\n\n"
78
+ f"**Rating:** {row['Rating']} / 5\n"
79
+ )
80
+ return response
81
+
82
+ def extract_filters(query):
83
+ """Extract filters from the user's query."""
84
+ filters = {}
85
+ query_lower = query.lower()
86
+
87
+ if 'best' in query_lower:
88
+ if 'rating' in query_lower:
89
+ filters['Rating'] = category_df['Rating'].max()
90
+
91
+ if 'phones' in query_lower:
92
+ filters['Category'] = 'phone'
93
+ elif 'laptops' in query_lower:
94
+ filters['Category'] = 'laptop'
95
+
96
+ return filters
97
+
98
+ def apply_filters(df, filters):
99
+ """Apply filters to a DataFrame based on the provided filter dictionary."""
100
+ for key, value in filters.items():
101
+ if key in df.columns:
102
+ if isinstance(value, str):
103
+ df = df[df[key].str.contains(value, case=False, na=False)]
104
+ elif isinstance(value, (int, float)):
105
+ df = df[df[key] == value]
106
+ return df
107
+
108
+ def fetch_response_from_df(query, category_df, format_response_func):
109
+ """Fetch response from the dataset based on the category and filters."""
110
+ filters = extract_filters(query)
111
+
112
+ filtered_df = apply_filters(category_df, filters)
113
+
114
+ if filtered_df.empty:
115
+ return "Sorry, I couldn't find an answer in our records."
116
+
117
+ responses = []
118
+ for _, row in filtered_df.iterrows():
119
+ responses.append(format_response_func(row))
120
+
121
+ if responses:
122
+ return "\n\n".join(responses)
123
+
124
+ return "Sorry, I couldn't find an answer in our records."
125
+
126
+ def electronics_response(query):
127
+ """Get response from the electronics dataset."""
128
+ return fetch_response_from_df(query, electronics_df, format_electronics_response)
129
+
130
+ def fashion_response(query):
131
+ """Get response from the fashion dataset."""
132
+ return fetch_response_from_df(query, fashion_df, format_fashion_response)
133
+
134
  def get_response(user_input):
135
  """Determine the category and fetch the appropriate response."""
136
  if 'hi' in user_input.lower() or 'hello' in user_input.lower():
 
143
  elif category == 'fashion':
144
  response = fashion_response(user_input)
145
  else:
146
+ response = "Sorry, I couldn't find an answer in our records."
 
147
 
148
  return response
149
 
150
+ # Streamlit Interface
151
  def main():
152
  st.title("Customer Support Chatbot")
153
 
154
+ # Custom CSS for chat bubbles
155
+ st.markdown("""
156
+ <style>
157
+ .chat-container {
158
+ max-width: 800px;
159
+ margin: 0 auto;
160
+ padding: 20px;
161
+ }
162
+ .chat-bubble {
163
+ border-radius: 15px;
164
+ padding: 10px;
165
+ margin: 5px 0;
166
+ max-width: 70%;
167
+ display: inline-block;
168
+ word-wrap: break-word;
169
+ }
170
+ .user-bubble {
171
+ background-color: #DCF8C6;
172
+ float: right;
173
+ text-align: left;
174
+ }
175
+ .assistant-bubble {
176
+ background-color: #FFFFFF;
177
+ float: left;
178
+ text-align: left;
179
+ }
180
+ .chat-history {
181
+ max-height: 500px;
182
+ overflow-y: auto;
183
+ }
184
+ </style>
185
+ """, unsafe_allow_html=True)
186
+
187
  if 'chat_history' not in st.session_state:
188
  st.session_state.chat_history = []
189
 
190
+ st.markdown('<div class="chat-container">', unsafe_allow_html=True)
191
+
192
+ st.markdown('<div class="chat-history">', unsafe_allow_html=True)
193
+ for message in st.session_state.chat_history:
194
+ if message['role'] == 'user':
195
+ st.markdown(f'<div class="chat-bubble user-bubble">{message["content"]}</div>', unsafe_allow_html=True)
196
+ else:
197
+ st.markdown(f'<div class="chat-bubble assistant-bubble">{message["content"]}</div>', unsafe_allow_html=True)
198
+ st.markdown('</div>', unsafe_allow_html=True)
199
+
200
  user_input = st.text_input("Type your message here:")
201
 
202
  if st.button("Send"):
 
204
  response_message = get_response(user_input)
205
  st.session_state.chat_history.append({"role": "user", "content": user_input})
206
  st.session_state.chat_history.append({"role": "assistant", "content": response_message})
207
+
208
+ # Clear the input after sending
209
+ st.session_state.user_input = ''
210
 
211
+ st.markdown('</div>', unsafe_allow_html=True)
 
212
 
213
  if __name__ == "__main__":
214
  main()