nileshhanotia commited on
Commit
d7a26ff
·
verified ·
1 Parent(s): ac02b78

Update models/rag_system.py

Browse files
Files changed (1) hide show
  1. models/rag_system.py +54 -32
models/rag_system.py CHANGED
@@ -1,9 +1,9 @@
1
  import os
2
- import pandas as pd
3
- from transformers import pipeline, AutoTokenizer, AutoModel
4
- import torch
5
  import numpy as np
 
 
6
  from sentence_transformers import SentenceTransformer
 
7
  from utils.logger import setup_logger
8
  from utils.model_loader import ModelLoader
9
 
@@ -12,13 +12,18 @@ logger = setup_logger(__name__)
12
  class RAGSystem:
13
  def __init__(self, csv_path="apparel.csv"):
14
  try:
15
- self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
16
- self.setup_system(csv_path)
17
- self.qa_pipeline = ModelLoader.load_model_with_retry(
18
- "distilbert-base-cased-distilled-squad",
19
- pipeline,
20
- task="question-answering"
 
 
21
  )
 
 
 
22
  except Exception as e:
23
  logger.error(f"Failed to initialize RAGSystem: {str(e)}")
24
  raise
@@ -28,41 +33,58 @@ class RAGSystem:
28
  raise FileNotFoundError(f"CSV file not found at {csv_path}")
29
 
30
  try:
 
31
  self.documents = pd.read_csv(csv_path)
 
 
32
  # Create embeddings for all documents
33
- self.doc_embeddings = self.model.encode(
34
- self.documents['Title'].astype(str).tolist(),
35
- convert_to_tensor=True
36
- )
37
  except Exception as e:
38
  logger.error(f"Failed to setup RAG system: {str(e)}")
39
  raise
40
 
41
  def get_relevant_documents(self, query, top_k=5):
42
- # Get query embedding
43
- query_embedding = self.model.encode(query, convert_to_tensor=True)
44
-
45
- # Calculate cosine similarities
46
- cos_scores = torch.nn.functional.cosine_similarity(
47
- query_embedding.unsqueeze(0),
48
- self.doc_embeddings
49
- )
50
-
51
- # Get top_k most similar documents
52
- top_indices = torch.topk(cos_scores, min(top_k, len(self.documents))).indices
53
- return [str(self.documents.iloc[idx]['Title']) for idx in top_indices]
 
 
 
54
 
55
  def process_query(self, query):
56
  try:
57
- retrieved_docs = self.get_relevant_documents(query)
58
- retrieved_text = "\n".join(retrieved_docs)[:1000]
 
 
 
 
 
 
59
 
 
60
  qa_input = {
61
  "question": query,
62
- "context": retrieved_text
63
  }
64
- response = self.qa_pipeline(qa_input)
65
- return response['answer']
 
 
 
 
66
  except Exception as e:
67
- logger.error(f"Query processing error: {str(e)}")
68
- return "Failed to process query due to an error."
 
1
  import os
 
 
 
2
  import numpy as np
3
+ import pandas as pd
4
+ from transformers import pipeline
5
  from sentence_transformers import SentenceTransformer
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
  from utils.logger import setup_logger
8
  from utils.model_loader import ModelLoader
9
 
 
12
  class RAGSystem:
13
  def __init__(self, csv_path="apparel.csv"):
14
  try:
15
+ # Initialize the sentence transformer model
16
+ self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
17
+
18
+ # Initialize the QA pipeline
19
+ self.qa_pipeline = pipeline(
20
+ "question-answering",
21
+ model="distilbert-base-cased-distilled-squad",
22
+ tokenizer="distilbert-base-cased-distilled-squad"
23
  )
24
+
25
+ self.setup_system(csv_path)
26
+
27
  except Exception as e:
28
  logger.error(f"Failed to initialize RAGSystem: {str(e)}")
29
  raise
 
33
  raise FileNotFoundError(f"CSV file not found at {csv_path}")
34
 
35
  try:
36
+ # Load and preprocess documents
37
  self.documents = pd.read_csv(csv_path)
38
+ self.texts = self.documents['Title'].astype(str).tolist()
39
+
40
  # Create embeddings for all documents
41
+ self.embeddings = self.embedder.encode(self.texts)
42
+
43
+ logger.info(f"Successfully loaded {len(self.texts)} documents")
44
+
45
  except Exception as e:
46
  logger.error(f"Failed to setup RAG system: {str(e)}")
47
  raise
48
 
49
  def get_relevant_documents(self, query, top_k=5):
50
+ try:
51
+ # Get query embedding
52
+ query_embedding = self.embedder.encode([query])
53
+
54
+ # Calculate similarities
55
+ similarities = cosine_similarity(query_embedding, self.embeddings)[0]
56
+
57
+ # Get top k most similar documents
58
+ top_indices = np.argsort(similarities)[-top_k:][::-1]
59
+
60
+ return [self.texts[i] for i in top_indices]
61
+
62
+ except Exception as e:
63
+ logger.error(f"Error retrieving relevant documents: {str(e)}")
64
+ return []
65
 
66
  def process_query(self, query):
67
  try:
68
+ # Get relevant documents
69
+ relevant_docs = self.get_relevant_documents(query)
70
+
71
+ if not relevant_docs:
72
+ return "No relevant documents found."
73
+
74
+ # Combine retrieved documents into context
75
+ context = " ".join(relevant_docs)
76
 
77
+ # Prepare QA input
78
  qa_input = {
79
  "question": query,
80
+ "context": context[:512] # Limit context length for the model
81
  }
82
+
83
+ # Get answer using QA pipeline
84
+ answer = self.qa_pipeline(qa_input)
85
+
86
+ return answer['answer']
87
+
88
  except Exception as e:
89
+ logger.error(f"Error processing query: {str(e)}")
90
+ return f"Failed to process query: {str(e)}"