# Source: bonsAI-NLP / app.py (author: jin3213, commit 34a5565 "Update app.py")
# app.py
# BonsAI – Pharmaceutical QA System (readers: DistilBERT, XLM-RoBERTa, ClinicalBERT)
# pip install -U gradio transformers torch sentence-transformers scikit-learn numpy rapidfuzz safetensors huggingface_hub
# python app.py
import difflib
import hashlib
import json
import os
import re
from typing import List, Dict, Tuple

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

try:
    from sentence_transformers.cross_encoder import CrossEncoder
except Exception:
    CrossEncoder = None

# Better fuzzy matching (optional but recommended)
try:
    from rapidfuzz import process, fuzz
    HAS_RAPIDFUZZ = True
except Exception:
    HAS_RAPIDFUZZ = False
# -------------------------
# CONFIG (EDIT IF NEEDED)
# -------------------------

# Corpus file: a JSON list of objects with keys ingredient, forms_and_strengths, page.
CORPUS_PATH = "drug_entries.json"

# Reader models offered in the UI dropdown: display label -> Hugging Face Hub repo id.
# Transformers downloads each repo the first time it is selected.
MODEL_CHOICES = {
    "DistilBERT (fine-tuned)": "jin3213/distilbert",
    "XLM-RoBERTa (fine-tuned)": "jin3213/xlm-roberta",
    "ClinicalBERT (fine-tuned)": "jin3213/clinicalbert",
}

# --- Retrieval ---
DENSE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"    # bi-encoder for dense search
RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"     # optional second-stage reranker
USE_RERANKER = True   # False = skip the cross-encoder (faster, one model fewer to load)
TOPK_SOURCES = 5      # sources returned to the UI
FUSION_K = 60         # RRF constant: each list contributes 1 / (FUSION_K + rank)
TOPN_RERANK = 20      # fused candidates kept before reranking (at least TOPK_SOURCES)

# On-disk cache of the dense passage embeddings (speeds up restarts).
EMB_CACHE_PATH = "dense_embeddings_cache.npy"

# --- Drug-name detection ---
# Minimum similarity (0..1) for a fuzzy match; higher = stricter = fewer false matches.
FUZZY_DRUG_CUTOFF = 0.75

# --- Reader ---
MAX_ANSWER_LEN = 80   # max answer span, in tokens (end index - start index)
MAX_SEQ_LEN = 384     # max tokens per question+context input
DOC_STRIDE = 128      # token overlap between consecutive context windows

# True: the QA reader extracts the answer span from the best retrieved passage.
# False: the entry's forms_and_strengths text from the JSON is returned directly (no QA).
USE_QA_READER = True

# Load "slow" (Python) tokenizers instead of "fast" ones. Chosen to avoid Hugging Face
# "backend tokenizer" instantiation/conversion errors seen on Spaces.
FORCE_SLOW_TOKENIZER = True
# -------------------------
# TEXT UTILS
# -------------------------
def normalize(s: str) -> str:
    """Lower-case *s*, turn curly apostrophes into ASCII ones and collapse whitespace.

    Every whitespace run (line breaks included) becomes one space, and the
    result is stripped at both ends. Non-string input is converted with str().
    """
    lowered = str(s).lower().replace("’", "'")
    return re.sub(r"\s+", " ", lowered).strip()
def normalize_question(q: str) -> str:
    """Normalize a user question for retrieval and drug-name matching.

    Applies normalize(), then replaces every character other than ASCII
    lower-case letters, digits, whitespace, "-", "+" and "/" with a space and
    collapses the resulting whitespace.
    """
    kept = re.sub(r"[^a-z0-9\s\-\+\/]", " ", normalize(q))
    return re.sub(r"\s+", " ", kept).strip()
def clean_drug_name(name: str) -> str:
    """Reduce a raw ingredient label to the canonical key used for lookups.

    Steps: normalize() (lower-case, all whitespace including line breaks folded
    into single spaces), drop "(see ...)" cross-references and any other
    parenthesised text, turn punctuation other than "+", "-", "/" into spaces,
    and collapse whitespace.

    Because normalize() already joins lines, a multi-line label is treated as a
    single line (e.g. "Abacavir\\n+ Lamivudine" -> "abacavir + lamivudine").
    Blank input returns "" (previously it raised IndexError).
    """
    text = normalize(name)
    text = re.sub(r"\(see [^)]+\)", "", text).strip()
    text = re.sub(r"\([^)]*\)", "", text).strip()
    text = re.sub(r"[^\w\s\+\-\/]", " ", text)
    return re.sub(r"\s+", " ", text).strip()
