| | import re |
| | from hashlib import md5 |
| | from sentence_transformers import SentenceTransformer |
| | from langchain_text_splitters import RecursiveCharacterTextSplitter |
| | from transformers import AutoTokenizer |
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Lazily-initialized singletons: populated on first use by get_embed_model()
# and get_flant5_tokenizer() below so importing this module stays cheap.
_embed_model = None
_flant5tokenizer = None
| |
|
def get_embed_model():
    """Return the shared sentence-embedding model, loading it on first call."""
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    # First call: load MiniLM once and cache it at module level.
    _embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _embed_model
| |
|
def get_flant5_tokenizer():
    """Return the shared Flan-T5 tokenizer, loading it on first call."""
    global _flant5tokenizer
    if _flant5tokenizer is not None:
        return _flant5tokenizer
    # First call: download/load the tokenizer once and cache it.
    _flant5tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    return _flant5tokenizer
| |
|
| |
|
def normalize_text(text):
    """Normalize text for duplicate detection.

    Lowercases the input, removes punctuation, and collapses every run of
    whitespace into a single space. Non-string input is treated as empty.

    Returns: str
    """
    if not isinstance(text, str):
        return ""

    text = text.lower()
    # Remove punctuation BEFORE collapsing whitespace: the original order
    # stripped a punctuation token sitting between spaces ("a - b") and left
    # a double space behind, so equivalent texts normalized differently and
    # escaped duplicate detection.
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
| |
|
| |
|
def chunk_text(text, chunk_size=384, chunk_overlap=20):
    """Split *text* into (possibly overlapping) chunks for embedding.

    Returns: list of chunk strings.
    """
    # Prefer paragraph, then line, then word boundaries; fall back to chars.
    separators = ["\n\n", "\n", " ", ""]
    return RecursiveCharacterTextSplitter(
        separators=separators,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    ).split_text(text)
| |
|
| |
|
def create_embeddings(texts):
    """Encode a list of texts into L2-normalized embedding vectors.

    Returns: numpy array of embeddings, one row per input text.
    """
    model = get_embed_model()
    encode_opts = dict(
        batch_size=64,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,  # unit vectors -> cosine == dot product
    )
    return model.encode(texts, **encode_opts)
| |
|
| |
|
def refine_response(answer):
    """Clean and format generated response text.

    Normalizes spacing around periods, drops an unfinished trailing
    sentence, and capitalizes each sentence (lowercasing the remainder of
    the sentence, as str.capitalize does).

    Returns: str(refined_answer)
    """
    # Collapse runs of spaces after a period, then guarantee exactly one
    # space between a period and the next non-space character.
    answer = re.sub(r'\. {2,}', '.', answer)
    answer = re.sub(r'\.([^\s])', r'. \1', answer)

    # If the text does not end on sentence punctuation, truncate back to
    # the last complete sentence (when one exists).
    if not answer.strip().endswith(('.', '!', '?')):
        cut = max(answer.rfind(ch) for ch in '.!?')
        if cut != -1:
            answer = answer[:cut + 1]

    # re.split with a capturing group alternates sentence bodies and their
    # trailing separators; keep a separator only after a non-empty body.
    parts = re.split(r'([.!?]\s*)', answer)
    total = len(parts)
    pieces = []
    idx = 0
    while idx < total:
        body = parts[idx].strip()
        if body:
            pieces.append(body.capitalize())
            if idx + 1 < total:
                pieces.append(parts[idx + 1])
        idx += 2

    return ''.join(pieces).strip()
| |
|
| |
|
def build_prompt(user_query, context, max_tokens=512):
    """Build a Flan-T5 prompt from retrieved context chunks and the query.

    Greedily packs context entries, in the order given, until the next
    entry would exceed the token budget, then appends the question footer.

    Args:
        user_query: The user's question.
        context: Iterable of dicts with 'question' and 'chunk_answer' keys,
            assumed pre-ranked by relevance (TODO confirm with retriever).
        max_tokens: Upper bound on the prompt's total token count.

    Returns:
        str: the assembled prompt, or a fixed fallback prompt when no
        context was retrieved.
    """
    flant5tokenizer = get_flant5_tokenizer()

    if not context:
        return f"""No relevant medical information found.
Q: {user_query}
A: Information unavailable."""

    instruction_text = "Medical Context:\n"
    query_footer = f"\nQ: {user_query}\nA:"

    # Fixed token cost of the header and footer; the +5 looks like headroom
    # for special tokens added at generation time — confirm against the
    # model's actual special-token count.
    inst_tokens = len(flant5tokenizer.encode(instruction_text, add_special_tokens=False))
    query_tokens = len(flant5tokenizer.encode(query_footer, add_special_tokens=False))
    total_static_cost = inst_tokens + query_tokens + 5

    # Token budget left for context chunks (never negative).
    remaining_tokens = max(max_tokens - total_static_cost, 0)

    valid_contexts = []
    current_context_tokens = 0
    for idx, c in enumerate(context, start=1):
        # Renamed from `chunk_text` — the original local shadowed the
        # module-level chunk_text() helper.
        chunk_str = f"[C{idx}] {c['question']}\n{c['chunk_answer']}"
        chunk_len = len(flant5tokenizer.encode(chunk_str, add_special_tokens=False))

        # Stop at the first chunk that would overflow the budget.
        if current_context_tokens + chunk_len > remaining_tokens:
            break

        valid_contexts.append(chunk_str)
        current_context_tokens += chunk_len

    context_block = "\n".join(valid_contexts)
    return f"{instruction_text}{context_block}{query_footer}"
| |
|