rivapereira123 committed (verified)
Commit f67b75b · 1 Parent(s): 22adc85

Update app.py

Files changed (1)
  1. app.py +26 -56
app.py CHANGED
@@ -307,56 +307,32 @@ class OptimizedGazaRAGSystem:
         logger.info("🚀 Initializing Optimized Gaza RAG System...")
         self.knowledge_base.initialize()
         logger.info("✅ Optimized Gaza RAG System ready!")
+
 
+
     def _initialize_llm(self):
-        """Enhanced LLM initialization with better error handling"""
-        if self.llm is not None:
-            return
-
-        model_name = "microsoft/Phi-3-mini-4k-instruct"
-        try:
-            logger.info(f"🔄 Loading LLM: {model_name}")
-
-            # Enhanced quantization configuration
-            quantization_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.float16,
-            )
-
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_name,
-                trust_remote_code=True,
-                padding_side="left"
-            )
-
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            self.llm = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                quantization_config=quantization_config,
-                device_map="auto",
-                trust_remote_code=True,
-                torch_dtype=torch.float16,
-                low_cpu_mem_usage=True
-            )
-
-            self.generation_pipeline = pipeline(
-                "text-generation",
-                model=self.llm,
-                tokenizer=self.tokenizer,
-                device_map="auto",
-                torch_dtype=torch.float16,
-                return_full_text=False
-            )
-
-            logger.info("✅ LLM loaded successfully")
-
-        except Exception as e:
-            logger.error(f"❌ Error loading primary model: {e}")
-            raise RuntimeError("Model loading failed — check GPU availability and bitsandbytes install")
+        """Initialize FLAN-T5 model (CPU-friendly)"""
+        model_name = "google/flan-t5-base"
+        try:
+            logger.info(f"🔄 Loading fallback CPU model: {model_name}")
+
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+            self.llm = AutoModelForCausalLM.from_pretrained(model_name)
+
+            self.generation_pipeline = pipeline(
+                "text2text-generation",  # <-- Important for T5!
+                model=self.llm,
+                tokenizer=self.tokenizer,
+                return_full_text=False
+            )
+
+            logger.info("✅ FLAN-T5 model loaded successfully")
+
+        except Exception as e:
+            logger.error(f"❌ Error loading FLAN-T5 model: {e}")
+            self.llm = None
+            self.generation_pipeline = None
 
 
     def _initialize_fallback_llm(self):
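
Note on the new _initialize_llm: google/flan-t5-base is an encoder-decoder (seq2seq) model, so AutoModelForCausalLM.from_pretrained will not accept its config; the usual loading path for T5-family models is AutoModelForSeq2SeqLM. Likewise, return_full_text is an option of the plain text-generation pipeline and may be rejected or ignored by text2text-generation, whose output is already only the generated text. Below is a minimal sketch of the CPU-friendly initialization under those assumptions, keeping the attribute names used in the diff (self.tokenizer, self.llm, self.generation_pipeline); it is an illustration, not the committed code.

# Sketch only: mirrors the _initialize_llm method from this commit, using the
# seq2seq auto class that FLAN-T5 requires.
import logging

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)

def _initialize_llm(self):
    """Initialize FLAN-T5 (encoder-decoder), which runs acceptably on CPU."""
    model_name = "google/flan-t5-base"
    try:
        logger.info(f"🔄 Loading fallback CPU model: {model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # FLAN-T5 is a T5 (seq2seq) architecture, so it loads through
        # AutoModelForSeq2SeqLM rather than AutoModelForCausalLM.
        self.llm = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        self.generation_pipeline = pipeline(
            "text2text-generation",  # T5-style task
            model=self.llm,
            tokenizer=self.tokenizer,
            # return_full_text is omitted here: text2text-generation already
            # returns only the generated answer.
        )

        logger.info("✅ FLAN-T5 model loaded successfully")

    except Exception as e:
        logger.error(f"❌ Error loading FLAN-T5 model: {e}")
        self.llm = None
        self.generation_pipeline = None
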
@@ -529,15 +505,9 @@ RESPONSE (provide practical, Gaza-appropriate medical guidance):"""
 
         # Generate the response
         with torch.no_grad():
-            outputs = self.llm.generate(
-                **inputs,
-                max_new_tokens=600,
-                temperature=0.3,
-                pad_token_id=self.tokenizer.eos_token_id,
-                do_sample=True,
-                repetition_penalty=1.15,
-                no_repeat_ngram_size=3
-            )
+            outputs = self.generation_pipeline(prompt, max_new_tokens=300, temperature=0.3, repetition_penalty=1.15, no_repeat_ngram_size=3)
+            response_text = outputs[0]["generated_text"]
+
 
         # Decode and clean up
         response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
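
In this second hunk, the text2text pipeline already returns decoded strings (a list of dicts keyed by generated_text), so the unchanged "# Decode and clean up" context line that follows would now call tokenizer.decode on a dict rather than on token IDs; it reads as a leftover from the removed self.llm.generate path. The sketch below shows how the block might look with that leftover dropped, using a hypothetical _generate_response wrapper (the real method name is not visible in the hunk) and assuming prompt holds the fully built prompt string.

# Sketch only: the generation step after switching to the pipeline API.
import torch

def _generate_response(self, prompt: str) -> str:
    with torch.no_grad():  # kept from the original code; pipelines disable gradients themselves
        outputs = self.generation_pipeline(
            prompt,
            max_new_tokens=300,
            temperature=0.3,        # note: temperature only takes effect if do_sample=True is also passed
            repetition_penalty=1.15,
            no_repeat_ngram_size=3,
        )
    # The pipeline output is already decoded text, so no further
    # tokenizer.decode(...) pass over outputs[0] is needed.
    return outputs[0]["generated_text"].strip()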