Prathamesh1420 commited on
Commit
296f87c
·
verified ·
1 Parent(s): 3b95758

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +320 -111
chatbot.py CHANGED
@@ -1,120 +1,329 @@
1
- import os
2
- import pickle
3
  import torch
4
- import matplotlib.pyplot as plt
5
- from langchain_community.document_loaders import TextLoader
6
- from datasets import load_dataset
7
- from sentence_transformers import SentenceTransformer, util
8
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
9
- from transformers import BertModel, BertTokenizer
10
- from langchain_core.prompts import PromptTemplate
11
- from transformers import BlipProcessor, BlipForConditionalGeneration
12
- from PIL import Image
13
 
14
- os.environ['HUGGINGFACEHUB_API_TOKEN'] = "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"
 
 
15
 
16
  class Chatbot:
17
  def __init__(self):
18
- self.load_data()
 
 
 
 
 
19
  self.load_models()
20
- self.load_embeddings()
21
- self.load_template()
22
-
23
- def load_data(self):
24
- self.data = load_dataset("ashraq/fashion-product-images-small", split="train")
25
- self.images = self.data["image"]
26
- self.product_frame = self.data.remove_columns("image").to_pandas()
27
- self.product_data = self.product_frame.reset_index(drop=True).to_dict(orient='index')
28
-
29
- def load_template(self):
30
- self.template = """
31
- You are a fashion shopping assistant that wants to convert customers based on the information given.
32
- Describe season and usage given in the context in your interaction with the customer.
33
- Use a bullet list when describing each product.
34
- If user ask general question then answer them accordingly, the question may be like when the store will open, where is your store located.
35
- Context: {context}
36
- User question: {question}
37
- Your response: {response}
38
- """
39
- self.prompt = PromptTemplate.from_template(self.template)
40
-
41
  def load_models(self):
42
- self.model = SentenceTransformer('clip-ViT-B-32')
43
- self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
44
- self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
45
-
46
- def load_embeddings(self):
47
- if os.path.exists("embeddings_cache.pkl"):
48
- with open("embeddings_cache.pkl", "rb") as f:
49
- embeddings_cache = pickle.load(f)
50
- self.image_embeddings = embeddings_cache["image_embeddings"]
51
- self.text_embeddings = embeddings_cache["text_embeddings"]
52
- else:
53
- self.image_embeddings = self.model.encode([image for image in self.images])
54
- self.text_embeddings = self.model.encode(self.product_frame['productDisplayName'])
55
- embeddings_cache = {"image_embeddings": self.image_embeddings, "text_embeddings": self.text_embeddings}
56
- with open("embeddings_cache.pkl", "wb") as f:
57
- pickle.dump(embeddings_cache, f)
58
-
59
- def create_docs(self, results):
60
- docs = []
61
- for result in results:
62
- pid = result['corpus_id']
63
- score = result['score']
64
- result_string = ''
65
- result_string += "Product Name:" + self.product_data[pid]['productDisplayName'] + \
66
- ';' + "Category:" + self.product_data[pid]['masterCategory'] + \
67
- ';' + "Article Type:" + self.product_data[pid]['articleType'] + \
68
- ';' + "Usage:" + self.product_data[pid]['usage'] + \
69
- ';' + "Season:" + self.product_data[pid]['season'] + \
70
- ';' + "Gender:" + self.product_data[pid]['gender']
71
- # Assuming text is imported from somewhere else
72
- doc = text(page_content=result_string)
73
- doc.metadata['pid'] = str(pid)
74
- doc.metadata['score'] = score
75
- docs.append(doc)
76
- return docs
77
-
78
- def get_results(self, query, embeddings, top_k=5):
79
- query_embedding = self.model.encode([query])
80
- cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
81
- top_results = torch.topk(cos_scores, k=top_k)
82
- indices = top_results.indices.tolist()
83
- scores = top_results.values.tolist()
84
- results = [{'corpus_id': idx, 'score': score} for idx, score in zip(indices, scores)]
85
- return results
86
-
87
- def display_text_and_images(self, results_text):
88
- for result in results_text:
89
- pid = result['corpus_id']
90
- product_info = self.product_data[pid]
91
- print("Product Name:", product_info['productDisplayName'])
92
- print("Category:", product_info['masterCategory'])
93
- print("Article Type:", product_info['articleType'])
94
- print("Usage:", product_info['usage'])
95
- print("Season:", product_info['season'])
96
- print("Gender:", product_info['gender'])
97
- print("Score:", result['score'])
98
- plt.imshow(self.images[pid])
99
- plt.axis('off')
100
- plt.show()
101
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def generate_image_caption(self, image_path):
103
- raw_image = Image.open(image_path).convert('RGB')
104
- inputs = self.blip_processor(raw_image, return_tensors="pt")
105
- out = self.blip_model.generate(**inputs)
106
- caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
107
- return caption
108
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def generate_response(self, query):
110
- # Process the user query and generate a response
111
- results_text = self.get_results(query, self.text_embeddings)
112
-
113
- # Generate chatbot response
114
- chatbot_response = "This is a placeholder response from the chatbot." # Placeholder, replace with actual response
115
-
116
- # Display recommended products
117
- self.display_text_and_images(results_text)
118
-
119
- # Return both chatbot response and recommended products
120
- return chatbot_response, results_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import numpy as np
3
+ from sentence_transformers import SentenceTransformer
4
+ import pandas as pd
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ import random
7
+ import logging
8
+ import json
9
+ import os
 
