Update app.py
app.py CHANGED
@@ -480,52 +480,64 @@ Provide clear, actionable advice while emphasizing the need for professional med
         return "\n\n".join(context_parts)
 
     def _generate_response(self, query: str, context: str) -> str:
-        """
+        """Generate response using T5-style seq2seq model with Gaza-specific context"""
         if self.llm is None or self.tokenizer is None:
             return self._generate_fallback_response(query, context)
-
-        # Build prompt with Gaza-specific context
-        prompt = f"""{self.system_prompt}
-
+        prompt = f"""{self.system_prompt}
 MEDICAL KNOWLEDGE CONTEXT:
 {context}
 
 PATIENT QUESTION: {query}
 
 RESPONSE (provide practical, Gaza-appropriate medical guidance):"""
-
         try:
-            outputs = self.generation_pipeline(prompt, max_new_tokens=300, temperature=0.3, repetition_penalty=1.15, no_repeat_ngram_size=3)
-            response_text = outputs[0]["generated_text"]
-
-            # Extract only the generated part
-            if "RESPONSE (provide practical, Gaza-appropriate medical guidance):" in response_text:
-                response_text = response_text.split("RESPONSE (provide practical, Gaza-appropriate medical guidance):")[1]
-
-            # Clean up the response
-            lines = response_text.split('\n')
-            unique_lines = []
-                unique_lines.append(line)
-            final_response = '\n'.join(unique_lines)
-            logger.info(f"🧪 Final cleaned response:\n{final_response}")
-
-        except Exception as e:
-            logger.error(f"❌ Error in LLM generate(): {e}")
-            return self._generate_fallback_response(query, context)
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512,
+                padding="max_length"
+            )
+            input_ids = inputs["input_ids"]
+            attention_mask = inputs["attention_mask"]
+            device = self.llm.device if hasattr(self.llm, "device") else "cpu"
+            input_ids = input_ids.to(device)
+            attention_mask = attention_mask.to(device)
+
+            # Generate output
+            with torch.no_grad():
+                outputs = self.llm.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=256,
+                    temperature=0.3,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    do_sample=True,
+                    repetition_penalty=1.15,
+                    no_repeat_ngram_size=3
+                )
+
+            # Decode result
+            response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            # Clean and filter output
+            lines = response_text.split('\n')
+            unique_lines = []
+            for line in lines:
+                line = line.strip()
+                if line and line not in unique_lines and len(line) > 10:
+                    unique_lines.append(line)
+
+            final_response = '\n'.join(unique_lines)
+            logger.info(f"🧪 Final cleaned response:\n{final_response}")
+
+            return final_response
+
+        except Exception as e:
+            logger.error(f"❌ Error in LLM generate(): {e}")
+            return self._generate_fallback_response(query, context)
 
     def _generate_fallback_response(self, query: str, context: str) -> str:
         """Enhanced fallback response with Gaza-specific guidance"""
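
For quick verification outside the Space, here is a minimal standalone sketch of the same tokenize -> generate -> decode -> clean path introduced in this commit. The checkpoint name ("google/flan-t5-small"), the sample prompt, and the variable names are placeholders for illustration only; app.py uses its own self.llm, self.tokenizer, and system prompt, but the generation parameters below mirror the ones in the diff.

# Standalone sketch of the new seq2seq generation path (illustrative only).
# "google/flan-t5-small" and the prompt text are assumptions, not the
# actual model or system prompt used by the Space.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "google/flan-t5-small"  # placeholder T5-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

prompt = (
    "MEDICAL KNOWLEDGE CONTEXT:\n"
    "Clean water and oral rehydration salts help treat dehydration.\n\n"
    "PATIENT QUESTION: How do I treat mild dehydration at home?\n\n"
    "RESPONSE (provide practical, Gaza-appropriate medical guidance):"
)

# Tokenize with the same truncation/padding settings as the diff
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=512,
    padding="max_length",
)
device = model.device if hasattr(model, "device") else "cpu"
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

# Generate with the same sampling parameters as the diff
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=256,
        temperature=0.3,
        do_sample=True,
        repetition_penalty=1.15,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.eos_token_id,
    )

response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Same de-duplication and minimum-length filter as the diff's cleanup step
unique_lines = []
for line in response_text.split("\n"):
    line = line.strip()
    if line and line not in unique_lines and len(line) > 10:
        unique_lines.append(line)

print("\n".join(unique_lines))

A note on the switch away from self.generation_pipeline: with an encoder-decoder model, generate() returns only newly generated tokens, so decoding the output directly replaces the old split-on-"RESPONSE ..." step, and the explicit attention_mask and device placement avoid relying on the pipeline's defaults.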