def split_multi_ingredient(raw: str) -> List[str]:
    """Split a combination-product label on "+" into its normalized components.

    normalize() folds every line break into a space, so the label is handled
    as one line. Pieces are stripped; empty pieces are dropped. A label
    without "+" yields a single-element list, blank input an empty list.
    """
    pieces = (piece.strip() for piece in normalize(raw).split("+"))
    return [piece for piece in pieces if piece]
def pretty_answer(text: str) -> str:
    """Format a forms-and-strengths string for display, one item per line.

    Semicolons become line breaks, whitespace around line breaks is trimmed,
    and each route label ("Oral:", "Injection:", "Inhalation:", "Topical:")
    is moved onto its own line. Repeated line breaks and outer whitespace are
    removed.
    """
    out = str(text).strip().replace(";", "\n")
    out = re.sub(r"\s*\n\s*", "\n", out).strip()
    # Applied one label at a time, in the same order as before.
    for route in ("Oral", "Injection", "Inhalation", "Topical"):
        out = re.sub(rf"\s*({route}:)", r"\n\1", out)
    return re.sub(r"\n+", "\n", out).strip()
# -------------------------
# LOAD CORPUS
# -------------------------
if not os.path.exists(CORPUS_PATH):
    raise FileNotFoundError(
        f"Cannot find {CORPUS_PATH}. Put drug_entries.json beside app.py (in the Space repo root)."
    )
with open(CORPUS_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)
if not isinstance(data, list):
    raise ValueError("drug_entries.json must be a LIST of objects with keys: ingredient, forms_and_strengths, page")

# canonical key -> record; a later object with the same canonical key replaces an earlier one
entries: Dict[str, Dict[str, str]] = {}
# lookup string -> canonical key (used by detect_drug_alias)
aliases: Dict[str, str] = {}
# Parallel per-passage lists, all in `entries` iteration order.
passages: List[str] = []
meta: List[Dict[str, str]] = []
canonical_keys: List[str] = []

for obj in data:
    if not isinstance(obj, dict):
        continue
    ingredient_raw = obj.get("ingredient", "")
    fas = obj.get("forms_and_strengths", "")
    # Skip missing or whitespace-only fields.
    if not ingredient_raw or not str(ingredient_raw).strip():
        continue
    if not fas or not str(fas).strip():
        continue
    canonical = clean_drug_name(ingredient_raw)
    if not canonical:
        continue
    entries[canonical] = {
        "ingredient": ingredient_raw,
        "forms_and_strengths": fas,
        "page": obj.get("page", ""),
    }
    # Secondary aliases: each component of a combination product, plus the
    # space-less spelling. setdefault means they never steal an alias that is
    # already registered. Previously "amoxicillin + clavulanic acid" could
    # remap "amoxicillin" away from the single-ingredient entry, depending on
    # JSON order.
    for part in split_multi_ingredient(ingredient_raw):
        base = clean_drug_name(part)
        if base:
            aliases.setdefault(base, canonical)
    aliases.setdefault(canonical.replace(" ", ""), canonical)

# An exact canonical name always resolves to its own entry, whatever the JSON order.
for canon in entries:
    aliases[canon] = canon

for canon, rec in entries.items():
    canonical_keys.append(canon)
    passages.append(f"{rec['ingredient']}\n{rec['forms_and_strengths']}")
    meta.append({
        "canonical": canon,
        "ingredient": rec["ingredient"],
        "page": rec.get("page", ""),
        "source": "PNF-EML_11022022.pdf",
    })

if not entries:
    raise ValueError("No valid entries built from drug_entries.json. Check your JSON fields.")

# Longest aliases first, so exact matching prefers the most specific name.
alias_list = sorted(aliases.keys(), key=len, reverse=True)
# -------------------------
# RETRIEVAL INDEX (Option C)
# -------------------------
# Hybrid retrieval: sparse TF-IDF (word 1-2 grams) plus dense sentence
# embeddings, fused with RRF and optionally reranked by a cross-encoder.
tfidf = TfidfVectorizer(
    lowercase=True,
    analyzer="word",
    ngram_range=(1, 2),
    min_df=1,
)
tfidf_matrix = tfidf.fit_transform(passages)
dense_model = SentenceTransformer(DENSE_MODEL_NAME)


def _cache_fingerprint(model_name: str, texts: List[str]) -> str:
    """Return a SHA-256 hex digest identifying *model_name* plus the exact passage texts."""
    digest = hashlib.sha256(model_name.encode("utf-8"))
    for text in texts:
        digest.update(b"\x00")  # separator so ["ab", "c"] and ["a", "bc"] differ
        digest.update(text.encode("utf-8"))
    return digest.hexdigest()