10
 
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
 
15
  class Chatbot:
16
  def __init__(self):
17
+ self.device = 'cpu' # Force CPU usage for Hugging Face Spaces
18
+ logger.info("🚀 Initializing Fashion Chatbot with CPU...")
19
+ self.model = None
20
+ self.product_data = {}
21
+ self.images = {}
22
+ self.product_embeddings = None
23
  self.load_models()
24
+ self.setup_sample_data()
25
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def load_models(self):
27
+ """Load all required models with CPU-only configuration"""
28
+ try:
29
+ logger.info("📥 Loading SentenceTransformer model on CPU...")
30
+
31
+ # Force CPU for all operations
32
+ torch.device('cpu')
33
+
34
+ # Load a lightweight model suitable for CPU
35
+ self.model = SentenceTransformer(
36
+ 'all-MiniLM-L6-v2', # Lightweight model for CPU
37
+ device='cpu'
38
+ )
39
+
40
+ logger.info("✅ Model loaded successfully on CPU")
41
+
42
+ except Exception as e:
43
+ logger.error(f"❌ Error loading model: {e}")
44
+ # Create a dummy model for fallback
45
+ self.model = None
46
+
47
+ def setup_sample_data(self):
48
+ """Setup sample fashion product data for demonstration"""
49
+ logger.info("🛍️ Setting up sample fashion data...")
50
+
51
+ # Sample fashion products data
52
+ self.product_data = {
53
+ 0: {
54
+ 'productDisplayName': 'Classic White T-Shirt',
55
+ 'masterCategory': 'Apparel',
56
+ 'articleType': 'T-Shirt',
57
+ 'usage': 'Casual',
58
+ 'season': 'All Season',
59
+ 'gender': 'Unisex',
60
+ 'baseColour': 'White',
61
+ 'price': 29.99
62
+ },
63
+ 1: {
64
+ 'productDisplayName': 'Denim Jacket',
65
+ 'masterCategory': 'Apparel',
66
+ 'articleType': 'Jacket',
67
+ 'usage': 'Casual',
68
+ 'season': 'Spring, Fall',
69
+ 'gender': 'Unisex',
70
+ 'baseColour': 'Blue',
71
+ 'price': 89.99
72
+ },
73
+ 2: {
74
+ 'productDisplayName': 'Black Leather Boots',
75
+ 'masterCategory': 'Footwear',
76
+ 'articleType': 'Boots',
77
+ 'usage': 'Casual',
78
+ 'season': 'Winter, Fall',
79
+ 'gender': 'Unisex',
80
+ 'baseColour': 'Black',
81
+ 'price': 129.99
82
+ },
83
+ 3: {
84
+ 'productDisplayName': 'Summer Floral Dress',
85
+ 'masterCategory': 'Apparel',
86
+ 'articleType': 'Dress',
87
+ 'usage': 'Casual',
88
+ 'season': 'Summer',
89
+ 'gender': 'Women',
90
+ 'baseColour': 'Multicolor',
91
+ 'price': 59.99
92
+ },
93
+ 4: {
94
+ 'productDisplayName': 'Sports Running Shoes',
95
+ 'masterCategory': 'Footwear',
96
+ 'articleType': 'Sports Shoes',
97
+ 'usage': 'Sports',
98
+ 'season': 'All Season',
99
+ 'gender': 'Unisex',
100
+ 'baseColour': 'White',
101
+ 'price': 79.99
102
+ },
103
+ 5: {
104
+ 'productDisplayName': 'Wool Winter Scarf',
105
+ 'masterCategory': 'Accessories',
106
+ 'articleType': 'Scarf',
107
+ 'usage': 'Casual',
108
+ 'season': 'Winter',
109
+ 'gender': 'Unisex',
110
+ 'baseColour': 'Grey',
111
+ 'price': 34.99
112
+ }
113
+ }
114
+
115
+ # Generate sample product images
116
+ self.images = {}
117
+ for pid in self.product_data.keys():
118
+ self.images[pid] = self.generate_sample_image(pid)
119
+
120
+ # Create sample embeddings for products
121
+ self.create_sample_embeddings()
122
+
123
+ logger.info(f"✅ Loaded {len(self.product_data)} sample products")
124
+
125
+ def generate_sample_image(self, product_id):
126
+ """Generate a sample product image for demonstration"""
127
+ # Create a simple colored image with text
128
+ img = Image.new('RGB', (200, 200), color=self.get_color_for_product(product_id))
129
+ draw = ImageDraw.Draw(img)
130
+
131
+ # Add product type text
132
+ product_type = self.product_data[product_id]['articleType']
133
+ draw.text((50, 90), product_type, fill='white')
134
+
135
+ return img
136
+
137
+ def get_color_for_product(self, product_id):
138
+ """Get color based on product"""
139
+ color_map = {
140
+ 'White': (255, 255, 255),
141
+ 'Blue': (0, 0, 255),
142
+ 'Black': (0, 0, 0),
143
+ 'Multicolor': (255, 0, 0),
144
+ 'Grey': (128, 128, 128)
145
+ }
146
+ base_color = self.product_data[product_id]['baseColour']
147
+ return color_map.get(base_color, (200, 200, 200))
148
+
149
+ def create_sample_embeddings(self):
150
+ """Create sample embeddings for products"""
151
+ try:
152
+ if self.model is not None:
153
+ product_descriptions = []
154
+ for pid, data in self.product_data.items():
155
+ desc = f"{data['productDisplayName']} {data['articleType']} {data['usage']} {data['season']} {data['gender']}"
156
+ product_descriptions.append(desc)
157
+
158
+ self.product_embeddings = self.model.encode(product_descriptions)
159
+ else:
160
+ # Create dummy embeddings
161
+ self.product_embeddings = np.random.randn(len(self.product_data), 384)
162
+ except Exception as e:
163
+ logger.error(f"Error creating embeddings: {e}")
164
+ self.product_embeddings = np.random.randn(len(self.product_data), 384)
165
+
166
+ def load_data(self):
167
+ """Load product data - using sample data for demo"""
168
+ logger.info("📊 Loading product data...")
169
+ # Data is already loaded in setup_sample_data
170
+ pass
171
+
172
  def generate_image_caption(self, image_path):
