File size: 8,552 Bytes
deb090d
d994686
deb090d
 
 
 
 
5b65de2
deb090d
 
f2611d0
deb090d
 
 
 
 
 
2fb49a3
deb090d
 
 
 
 
 
 
 
 
65726e0
 
 
 
 
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b65de2
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b65de2
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b65de2
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b65de2
deb090d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3092e52
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import asyncio
import os
import re
from typing import List

import PyPDF2
from docx import Document
from google import genai
from langchain_text_splitters import RecursiveCharacterTextSplitter

from .vector_store import VectorStore

class RAG:
    """Retrieval-augmented generation assistant for CA (Chartered Accountant) study material.

    Documents (PDF, DOCX, TXT) are split into overlapping chunks with a
    RecursiveCharacterTextSplitter, stored in a Qdrant-backed ``VectorStore``,
    and retrieved as context for streaming answers produced by Gemini.
    """

    # Pre-compiled patterns for "greeting followed by a real question"
    # (e.g. "hello, what is...") — such inputs are NOT treated as casual.
    # Compiled once at class-creation time instead of per call.
    _GREETING_QUESTION_PATTERNS = [
        re.compile(r'^(hello|hi|hey|good morning|good afternoon|good evening),?\s+(what|how|why|when|where|explain|define|tell|can you)'),
        re.compile(r'^(hello|hi|hey)\s+(there|everyone)?,?\s+(what|how|why|when|where|explain|define|tell|can you)')
    ]

    def __init__(self, google_api_key: str, collection_name: str = "ca-documents"):
        """Set up the Gemini client, the vector store and the text splitter.

        Args:
            google_api_key: API key passed to the Gemini client. The client
                can also read it from the `GOOGLE_API_KEY` environment
                variable.
            collection_name: Kept for backward compatibility.
                NOTE(review): currently unused — VectorStore takes its
                configuration from environment variables; confirm whether
                this should be forwarded to it.
        """
        self.client = genai.Client(api_key=google_api_key)

        # Qdrant configuration is handled via environment variables.
        self.vector_store = VectorStore()

        # Prefer paragraph, then line, then sentence boundaries when splitting.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", ". ", " ", ""]
        )

    async def initialize(self):
        """Asynchronous initialization to be called after object creation."""
        await self.vector_store.initialize()
        await self.vector_store.verify_collection_health()

    def process_pdf(self, file_path: str) -> List[str]:
        """Extract text from a PDF and split it into overlapping chunks.

        Args:
            file_path: Path to the PDF file.

        Returns:
            Non-empty, stripped text chunks.
        """
        full_text = ""
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            for page in pdf_reader.pages:
                text = page.extract_text()
                # BUG FIX: extract_text() may return None for pages without
                # a text layer (e.g. scanned images) — guard before strip().
                if text and text.strip():
                    full_text += text + "\n"

        chunks = self.text_splitter.split_text(full_text)
        return [chunk.strip() for chunk in chunks if chunk.strip()]

    def process_docx(self, file_path: str) -> List[str]:
        """Extract text from a DOCX and split it into overlapping chunks.

        Args:
            file_path: Path to the DOCX file.

        Returns:
            Non-empty, stripped text chunks.
        """
        doc = Document(file_path)
        full_text = "\n".join(paragraph.text for paragraph in doc.paragraphs)

        chunks = self.text_splitter.split_text(full_text)
        return [chunk.strip() for chunk in chunks if chunk.strip()]

    async def upload_document(self, file_path: str) -> bool:
        """Process a document and store its chunks in the vector store.

        Supports .pdf, .docx and .txt; the extension match is
        case-insensitive, so e.g. "Report.PDF" is accepted.

        Args:
            file_path: Path to the document to ingest.

        Returns:
            True on success, False for unsupported formats or on any error
            (errors are printed, not raised — best-effort by design).
        """
        try:
            filename = os.path.basename(file_path)
            # Normalize the extension so the dispatch below is
            # case-insensitive.
            extension = os.path.splitext(file_path)[1].lower()

            if extension == '.pdf':
                chunks = self.process_pdf(file_path)
            elif extension == '.docx':
                chunks = self.process_docx(file_path)
            elif extension == '.txt':
                with open(file_path, 'r', encoding='utf-8') as f:
                    full_text = f.read()
                chunks = self.text_splitter.split_text(full_text)
                chunks = [chunk.strip() for chunk in chunks if chunk.strip()]
            else:
                print("Unsupported file format")
                return False

            # Store each chunk with enough metadata to trace it back to its
            # source file and position.
            for i, chunk in enumerate(chunks):
                await self.vector_store.add_document(
                    text=chunk,
                    metadata={"source": filename, "chunk_id": i}
                )

            # BUG FIX: the message previously interpolated a stale
            # placeholder instead of the actual file name.
            print(f"Uploaded {len(chunks)} chunks from {filename}")
            return True

        except Exception as e:
            print(f"Error uploading document: {e}")
            return False

    def is_casual_conversation(self, question: str) -> bool:
        """Classify whether *question* is small talk rather than a CA query.

        Heuristics, in order:
        1. Exact match against a list of pure greetings/sign-offs.
        2. Short inputs (<= 4 words) containing a casual word but no
           technical indicator.
        3. A greeting prefix followed by a real question (e.g.
           "hi, what is...") is explicitly NOT casual.

        Args:
            question: Raw user input.

        Returns:
            True if the input should get a friendly canned-style reply
            instead of a retrieval-backed answer.
        """
        question_lower = question.lower().strip()

        # Pure casual greetings (exact matches only).
        pure_casual = [
            'hello', 'hi', 'hey', 'good morning', 'good afternoon', 'good evening',
            'how are you', 'what\'s up', 'greetings', 'thanks', 'thank you',
            'bye', 'goodbye', 'see you', 'nice to meet you', 'who are you',
            'what can you do', 'help me', 'what is your name', 'introduce yourself',
            'how do you work', 'what are you', 'can you help me'
        ]

        if question_lower in pure_casual:
            return True

        # Very short input: casual words without technical terms -> casual.
        words = question_lower.split()
        if len(words) <= 4:
            technical_indicators = [
                'what', 'how', 'why', 'when', 'where', 'explain', 'define', 'calculate',
                'accounting', 'audit', 'tax', 'finance', 'depreciation', 'balance', 'sheet',
                'profit', 'loss', 'asset', 'liability', 'equity', 'revenue', 'expense',
                'journal', 'ledger', 'trial', 'cash', 'flow', 'ratio', 'analysis'
            ]

            has_casual = any(casual in question_lower for casual in ['hello', 'hi', 'hey', 'thanks', 'bye'])
            has_technical = any(tech in question_lower for tech in technical_indicators)

            if has_casual and not has_technical:
                return True

        # Greeting followed by an actual question is not pure casual.
        for pattern in self._GREETING_QUESTION_PATTERNS:
            if pattern.search(question_lower):
                return False

        return False

    async def ask_question_stream(self, question: str):
        """Answer *question* as an async generator of streamed text chunks.

        Casual inputs get a friendly direct reply; CA questions are answered
        with retrieved document context when available, or from model
        knowledge otherwise. Errors are yielded as text, never raised.

        Args:
            question: The user's question.

        Yields:
            Text fragments of the streamed answer.
        """
        try:
            # 1. Casual conversation: skip retrieval entirely.
            if self.is_casual_conversation(question):
                casual_prompt = f"""You are a friendly CA (Chartered Accountant) study assistant. The user said: "{question}"

Respond naturally and warmly as a CA study assistant. Be helpful and mention that you can help with CA studies, accounting concepts, financial topics, etc. Keep it brief but friendly."""

                for chunk in self.client.models.generate_content_stream(
                    model='gemini-2.5-flash',
                    contents=casual_prompt
                ):
                    yield chunk.text
                return

            # 2. CA-specific question: retrieve the most similar chunks.
            similar_docs = await self.vector_store.search_similar(question, limit=3)

            if similar_docs and len(similar_docs) > 0:
                # 3. Build the context block from the retrieved chunks.
                context = "\n\n".join([doc["text"] for doc in similar_docs])

                # 4. Grounded prompt: answer from the retrieved context.
                prompt = f"""You are a CA study assistant. Based on the following context from uploaded documents, answer the question.

Context:
{context}

Question: {question}

Please provide a detailed answer based on the context above. If you need more specific information, suggest what documents might be helpful."""

            else:
                # 5. No documents found — fall back to model knowledge.
                prompt = f"""You are a CA (Chartered Accountant) study assistant. The user asked: "{question}"

Even though no specific study materials have been uploaded yet, provide a helpful answer based on your knowledge of CA studies, accounting, finance, taxation, and auditing. Be informative and suggest that uploading relevant study materials would help provide more specific and detailed answers.

Question: {question}"""

            # 6. Stream the generated answer.
            for chunk in self.client.models.generate_content_stream(
                model='gemini-2.5-flash',
                contents=prompt
            ):
                yield chunk.text

        except Exception as e:
            # Surface the error in-band so the stream consumer sees it.
            yield f"Error generating answer: {e}"

# Simple usage example
# if __name__ == "__main__":
#     # Initialize
#     rag = RAG(
#         google_api_key="your_google_api_key",
#         collection_name="ca-documents"
#     )