def load_dense_cache(path: str, n_expected: int, fingerprint: str = ""):
    """Return cached embeddings from *path*, or None if missing, unreadable or stale.

    The array is accepted only if it is 2-D with *n_expected* rows. When
    *fingerprint* is non-empty, the sidecar file ``path + ".fingerprint"`` must
    hold the same value. A row count alone cannot detect edited passages or a
    different embedding model.
    """
    if not os.path.exists(path):
        return None
    if fingerprint:
        try:
            with open(path + ".fingerprint", "r", encoding="utf-8") as fh:
                stored = fh.read().strip()
        except OSError:
            return None  # cache written without a fingerprint: treat as stale
        if stored != fingerprint:
            return None
    try:
        arr = np.load(path)
    except (OSError, ValueError, EOFError) as err:
        print(f"[Embedding cache] ignoring unreadable {path}: {err!r}")
        return None
    if arr.ndim != 2 or arr.shape[0] != n_expected:
        return None
    return arr


def save_dense_cache(path: str, arr: np.ndarray, fingerprint: str = ""):
    """Best-effort write of *arr* (and its fingerprint sidecar); failures are logged, not raised."""
    try:
        np.save(path, arr)
        if fingerprint:
            with open(path + ".fingerprint", "w", encoding="utf-8") as fh:
                fh.write(fingerprint)
    except (OSError, ValueError) as err:
        print(f"[Embedding cache] could not write {path}: {err!r}")