173
+ """Generate caption for uploaded image"""
174
+ try:
175
+ # For CPU deployment, use a simpler approach
176
+ image = Image.open(image_path)
177
+
178
+ # Simple analysis based on image characteristics
179
+ width, height = image.size
180
+ dominant_color = self.get_dominant_color(image)
181
+
182
+ # Generate descriptive caption based on image properties
183
+ size_desc = "large" if width > 1000 else "medium" if width > 500 else "small"
184
+ color_desc = self.get_color_name(dominant_color)
185
+
186
+ captions = [
187
+ f"A {size_desc} {color_desc} fashion item perfect for your style",
188
+ f"Stylish {color_desc} clothing item that matches current trends",
189
+ f"Fashionable {size_desc} apparel in {color_desc} color",
190
+ f"Trendy {color_desc} fashion piece suitable for various occasions"
191
+ ]
192
+
193
+ return random.choice(captions)
194
+
195
+ except Exception as e:
196
+ logger.error(f"Error generating caption: {e}")
197
+ return "A fashionable clothing item that suits your style"
198
+
199
+ def get_dominant_color(self, image):
200
+ """Get dominant color from image (simplified)"""
201
+ try:
202
+ # Resize image for faster processing
203
+ image = image.resize((50, 50))
204
+ # Convert to numpy array and get average color
205
+ np_image = np.array(image)
206
+ return tuple(np.mean(np_image, axis=(0, 1)).astype(int))
207
+ except:
208
+ return (128, 128, 128) # Default gray
209
+
210
+ def get_color_name(self, rgb):
211
+ """Convert RGB to color name"""
212
+ colors = {
213
+ (255, 255, 255): "white",
214
+ (0, 0, 0): "black",
215
+ (255, 0, 0): "red",
216
+ (0, 255, 0): "green",
217
+ (0, 0, 255): "blue",
218
+ (255, 255, 0): "yellow",
219
+ (128, 128, 128): "gray",
220
+ (255, 165, 0): "orange",
221
+ (128, 0, 128): "purple"
222
+ }
223
+
224
+ # Find closest color
225
+ min_dist = float('inf')
226
+ closest_color = "colored"
227
+ for color, name in colors.items():
228
+ dist = sum((a - b) ** 2 for a, b in zip(rgb, color))
229
+ if dist < min_dist:
230
+ min_dist = dist
231
+ closest_color = name
232
+
233
+ return closest_color
234
+
235
  def generate_response(self, query):
