Spaces:
Sleeping
Sleeping
File size: 1,719 Bytes
8595342 67e765b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# app/services/query_expansion_service.py
import logging
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import csv
import re
import os
from typing import Dict
import re # <-- NEW IMPORT for regular expressions
# --- Core App Imports ---
from app.core import state
from app.core.config import settings
logger = logging.getLogger(__name__)
def replace_abbreviations(query_text: str) -> str:
    """Expand a predefined set of abbreviations found in a query string.

    Matching is case-insensitive and limited to whole words. "tph" is always
    expanded to "Payment Hub". "aa" is expanded to "Arrangement" only when the
    whole word "Arrangement" does not already occur in the query.

    Args:
        query_text: The query text to expand.

    Returns:
        The query with every active abbreviation replaced; the text is
        returned unchanged when no abbreviation matches.
    """
    # Keys MUST be lowercase: the lookup below lowercases the matched text
    # before indexing into this table.
    replacements = {
        'tph': 'Payment Hub',
        'aa': 'Arrangement',
        # Add other unconditional replacements here.
    }

    # Skip the "aa" rule when the expanded word is already present as a whole
    # word, so the query does not end up saying "Arrangement" twice.
    if re.search(r'\bArrangement\b', query_text, re.IGNORECASE):
        # pop() with a default tolerates the rule being absent from the table.
        replacements.pop('aa', None)

    # Nothing left to apply: an empty alternation would match at every word
    # boundary, so bail out before building the pattern.
    if not replacements:
        return query_text

    # Escape each key so an abbreviation containing regex metacharacters is
    # matched literally instead of altering the pattern.
    alternatives = '|'.join(re.escape(abbreviation) for abbreviation in replacements)
    pattern = re.compile(r'\b(' + alternatives + r')\b', re.IGNORECASE)

    def get_replacement(match):
        # Normalise the matched text to the table's lowercase key form.
        return replacements[match.group(0).lower()]

    return pattern.sub(get_replacement, query_text)
|