_dense_fingerprint = _cache_fingerprint(DENSE_MODEL_NAME, passages)
dense_embeddings = load_dense_cache(EMB_CACHE_PATH, len(passages), _dense_fingerprint)
if dense_embeddings is None:
    dense_embeddings = dense_model.encode(
        passages,
        batch_size=64,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    save_dense_cache(EMB_CACHE_PATH, dense_embeddings, _dense_fingerprint)

reranker = None
if USE_RERANKER and CrossEncoder is not None:
    try:
        reranker = CrossEncoder(RERANK_MODEL_NAME)
    except Exception as err:
        # Reranking is optional: fall back to RRF-only retrieval, but say why.
        print(f"[Reranker] disabled: {err!r}")
        reranker = None
def sparse_retrieve(query: str, topk: int = 80) -> List[int]:
    """Return up to *topk* passage indices ranked by TF-IDF cosine similarity to *query*.

    Passages with zero similarity share no indexed term with the query. Their
    relative order is an arbitrary tie-break, so they are left out instead of
    being handed to RRF fusion as ranked hits. The result may therefore be
    shorter than *topk*, or empty.
    """
    q_vec = tfidf.transform([query])
    sims = cosine_similarity(q_vec, tfidf_matrix).ravel()
    order = sims.argsort()[::-1][:topk]
    return [int(i) for i in order if sims[i] > 0.0]
def dense_retrieve(query: str, topk: int = 80) -> List[int]:
    """Return the indices of the *topk* passages whose embeddings best match *query*.

    The query is embedded with normalize_embeddings=True and scored by dot
    product against dense_embeddings; indices come back best-first.
    """
    query_vec = dense_model.encode([query], normalize_embeddings=True)[0]
    scores = (dense_embeddings @ query_vec).astype(float)
    ranking = np.argsort(scores)[::-1]
    return ranking[:topk].tolist()
def rrf_fusion(ranks_a: List[int], ranks_b: List[int], k: int = FUSION_K) -> Dict[int, float]:
    """Reciprocal Rank Fusion of two best-first index lists.

    Each list adds 1 / (k + rank) to an index's score (rank starts at 1).
    Indices present in either list appear in the returned dict.
    """
    scores: Dict[int, float] = {}
    for ranking in (ranks_a, ranks_b):
        for position, doc in enumerate(ranking):
            rank = position + 1
            scores[doc] = scores.get(doc, 0.0) + 1.0 / (k + rank)
    return scores
def minmax_norm(items: List[Tuple[int, float]]) -> List[Tuple[int, float]]:
    """Rescale the scores in (index, score) pairs linearly onto 0..1, keeping order.

    If all scores are (nearly) equal, every item gets 1.0. Empty input is
    returned as-is.
    """
    if not items:
        return items
    scores = [score for _, score in items]
    low, high = min(scores), max(scores)
    if high - low < 1e-9:
        return [(idx, 1.0) for idx, _ in items]
    return [(idx, (score - low) / (high - low)) for idx, score in items]
def rerank(query: str, candidate_idxs: List[int]) -> List[Tuple[int, float]]:
    """Score each candidate passage against *query* with the cross-encoder.

    Returns (index, score) pairs sorted best-first; equal scores keep their
    input order. Without a reranker, every candidate gets 0.0 in input order.
    """
    if reranker is None:
        return [(i, 0.0) for i in candidate_idxs]
    raw_scores = reranker.predict([(query, passages[i]) for i in candidate_idxs])
    scored = [(idx, float(score)) for idx, score in zip(candidate_idxs, raw_scores)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)
def rag_retrieve(query: str, topk_sources: int = TOPK_SOURCES) -> List[Dict]:
    """Hybrid retrieval: sparse + dense -> RRF fusion -> optional cross-encoder rerank.

    Returns up to *topk_sources* dicts, best first. Each dict is the passage's
    meta record plus "idx" (passage index), "score" (min-max normalized to
    0..1 within the candidate pool) and "method" (which stage produced the
    ranking).
    """
    q = normalize_question(query)
    sparse_hits = sparse_retrieve(q, topk=80)
    dense_hits = dense_retrieve(q, topk=80)
    fused = rrf_fusion(sparse_hits, dense_hits, k=FUSION_K)

    by_fused_score = sorted(fused.items(), key=lambda kv: kv[1], reverse=True)
    pool_size = max(TOPN_RERANK, topk_sources)
    shortlist = [idx for idx, _ in by_fused_score[:pool_size]]

    if reranker is None:
        scored = [(idx, fused[idx]) for idx in shortlist]
        method = "fusion(RRF)"
    else:
        scored = rerank(q, shortlist)
        method = "rerank(cross-encoder)"

    return [
        {**meta[idx], "idx": idx, "score": float(score), "method": method}
        for idx, score in minmax_norm(scored)[:topk_sources]
    ]
# -------------------------
# DRUG DETECTION (for display)
# -------------------------
def detect_drug_alias(question: str):
    """Find the known drug alias a question mentions, tolerating misspellings.

    Strategies, in order:
      1. EXACT: an alias occurs as a space-delimited phrase, or as a substring
         of the question with all spaces removed. alias_list is sorted
         longest-first, so longer aliases win.
      2. RAPIDFUZZ (only if rapidfuzz is installed): best WRatio match of the
         whole normalized question against all aliases, accepted when
         score/100 >= FUZZY_DRUG_CUTOFF.
      3. DIFFLIB: per-token (tokens of length >= 4) close match, keeping the
         highest SequenceMatcher ratio. Tokens are visited in question order so
         ties resolve deterministically.

    Returns:
        (alias, score in 0..1, method label), or (None, 0.0, "NONE") if nothing
        qualifies.
    """
    q_raw = normalize_question(question)
    q = " " + q_raw + " "
    q_nospace = q.replace(" ", "")

    # 1) Exact / substring match first
    for a in alias_list:
        if f" {a} " in q or (a and a in q_nospace):
            return a, 1.0, "EXACT"

    # 2) Strong fuzzy match over the whole question (RapidFuzz scores are 0..100)
    if HAS_RAPIDFUZZ:
        best = process.extractOne(q_raw, alias_list, scorer=fuzz.WRatio)
        if best:
            cand, score, _ = best
            score01 = float(score) / 100.0
            if score01 >= FUZZY_DRUG_CUTOFF:
                return cand, score01, "RAPIDFUZZ"

    # 3) Fallback: token-based difflib. dict.fromkeys de-duplicates while keeping
    #    question order; a set's order would depend on the string hash seed.
    tokens = [t for t in q_raw.split() if len(t) >= 4]
    best = None
    best_score = 0.0
    best_tok = None
    for word in dict.fromkeys(tokens):
        m = difflib.get_close_matches(word, alias_list, n=1, cutoff=FUZZY_DRUG_CUTOFF)
        if m:
            cand = m[0]
            score = difflib.SequenceMatcher(None, word, cand).ratio()
            if score > best_score:
                best_score = score
                best = cand
                best_tok = word
    if best:
        return best, float(best_score), f"DIFFLIB({best_tok}→{best})"
    return None, 0.0, "NONE"
# -------------------------
# QA READER (MODEL DROPDOWN) - HF REPOS
# -------------------------
# model_key -> (tokenizer, model); filled lazily by get_reader.
_loaded_readers: Dict[str, Tuple[AutoTokenizer, AutoModelForQuestionAnswering]] = {}


def get_reader(model_key: str) -> Tuple[AutoTokenizer, AutoModelForQuestionAnswering]:
    """Load (once) and return the tokenizer/model pair for a MODEL_CHOICES label.

    Models are downloaded from the Hugging Face Hub, so no local folders are
    needed. HF_TOKEN from the environment is passed if set.

    The tokenizer flavour requested by FORCE_SLOW_TOKENIZER is tried first. If
    it cannot be instantiated (for example, missing optional tokenizer
    dependencies), the other flavour is tried before giving up. Previously a
    failing forced-slow tokenizer made that reader permanently unusable.

    Raises:
        ValueError: *model_key* is not in MODEL_CHOICES.
        Exception: the first tokenizer-loading error when neither flavour
            loads; model-loading errors propagate unchanged.
    """
    cached = _loaded_readers.get(model_key)
    if cached is not None:
        return cached

    model_id = MODEL_CHOICES.get(model_key)
    if not model_id:
        raise ValueError(f"Unknown model choice: {model_key}")
    token = os.getenv("HF_TOKEN", None)

    tok = None
    first_error = None
    prefer_fast = not FORCE_SLOW_TOKENIZER
    for use_fast in (prefer_fast, not prefer_fast):
        try:
            tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=use_fast)
            break
        except Exception as err:  # retried with the other flavour, re-raised below
            print(f"[Reader] tokenizer load failed ({model_id}, use_fast={use_fast}): {err!r}")
            if first_error is None:
                first_error = err
    if tok is None:
        raise first_error

    mdl = AutoModelForQuestionAnswering.from_pretrained(model_id, token=token)
    mdl.eval()  # inference only
    _loaded_readers[model_key] = (tok, mdl)
    return tok, mdl
