alx-d committed
Commit 10a7b38 · verified · 1 parent: 23e5fe5

Upload folder using huggingface_hub

Files changed (1)
  1. advanced_rag.py +1113 -1102
advanced_rag.py CHANGED
@@ -1,1102 +1,1113 @@
1
- import os
2
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
- import datetime
4
- import functools
5
- import traceback
6
- from typing import List, Optional, Any, Dict
7
-
8
- import torch
9
- import transformers
10
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
- from langchain_community.llms import HuggingFacePipeline
12
-
13
- # Other LangChain and community imports
14
- from langchain_community.document_loaders import OnlinePDFLoader
15
- from langchain.text_splitter import RecursiveCharacterTextSplitter
16
- from langchain_community.vectorstores import FAISS
17
- from langchain.embeddings import HuggingFaceEmbeddings
18
- from langchain_community.retrievers import BM25Retriever
19
- from langchain.retrievers import EnsembleRetriever
20
- from langchain.prompts import ChatPromptTemplate
21
- from langchain.schema import StrOutputParser, Document
22
- from langchain_core.runnables import RunnableParallel, RunnableLambda
23
- from transformers.quantizers.auto import AutoQuantizationConfig
24
- import gradio as gr
25
- import requests
26
- from pydantic import PrivateAttr
27
- import pydantic
28
-
29
- from langchain.llms.base import LLM
30
- from typing import Any, Optional, List
31
- import typing
32
- import time
33
-
34
- print("Pydantic Version: ")
35
- print(pydantic.__version__)
36
- # Add Mistral imports with fallback handling
37
-
38
- try:
39
- from mistralai import Mistral
40
- MISTRAL_AVAILABLE = True
41
- debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
42
- debug_print("Loaded latest Mistral client library")
43
- except ImportError:
44
- MISTRAL_AVAILABLE = False
45
- debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
46
- debug_print("Mistral client library not found. Install with: pip install mistralai")
47
-
48
- def debug_print(message: str):
49
- print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True)
50
-
51
- def word_count(text: str) -> int:
52
- return len(text.split())
53
-
54
- # Initialize a tokenizer for token counting (using gpt2 as a generic fallback)
55
- def initialize_tokenizer():
56
- try:
57
- return AutoTokenizer.from_pretrained("gpt2")
58
- except Exception as e:
59
- debug_print("Failed to initialize tokenizer: " + str(e))
60
- return None
61
-
62
- global_tokenizer = initialize_tokenizer()
63
-
64
- def count_tokens(text: str) -> int:
65
- if global_tokenizer:
66
- try:
67
- return len(global_tokenizer.encode(text))
68
- except Exception as e:
69
- return len(text.split())
70
- return len(text.split())
71
-
72
-
73
- # Add these imports at the top of your file
74
- import uuid
75
- import threading
76
- import queue
77
- from typing import Dict, Any, Tuple, Optional
78
- import time
79
-
80
- # Global storage for jobs and results
81
- jobs = {} # Stores job status and results
82
- results_queue = queue.Queue() # Thread-safe queue for completed jobs
83
- processing_lock = threading.Lock() # Prevent simultaneous processing of the same job
84
-
85
- # Add a global variable to store the last job ID
86
- last_job_id = None
87
-
88
- # Add these missing async processing functions
89
-
90
- def process_in_background(job_id, function, args):
91
- """Process a function in the background and store results"""
92
- try:
93
- debug_print(f"Processing job {job_id} in background")
94
- result = function(*args)
95
- results_queue.put((job_id, result))
96
- debug_print(f"Job {job_id} completed and added to results queue")
97
- except Exception as e:
98
- debug_print(f"Error in background job {job_id}: {str(e)}")
99
- error_result = (f"Error processing job: {str(e)}", "", "", "")
100
- results_queue.put((job_id, error_result))
101
-
102
- def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
103
- """Asynchronous version of load_pdfs_updated to prevent timeouts"""
104
- global last_job_id
105
- if not file_links:
106
- return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list()
107
-
108
- job_id = str(uuid.uuid4())
109
- debug_print(f"Starting async job {job_id} for file loading")
110
-
111
- # Start background thread
112
- threading.Thread(
113
- target=process_in_background,
114
- args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p])
115
- ).start()
116
-
117
- job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
118
- jobs[job_id] = {
119
- "status": "processing",
120
- "type": "load_files",
121
- "start_time": time.time(),
122
- "query": job_query
123
- }
124
-
125
- last_job_id = job_id
126
-
127
- return (
128
- f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
129
- f"Use 'Check Job Status' tab with this ID to get results.",
130
- f"Job ID: {job_id}",
131
- f"Model requested: {model_choice}",
132
- job_id, # Return job_id to update the job_id_input component
133
- job_query, # Return job_query to update the job_query_display component
134
- get_job_list() # Return updated job list
135
- )
136
-
137
- def submit_query_async(query, model_choice=None):
138
- """Asynchronous version of submit_query_updated to prevent timeouts"""
139
- global last_job_id
140
- if not query:
141
- return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
142
-
143
- job_id = str(uuid.uuid4())
144
- debug_print(f"Starting async job {job_id} for query: {query}")
145
-
146
- # Update model if specified
147
- if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
148
- debug_print(f"Updating model to {model_choice} for this query")
149
- rag_chain.update_llm_pipeline(model_choice, rag_chain.temperature, rag_chain.top_p,
150
- rag_chain.prompt_template, rag_chain.bm25_weight)
151
-
152
- # Start background thread
153
- threading.Thread(
154
- target=process_in_background,
155
- args=(job_id, submit_query_updated, [query])
156
- ).start()
157
-
158
- jobs[job_id] = {
159
- "status": "processing",
160
- "type": "query",
161
- "start_time": time.time(),
162
- "query": query,
163
- "model": rag_chain.llm_choice if hasattr(rag_chain, 'llm_choice') else "Unknown"
164
- }
165
-
166
- last_job_id = job_id
167
-
168
- return (
169
- f"Query submitted and processing in the background (Job ID: {job_id}).\n\n"
170
- f"Use 'Check Job Status' tab with this ID to get results.",
171
- f"Job ID: {job_id}",
172
- f"Input tokens: {count_tokens(query)}",
173
- "Output tokens: pending",
174
- job_id, # Return job_id to update the job_id_input component
175
- query, # Return query to update the job_query_display component
176
- get_job_list() # Return updated job list
177
- )
178
-
179
- def update_ui_with_last_job_id():
180
- # This function doesn't need to do anything anymore
181
- # We'll update the UI directly in the functions that call this
182
- pass
183
-
184
- # Function to display all jobs as a clickable list
185
- def get_job_list():
186
- job_list_md = "### Submitted Jobs\n\n"
187
-
188
- if not jobs:
189
- return "No jobs found. Submit a query or load files to create jobs."
190
-
191
- # Sort jobs by start time (newest first)
192
- sorted_jobs = sorted(
193
- [(job_id, job_info) for job_id, job_info in jobs.items()],
194
- key=lambda x: x[1].get("start_time", 0),
195
- reverse=True
196
- )
197
-
198
- for job_id, job_info in sorted_jobs:
199
- status = job_info.get("status", "unknown")
200
- job_type = job_info.get("type", "unknown")
201
- query = job_info.get("query", "")
202
- start_time = job_info.get("start_time", 0)
203
- time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
204
-
205
- # Create a shortened query preview
206
- query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
207
-
208
- # Create clickable links using Markdown
209
- if job_type == "query":
210
- job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - Query: {query_preview}\n"
211
- else:
212
- job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status} - File Load Job\n"
213
-
214
- return job_list_md
215
-
216
- # Function to handle job list clicks
217
- def job_selected(job_id):
218
- if job_id in jobs:
219
- return job_id, jobs[job_id].get("query", "No query for this job")
220
- return job_id, "Job not found"
221
-
222
- # Function to refresh the job list
223
- def refresh_job_list():
224
- return get_job_list()
225
-
226
- # Function to sync model dropdown boxes
227
- def sync_model_dropdown(value):
228
- return value
229
-
230
- # Function to check job status
231
- def check_job_status(job_id):
232
- if not job_id:
233
- return "Please enter a job ID", "", "", "", ""
234
-
235
- # Process any completed jobs in the queue
236
- try:
237
- while not results_queue.empty():
238
- completed_id, result = results_queue.get_nowait()
239
- if completed_id in jobs:
240
- jobs[completed_id]["status"] = "completed"
241
- jobs[completed_id]["result"] = result
242
- jobs[completed_id]["end_time"] = time.time()
243
- debug_print(f"Job {completed_id} completed and stored in jobs dictionary")
244
- except queue.Empty:
245
- pass
246
-
247
- # Check if the requested job exists
248
- if job_id not in jobs:
249
- return "Job not found. Please check the ID and try again.", "", "", "", ""
250
-
251
- job = jobs[job_id]
252
- job_query = job.get("query", "No query available for this job")
253
-
254
- # If job is still processing
255
- if job["status"] == "processing":
256
- elapsed_time = time.time() - job["start_time"]
257
- job_type = job.get("type", "unknown")
258
-
259
- if job_type == "load_files":
260
- return (
261
- f"Files are still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
262
- f"Try checking again in a few seconds.",
263
- f"Job ID: {job_id}",
264
- f"Status: Processing",
265
- "",
266
- job_query
267
- )
268
- else: # query job
269
- return (
270
- f"Query is still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
271
- f"Try checking again in a few seconds.",
272
- f"Job ID: {job_id}",
273
- f"Input tokens: {count_tokens(job.get('query', ''))}",
274
- "Output tokens: pending",
275
- job_query
276
- )
277
-
278
- # If job is completed
279
- if job["status"] == "completed":
280
- result = job["result"]
281
- processing_time = job["end_time"] - job["start_time"]
282
-
283
- if job.get("type") == "load_files":
284
- return (
285
- f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
286
- result[1],
287
- result[2],
288
- "",
289
- job_query
290
- )
291
- else: # query job
292
- return (
293
- f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
294
- result[1],
295
- result[2],
296
- result[3],
297
- job_query
298
- )
299
-
300
- # Fallback for unknown status
301
- return f"Job status: {job['status']}", "", "", "", job_query
302
-
303
- # Function to clean up old jobs
304
- def cleanup_old_jobs():
305
- current_time = time.time()
306
- to_delete = []
307
-
308
- for job_id, job in jobs.items():
309
- # Keep completed jobs for 1 hour, processing jobs for 2 hours
310
- if job["status"] == "completed" and (current_time - job.get("end_time", 0)) > 3600:
311
- to_delete.append(job_id)
312
- elif job["status"] == "processing" and (current_time - job.get("start_time", 0)) > 7200:
313
- to_delete.append(job_id)
314
-
315
- for job_id in to_delete:
316
- del jobs[job_id]
317
-
318
- debug_print(f"Cleaned up {len(to_delete)} old jobs. {len(jobs)} jobs remaining.")
319
- return f"Cleaned up {len(to_delete)} old jobs", "", ""
320
-
321
- # Improve the truncate_prompt function to be more aggressive with limiting context
322
- def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
323
- """Truncate prompt to fit within token limit, preserving the most recent/relevant parts."""
324
- if not prompt:
325
- return ""
326
-
327
- if global_tokenizer:
328
- try:
329
- tokens = global_tokenizer.encode(prompt)
330
- if len(tokens) > max_tokens:
331
- # For prompts, we often want to keep the beginning instructions and the end context
332
- # So we'll keep the first 20% and the last 80% of the max tokens
333
- beginning_tokens = int(max_tokens * 0.2)
334
- ending_tokens = max_tokens - beginning_tokens
335
-
336
- new_tokens = tokens[:beginning_tokens] + tokens[-(ending_tokens):]
337
- return global_tokenizer.decode(new_tokens)
338
- except Exception as e:
339
- debug_print(f"Truncation error: {str(e)}")
340
-
341
- # Fallback to word-based truncation
342
- words = prompt.split()
343
- if len(words) > max_tokens:
344
- beginning_words = int(max_tokens * 0.2)
345
- ending_words = max_tokens - beginning_words
346
-
347
- return " ".join(words[:beginning_words] + words[-(ending_words):])
348
-
349
- return prompt
350
-
351
-
352
-
353
-
354
- default_prompt = """\
355
- {conversation_history}
356
- Use the following context to provide a detailed technical answer to the user's question.
357
- Do not include an introduction like "Based on the provided documents, ...". Just answer the question.
358
- If you don't know the answer, please respond with "I don't know".
359
-
360
- Context:
361
- {context}
362
-
363
- User's question:
364
- {question}
365
- """
366
-
367
- def load_txt_from_url(url: str) -> Document:
368
- response = requests.get(url)
369
- if response.status_code == 200:
370
- text = response.text.strip()
371
- if not text:
372
- raise ValueError(f"TXT file at {url} is empty.")
373
- return Document(page_content=text, metadata={"source": url})
374
- else:
375
- raise Exception(f"Failed to load {url} with status {response.status_code}")
376
-
377
- class ElevatedRagChain:
378
- def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
379
- bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
380
- debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
381
- self.embed_func = HuggingFaceEmbeddings(
382
- model_name="sentence-transformers/all-MiniLM-L6-v2",
383
- model_kwargs={"device": "cpu"}
384
- )
385
- self.bm25_weight = bm25_weight
386
- self.faiss_weight = 1.0 - bm25_weight
387
- self.top_k = 5
388
- self.llm_choice = llm_choice
389
- self.temperature = temperature
390
- self.top_p = top_p
391
- self.prompt_template = prompt_template
392
- self.context = ""
393
- self.conversation_history: List[Dict[str, str]] = []
394
- self.raw_data = None
395
- self.split_data = None
396
- self.elevated_rag_chain = None
397
-
398
- # Instance method to capture context and conversation history
399
- def capture_context(self, result):
400
- self.context = "\n".join([str(doc) for doc in result["context"]])
401
- result["context"] = self.context
402
- history_text = (
403
- "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
404
- if self.conversation_history else ""
405
- )
406
- result["conversation_history"] = history_text
407
- return result
408
-
409
- # Instance method to extract question from input data
410
- def extract_question(self, input_data):
411
- return input_data["question"]
412
-
413
- # Improve error handling in the ElevatedRagChain class
414
- def create_llm_pipeline(self):
415
- from langchain.llms.base import LLM # Import LLM here so it's always defined
416
- normalized = self.llm_choice.lower()
417
- try:
418
- if "remote" in normalized:
419
- debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
420
- from huggingface_hub import InferenceClient
421
- repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
422
- hf_api_token = os.environ.get("HF_API_TOKEN")
423
- if not hf_api_token:
424
- raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
425
-
426
- client = InferenceClient(token=hf_api_token, timeout=120)
427
-
428
- # We no longer use wait_for_model because it's unsupported
429
- def remote_generate(prompt: str) -> str:
430
- max_retries = 3
431
- backoff = 2 # start with 2 seconds
432
- for attempt in range(max_retries):
433
- try:
434
- debug_print(f"Remote generation attempt {attempt+1}")
435
- response = client.text_generation(
436
- prompt,
437
- model=repo_id,
438
- temperature=self.temperature,
439
- top_p=self.top_p,
440
- max_new_tokens=512 # Reduced token count for speed
441
- )
442
- return response
443
- except Exception as e:
444
- debug_print(f"Attempt {attempt+1} failed with error: {e}")
445
- if attempt == max_retries - 1:
446
- raise
447
- time.sleep(backoff)
448
- backoff *= 2 # exponential backoff
449
- return "Failed to generate response after multiple attempts."
450
-
451
- class RemoteLLM(LLM):
452
- @property
453
- def _llm_type(self) -> str:
454
- return "remote_llm"
455
-
456
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
457
- return remote_generate(prompt)
458
-
459
- @property
460
- def _identifying_params(self) -> dict:
461
- return {"model": repo_id}
462
-
463
- debug_print("Remote Meta-Llama-3 pipeline created successfully.")
464
- return RemoteLLM()
465
-
466
- elif "mistral-api" in normalized:
467
- debug_print("Creating Mistral API pipeline...")
468
- mistral_api_key = os.environ.get("MISTRAL_API_KEY")
469
- if not mistral_api_key:
470
- raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
471
- try:
472
- from mistralai import Mistral
473
- debug_print("Mistral library imported successfully")
474
- except ImportError:
475
- debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
476
- normalized = "llama"
477
- if normalized != "llama":
478
- # from pydantic import PrivateAttr
479
- # from langchain.llms.base import LLM
480
- # from typing import Any, Optional, List
481
- # import typing
482
-
483
- class MistralLLM(LLM):
484
- temperature: float = 0.7
485
- top_p: float = 0.95
486
- _client: Any = PrivateAttr(default=None)
487
-
488
- def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
489
- try:
490
- super().__init__(**kwargs)
491
- # Bypass Pydantic's __setattr__ to assign to _client
492
- object.__setattr__(self, '_client', Mistral(api_key=api_key))
493
- self.temperature = temperature
494
- self.top_p = top_p
495
- except Exception as e:
496
- debug_print(f"Init Mistral failed with error: {e}")
497
-
498
- @property
499
- def _llm_type(self) -> str:
500
- return "mistral_llm"
501
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
502
- try:
503
- debug_print("Calling Mistral API...")
504
- response = self._client.chat.complete(
505
- model="mistral-small-latest",
506
- messages=[{"role": "user", "content": prompt}],
507
- temperature=self.temperature,
508
- top_p=self.top_p
509
- )
510
- return response.choices[0].message.content
511
- except Exception as e:
512
- debug_print(f"Mistral API error: {str(e)}")
513
- return f"Error generating response: {str(e)}"
514
- @property
515
- def _identifying_params(self) -> dict:
516
- return {"model": "mistral-small-latest"}
517
- debug_print("Creating Mistral LLM instance")
518
- mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
519
- debug_print("Mistral API pipeline created successfully.")
520
- return mistral_llm
521
-
522
- else:
523
- # Default case - using a fallback model (or Llama)
524
- debug_print("Using local/fallback model pipeline")
525
- model_id = "facebook/opt-350m" # Use a smaller model as fallback
526
- pipe = pipeline(
527
- "text-generation",
528
- model=model_id,
529
- device=-1, # CPU
530
- max_length=1024
531
- )
532
-
533
- class LocalLLM(LLM):
534
- @property
535
- def _llm_type(self) -> str:
536
- return "local_llm"
537
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
538
- # For this fallback, truncate prompt if it exceeds limits
539
- reserved_gen = 128
540
- max_total = 1024
541
- max_prompt_tokens = max_total - reserved_gen
542
- truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
543
- generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
544
- return generated
545
- @property
546
- def _identifying_params(self) -> dict:
547
- return {"model": model_id, "max_length": 1024}
548
-
549
- debug_print("Local fallback pipeline created.")
550
- return LocalLLM()
551
-
552
- except Exception as e:
553
- debug_print(f"Error creating LLM pipeline: {str(e)}")
554
- # Return a dummy LLM that explains the error
555
- class ErrorLLM(LLM):
556
- @property
557
- def _llm_type(self) -> str:
558
- return "error_llm"
559
-
560
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
561
- return f"Error initializing LLM: \n\nPlease check your environment variables and try again."
562
-
563
- @property
564
- def _identifying_params(self) -> dict:
565
- return {"model": "error"}
566
-
567
- return ErrorLLM()
568
-
569
-
570
- def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
571
- debug_print(f"Updating chain with new model: {new_model_choice}")
572
- self.llm_choice = new_model_choice
573
- self.temperature = temperature
574
- self.top_p = top_p
575
- self.prompt_template = prompt_template
576
- self.bm25_weight = bm25_weight
577
- self.faiss_weight = 1.0 - bm25_weight
578
- self.llm = self.create_llm_pipeline()
579
- def format_response(response: str) -> str:
580
- input_tokens = count_tokens(self.context + self.prompt_template)
581
- output_tokens = count_tokens(response)
582
- formatted = f"### Response\n\n{response}\n\n---\n"
583
- formatted += f"- **Input tokens:** {input_tokens}\n"
584
- formatted += f"- **Output tokens:** {output_tokens}\n"
585
- formatted += f"- **Generated using:** {self.llm_choice}\n"
586
- formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
587
- return formatted
588
- base_runnable = RunnableParallel({
589
- "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
590
- "question": RunnableLambda(self.extract_question)
591
- }) | self.capture_context
592
- self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
593
- debug_print("Chain updated successfully with new LLM pipeline.")
594
-
595
- def add_pdfs_to_vectore_store(self, file_links: List[str]) -> None:
596
- debug_print(f"Processing files using {self.llm_choice}")
597
- self.raw_data = []
598
- for link in file_links:
599
- if link.lower().endswith(".pdf"):
600
- debug_print(f"Loading PDF: {link}")
601
- loaded_docs = OnlinePDFLoader(link).load()
602
- if loaded_docs:
603
- self.raw_data.append(loaded_docs[0])
604
- else:
605
- debug_print(f"No content found in PDF: {link}")
606
- elif link.lower().endswith(".txt") or link.lower().endswith(".utf-8"):
607
- debug_print(f"Loading TXT: {link}")
608
- try:
609
- self.raw_data.append(load_txt_from_url(link))
610
- except Exception as e:
611
- debug_print(f"Error loading TXT file {link}: {e}")
612
- else:
613
- debug_print(f"File type not supported for URL: {link}")
614
- if not self.raw_data:
615
- raise ValueError("No files were successfully loaded. Please check the URLs and file formats.")
616
- debug_print("Files loaded successfully.")
617
- debug_print("Starting text splitting...")
618
- self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
619
- self.split_data = self.text_splitter.split_documents(self.raw_data)
620
- if not self.split_data:
621
- raise ValueError("Text splitting resulted in no chunks. Check the file contents.")
622
- debug_print(f"Text splitting completed. Number of chunks: {len(self.split_data)}")
623
- debug_print("Creating BM25 retriever...")
624
- self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
625
- self.bm25_retriever.k = self.top_k
626
- debug_print("BM25 retriever created.")
627
- debug_print("Embedding chunks and creating FAISS vector store...")
628
- self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
629
- self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
630
- debug_print("FAISS vector store created successfully.")
631
- self.ensemble_retriever = EnsembleRetriever(
632
- retrievers=[self.bm25_retriever, self.faiss_retriever],
633
- weights=[self.bm25_weight, self.faiss_weight]
634
- )
635
-
636
- base_runnable = RunnableParallel({
637
- "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
638
- "question": RunnableLambda(self.extract_question)
639
- }) | self.capture_context
640
-
641
- # Ensure the prompt template is set
642
- self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
643
- if self.rag_prompt is None:
644
- raise ValueError("Prompt template could not be created from the given template.")
645
- prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
646
-
647
- self.str_output_parser = StrOutputParser()
648
- debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
649
- self.llm = self.create_llm_pipeline()
650
- if self.llm is None:
651
- raise ValueError("LLM pipeline creation failed.")
652
-
653
- def format_response(response: str) -> str:
654
- input_tokens = count_tokens(self.context + self.prompt_template)
655
- output_tokens = count_tokens(response)
656
- formatted = f"### Response\n\n{response}\n\n---\n"
657
- formatted += f"- **Input tokens:** {input_tokens}\n"
658
- formatted += f"- **Output tokens:** {output_tokens}\n"
659
- formatted += f"- **Generated using:** {self.llm_choice}\n"
660
- formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
661
- return formatted
662
-
663
- self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
664
- debug_print("Elevated RAG chain successfully built and ready to use.")
665
-
666
-
667
-
668
- def get_current_context(self) -> str:
669
- base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
670
- history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
671
- recent = self.conversation_history[-3:]
672
- if recent:
673
- for i, conv in enumerate(recent, 1):
674
- history_summary += f"**Conversation {i}:**\n- Query: {conv['query']}\n- Response: {conv['response']}\n"
675
- else:
676
- history_summary += "No conversation history."
677
- return base_context + history_summary
678
-
679
- # ----------------------------
680
- # Gradio Interface Functions
681
- # ----------------------------
682
- global rag_chain
683
- rag_chain = ElevatedRagChain()
684
-
685
- def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
686
- debug_print("Inside load_pdfs function.")
687
- if not file_links:
688
- debug_print("Please enter non-empty URLs")
689
- return "Please enter non-empty URLs", "Word count: N/A", "Model used: N/A", "Context: N/A"
690
- try:
691
- links = [link.strip() for link in file_links.split("\n") if link.strip()]
692
- global rag_chain
693
- if rag_chain.raw_data:
694
- rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
695
- context_display = rag_chain.get_current_context()
696
- response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
697
- return (
698
- response_msg,
699
- f"Word count: {word_count(rag_chain.context)}",
700
- f"Model used: {rag_chain.llm_choice}",
701
- f"Context:\n{context_display}"
702
- )
703
- else:
704
- rag_chain = ElevatedRagChain(
705
- llm_choice=model_choice,
706
- prompt_template=prompt_template,
707
- bm25_weight=bm25_weight,
708
- temperature=temperature,
709
- top_p=top_p
710
- )
711
- rag_chain.add_pdfs_to_vectore_store(links)
712
- context_display = rag_chain.get_current_context()
713
- response_msg = f"Files loaded successfully. Using model: {model_choice}"
714
- return (
715
- response_msg,
716
- f"Word count: {word_count(rag_chain.context)}",
717
- f"Model used: {rag_chain.llm_choice}",
718
- f"Context:\n{context_display}"
719
- )
720
- except Exception as e:
721
- error_msg = traceback.format_exc()
722
- debug_print("Could not load files. Error: " + error_msg)
723
- return (
724
- "Error loading files: " + str(e),
725
- f"Word count: {word_count('')}",
726
- f"Model used: {rag_chain.llm_choice}",
727
- "Context: N/A"
728
- )
729
-
730
- def update_model(new_model: str):
731
- global rag_chain
732
- if rag_chain and rag_chain.raw_data:
733
- rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
734
- rag_chain.prompt_template, rag_chain.bm25_weight)
735
- debug_print(f"Model updated to {rag_chain.llm_choice}")
736
- return f"Model updated to: {rag_chain.llm_choice}"
737
- else:
738
- return "No files loaded; please load files first."
739
-
740
-
741
- # Update submit_query_updated to better handle context limitation
742
- def submit_query_updated(query):
743
- debug_print(f"Processing query: {query}")
744
- if not query:
745
- debug_print("Empty query received")
746
- return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
747
-
748
- if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
749
- debug_print("RAG chain not initialized")
750
- return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
751
-
752
- try:
753
- # Determine max context size based on model
754
- model_name = rag_chain.llm_choice.lower()
755
- max_context_tokens = 32000 if "mistral" in model_name else 4096
756
-
757
- # Reserve 20% of tokens for the question and response generation
758
- reserved_tokens = int(max_context_tokens * 0.2)
759
- max_context_tokens -= reserved_tokens
760
-
761
- # Collect conversation history (last 2 only to save tokens)
762
- if rag_chain.conversation_history:
763
- recent_history = rag_chain.conversation_history[-2:]
764
- history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response'][:300]}..."
765
- for conv in recent_history])
766
- else:
767
- history_text = ""
768
-
769
- # Get history token count
770
- history_tokens = count_tokens(history_text)
771
-
772
- # Adjust context tokens based on history size
773
- context_tokens = max_context_tokens - history_tokens
774
-
775
- # Ensure we have some minimum context
776
- context_tokens = max(context_tokens, 1000)
777
-
778
- # Truncate context if needed
779
- context = truncate_prompt(rag_chain.context, max_tokens=context_tokens)
780
-
781
- debug_print(f"Using model: {model_name}, context tokens: {count_tokens(context)}, history tokens: {history_tokens}")
782
-
783
- prompt_variables = {
784
- "conversation_history": history_text,
785
- "context": context,
786
- "question": query
787
- }
788
-
789
- debug_print("Invoking RAG chain")
790
- response = rag_chain.elevated_rag_chain.invoke({"question": query})
791
-
792
- # Store only a reasonable amount of the response in history
793
- trimmed_response = response[:1000] + ("..." if len(response) > 1000 else "")
794
- rag_chain.conversation_history.append({"query": query, "response": trimmed_response})
795
-
796
- input_token_count = count_tokens(query)
797
- output_token_count = count_tokens(response)
798
-
799
- debug_print(f"Query processed successfully. Output tokens: {output_token_count}")
800
-
801
- return (
802
- response,
803
- rag_chain.get_current_context(),
804
- f"Input tokens: {input_token_count}",
805
- f"Output tokens: {output_token_count}"
806
- )
807
- except Exception as e:
808
- error_msg = traceback.format_exc()
809
- debug_print(f"LLM error: {error_msg}")
810
- return (
811
- f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
812
- "",
813
- "Input tokens: 0",
814
- "Output tokens: 0"
815
- )
816
-
817
- def reset_app_updated():
818
- global rag_chain
819
- rag_chain = ElevatedRagChain()
820
- debug_print("App reset successfully.")
821
- return (
822
- "App reset successfully. You can now load new files",
823
- "",
824
- "Model used: Not selected"
825
- )
826
-
827
- # ----------------------------
828
- # Gradio Interface Setup
829
- # ----------------------------
830
- custom_css = """
831
- textarea {
832
- overflow-y: scroll !important;
833
- max-height: 200px;
834
- }
835
- """
836
-
837
- # Update the Gradio interface to include job status checking
838
- with gr.Blocks(css=custom_css, js="""
839
- document.addEventListener('DOMContentLoaded', function() {
840
- // Add event listener for job list clicks
841
- const jobListInterval = setInterval(() => {
842
- const jobLinks = document.querySelectorAll('.job-list-container a');
843
- if (jobLinks.length > 0) {
844
- jobLinks.forEach(link => {
845
- link.addEventListener('click', function(e) {
846
- e.preventDefault();
847
- const jobId = this.textContent.split(' ')[0];
848
- // Find the job ID input textbox and set its value
849
- const jobIdInput = document.querySelector('.job-id-input input');
850
- if (jobIdInput) {
851
- jobIdInput.value = jobId;
852
- // Trigger the input event to update Gradio's state
853
- jobIdInput.dispatchEvent(new Event('input', { bubbles: true }));
854
- }
855
- });
856
- });
857
- clearInterval(jobListInterval);
858
- }
859
- }, 500);
860
- });
861
- """) as app:
862
- gr.Markdown('''# PhiRAG - Async Version
863
- **PhiRAG**: Query Your Data with Advanced RAG Techniques
864
-
865
- **Model Selection & Parameters:** Choose from the following options:
866
- 🇺🇸 Remote Meta-Llama-3 - has a context window of 8000 tokens
867
- 🇪🇺 Mistral-API - has a context window of 32000 tokens
868
-
869
- **🔥 Randomness (Temperature):** Adjusts output predictability.
870
- - Example: 0.2 makes the output very deterministic (less creative), while 0.8 introduces more variety and spontaneity.
871
-
872
- **🎯 Word Variety (Top‑p):** Limits word choices to a set probability percentage.
873
- - Example: 0.5 restricts output to the most likely 50% of token choices for a focused answer; 0.95 allows almost all possibilities for more diverse responses.
874
-
875
- **⚖️ BM25 Weight:** Adjust Lexical vs Semantics.
876
- - Example: A value of 0.8 puts more emphasis on exact keyword (lexical) matching, while 0.3 shifts emphasis toward semantic similarity.
877
-
878
- **✏️ Prompt Template:** Edit as desired.
879
-
880
- **🔗 File URLs:** Enter one URL per line (.pdf or .txt).
881
- - Example: Provide one URL per line, such as
882
- https://www.gutenberg.org/ebooks/8438.txt.utf-8
883
-
884
- **🔍 Query:** Enter your query below.
885
-
886
- **⚠️ IMPORTANT: This app now uses asynchronous processing to avoid timeout issues**
887
- - When you load files or submit a query, you'll receive a Job ID
888
- - Use the "Check Job Status" tab to monitor and retrieve your results
889
- ''')
890
-
891
- with gr.Tabs() as tabs:
892
- with gr.TabItem("Setup & Load Files"):
893
- with gr.Row():
894
- with gr.Column():
895
- model_dropdown = gr.Dropdown(
896
- choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
897
- value="🇺🇸 Remote Meta-Llama-3",
898
- label="Select Model"
899
- )
900
- temperature_slider = gr.Slider(
901
- minimum=0.1, maximum=1.0, value=0.5, step=0.1,
902
- label="Randomness (Temperature)"
903
- )
904
- top_p_slider = gr.Slider(
905
- minimum=0.1, maximum=0.99, value=0.95, step=0.05,
906
- label="Word Variety (Top-p)"
907
- )
908
- with gr.Column():
909
- pdf_input = gr.Textbox(
910
- label="Enter your file URLs (one per line)",
911
- placeholder="Enter one URL per line (.pdf or .txt)",
912
- lines=4
913
- )
914
- prompt_input = gr.Textbox(
915
- label="Custom Prompt Template",
916
- placeholder="Enter your custom prompt template here",
917
- lines=8,
918
- value=default_prompt
919
- )
920
- with gr.Column():
921
- bm25_weight_slider = gr.Slider(
922
- minimum=0.0, maximum=1.0, value=0.6, step=0.1,
923
- label="Lexical vs Semantics (BM25 Weight)"
924
- )
925
- load_button = gr.Button("Load Files (Async)")
926
- load_status = gr.Markdown("Status: Waiting for files")
927
-
928
- with gr.Row():
929
- load_response = gr.Textbox(
930
- label="Load Response",
931
- placeholder="Response will appear here",
932
- lines=4
933
- )
934
- load_context = gr.Textbox(
935
- label="Context Info",
936
- placeholder="Context info will appear here",
937
- lines=4
938
- )
939
-
940
- with gr.Row():
941
- model_output = gr.Markdown("**Current Model**: Not selected")
942
-
943
- with gr.TabItem("Submit Query"):
944
- with gr.Row():
945
- # Add this line to define the query_model_dropdown
946
- query_model_dropdown = gr.Dropdown(
947
- choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
948
- value="🇺🇸 Remote Meta-Llama-3",
949
- label="Query Model"
950
- )
951
-
952
- query_input = gr.Textbox(
953
- label="Enter your query here",
954
- placeholder="Type your query",
955
- lines=4
956
- )
957
- submit_button = gr.Button("Submit Query (Async)")
958
-
959
- with gr.Row():
960
- query_response = gr.Textbox(
961
- label="Query Response",
962
- placeholder="Response will appear here (formatted as Markdown)",
963
- lines=6
964
- )
965
- query_context = gr.Textbox(
966
- label="Context Information",
967
- placeholder="Retrieved context and conversation history will appear here",
968
- lines=6
969
- )
970
-
971
- with gr.Row():
972
- input_tokens = gr.Markdown("Input tokens: 0")
973
- output_tokens = gr.Markdown("Output tokens: 0")
974
-
975
- with gr.TabItem("Check Job Status"):
976
- with gr.Row():
977
- with gr.Column(scale=1):
978
- job_list = gr.Markdown(
979
- value="No jobs yet",
980
- label="Job List (Click to select)"
981
- )
982
- refresh_button = gr.Button("Refresh Job List")
983
-
984
- with gr.Column(scale=2):
985
- job_id_input = gr.Textbox(
986
- label="Job ID",
987
- placeholder="Job ID will appear here when selected from the list",
988
- lines=1
989
- )
990
- job_query_display = gr.Textbox(
991
- label="Job Query",
992
- placeholder="The query associated with this job will appear here",
993
- lines=2,
994
- interactive=False
995
- )
996
- check_button = gr.Button("Check Status")
997
- cleanup_button = gr.Button("Cleanup Old Jobs")
998
-
999
- with gr.Row():
1000
- status_response = gr.Textbox(
1001
- label="Job Result",
1002
- placeholder="Job result will appear here",
1003
- lines=6
1004
- )
1005
- status_context = gr.Textbox(
1006
- label="Context Information",
1007
- placeholder="Context information will appear here",
1008
- lines=6
1009
- )
1010
-
1011
- with gr.Row():
1012
- status_tokens1 = gr.Markdown("")
1013
- status_tokens2 = gr.Markdown("")
1014
-
1015
- with gr.TabItem("App Management"):
1016
- with gr.Row():
1017
- reset_button = gr.Button("Reset App")
1018
-
1019
- with gr.Row():
1020
- reset_response = gr.Textbox(
1021
- label="Reset Response",
1022
- placeholder="Reset confirmation will appear here",
1023
- lines=2
1024
- )
1025
- reset_context = gr.Textbox(
1026
- label="",
1027
- placeholder="",
1028
- lines=2,
1029
- visible=False
1030
- )
1031
-
1032
- with gr.Row():
1033
- reset_model = gr.Markdown("")
1034
-
1035
- # Connect the buttons to their respective functions
1036
- load_button.click(
1037
- load_pdfs_async,
1038
- inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
1039
- outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list]
1040
- )
1041
-
1042
- # Also sync in the other direction
1043
- query_model_dropdown.change(
1044
- fn=sync_model_dropdown,
1045
- inputs=query_model_dropdown,
1046
- outputs=model_dropdown
1047
- )
1048
-
1049
- submit_button.click(
1050
- submit_query_async,
1051
- inputs=[query_input, query_model_dropdown],
1052
- outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
1053
- )
1054
-
1055
- check_button.click(
1056
- check_job_status,
1057
- inputs=[job_id_input],
1058
- outputs=[status_response, status_context, status_tokens1, status_tokens2, job_query_display]
1059
- )
1060
-
1061
- refresh_button.click(
1062
- refresh_job_list,
1063
- inputs=[],
1064
- outputs=[job_list]
1065
- )
1066
-
1067
- # Connect the job list selection event (this is handled by JavaScript)
1068
- job_id_input.change(
1069
- job_selected,
1070
- inputs=[job_id_input],
1071
- outputs=[job_id_input, job_query_display]
1072
- )
1073
-
1074
- cleanup_button.click(
1075
- cleanup_old_jobs,
1076
- inputs=[],
1077
- outputs=[status_response, status_context, status_tokens1]
1078
- )
1079
-
1080
- reset_button.click(
1081
- reset_app_updated,
1082
- inputs=[],
1083
- outputs=[reset_response, reset_context, reset_model]
1084
- )
1085
-
1086
-
1087
- model_dropdown.change(
1088
- fn=sync_model_dropdown,
1089
- inputs=model_dropdown,
1090
- outputs=query_model_dropdown
1091
- )
1092
-
1093
- # Add an event to refresh the job list on page load
1094
- app.load(
1095
- fn=refresh_job_list,
1096
- inputs=None,
1097
- outputs=job_list
1098
- )
1099
-
1100
- if __name__ == "__main__":
1101
- debug_print("Launching Gradio interface.")
1102
- app.launch(share=False)
1
+ import os
2
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
+ import datetime
4
+ import functools
5
+ import traceback
6
+ from typing import List, Optional, Any, Dict
7
+
8
+ import torch
9
+ import transformers
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
+ from langchain_community.llms import HuggingFacePipeline
12
+
13
+ # Other LangChain and community imports
14
+ from langchain_community.document_loaders import OnlinePDFLoader
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+ from langchain_community.vectorstores import FAISS
17
+ from langchain.embeddings import HuggingFaceEmbeddings
18
+ from langchain_community.retrievers import BM25Retriever
19
+ from langchain.retrievers import EnsembleRetriever
20
+ from langchain.prompts import ChatPromptTemplate
21
+ from langchain.schema import StrOutputParser, Document
22
+ from langchain_core.runnables import RunnableParallel, RunnableLambda
23
+ from transformers.quantizers.auto import AutoQuantizationConfig
24
+ import gradio as gr
25
+ import requests
26
+ from pydantic import PrivateAttr
27
+ import pydantic
28
+
29
+ from langchain.llms.base import LLM
30
+ from typing import Any, Optional, List
31
+ import typing
32
+ import time
33
+
34
+ print("Pydantic Version: ")
35
+ print(pydantic.__version__)
36
+ # Add Mistral imports with fallback handling
37
+
38
+ try:
39
+ from mistralai import Mistral
40
+ MISTRAL_AVAILABLE = True
41
+ debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
42
+ debug_print("Loaded latest Mistral client library")
43
+ except ImportError:
44
+ MISTRAL_AVAILABLE = False
45
+ debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
46
+ debug_print("Mistral client library not found. Install with: pip install mistralai")
47
+
48
+ def debug_print(message: str):
49
+ print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True)
50
+
51
+ def word_count(text: str) -> int:
52
+ return len(text.split())
53
+
54
+ # Initialize a tokenizer for token counting (using gpt2 as a generic fallback)
55
+ def initialize_tokenizer():
56
+ try:
57
+ return AutoTokenizer.from_pretrained("gpt2")
58
+ except Exception as e:
59
+ debug_print("Failed to initialize tokenizer: " + str(e))
60
+ return None
61
+
62
+ global_tokenizer = initialize_tokenizer()
63
+
64
+ def count_tokens(text: str) -> int:
65
+ if global_tokenizer:
66
+ try:
67
+ return len(global_tokenizer.encode(text))
68
+ except Exception as e:
69
+ return len(text.split())
70
+ return len(text.split())
71
+
72
+
73
+ # Add these imports at the top of your file
74
+ import uuid
75
+ import threading
76
+ import queue
77
+ from typing import Dict, Any, Tuple, Optional
78
+ import time
79
+
80
+ # Global storage for jobs and results
81
+ jobs = {} # Stores job status and results
82
+ results_queue = queue.Queue() # Thread-safe queue for completed jobs
83
+ processing_lock = threading.Lock() # Prevent simultaneous processing of the same job
84
+
85
+ # Add a global variable to store the last job ID
86
+ last_job_id = None
87
+
88
+ # Add these missing async processing functions
89
+
90
+ def process_in_background(job_id, function, args):
91
+ """Process a function in the background and store results"""
92
+ try:
93
+ debug_print(f"Processing job {job_id} in background")
94
+ result = function(*args)
95
+ results_queue.put((job_id, result))
96
+ debug_print(f"Job {job_id} completed and added to results queue")
97
+ except Exception as e:
98
+ debug_print(f"Error in background job {job_id}: {str(e)}")
99
+ error_result = (f"Error processing job: {str(e)}", "", "", "")
100
+ results_queue.put((job_id, error_result))
101
+
102
+ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
103
+ """Asynchronous version of load_pdfs_updated to prevent timeouts"""
104
+ global last_job_id
105
+ if not file_links:
106
+ return "Please enter non-empty URLs", "", "Model used: N/A", "", "", get_job_list()
107
+
108
+ job_id = str(uuid.uuid4())
109
+ debug_print(f"Starting async job {job_id} for file loading")
110
+
111
+ # Start background thread
112
+ threading.Thread(
113
+ target=process_in_background,
114
+ args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p])
115
+ ).start()
116
+
117
+ job_query = f"Loading files: {file_links.split()[0]}..." if file_links else "No files"
118
+ jobs[job_id] = {
119
+ "status": "processing",
120
+ "type": "load_files",
121
+ "start_time": time.time(),
122
+ "query": job_query
123
+ }
124
+
125
+ last_job_id = job_id
126
+
127
+ return (
128
+ f"Files submitted and processing in the background (Job ID: {job_id}).\n\n"
129
+ f"Use 'Check Job Status' tab with this ID to get results.",
130
+ f"Job ID: {job_id}",
131
+ f"Model requested: {model_choice}",
132
+ job_id, # Return job_id to update the job_id_input component
133
+ job_query, # Return job_query to update the job_query_display component
134
+ get_job_list() # Return updated job list
135
+ )
136
+
137
+ def submit_query_async(query, model_choice=None):
138
+ """Asynchronous version of submit_query_updated to prevent timeouts"""
139
+ global last_job_id
140
+ if not query:
141
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0", "", "", get_job_list()
142
+
143
+ job_id = str(uuid.uuid4())
144
+ debug_print(f"Starting async job {job_id} for query: {query}")
145
+
146
+ # Update model if specified
147
+ if model_choice and rag_chain and rag_chain.llm_choice != model_choice:
148
+ debug_print(f"Updating model to {model_choice} for this query")
149
+ rag_chain.update_llm_pipeline(model_choice, rag_chain.temperature, rag_chain.top_p,
150
+ rag_chain.prompt_template, rag_chain.bm25_weight)
151
+
152
+ # Start background thread
153
+ threading.Thread(
154
+ target=process_in_background,
155
+ args=(job_id, submit_query_updated, [query])
156
+ ).start()
157
+
158
+ jobs[job_id] = {
159
+ "status": "processing",
160
+ "type": "query",
161
+ "start_time": time.time(),
162
+ "query": query,
163
+ "model": rag_chain.llm_choice if hasattr(rag_chain, 'llm_choice') else "Unknown"
164
+ }
165
+
166
+ last_job_id = job_id
167
+
168
+ return (
169
+ f"Query submitted and processing in the background (Job ID: {job_id}).\n\n"
170
+ f"Use 'Check Job Status' tab with this ID to get results.",
171
+ f"Job ID: {job_id}",
172
+ f"Input tokens: {count_tokens(query)}",
173
+ "Output tokens: pending",
174
+ job_id, # Return job_id to update the job_id_input component
175
+ query, # Return query to update the job_query_display component
176
+ get_job_list() # Return updated job list
177
+ )
178
+
179
+ def update_ui_with_last_job_id():
180
+ # This function doesn't need to do anything anymore
181
+ # We'll update the UI directly in the functions that call this
182
+ pass
183
+
184
+ # Function to display all jobs as a clickable list
185
+ def get_job_list():
186
+ job_list_md = "### Submitted Jobs\n\n"
187
+
188
+ if not jobs:
189
+ return "No jobs found. Submit a query or load files to create jobs."
190
+
191
+ # Sort jobs by start time (newest first)
192
+ sorted_jobs = sorted(
193
+ [(job_id, job_info) for job_id, job_info in jobs.items()],
194
+ key=lambda x: x[1].get("start_time", 0),
195
+ reverse=True
196
+ )
197
+
198
+ for job_id, job_info in sorted_jobs:
199
+ status = job_info.get("status", "unknown")
200
+ job_type = job_info.get("type", "unknown")
201
+ query = job_info.get("query", "")
202
+ start_time = job_info.get("start_time", 0)
203
+ time_str = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
204
+
205
+ # Create a shortened query preview
206
+ query_preview = query[:30] + "..." if query and len(query) > 30 else query or "N/A"
207
+
208
+ # Add color and icons based on status
209
+ if status == "processing":
210
+ # Red color with processing icon for processing jobs
211
+ status_formatted = f"<span style='color: red'>⏳ {status}</span>"
212
+ elif status == "completed":
213
+ # Green color with checkmark for completed jobs
214
+ status_formatted = f"<span style='color: green'>✅ {status}</span>"
215
+ else:
216
+ # Default formatting for unknown status
217
+ status_formatted = f"<span style='color: orange'>❓ {status}</span>"
218
+
219
+ # Create clickable links using Markdown
220
+ if job_type == "query":
221
+ job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status_formatted} - Query: {query_preview}\n"
222
+ else:
223
+ job_list_md += f"- [{job_id}](javascript:void) - {time_str} - {status_formatted} - File Load Job\n"
224
+
225
+ return job_list_md
226
+
227
+ # Function to handle job list clicks
228
+ def job_selected(job_id):
229
+ if job_id in jobs:
230
+ return job_id, jobs[job_id].get("query", "No query for this job")
231
+ return job_id, "Job not found"
232
+
233
+ # Function to refresh the job list
234
+ def refresh_job_list():
235
+ return get_job_list()
236
+
237
+ # Function to sync model dropdown boxes
238
+ def sync_model_dropdown(value):
239
+ return value
240
+
241
+ # Function to check job status
242
+ def check_job_status(job_id):
243
+ if not job_id:
244
+ return "Please enter a job ID", "", "", "", ""
245
+
246
+ # Process any completed jobs in the queue
247
+ try:
248
+ while not results_queue.empty():
249
+ completed_id, result = results_queue.get_nowait()
250
+ if completed_id in jobs:
251
+ jobs[completed_id]["status"] = "completed"
252
+ jobs[completed_id]["result"] = result
253
+ jobs[completed_id]["end_time"] = time.time()
254
+ debug_print(f"Job {completed_id} completed and stored in jobs dictionary")
255
+ except queue.Empty:
256
+ pass
257
+
258
+ # Check if the requested job exists
259
+ if job_id not in jobs:
260
+ return "Job not found. Please check the ID and try again.", "", "", "", ""
261
+
262
+ job = jobs[job_id]
263
+ job_query = job.get("query", "No query available for this job")
264
+
265
+ # If job is still processing
266
+ if job["status"] == "processing":
267
+ elapsed_time = time.time() - job["start_time"]
268
+ job_type = job.get("type", "unknown")
269
+
270
+ if job_type == "load_files":
271
+ return (
272
+ f"Files are still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
273
+ f"Try checking again in a few seconds.",
274
+ f"Job ID: {job_id}",
275
+ f"Status: Processing",
276
+ "",
277
+ job_query
278
+ )
279
+ else: # query job
280
+ return (
281
+ f"Query is still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
282
+ f"Try checking again in a few seconds.",
283
+ f"Job ID: {job_id}",
284
+ f"Input tokens: {count_tokens(job.get('query', ''))}",
285
+ "Output tokens: pending",
286
+ job_query
287
+ )
288
+
289
+ # If job is completed
290
+ if job["status"] == "completed":
291
+ result = job["result"]
292
+ processing_time = job["end_time"] - job["start_time"]
293
+
294
+ if job.get("type") == "load_files":
295
+ return (
296
+ f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
297
+ result[1],
298
+ result[2],
299
+ "",
300
+ job_query
301
+ )
302
+ else: # query job
303
+ return (
304
+ f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
305
+ result[1],
306
+ result[2],
307
+ result[3],
308
+ job_query
309
+ )
310
+
311
+ # Fallback for unknown status
312
+ return f"Job status: {job['status']}", "", "", "", job_query
313
+
314
+ # Function to clean up old jobs
315
+ def cleanup_old_jobs():
316
+ current_time = time.time()
317
+ to_delete = []
318
+
319
+ for job_id, job in jobs.items():
320
+ # Keep completed jobs for 1 hour, processing jobs for 2 hours
321
+ if job["status"] == "completed" and (current_time - job.get("end_time", 0)) > 3600:
322
+ to_delete.append(job_id)
323
+ elif job["status"] == "processing" and (current_time - job.get("start_time", 0)) > 7200:
324
+ to_delete.append(job_id)
325
+
326
+ for job_id in to_delete:
327
+ del jobs[job_id]
328
+
329
+ debug_print(f"Cleaned up {len(to_delete)} old jobs. {len(jobs)} jobs remaining.")
330
+ return f"Cleaned up {len(to_delete)} old jobs", "", ""
331
+
332
+ # Improve the truncate_prompt function to be more aggressive with limiting context
333
+ def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
334
+ """Truncate prompt to fit within token limit, preserving the most recent/relevant parts."""
335
+ if not prompt:
336
+ return ""
337
+
338
+ if global_tokenizer:
339
+ try:
340
+ tokens = global_tokenizer.encode(prompt)
341
+ if len(tokens) > max_tokens:
342
+ # For prompts, we often want to keep the beginning instructions and the end context
343
+ # So we'll keep the first 20% and the last 80% of the max tokens
344
+ beginning_tokens = int(max_tokens * 0.2)
345
+ ending_tokens = max_tokens - beginning_tokens
346
+
347
+ new_tokens = tokens[:beginning_tokens] + tokens[-(ending_tokens):]
348
+ return global_tokenizer.decode(new_tokens)
349
+ except Exception as e:
350
+ debug_print(f"Truncation error: {str(e)}")
351
+
352
+ # Fallback to word-based truncation
353
+ words = prompt.split()
354
+ if len(words) > max_tokens:
355
+ beginning_words = int(max_tokens * 0.2)
356
+ ending_words = max_tokens - beginning_words
357
+
358
+ return " ".join(words[:beginning_words] + words[-(ending_words):])
359
+
360
+ return prompt
361
+
362
+
363
+
364
+
365
+ default_prompt = """\
366
+ {conversation_history}
367
+ Use the following context to provide a detailed technical answer to the user's question.
368
+ Do not include an introduction like "Based on the provided documents, ...". Just answer the question.
369
+ If you don't know the answer, please respond with "I don't know".
370
+
371
+ Context:
372
+ {context}
373
+
374
+ User's question:
375
+ {question}
376
+ """
377
+
378
+ def load_txt_from_url(url: str) -> Document:
379
+ response = requests.get(url)
380
+ if response.status_code == 200:
381
+ text = response.text.strip()
382
+ if not text:
383
+ raise ValueError(f"TXT file at {url} is empty.")
384
+ return Document(page_content=text, metadata={"source": url})
385
+ else:
386
+ raise Exception(f"Failed to load {url} with status {response.status_code}")
387
+
388
+ class ElevatedRagChain:
389
+ def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
390
+ bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
391
+ debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
392
+ self.embed_func = HuggingFaceEmbeddings(
393
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
394
+ model_kwargs={"device": "cpu"}
395
+ )
396
+ self.bm25_weight = bm25_weight
397
+ self.faiss_weight = 1.0 - bm25_weight
398
+ self.top_k = 5
399
+ self.llm_choice = llm_choice
400
+ self.temperature = temperature
401
+ self.top_p = top_p
402
+ self.prompt_template = prompt_template
403
+ self.context = ""
404
+ self.conversation_history: List[Dict[str, str]] = []
405
+ self.raw_data = None
406
+ self.split_data = None
407
+ self.elevated_rag_chain = None
408
+
409
+ # Instance method to capture context and conversation history
410
+ def capture_context(self, result):
411
+ self.context = "\n".join([str(doc) for doc in result["context"]])
412
+ result["context"] = self.context
413
+ history_text = (
414
+ "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
415
+ if self.conversation_history else ""
416
+ )
417
+ result["conversation_history"] = history_text
418
+ return result
419
+
420
+ # Instance method to extract question from input data
421
+ def extract_question(self, input_data):
422
+ return input_data["question"]
423
+
424
+ # Improve error handling in the ElevatedRagChain class
425
+ def create_llm_pipeline(self):
426
+ from langchain.llms.base import LLM # Import LLM here so it's always defined
427
+ normalized = self.llm_choice.lower()
428
+ try:
429
+ if "remote" in normalized:
430
+ debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
431
+ from huggingface_hub import InferenceClient
432
+ repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
433
+ hf_api_token = os.environ.get("HF_API_TOKEN")
434
+ if not hf_api_token:
435
+ raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
436
+
437
+ client = InferenceClient(token=hf_api_token, timeout=120)
438
+
439
+ # We no longer use wait_for_model because it's unsupported
440
+ def remote_generate(prompt: str) -> str:
441
+ max_retries = 3
442
+ backoff = 2 # start with 2 seconds
443
+ for attempt in range(max_retries):
444
+ try:
445
+ debug_print(f"Remote generation attempt {attempt+1}")
446
+ response = client.text_generation(
447
+ prompt,
448
+ model=repo_id,
449
+ temperature=self.temperature,
450
+ top_p=self.top_p,
451
+ max_new_tokens=512 # Reduced token count for speed
452
+ )
453
+ return response
454
+ except Exception as e:
455
+ debug_print(f"Attempt {attempt+1} failed with error: {e}")
456
+ if attempt == max_retries - 1:
457
+ raise
458
+ time.sleep(backoff)
459
+ backoff *= 2 # exponential backoff
460
+ return "Failed to generate response after multiple attempts."
461
+
462
+ class RemoteLLM(LLM):
463
+ @property
464
+ def _llm_type(self) -> str:
465
+ return "remote_llm"
466
+
467
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
468
+ return remote_generate(prompt)
469
+
470
+ @property
471
+ def _identifying_params(self) -> dict:
472
+ return {"model": repo_id}
473
+
474
+ debug_print("Remote Meta-Llama-3 pipeline created successfully.")
475
+ return RemoteLLM()
476
+
477
+ elif "mistral-api" in normalized:
478
+ debug_print("Creating Mistral API pipeline...")
479
+ mistral_api_key = os.environ.get("MISTRAL_API_KEY")
480
+ if not mistral_api_key:
481
+ raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
482
+ try:
483
+ from mistralai import Mistral
484
+ debug_print("Mistral library imported successfully")
485
+ except ImportError:
486
+ debug_print("Mistral client library not installed. Falling back to Llama pipeline.")
487
+ normalized = "llama"
488
+ if normalized != "llama":
489
+ # from pydantic import PrivateAttr
490
+ # from langchain.llms.base import LLM
491
+ # from typing import Any, Optional, List
492
+ # import typing
493
+
494
+ class MistralLLM(LLM):
495
+ temperature: float = 0.7
496
+ top_p: float = 0.95
497
+ _client: Any = PrivateAttr(default=None)
498
+
499
+ def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
500
+ try:
501
+ super().__init__(**kwargs)
502
+ # Bypass Pydantic's __setattr__ to assign to _client
503
+ object.__setattr__(self, '_client', Mistral(api_key=api_key))
504
+ self.temperature = temperature
505
+ self.top_p = top_p
506
+ except Exception as e:
507
+ debug_print(f"Init Mistral failed with error: {e}")
508
+
509
+ @property
510
+ def _llm_type(self) -> str:
511
+ return "mistral_llm"
512
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
513
+ try:
514
+ debug_print("Calling Mistral API...")
515
+ response = self._client.chat.complete(
516
+ model="mistral-small-latest",
517
+ messages=[{"role": "user", "content": prompt}],
518
+ temperature=self.temperature,
519
+ top_p=self.top_p
520
+ )
521
+ return response.choices[0].message.content
522
+ except Exception as e:
523
+ debug_print(f"Mistral API error: {str(e)}")
524
+ return f"Error generating response: {str(e)}"
525
+ @property
526
+ def _identifying_params(self) -> dict:
527
+ return {"model": "mistral-small-latest"}
528
+ debug_print("Creating Mistral LLM instance")
529
+ mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
530
+ debug_print("Mistral API pipeline created successfully.")
531
+ return mistral_llm
532
+
533
+ else:
534
+ # Default case - using a fallback model (or Llama)
535
+ debug_print("Using local/fallback model pipeline")
536
+ model_id = "facebook/opt-350m" # Use a smaller model as fallback
537
+ pipe = pipeline(
538
+ "text-generation",
539
+ model=model_id,
540
+ device=-1, # CPU
541
+ max_length=1024
542
+ )
543
+
544
+ class LocalLLM(LLM):
545
+ @property
546
+ def _llm_type(self) -> str:
547
+ return "local_llm"
548
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
549
+ # For this fallback, truncate prompt if it exceeds limits
550
+ reserved_gen = 128
551
+ max_total = 1024
552
+ max_prompt_tokens = max_total - reserved_gen
553
+ truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
554
+ generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
555
+ return generated
556
+ @property
557
+ def _identifying_params(self) -> dict:
558
+ return {"model": model_id, "max_length": 1024}
559
+
560
+ debug_print("Local fallback pipeline created.")
561
+ return LocalLLM()
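+ # The local fallback reserves 128 of the pipeline's 1024-token max_length for generation,
+ # so prompts are truncated to roughly 896 tokens before the pipeline call.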
562
+
563
+ except Exception as e:
564
+ debug_print(f"Error creating LLM pipeline: {str(e)}")
565
+ # Return a dummy LLM that explains the error
566
+ class ErrorLLM(LLM):
567
+ @property
568
+ def _llm_type(self) -> str:
569
+ return "error_llm"
570
+
571
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
572
+ return f"Error initializing LLM: \n\nPlease check your environment variables and try again."
573
+
574
+ @property
575
+ def _identifying_params(self) -> dict:
576
+ return {"model": "error"}
577
+
578
+ return ErrorLLM()
579
+
580
+
581
+ def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
582
+ debug_print(f"Updating chain with new model: {new_model_choice}")
583
+ self.llm_choice = new_model_choice
584
+ self.temperature = temperature
585
+ self.top_p = top_p
586
+ self.prompt_template = prompt_template
587
+ self.bm25_weight = bm25_weight
588
+ self.faiss_weight = 1.0 - bm25_weight
589
+ self.llm = self.create_llm_pipeline()
590
+ def format_response(response: str) -> str:
591
+ input_tokens = count_tokens(self.context + self.prompt_template)
592
+ output_tokens = count_tokens(response)
593
+ formatted = f"### Response\n\n{response}\n\n---\n"
594
+ formatted += f"- **Input tokens:** {input_tokens}\n"
595
+ formatted += f"- **Output tokens:** {output_tokens}\n"
596
+ formatted += f"- **Generated using:** {self.llm_choice}\n"
597
+ formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
598
+ return formatted
599
+ base_runnable = RunnableParallel({
600
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
601
+ "question": RunnableLambda(self.extract_question)
602
+ }) | self.capture_context
603
+ self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
604
+ debug_print("Chain updated successfully with new LLM pipeline.")
605
+
606
+ def add_pdfs_to_vectore_store(self, file_links: List[str]) -> None:
607
+ debug_print(f"Processing files using {self.llm_choice}")
608
+ self.raw_data = []
609
+ for link in file_links:
610
+ if link.lower().endswith(".pdf"):
611
+ debug_print(f"Loading PDF: {link}")
612
+ loaded_docs = OnlinePDFLoader(link).load()
613
+ if loaded_docs:
614
+ self.raw_data.append(loaded_docs[0])
615
+ else:
616
+ debug_print(f"No content found in PDF: {link}")
617
+ elif link.lower().endswith(".txt") or link.lower().endswith(".utf-8"):
618
+ debug_print(f"Loading TXT: {link}")
619
+ try:
620
+ self.raw_data.append(load_txt_from_url(link))
621
+ except Exception as e:
622
+ debug_print(f"Error loading TXT file {link}: {e}")
623
+ else:
624
+ debug_print(f"File type not supported for URL: {link}")
625
+ if not self.raw_data:
626
+ raise ValueError("No files were successfully loaded. Please check the URLs and file formats.")
627
+ debug_print("Files loaded successfully.")
628
+ debug_print("Starting text splitting...")
629
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
630
+ self.split_data = self.text_splitter.split_documents(self.raw_data)
631
+ if not self.split_data:
632
+ raise ValueError("Text splitting resulted in no chunks. Check the file contents.")
633
+ debug_print(f"Text splitting completed. Number of chunks: {len(self.split_data)}")
634
+ debug_print("Creating BM25 retriever...")
635
+ self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
636
+ self.bm25_retriever.k = self.top_k
637
+ debug_print("BM25 retriever created.")
638
+ debug_print("Embedding chunks and creating FAISS vector store...")
639
+ self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
640
+ self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
641
+ debug_print("FAISS vector store created successfully.")
642
+ self.ensemble_retriever = EnsembleRetriever(
643
+ retrievers=[self.bm25_retriever, self.faiss_retriever],
644
+ weights=[self.bm25_weight, self.faiss_weight]
645
+ )
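+ # Hybrid retrieval: each retriever contributes its top_k candidates and the ranked lists
+ # are fused by LangChain's EnsembleRetriever (weighted reciprocal rank fusion) using the
+ # user-selected BM25/FAISS weights.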
646
+
647
+ base_runnable = RunnableParallel({
648
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
649
+ "question": RunnableLambda(self.extract_question)
650
+ }) | self.capture_context
651
+
652
+ # Ensure the prompt template is set
653
+ self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
654
+ if self.rag_prompt is None:
655
+ raise ValueError("Prompt template could not be created from the given template.")
656
+ prompt_runnable = RunnableLambda(lambda vars: self.rag_prompt.format(**vars))
657
+
658
+ self.str_output_parser = StrOutputParser()
659
+ debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
660
+ self.llm = self.create_llm_pipeline()
661
+ if self.llm is None:
662
+ raise ValueError("LLM pipeline creation failed.")
663
+
664
+ def format_response(response: str) -> str:
665
+ input_tokens = count_tokens(self.context + self.prompt_template)
666
+ output_tokens = count_tokens(response)
667
+ formatted = f"### Response\n\n{response}\n\n---\n"
668
+ formatted += f"- **Input tokens:** {input_tokens}\n"
669
+ formatted += f"- **Output tokens:** {output_tokens}\n"
670
+ formatted += f"- **Generated using:** {self.llm_choice}\n"
671
+ formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
672
+ return formatted
673
+
674
+ self.elevated_rag_chain = base_runnable | prompt_runnable | self.llm | format_response
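+ # Chain data flow: {"question"} -> parallel retrieval + question passthrough ->
+ # capture_context (flattens docs, attaches conversation history) -> rendered prompt
+ # string -> LLM -> Markdown-formatted response with token counts.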
675
+ debug_print("Elevated RAG chain successfully built and ready to use.")
676
+
677
+
678
+
679
+ def get_current_context(self) -> str:
680
+ base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
681
+ history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
682
+ recent = self.conversation_history[-3:]
683
+ if recent:
684
+ for i, conv in enumerate(recent, 1):
685
+ history_summary += f"**Conversation {i}:**\n- Query: {conv['query']}\n- Response: {conv['response']}\n"
686
+ else:
687
+ history_summary += "No conversation history."
688
+ return base_context + history_summary
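+ # get_current_context is display-only: it previews the first three chunks plus the last
+ # three Q/A pairs for the "Context Info" boxes in the UI.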
689
+
690
+ # ----------------------------
691
+ # Gradio Interface Functions
692
+ # ----------------------------
693
+ global rag_chain
694
+ rag_chain = ElevatedRagChain()
695
+
696
+ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
697
+ debug_print("Inside load_pdfs function.")
698
+ if not file_links:
699
+ debug_print("Please enter non-empty URLs")
700
+ return "Please enter non-empty URLs", "Word count: N/A", "Model used: N/A", "Context: N/A"
701
+ try:
702
+ links = [link.strip() for link in file_links.split("\n") if link.strip()]
703
+ global rag_chain
704
+ if rag_chain.raw_data:
705
+ rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
706
+ context_display = rag_chain.get_current_context()
707
+ response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
708
+ return (
709
+ response_msg,
710
+ f"Word count: {word_count(rag_chain.context)}",
711
+ f"Model used: {rag_chain.llm_choice}",
712
+ f"Context:\n{context_display}"
713
+ )
714
+ else:
715
+ rag_chain = ElevatedRagChain(
716
+ llm_choice=model_choice,
717
+ prompt_template=prompt_template,
718
+ bm25_weight=bm25_weight,
719
+ temperature=temperature,
720
+ top_p=top_p
721
+ )
722
+ rag_chain.add_pdfs_to_vectore_store(links)
723
+ context_display = rag_chain.get_current_context()
724
+ response_msg = f"Files loaded successfully. Using model: {model_choice}"
725
+ return (
726
+ response_msg,
727
+ f"Word count: {word_count(rag_chain.context)}",
728
+ f"Model used: {rag_chain.llm_choice}",
729
+ f"Context:\n{context_display}"
730
+ )
731
+ except Exception as e:
732
+ error_msg = traceback.format_exc()
733
+ debug_print("Could not load files. Error: " + error_msg)
734
+ return (
735
+ "Error loading files: " + str(e),
736
+ f"Word count: {word_count('')}",
737
+ f"Model used: {rag_chain.llm_choice}",
738
+ "Context: N/A"
739
+ )
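+ # Note: when documents are already loaded, load_pdfs_updated only refreshes the LLM and
+ # settings via update_llm_pipeline instead of re-downloading and re-embedding the files.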
740
+
741
+ def update_model(new_model: str):
742
+ global rag_chain
743
+ if rag_chain and rag_chain.raw_data:
744
+ rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
745
+ rag_chain.prompt_template, rag_chain.bm25_weight)
746
+ debug_print(f"Model updated to {rag_chain.llm_choice}")
747
+ return f"Model updated to: {rag_chain.llm_choice}"
748
+ else:
749
+ return "No files loaded; please load files first."
750
+
751
+
752
+ # Update submit_query_updated to better handle context limitation
753
+ def submit_query_updated(query):
754
+ debug_print(f"Processing query: {query}")
755
+ if not query:
756
+ debug_print("Empty query received")
757
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
758
+
759
+ if rag_chain.elevated_rag_chain is None or not rag_chain.raw_data:
760
+ debug_print("RAG chain not initialized")
761
+ return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
762
+
763
+ try:
764
+ # Determine max context size based on model
765
+ model_name = rag_chain.llm_choice.lower()
766
+ max_context_tokens = 32000 if "mistral" in model_name else 4096
767
+
768
+ # Reserve 20% of tokens for the question and response generation
769
+ reserved_tokens = int(max_context_tokens * 0.2)
770
+ max_context_tokens -= reserved_tokens
771
+
772
+ # Collect conversation history (last 2 only to save tokens)
773
+ if rag_chain.conversation_history:
774
+ recent_history = rag_chain.conversation_history[-2:]
775
+ history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response'][:300]}..."
776
+ for conv in recent_history])
777
+ else:
778
+ history_text = ""
779
+
780
+ # Get history token count
781
+ history_tokens = count_tokens(history_text)
782
+
783
+ # Adjust context tokens based on history size
784
+ context_tokens = max_context_tokens - history_tokens
785
+
786
+ # Ensure we have some minimum context
787
+ context_tokens = max(context_tokens, 1000)
788
+
789
+ # Truncate context if needed
790
+ context = truncate_prompt(rag_chain.context, max_tokens=context_tokens)
791
+
792
+ debug_print(f"Using model: {model_name}, context tokens: {count_tokens(context)}, history tokens: {history_tokens}")
793
+
794
+ prompt_variables = {
795
+ "conversation_history": history_text,
796
+ "context": context,
797
+ "question": query
798
+ }
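+ # Note: the chain re-derives retrieval context and conversation history internally
+ # (see capture_context), so prompt_variables is not passed into the invocation below;
+ # it is kept alongside the token-count logging above.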
799
+
800
+ debug_print("Invoking RAG chain")
801
+ response = rag_chain.elevated_rag_chain.invoke({"question": query})
802
+
803
+ # Store only a reasonable amount of the response in history
804
+ trimmed_response = response[:1000] + ("..." if len(response) > 1000 else "")
805
+ rag_chain.conversation_history.append({"query": query, "response": trimmed_response})
806
+
807
+ input_token_count = count_tokens(query)
808
+ output_token_count = count_tokens(response)
809
+
810
+ debug_print(f"Query processed successfully. Output tokens: {output_token_count}")
811
+
812
+ return (
813
+ response,
814
+ rag_chain.get_current_context(),
815
+ f"Input tokens: {input_token_count}",
816
+ f"Output tokens: {output_token_count}"
817
+ )
818
+ except Exception as e:
819
+ error_msg = traceback.format_exc()
820
+ debug_print(f"LLM error: {error_msg}")
821
+ return (
822
+ f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
823
+ "",
824
+ "Input tokens: 0",
825
+ "Output tokens: 0"
826
+ )
827
+
828
+ def reset_app_updated():
829
+ global rag_chain
830
+ rag_chain = ElevatedRagChain()
831
+ debug_print("App reset successfully.")
832
+ return (
833
+ "App reset successfully. You can now load new files",
834
+ "",
835
+ "Model used: Not selected"
836
+ )
837
+
838
+ # ----------------------------
839
+ # Gradio Interface Setup
840
+ # ----------------------------
841
+ custom_css = """
842
+ textarea {
843
+ overflow-y: scroll !important;
844
+ max-height: 200px;
845
+ }
846
+ """
847
+
848
+ # Update the Gradio interface to include job status checking
849
+ with gr.Blocks(css=custom_css, js="""
850
+ document.addEventListener('DOMContentLoaded', function() {
851
+ // Add event listener for job list clicks
852
+ const jobListInterval = setInterval(() => {
853
+ const jobLinks = document.querySelectorAll('.job-list-container a');
854
+ if (jobLinks.length > 0) {
855
+ jobLinks.forEach(link => {
856
+ link.addEventListener('click', function(e) {
857
+ e.preventDefault();
858
+ const jobId = this.textContent.split(' ')[0];
859
+ // Find the job ID input textbox and set its value
860
+ const jobIdInput = document.querySelector('.job-id-input input');
861
+ if (jobIdInput) {
862
+ jobIdInput.value = jobId;
863
+ // Trigger the input event to update Gradio's state
864
+ jobIdInput.dispatchEvent(new Event('input', { bubbles: true }));
865
+ }
866
+ });
867
+ });
868
+ clearInterval(jobListInterval);
869
+ }
870
+ }, 500);
871
+ });
872
+ """) as app:
873
+ gr.Markdown('''# PhiRAG - Async Version
874
+ **PhiRAG**: Query Your Data with Advanced RAG Techniques
875
+
876
+ **Model Selection & Parameters:** Choose from the following options:
877
+ - 🇺🇸 Remote Meta-Llama-3 - has a context window of 8,000 tokens
878
+ - 🇪🇺 Mistral-API - has a context window of 32,000 tokens
879
+
880
+ **🔥 Randomness (Temperature):** Adjusts output predictability.
881
+ - Example: 0.2 makes the output very deterministic (less creative), while 0.8 introduces more variety and spontaneity.
882
+
883
+ **🎯 Word Variety (Top‑p):** Limits sampling to the smallest set of tokens whose cumulative probability reaches p.
884
+ - Example: 0.5 restricts output to the most likely 50% of token choices for a focused answer; 0.95 allows almost all possibilities for more diverse responses.
885
+
886
+ **⚖️ BM25 Weight:** Adjust Lexical vs Semantics.
887
+ - Example: A value of 0.8 puts more emphasis on exact keyword (lexical) matching, while 0.3 shifts emphasis toward semantic similarity.
888
+
889
+ **✏️ Prompt Template:** Edit as desired.
890
+
891
+ **🔗 File URLs:** Enter one URL per line (.pdf or .txt).
892
+ - Example: Provide one URL per line, such as
893
+ https://www.gutenberg.org/ebooks/8438.txt.utf-8
894
+
895
+ **🔍 Query:** Enter your query below.
896
+
897
+ **⚠️ IMPORTANT: This app now uses asynchronous processing to avoid timeout issues**
898
+ - When you load files or submit a query, you'll receive a Job ID
899
+ - Use the "Check Job Status" tab to monitor and retrieve your results
900
+ ''')
901
+
902
+ with gr.Tabs() as tabs:
903
+ with gr.TabItem("Setup & Load Files"):
904
+ with gr.Row():
905
+ with gr.Column():
906
+ model_dropdown = gr.Dropdown(
907
+ choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
908
+ value="🇺🇸 Remote Meta-Llama-3",
909
+ label="Select Model"
910
+ )
911
+ temperature_slider = gr.Slider(
912
+ minimum=0.1, maximum=1.0, value=0.5, step=0.1,
913
+ label="Randomness (Temperature)"
914
+ )
915
+ top_p_slider = gr.Slider(
916
+ minimum=0.1, maximum=0.99, value=0.95, step=0.05,
917
+ label="Word Variety (Top-p)"
918
+ )
919
+ with gr.Column():
920
+ pdf_input = gr.Textbox(
921
+ label="Enter your file URLs (one per line)",
922
+ placeholder="Enter one URL per line (.pdf or .txt)",
923
+ lines=4
924
+ )
925
+ prompt_input = gr.Textbox(
926
+ label="Custom Prompt Template",
927
+ placeholder="Enter your custom prompt template here",
928
+ lines=8,
929
+ value=default_prompt
930
+ )
931
+ with gr.Column():
932
+ bm25_weight_slider = gr.Slider(
933
+ minimum=0.0, maximum=1.0, value=0.6, step=0.1,
934
+ label="Lexical vs Semantics (BM25 Weight)"
935
+ )
936
+ load_button = gr.Button("Load Files (Async)")
937
+ load_status = gr.Markdown("Status: Waiting for files")
938
+
939
+ with gr.Row():
940
+ load_response = gr.Textbox(
941
+ label="Load Response",
942
+ placeholder="Response will appear here",
943
+ lines=4
944
+ )
945
+ load_context = gr.Textbox(
946
+ label="Context Info",
947
+ placeholder="Context info will appear here",
948
+ lines=4
949
+ )
950
+
951
+ with gr.Row():
952
+ model_output = gr.Markdown("**Current Model**: Not selected")
953
+
954
+ with gr.TabItem("Submit Query"):
955
+ with gr.Row():
956
+ # Add this line to define the query_model_dropdown
957
+ query_model_dropdown = gr.Dropdown(
958
+ choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
959
+ value="🇺🇸 Remote Meta-Llama-3",
960
+ label="Query Model"
961
+ )
962
+
963
+ query_input = gr.Textbox(
964
+ label="Enter your query here",
965
+ placeholder="Type your query",
966
+ lines=4
967
+ )
968
+ submit_button = gr.Button("Submit Query (Async)")
969
+
970
+ with gr.Row():
971
+ query_response = gr.Textbox(
972
+ label="Query Response",
973
+ placeholder="Response will appear here (formatted as Markdown)",
974
+ lines=6
975
+ )
976
+ query_context = gr.Textbox(
977
+ label="Context Information",
978
+ placeholder="Retrieved context and conversation history will appear here",
979
+ lines=6
980
+ )
981
+
982
+ with gr.Row():
983
+ input_tokens = gr.Markdown("Input tokens: 0")
984
+ output_tokens = gr.Markdown("Output tokens: 0")
985
+
986
+ with gr.TabItem("Check Job Status"):
987
+ with gr.Row():
988
+ with gr.Column(scale=1):
989
+ job_list = gr.Markdown(
990
+ value="No jobs yet",
991
+ label="Job List (Click to select)"
992
+ )
993
+ refresh_button = gr.Button("Refresh Job List")
994
+
995
+ with gr.Column(scale=2):
996
+ job_id_input = gr.Textbox(
997
+ label="Job ID",
998
+ placeholder="Job ID will appear here when selected from the list",
999
+ lines=1
1000
+ )
1001
+ job_query_display = gr.Textbox(
1002
+ label="Job Query",
1003
+ placeholder="The query associated with this job will appear here",
1004
+ lines=2,
1005
+ interactive=False
1006
+ )
1007
+ check_button = gr.Button("Check Status")
1008
+ cleanup_button = gr.Button("Cleanup Old Jobs")
1009
+
1010
+ with gr.Row():
1011
+ status_response = gr.Textbox(
1012
+ label="Job Result",
1013
+ placeholder="Job result will appear here",
1014
+ lines=6
1015
+ )
1016
+ status_context = gr.Textbox(
1017
+ label="Context Information",
1018
+ placeholder="Context information will appear here",
1019
+ lines=6
1020
+ )
1021
+
1022
+ with gr.Row():
1023
+ status_tokens1 = gr.Markdown("")
1024
+ status_tokens2 = gr.Markdown("")
1025
+
1026
+ with gr.TabItem("App Management"):
1027
+ with gr.Row():
1028
+ reset_button = gr.Button("Reset App")
1029
+
1030
+ with gr.Row():
1031
+ reset_response = gr.Textbox(
1032
+ label="Reset Response",
1033
+ placeholder="Reset confirmation will appear here",
1034
+ lines=2
1035
+ )
1036
+ reset_context = gr.Textbox(
1037
+ label="",
1038
+ placeholder="",
1039
+ lines=2,
1040
+ visible=False
1041
+ )
1042
+
1043
+ with gr.Row():
1044
+ reset_model = gr.Markdown("")
1045
+
1046
+ # Connect the buttons to their respective functions
1047
+ load_button.click(
1048
+ load_pdfs_async,
1049
+ inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
1050
+ outputs=[load_response, load_context, model_output, job_id_input, job_query_display, job_list]
1051
+ )
1052
+
1053
+ # Also sync in the other direction
1054
+ query_model_dropdown.change(
1055
+ fn=sync_model_dropdown,
1056
+ inputs=query_model_dropdown,
1057
+ outputs=model_dropdown
1058
+ )
1059
+
1060
+ submit_button.click(
1061
+ submit_query_async,
1062
+ inputs=[query_input, query_model_dropdown],
1063
+ outputs=[query_response, query_context, input_tokens, output_tokens, job_id_input, job_query_display, job_list]
1064
+ )
1065
+
1066
+ check_button.click(
1067
+ check_job_status,
1068
+ inputs=[job_id_input],
1069
+ outputs=[status_response, status_context, status_tokens1, status_tokens2, job_query_display]
1070
+ )
1071
+
1072
+ refresh_button.click(
1073
+ refresh_job_list,
1074
+ inputs=[],
1075
+ outputs=[job_list]
1076
+ )
1077
+
1078
+ # Connect the job list selection event (this is handled by JavaScript)
1079
+ job_id_input.change(
1080
+ job_selected,
1081
+ inputs=[job_id_input],
1082
+ outputs=[job_id_input, job_query_display]
1083
+ )
1084
+
1085
+ cleanup_button.click(
1086
+ cleanup_old_jobs,
1087
+ inputs=[],
1088
+ outputs=[status_response, status_context, status_tokens1]
1089
+ )
1090
+
1091
+ reset_button.click(
1092
+ reset_app_updated,
1093
+ inputs=[],
1094
+ outputs=[reset_response, reset_context, reset_model]
1095
+ )
1096
+
1097
+
1098
+ model_dropdown.change(
1099
+ fn=sync_model_dropdown,
1100
+ inputs=model_dropdown,
1101
+ outputs=query_model_dropdown
1102
+ )
1103
+
1104
+ # Add an event to refresh the job list on page load
1105
+ app.load(
1106
+ fn=refresh_job_list,
1107
+ inputs=None,
1108
+ outputs=job_list
1109
+ )
1110
+
1111
+ if __name__ == "__main__":
1112
+ debug_print("Launching Gradio interface.")
1113
+ app.launch(share=False)