236
+ """Generate chatbot response and recommendations"""
237
+ try:
238
+ # Fashion-related responses
239
+ fashion_responses = {
240
+ 'casual': "Great choice! Casual wear is perfect for everyday comfort and style.",
241
+ 'formal': "Elegant choice! Formal wear always makes a strong impression.",
242
+ 'sports': "Active lifestyle! Sports wear combines comfort and performance.",
243
+ 'summer': "Perfect for warm weather! Light and breathable fabrics work best.",
244
+ 'winter': "Stay warm and stylish! Layering is key for winter fashion.",
245
+ 'dress': "Dresses are versatile and always in style!",
246
+ 'shirt': "Classic shirts never go out of fashion!",
247
+ 'shoes': "The right shoes can complete any outfit!",
248
+ 'jacket': "Jackets add style and functionality to any outfit!"
249
+ }
250
+
251
+ # Generate contextual response
252
+ query_lower = query.lower()
253
+ response_key = None
254
+
255
+ for key in fashion_responses.keys():
256
+ if key in query_lower:
257
+ response_key = key
258
+ break
259
+
260
+ if response_key:
261
+ bot_response = fashion_responses[response_key]
262
+ else:
263
+ generic_responses = [
264
+ f"I found some great fashion items related to '{query}'!",
265
+ f"Based on your interest in '{query}', here are my recommendations:",
266
+ f"Here are some stylish options for '{query}':",
267
+ f"Perfect! I have some fashion suggestions for '{query}':"
268
+ ]
269
+ bot_response = random.choice(generic_responses)
270
+
271
+ # Get recommendations
272
+ recommended_products = self.get_recommendations(query)
273
+
274
+ return bot_response, recommended_products
275
+
276
+ except Exception as e:
277
+ logger.error(f"Error generating response: {e}")
278
+ return "I apologize, but I'm having trouble processing your request right now.", []
279
+
280
+ def get_recommendations(self, query, top_k=3):
281
+ """Get product recommendations based on query"""
282
+ try:
283
+ if self.model is not None and self.product_embeddings is not None:
284
+ # Encode query
285
+ query_embedding = self.model.encode([query])
286
+
287
+ # Calculate similarities (using dot product for simplicity)
288
+ similarities = np.dot(self.product_embeddings, query_embedding.T).flatten()
289
+
290
+ # Get top products
291
+ top_indices = np.argsort(similarities)[::-1][:top_k]
292
+ else:
293
+ # Fallback: random recommendations
294
+ top_indices = random.sample(list(self.product_data.keys()), min(top_k, len(self.product_data)))
295
+
296
+ recommended_products = []
297
+ for idx in top_indices:
298
+ recommended_products.append({
299
+ 'corpus_id': idx,
300
+ 'score': 0.9 - (len(recommended_products) * 0.1)
301
+ })
302
+
303
+ return recommended_products
304
+
305
+ except Exception as e:
306
+ logger.error(f"Error getting recommendations: {e}")
307
+ # Return random products as fallback
308
+ return [{'corpus_id': i, 'score': 0.8} for i in range(min(3, len(self.product_data)))]
309
+
310
+ def get_product_info(self, product_id):
311
+ """Get complete product information"""
312
+ try:
313
+ if product_id in self.product_data:
314
+ data = self.product_data[product_id]
315
+ return {
316
+ 'name': data['productDisplayName'],
317
+ 'category': data['masterCategory'],
318
+ 'article_type': data['articleType'],
319
+ 'usage': data['usage'],
320
+ 'season': data['season'],
321
+ 'gender': data['gender'],
322
+ 'color': data['baseColour'],
323
+ 'price': data['price'],
324
+ 'image': self.images.get(product_id)
325
+ }
326
+ return None
327
+ except Exception as e:
328
+ logger.error(f"Error getting product info: {e}")
329
+ return None