def _best_span(start_logits, end_logits, allowed, max_answer_len: int, top_n: int = 30):
    """Return (score, start, end) of the best span whose ends are allowed positions, or None.

    A span is valid when both ends are allowed, end >= start and
    end - start <= max_answer_len (token positions). Only the top_n starts
    and top_n ends by logit (among allowed positions) are paired, to keep the
    search cheap. Score is start_logit + end_logit.
    """
    allowed = np.asarray(allowed, dtype=bool)
    if not allowed.any():
        return None
    start_logits = np.asarray(start_logits, dtype=float)
    end_logits = np.asarray(end_logits, dtype=float)
    top_start = np.argsort(np.where(allowed, start_logits, -np.inf))[::-1][:top_n]
    top_end = np.argsort(np.where(allowed, end_logits, -np.inf))[::-1][:top_n]
    best = None
    for s in top_start:
        if not allowed[s]:
            continue
        for e in top_end:
            if not allowed[e] or e < s or e - s > max_answer_len:
                continue
            score = float(start_logits[s] + end_logits[e])
            if best is None or score > best[0]:
                best = (score, int(s), int(e))
    return best


def _match_in_context(span_text: str, context: str) -> str:
    """Map decoded span text back onto *context* to recover original casing/spacing.

    Decoding may lower-case (uncased vocabularies) or add spaces around
    punctuation, so the match ignores case and whitespace. Returns
    *span_text* unchanged if it cannot be located.
    """
    chars = [re.escape(ch) for ch in span_text if not ch.isspace()]
    if not chars:
        return ""
    found = re.search(r"\s*".join(chars), context, flags=re.IGNORECASE)
    return found.group(0).strip() if found else span_text


def _read_with_offsets(tok, mdl, question: str, context: str) -> str:
    """Extractive QA with a fast tokenizer: exact context slice via offset mappings."""
    enc = tok(
        question,
        context,
        truncation="only_second",
        max_length=MAX_SEQ_LEN,
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,  # long contexts -> overlapping windows
        return_offsets_mapping=True,
        padding=True,
        return_tensors="pt",
    )
    offsets = enc.pop("offset_mapping").tolist()
    enc.pop("overflow_to_sample_mapping", None)  # not a model input
    with torch.no_grad():
        out = mdl(**enc)
    starts = out.start_logits.cpu().numpy()
    ends = out.end_logits.cpu().numpy()

    best = None  # (score, start_char, end_char)
    for w, window_offsets in enumerate(offsets):
        # sequence id 1 marks context tokens (0 = question, None = special/padding)
        allowed = [sid == 1 for sid in enc.sequence_ids(w)]
        span = _best_span(starts[w], ends[w], allowed, MAX_ANSWER_LEN)
        if span is None:
            continue
        score, s, e = span
        if best is None or score > best[0]:
            best = (score, window_offsets[s][0], window_offsets[e][1])
    if best is None or best[2] <= best[1]:
        return ""
    return context[best[1]:best[2]].strip()


def _read_with_decode(tok, mdl, question: str, context: str) -> str:
    """Extractive QA with a slow tokenizer, which cannot return offset mappings.

    Inputs are built from the tokenizer's special-token helpers, the context is
    scanned in windows overlapping by DOC_STRIDE tokens, and the best token span
    is decoded and mapped back onto *context*.
    """
    q_ids = tok.encode(question, add_special_tokens=False)
    c_ids = tok.encode(context, add_special_tokens=False)
    budget = MAX_SEQ_LEN - tok.num_special_tokens_to_add(pair=True) - len(q_ids)
    if budget <= 0 or not c_ids:
        return ""
    step = max(1, budget - DOC_STRIDE)
    use_type_ids = "token_type_ids" in getattr(tok, "model_input_names", [])

    best = None  # (score, first context token, last context token) in c_ids indices
    for begin in range(0, len(c_ids), step):
        window = c_ids[begin:begin + budget]
        input_ids = tok.build_inputs_with_special_tokens(q_ids, window)
        special = tok.get_special_tokens_mask(q_ids, window)
        # Non-special tokens are the question tokens followed by the context tokens.
        allowed, ctx_index, seen = [], [], 0
        for is_special in special:
            in_ctx = (not is_special) and seen >= len(q_ids)
            allowed.append(in_ctx)
            ctx_index.append(seen - len(q_ids) if in_ctx else -1)
            if not is_special:
                seen += 1

        feeds = {
            "input_ids": torch.tensor([input_ids]),
            "attention_mask": torch.ones((1, len(input_ids)), dtype=torch.long),
        }
        if use_type_ids:
            feeds["token_type_ids"] = torch.tensor(
                [tok.create_token_type_ids_from_sequences(q_ids, window)]
            )
        with torch.no_grad():
            out = mdl(**feeds)

        span = _best_span(
            out.start_logits[0].cpu().numpy(),
            out.end_logits[0].cpu().numpy(),
            allowed,
            MAX_ANSWER_LEN,
        )
        if span is not None and (best is None or span[0] > best[0]):
            best = (span[0], begin + ctx_index[span[1]], begin + ctx_index[span[2]])
        if begin + budget >= len(c_ids):
            break

    if best is None:
        return ""
    text = tok.decode(c_ids[best[1]:best[2] + 1], skip_special_tokens=True).strip()
    return _match_in_context(text, context)


def run_reader(question: str, context: str, model_key: str) -> str:
    """Return the extractive answer span for *question* from *context* ("" if none).

    Uses the reader selected by *model_key* (loaded via get_reader). Only
    context tokens can start or end the answer, and spans are limited to
    MAX_ANSWER_LEN tokens. Contexts longer than MAX_SEQ_LEN are processed in
    windows overlapping by DOC_STRIDE tokens.

    Fast tokenizers return an exact slice of *context* via offset mappings.
    Slow tokenizers (forced by FORCE_SLOW_TOKENIZER) raise on
    return_offsets_mapping, so for them the token span is decoded and located
    in *context* instead. Previously the reader always failed on that path.
    """
    tok, mdl = get_reader(model_key)
    if getattr(tok, "is_fast", False):
        return _read_with_offsets(tok, mdl, question, context)
    return _read_with_decode(tok, mdl, question, context)
# -------------------------
# DISPLAY HELPERS
# -------------------------
def clamp01(x: float) -> float:
    """Clamp *x* into the closed interval [0.0, 1.0] (NaN maps to 1.0)."""
    capped = x if x < 1.0 else 1.0
    return capped if capped > 0.0 else 0.0
def confidence_bar_html(label: str, pct01: float, subtitle: str = "") -> str:
pct01 = clamp01(pct01)
pct = int(round(pct01 * 100))
sub = f"<div class='conf-sub'>{subtitle}</div>" if subtitle else ""
return f"""
<div class="conf-wrap">
<div class="conf-top">
<div class="conf-title">{label}</div>
<div class="conf-pct">{pct}%</div>
</div>
{sub}
<div class="conf-bar">
<div class="conf-fill" style="width:{pct}%;"></div>
</div>
</div>
"""
def format_sources_block(sources: List[Dict]) -> str:
    """Render retrieved sources as numbered text lines under a "Sources:" header.

    Each line shows the source file (default "PNF-EML_11022022.pdf"), the page
    (or "Page ?" when missing/empty) and the score to 3 decimals.
    """
    rows = ["Sources:"]
    rows.extend(
        f" [{n}] {src.get('source','PNF-EML_11022022.pdf')} {src.get('page') or 'Page ?'} score={src['score']:.3f}"
        for n, src in enumerate(sources, start=1)
    )
    return "\n".join(rows)
# -------------------------
# MAIN PIPELINE
# -------------------------
def qa_system(question: str, model_key: str):
    """Run the full pipeline for one question and return the four UI outputs.

    Steps: hybrid retrieval (rag_retrieve) -> top passage -> extractive answer
    from the selected reader. The entry's forms_and_strengths text is used
    instead when the reader is disabled, raises, or returns nothing. Finally,
    drug-name detection produces the "Detected" line.

    Returns:
        (answer_text, detected_html, score_html, sources_text). A blank
        question returns empty text and placeholder HTML.
    """
    if not question or not question.strip():
        return (
            "",
            '<div class="meta_box">Detected: —</div>',
            confidence_bar_html("Retrieval ranking score", 0.0, "—"),
            "",
        )

    # Retrieve sources; the first one is the passage the reader answers from.
    sources = rag_retrieve(question, topk_sources=TOPK_SOURCES)
    sources_text = format_sources_block(sources)
    best = sources[0]
    idx = best["idx"]
    rec = entries[meta[idx]["canonical"]]
    context = passages[idx]

    ans = ""
    if USE_QA_READER:
        try:
            ans = run_reader(question, context, model_key).strip()
        except Exception as e:
            # Do not show the error to the user; log it and fall back below.
            print(f"[Reader error] {repr(e)}")
            ans = ""
    # Reader disabled, failed, or empty -> show the entry's forms and strengths.
    if not ans:
        ans = pretty_answer(rec["forms_and_strengths"])

    # "Detected" line: prefer the (misspelling-tolerant) alias match; otherwise
    # describe the top retrieved entry. The page part is omitted when empty.
    alias, match_score, how = detect_drug_alias(question)
    if alias:
        detected = entries[aliases[alias]]
        detected_name = detected["ingredient"]
        detected_page = detected.get("page", "")
        detected_txt = (
            f"Detected: {detected_name} | {detected_page} | match={match_score:.2f} ({how})"
            if detected_page
            else f"Detected: {detected_name} | match={match_score:.2f} ({how})"
        )
    else:
        fallback_page = rec.get("page", "")
        detected_txt = (
            f"Detected: {rec['ingredient']} | {fallback_page}"
            if fallback_page
            else f"Detected: {rec['ingredient']}"
        ).strip()
    meta_html = f'<div class="meta_box">{detected_txt}</div>'

    # This is NOT accuracy: it is the min-max normalized ranking score (0..1).
    conf_html = confidence_bar_html(
        "Retrieval ranking score",
        float(best["score"]),
        f"Reader: {model_key} • Retrieval: {best.get('method','retrieval')} • TopK={TOPK_SOURCES}",
    )
    return ans, meta_html, conf_html, sources_text
def do_clear():
    """Return reset values for (question, detected_html, score_html, sources).

    The placeholder dash was previously mis-encoded ("β€”"); it is now "—".
    """
    return (
        "",
        '<div class="meta_box">Detected: —</div>',
        confidence_bar_html("Retrieval ranking score", 0.0, "—"),
        "",
    )
# -------------------------
# UI styling: dark theme using the Inter web font (imported from Google Fonts).
# The .meta_box and .conf-* rules style the HTML built by qa_system and
# confidence_bar_html; CSS is applied via demo.launch(css=CSS) below.
# -------------------------
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700;800&display=swap');
:root{
--bg: #0b0f14;
--card: rgba(255,255,255,0.06);
--card2: rgba(255,255,255,0.08);
--text: #e6edf3;
--muted: rgba(230,237,243,0.72);
--accent: #6d5cff;
--border: rgba(255,255,255,0.12);
}
* { font-family: Inter, system-ui, -apple-system, Segoe UI, Roboto, Arial, sans-serif; }
.gradio-container{
background:
radial-gradient(1200px 500px at 20% 0%, rgba(109,92,255,0.20), transparent 55%),
radial-gradient(1200px 500px at 80% 0%, rgba(0,180,255,0.12), transparent 55%),
linear-gradient(180deg, var(--bg), #06080c);
color: var(--text);
}
#app_wrap{ max-width: 1120px; margin: 0 auto; }
.header{
padding: 18px 18px 8px 18px;
border: 1px solid var(--border);
background: linear-gradient(180deg, rgba(255,255,255,0.08), rgba(255,255,255,0.04));
border-radius: 18px;
}
.brand{ font-size: 28px; font-weight: 800; letter-spacing: 0.2px; }
.card{
border: 1px solid var(--border);
background: var(--card);
border-radius: 18px;
padding: 14px;
}
.card h3{ margin: 0 0 10px 0; font-weight: 800; }
textarea, input{ border-radius: 14px !important; }
button.primary{
background: var(--accent) !important;
border: 1px solid rgba(109,92,255,0.45) !important;
border-radius: 14px !important;
font-weight: 800 !important;
}
button.secondary{
border-radius: 14px !important;
font-weight: 800 !important;
}
.meta_box{
border: 1px solid var(--border);
background: var(--card2);
border-radius: 14px;
padding: 10px 12px;
color: var(--muted);
font-size: 13px;
margin-top: 10px;
}
.conf-wrap{
border: 1px solid var(--border);
background: var(--card2);
border-radius: 14px;
padding: 12px;
margin-top: 12px;
}
.conf-top{
display:flex;
justify-content:space-between;
align-items:baseline;
gap: 12px;
}
.conf-title{ font-weight: 800; font-size: 14px; }
.conf-pct{ font-weight: 900; font-size: 20px; }
.conf-sub{ margin-top: 4px; color: var(--muted); font-size: 12px; }
.conf-bar{
margin-top: 10px;
height: 10px;
border-radius: 999px;
background: rgba(255,255,255,0.10);
overflow: hidden;
}
.conf-fill{
height: 100%;
border-radius: 999px;
background: linear-gradient(90deg, rgba(109,92,255,1), rgba(0,180,255,0.9));
}
.small-note{ color: var(--muted); font-size: 12px; margin-top: 8px; }
"""
# Two-column layout: question/model picker on the left, answer panel on the right.
# The initial "Detected" placeholders use "—" (previously mis-encoded as "β€”").
with gr.Blocks(title="BonsAI – Pharmaceutical QA System (RAG + Model Switch)") as demo:
    with gr.Column(elem_id="app_wrap"):
        # Page header (styled by the .header / .brand CSS rules)
        gr.HTML(
            """
<div class="header">
<div class="brand">BonsAI – Pharmaceutical QA System</div>
</div>
"""
        )
        gr.Markdown("")
        with gr.Row():
            # Left: reader model selection and question input
            with gr.Column(scale=7):
                with gr.Group(elem_classes="card"):
                    gr.HTML("<h3>Ask a Drug Question</h3>")
                    model_dd = gr.Dropdown(
                        choices=list(MODEL_CHOICES.keys()),
                        value=list(MODEL_CHOICES.keys())[0],
                        label="Select Reader Model",
                    )
                    q = gr.Textbox(
                        placeholder="Example: What are the available forms and strengths of Amoxicillin?",
                        lines=2,
                        label="Question",
                    )
                    with gr.Row():
                        clear_btn = gr.Button("Clear", elem_classes=["secondary"])
                        ask_btn = gr.Button("Submit", variant="primary", elem_classes=["primary"])
                    tip = "Tip: Misspellings are OK (e.g., amoxicilin, metformn). Switch models using the dropdown."
                    gr.HTML(f"<div class='small-note'>{tip}</div>")
            # Right: answer, detected drug, retrieval score bar, sources
            with gr.Column(scale=5):
                with gr.Group(elem_classes="card"):
                    gr.HTML("<h3>Answer</h3>")
                    ans = gr.Textbox(label="", lines=7)
                    meta_html = gr.HTML('<div class="meta_box">Detected: —</div>')
                    conf_html = gr.HTML(confidence_bar_html("Retrieval ranking score", 0.0, "—"))
                    sources_box = gr.Textbox(label="Sources (Top-k)", lines=9)

    # Clicking Submit and pressing Enter in the question box run the same pipeline.
    ask_btn.click(fn=qa_system, inputs=[q, model_dd], outputs=[ans, meta_html, conf_html, sources_box])
    q.submit(fn=qa_system, inputs=[q, model_dd], outputs=[ans, meta_html, conf_html, sources_box])
    # do_clear resets question/meta/score/sources; a second handler clears the answer box.
    clear_btn.click(fn=do_clear, inputs=None, outputs=[q, meta_html, conf_html, sources_box])
    clear_btn.click(lambda: "", inputs=None, outputs=ans)
if __name__ == "__main__":
    # Port precedence: PORT, then GRADIO_SERVER_PORT, then 7860.
    env = os.environ
    port = int(env.get("PORT", env.get("GRADIO_SERVER_PORT", 7860)))
    # NOTE(review): `css` as a launch() argument matches Gradio 6+; older Gradio
    # versions take css on gr.Blocks(...) instead. Confirm the installed version.
    demo.launch(server_name="0.0.0.0", server_port=port, share=False, css=CSS)