Spaces:
Sleeping
Sleeping
| # app.py | |
| # BonsAI — Pharmaceutical QA System (DistilBERT, XLM-R) | |
| # pip install -U gradio transformers torch sentence-transformers scikit-learn numpy rapidfuzz safetensors huggingface_hub | |
| # python app.py | |
| import os | |
| import json | |
| import re | |
| import difflib | |
| from typing import List, Dict, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
| from sentence_transformers import SentenceTransformer | |
| try: | |
| from sentence_transformers.cross_encoder import CrossEncoder | |
| except Exception: | |
| CrossEncoder = None | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| # Better fuzzy matching (optional but recommended) | |
| try: | |
| from rapidfuzz import process, fuzz | |
| HAS_RAPIDFUZZ = True | |
| except Exception: | |
| HAS_RAPIDFUZZ = False | |
| # ------------------------- | |
| # CONFIG (EDIT IF NEEDED) | |
| # ------------------------- | |
# --- Corpus ---------------------------------------------------------------
# JSON file holding the drug entries; looked up relative to the working dir.
CORPUS_PATH = "drug_entries.json"

# --- Reader models --------------------------------------------------------
# Dropdown label -> Hugging Face Hub repo id of the fine-tuned QA checkpoint.
# Transformers downloads these on first use. The first key is the UI default.
MODEL_CHOICES = {
    "DistilBERT (fine-tuned)": "jin3213/distilbert",
    "XLM-RoBERTa (fine-tuned)": "jin3213/xlm-roberta",
    "ClinicalBERT (fine-tuned)": "jin3213/clinicalbert",
}

# --- Retrieval ------------------------------------------------------------
# Bi-encoder used for dense retrieval and optional cross-encoder reranker.
DENSE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # optional
USE_RERANKER = True  # False = faster start-up and fewer dependencies

TOPK_SOURCES = 5  # number of sources returned to the UI
FUSION_K = 60  # the k constant in reciprocal-rank fusion: 1 / (k + rank)
TOPN_RERANK = 20  # size of the fused candidate pool handed to the reranker

# On-disk cache for the passage embeddings (speeds up restarts).
EMB_CACHE_PATH = "dense_embeddings_cache.npy"

# --- Drug-name detection --------------------------------------------------
# Similarity threshold in 0..1; a higher value means fewer false matches.
FUZZY_DRUG_CUTOFF = 0.75

# --- Reader ---------------------------------------------------------------
MAX_ANSWER_LEN = 80  # longest allowed answer span, in tokens
MAX_SEQ_LEN = 384  # tokenizer max_length for question + context
DOC_STRIDE = 128  # tokenizer stride

# True: the answer is a span extracted by the QA reader from the top passage.
# False: the answer is the entry's forms_and_strengths text (no QA model).
USE_QA_READER = True

# Load slow (pure-Python) tokenizers instead of the fast backend; intended to
# avoid tokenizer-backend instantiation problems on Spaces.
FORCE_SLOW_TOKENIZER = True
| # ------------------------- | |
| # TEXT UTILS | |
| # ------------------------- | |
def normalize(s: str) -> str:
    """Return *s* lower-cased, with curly apostrophes straightened and
    all whitespace runs collapsed to single spaces.

    Any value is accepted; it is converted with ``str()`` first.

    The apostrophe literal used to be a mis-encoded character (a Greek beta),
    so real typographic apostrophes were never normalized and genuine beta
    characters were rewritten. It is now spelled as an escape so it cannot be
    damaged by re-encoding the file.
    """
    text = str(s).lower()
    text = text.replace("\u2019", "'")  # RIGHT SINGLE QUOTATION MARK -> '
    # \s+ also matches newlines, so the result is always a single line.
    return re.sub(r"\s+", " ", text).strip()
def normalize_question(q: str) -> str:
    """Normalize a user question for retrieval and alias matching.

    The text goes through ``normalize`` first; afterwards every character
    other than ``a-z``, ``0-9``, whitespace, ``-``, ``+`` and ``/`` is turned
    into a space, and whitespace is collapsed again.
    """
    lowered = normalize(q)
    only_allowed = re.sub(r"[^a-z0-9\s\-\+\/]", " ", lowered)
    return re.sub(r"\s+", " ", only_allowed).strip()
def clean_drug_name(name: str) -> str:
    """Reduce a raw ingredient string to a canonical lookup key.

    Steps: ``normalize`` the text, drop parenthesised remarks (``(see ...)``
    first, then any remaining ``(...)``), replace everything that is not a
    word character, whitespace, ``+``, ``-`` or ``/`` with a space, and
    collapse whitespace.

    Returns an empty string when nothing is left, including for empty or
    whitespace-only input. That case used to raise ``IndexError`` because
    ``"".splitlines()`` is an empty list.
    """
    text = normalize(name)
    if not text:
        return ""
    # NOTE(review): normalize() has already collapsed newlines into spaces, so
    # the former ``splitlines()[0]`` always returned the whole string. If only
    # the first raw line was meant to form the key, split before normalizing.
    text = re.sub(r"\(see [^)]+\)", "", text).strip()
    text = re.sub(r"\([^)]*\)", "", text).strip()
    text = re.sub(r"[^\w\s\+\-\/]", " ", text)
    return re.sub(r"\s+", " ", text).strip()
def split_multi_ingredient(raw: str) -> List[str]:
    """Split a raw ingredient string into its individual components.

    The text is normalized, then each line is split on ``+``; empty pieces
    are dropped and the rest are returned stripped, in order of appearance.

    NOTE(review): ``normalize`` collapses newlines into spaces, so the line
    loop below only ever sees a single line.
    """
    pieces: List[str] = []
    for row in normalize(raw).splitlines():
        row = row.strip()
        if not row:
            continue
        if "+" not in row:
            pieces.append(row)
            continue
        for chunk in row.split("+"):
            chunk = chunk.strip()
            if chunk:
                pieces.append(chunk)
    return pieces
def pretty_answer(text: str) -> str:
    """Lay out a forms-and-strengths string one item per line.

    Semicolons become line breaks, whitespace around line breaks is removed,
    and each route label (``Oral:``, ``Injection:``, ``Inhalation:``,
    ``Topical:``) is moved to the start of its own line.
    """
    route_labels = ("Oral:", "Injection:", "Inhalation:", "Topical:")

    out = str(text).strip().replace(";", "\n")
    out = re.sub(r"\s*\n\s*", "\n", out).strip()
    for label in route_labels:
        # Swallow whitespace in front of the label and start a fresh line.
        out = re.sub(rf"\s*({label})", r"\n\1", out)
    return re.sub(r"\n+", "\n", out).strip()
| # ------------------------- | |
| # LOAD CORPUS | |
| # ------------------------- | |
# Fail fast with a helpful message when the corpus file is missing.
if not os.path.exists(CORPUS_PATH):
    raise FileNotFoundError(
        f"Cannot find {CORPUS_PATH}. Put drug_entries.json beside app.py (in the Space repo root)."
    )

with open(CORPUS_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

if not isinstance(data, list):
    raise ValueError("drug_entries.json must be a LIST of objects with keys: ingredient, forms_and_strengths, page")

# canonical name -> record (raw ingredient text, forms/strengths, page)
entries: Dict[str, Dict[str, str]] = {}
# any accepted spelling -> canonical name
aliases: Dict[str, str] = {}
# passages[i], meta[i] and canonical_keys[i] describe the same entry
passages: List[str] = []
meta: List[Dict[str, str]] = []
canonical_keys: List[str] = []

for obj in data:
    if not isinstance(obj, dict):
        continue
    ingredient_raw = obj.get("ingredient", "")
    fas = obj.get("forms_and_strengths", "")
    if not ingredient_raw or not fas:
        continue
    canonical = clean_drug_name(ingredient_raw)
    if not canonical:
        continue
    # A later entry with the same canonical name replaces the earlier one.
    entries[canonical] = {
        "ingredient": ingredient_raw,
        "forms_and_strengths": fas,
        "page": obj.get("page", ""),
    }
    # Derived aliases (single components of a combination, and the name with
    # spaces removed) are registered first-come. They must never displace an
    # alias that is already taken.
    for part in split_multi_ingredient(ingredient_raw):
        base = clean_drug_name(part)
        if base:
            aliases.setdefault(base, canonical)
    aliases.setdefault(canonical.replace(" ", ""), canonical)

# An exact canonical name always resolves to its own entry. Previously a
# combination product listed later (e.g. "a + b") could overwrite the alias of
# the stand-alone entry "a", so "a" was reported as the combination.
for canon in entries:
    aliases[canon] = canon

for canon, rec in entries.items():
    canonical_keys.append(canon)
    passages.append(f"{rec['ingredient']}\n{rec['forms_and_strengths']}")
    meta.append({
        "canonical": canon,
        "ingredient": rec["ingredient"],
        "page": rec.get("page", ""),
        "source": "PNF-EML_11022022.pdf",
    })

if not entries:
    raise ValueError("No valid entries built from drug_entries.json. Check your JSON fields.")

# Longest aliases first, so the most specific name wins substring matching.
alias_list = sorted(aliases.keys(), key=len, reverse=True)
| # ------------------------- | |
| # RETRIEVAL INDEX (Option C) | |
| # ------------------------- | |
# Sparse side of the hybrid index: word-level TF-IDF over unigrams and
# bigrams, fitted on every passage of the corpus.
tfidf = TfidfVectorizer(lowercase=True, analyzer="word", ngram_range=(1, 2), min_df=1)
tfidf_matrix = tfidf.fit_transform(passages)

# Dense side: sentence-embedding model used for both passages and queries.
dense_model = SentenceTransformer(DENSE_MODEL_NAME)
def load_dense_cache(path: str, n_expected: int):
    """Load cached passage embeddings from *path*, or return ``None``.

    The cache is used only when the file exists, can be read, and holds a
    2-D array with exactly *n_expected* rows. Arrays of any other rank used
    to be accepted and would fail later in the matrix product.

    Loading is best-effort: any read error is reported on stdout and treated
    as "no cache".

    NOTE(review): only the row count is validated. A corpus edited without
    changing its number of entries will reuse stale embeddings.
    """
    try:
        if not os.path.exists(path):
            return None
        arr = np.load(path)
    except Exception as exc:  # best-effort: a bad cache must not stop start-up
        print(f"[Dense cache] ignoring unreadable cache {path!r}: {exc!r}")
        return None
    if not isinstance(arr, np.ndarray):
        return None
    if arr.ndim != 2 or arr.shape[0] != n_expected:
        return None
    return arr
def save_dense_cache(path: str, arr: np.ndarray):
    """Write *arr* to *path* with ``np.save``; never raises.

    Caching is an optimisation only, so a failed write (read-only file
    system, missing directory, ...) is reported on stdout instead of being
    silently ignored, and start-up continues.
    """
    try:
        np.save(path, arr)
    except Exception as exc:  # best-effort: the app works without the cache
        print(f"[Dense cache] could not write {path!r}: {exc!r}")
# Passage embeddings: reuse the on-disk cache when it matches the corpus size,
# otherwise encode every passage (L2-normalized) and refresh the cache.
dense_embeddings = load_dense_cache(EMB_CACHE_PATH, len(passages))
if dense_embeddings is None:
    dense_embeddings = dense_model.encode(
        passages,
        batch_size=64,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    save_dense_cache(EMB_CACHE_PATH, dense_embeddings)

# Optional cross-encoder reranker. When it is disabled, not installed, or
# fails to load, retrieval falls back to plain RRF fusion.
reranker = None
if USE_RERANKER and CrossEncoder is not None:
    try:
        reranker = CrossEncoder(RERANK_MODEL_NAME)
    except Exception as exc:
        # Report the reason instead of degrading silently.
        print(f"[Reranker] could not load {RERANK_MODEL_NAME}; using RRF fusion only: {exc!r}")
        reranker = None
def sparse_retrieve(query: str, topk: int = 80) -> List[int]:
    """Return up to *topk* passage indices ranked by TF-IDF cosine similarity
    to *query*, best first."""
    query_vec = tfidf.transform([query])
    scores = cosine_similarity(query_vec, tfidf_matrix).ravel()
    best_first = scores.argsort()[::-1]
    return best_first[:topk].tolist()
def dense_retrieve(query: str, topk: int = 80) -> List[int]:
    """Return up to *topk* passage indices ranked by the dot product between
    the normalized query embedding and the passage embeddings, best first."""
    query_emb = dense_model.encode([query], normalize_embeddings=True)[0]
    scores = (dense_embeddings @ query_emb).astype(float)
    best_first = np.argsort(scores)[::-1]
    return best_first[:topk].tolist()
def rrf_fusion(ranks_a: List[int], ranks_b: List[int], k: int = FUSION_K) -> Dict[int, float]:
    """Combine two rankings with reciprocal-rank fusion.

    Each list holds passage indices, best first. A passage at 1-based rank
    ``r`` in a list contributes ``1 / (k + r)``; contributions from both
    lists are summed. Returns a mapping of passage index to fused score.
    """
    fused: Dict[int, float] = {}
    for ranking in (ranks_a, ranks_b):
        for position, doc_idx in enumerate(ranking, start=1):
            contribution = 1.0 / (k + position)
            fused[doc_idx] = fused.get(doc_idx, 0.0) + contribution
    return fused
def minmax_norm(items: List[Tuple[int, float]]) -> List[Tuple[int, float]]:
    """Rescale the scores of ``(index, score)`` pairs linearly onto 0..1.

    Order is preserved. An empty list is returned unchanged. When all scores
    are (almost) equal, every score becomes ``1.0``.
    """
    if not items:
        return items
    scores = [score for _, score in items]
    low = min(scores)
    high = max(scores)
    spread = high - low
    if spread < 1e-9:
        # Avoid dividing by ~0 when there is nothing to separate.
        return [(idx, 1.0) for idx, _ in items]
    return [(idx, (score - low) / spread) for idx, score in items]
def rerank(query: str, candidate_idxs: List[int]) -> List[Tuple[int, float]]:
    """Score candidate passages against *query* with the cross-encoder.

    Returns ``(passage index, score)`` pairs sorted by score, best first.
    Without a loaded reranker the candidates keep their order and all get
    the score ``0.0``.
    """
    if reranker is None:
        return [(i, 0.0) for i in candidate_idxs]
    raw_scores = reranker.predict([(query, passages[i]) for i in candidate_idxs])
    scored = [(i, float(s)) for i, s in zip(candidate_idxs, raw_scores)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)
def rag_retrieve(query: str, topk_sources: int = TOPK_SOURCES) -> List[Dict]:
    """Retrieve the best-matching corpus entries for *query*.

    Sparse (TF-IDF) and dense rankings are merged with reciprocal-rank
    fusion. The fused head is then either used directly or, when a reranker
    is loaded, re-scored by the cross-encoder. Scores are min-max normalized
    over the candidate pool before the top *topk_sources* are kept.

    Each result is the entry's ``meta`` dict extended with ``idx`` (passage
    index), ``score`` (0..1) and ``method`` (how the score was produced).
    """
    q = normalize_question(query)
    fused_map = rrf_fusion(sparse_retrieve(q, topk=80), dense_retrieve(q, topk=80), k=FUSION_K)

    pool_size = max(TOPN_RERANK, topk_sources)
    by_fused_score = sorted(fused_map.items(), key=lambda kv: kv[1], reverse=True)
    fused_top = [idx for idx, _ in by_fused_score[:pool_size]]

    if reranker is None:
        scored = [(idx, fused_map[idx]) for idx in fused_top]
        method = "fusion(RRF)"
    else:
        scored = rerank(q, fused_top)
        method = "rerank(cross-encoder)"

    kept = minmax_norm(scored)[:topk_sources]
    return [
        {**meta[idx], "idx": idx, "score": float(score), "method": method}
        for idx, score in kept
    ]
| # ------------------------- | |
| # DRUG DETECTION (for display) | |
| # ------------------------- | |
def detect_drug_alias(question: str):
    """Find the drug alias mentioned in *question*.

    Three strategies are tried in order:

    1. exact match: the alias appears as a whole phrase, or as a substring of
       the question with all spaces removed (score ``1.0``, ``"EXACT"``);
    2. RapidFuzz ``WRatio`` over the whole question, when RapidFuzz is
       installed and the score reaches ``FUZZY_DRUG_CUTOFF``;
    3. difflib on each question token of four or more characters.

    Returns ``(alias, score, how)``, or ``(None, 0.0, "NONE")`` when nothing
    matches. ``score`` is in 0..1 and ``how`` names the strategy.
    """
    q_raw = normalize_question(question)
    padded = " " + q_raw + " "
    squeezed = padded.replace(" ", "")

    # alias_list is sorted longest-first, so the most specific alias wins.
    # NOTE(review): the space-less substring test also fires for a short alias
    # embedded in longer words.
    for alias in alias_list:
        if f" {alias} " in padded or (alias and alias in squeezed):
            return alias, 1.0, "EXACT"

    if HAS_RAPIDFUZZ:
        best = process.extractOne(q_raw, alias_list, scorer=fuzz.WRatio)
        if best:
            cand, score, _ = best
            score01 = float(score) / 100.0
            if score01 >= FUZZY_DRUG_CUTOFF:
                return cand, score01, "RAPIDFUZZ"

    # De-duplicate tokens while keeping question order. Iterating a set made
    # the winner of a score tie depend on string-hash randomisation.
    tokens = dict.fromkeys(t for t in q_raw.split() if len(t) >= 4)
    best = None
    best_score = 0.0
    best_tok = None
    for tok in tokens:
        close = difflib.get_close_matches(tok, alias_list, n=1, cutoff=FUZZY_DRUG_CUTOFF)
        if not close:
            continue
        cand = close[0]
        score = difflib.SequenceMatcher(None, tok, cand).ratio()
        if score > best_score:
            best_score = score
            best = cand
            best_tok = tok
    if best:
        # The separator was a mis-encoded character; restored as U+2192.
        return best, float(best_score), f"DIFFLIB({best_tok}\u2192{best})"
    return None, 0.0, "NONE"
| # ------------------------- | |
| # QA READER (MODEL DROPDOWN) - HF REPOS | |
| # ------------------------- | |
# Dropdown label -> (tokenizer, model), filled lazily by get_reader().
_loaded_readers: Dict[str, Tuple[AutoTokenizer, AutoModelForQuestionAnswering]] = {}


def get_reader(model_key: str) -> Tuple[AutoTokenizer, AutoModelForQuestionAnswering]:
    """Return the ``(tokenizer, model)`` pair for a dropdown label.

    The pair is loaded from the Hugging Face Hub on first use and then served
    from ``_loaded_readers``. The ``HF_TOKEN`` environment variable, when set,
    is passed as the access token. With ``FORCE_SLOW_TOKENIZER`` the
    tokenizer is loaded with ``use_fast=False``. The model is put in eval
    mode.

    Raises:
        ValueError: *model_key* is not a key of ``MODEL_CHOICES``.
    """
    cached = _loaded_readers.get(model_key)
    if cached is not None:
        return cached

    model_id = MODEL_CHOICES.get(model_key)
    if not model_id:
        raise ValueError(f"Unknown model choice: {model_key}")

    token = os.getenv("HF_TOKEN", None)
    tok_kwargs = {"token": token}
    if FORCE_SLOW_TOKENIZER:
        tok_kwargs["use_fast"] = False

    tok = AutoTokenizer.from_pretrained(model_id, **tok_kwargs)
    mdl = AutoModelForQuestionAnswering.from_pretrained(model_id, token=token)
    mdl.eval()

    _loaded_readers[model_key] = (tok, mdl)
    return tok, mdl
def _context_token_positions(special_tokens_mask: List[int]) -> List[int]:
    """Return the token indices that belong to the context of a pair encoding.

    A question/context pair is laid out as: special tokens, question tokens,
    special tokens, context tokens, special tokens (then padding, if any).
    The context is therefore the second run of non-special tokens.
    """
    positions: List[int] = []
    runs_seen = 0
    previous_special = True
    for i, flag in enumerate(special_tokens_mask):
        is_special = bool(flag)
        if not is_special:
            if previous_special:
                runs_seen += 1
            if runs_seen >= 2:
                positions.append(i)
        previous_special = is_special
    return positions


def run_reader(question: str, context: str, model_key: str) -> str:
    """Extract the answer span for *question* from *context*.

    The selected reader scores every token as a possible span start and end.
    The best-scoring span that lies entirely inside the context and is at
    most ``MAX_ANSWER_LEN`` tokens apart is returned. Only the 30 strongest
    start and end candidates are examined. Returns ``""`` when no valid span
    exists.

    Fixes over the previous version:

    * Character offsets are requested only from fast tokenizers. Slow
      tokenizers raise ``NotImplementedError`` for ``return_offsets_mapping``,
      which made every call fail when ``FORCE_SLOW_TOKENIZER`` was on. For
      slow tokenizers the answer is decoded from the selected token ids, so
      it may differ from the source text in casing or spacing.
    * Candidates are limited to context tokens, so a span can no longer
      start or end in the question, a special token or padding.
    * Inference runs under ``torch.no_grad()``.
    * The input is no longer padded to ``MAX_SEQ_LEN``; one sequence needs
      no padding.
    """
    import torch  # local import: torch is not imported at module level

    tok, mdl = get_reader(model_key)
    has_offsets = bool(getattr(tok, "is_fast", False))

    inputs = tok(
        question,
        context,
        truncation="only_second",
        max_length=MAX_SEQ_LEN,
        stride=DOC_STRIDE,
        return_overflowing_tokens=False,
        return_offsets_mapping=has_offsets,
        return_special_tokens_mask=True,
        return_tensors="pt",
    )
    # Both helper fields must be removed before the model call.
    special_mask = inputs.pop("special_tokens_mask")[0].tolist()
    offsets = inputs.pop("offset_mapping")[0].tolist() if has_offsets else None

    ctx_positions = np.asarray(_context_token_positions(special_mask), dtype=int)
    if ctx_positions.size == 0:
        return ""

    with torch.no_grad():
        outputs = mdl(**inputs)
    start_logits = outputs.start_logits[0].cpu().numpy()
    end_logits = outputs.end_logits[0].cpu().numpy()

    # Strongest 30 start / end candidates among the context tokens only.
    top_start = ctx_positions[np.argsort(start_logits[ctx_positions])[-30:][::-1]]
    top_end = ctx_positions[np.argsort(end_logits[ctx_positions])[-30:][::-1]]

    best_span = None
    best_score = float("-inf")
    for s in top_start:
        for e in top_end:
            if e < s or e - s > MAX_ANSWER_LEN:
                continue
            score = float(start_logits[s] + end_logits[e])
            if score > best_score:
                best_score = score
                best_span = (int(s), int(e))
    if best_span is None:
        return ""

    span_start, span_end = best_span
    if offsets is not None:
        start_char = offsets[span_start][0]
        end_char = offsets[span_end][1]
        if end_char <= start_char:
            return ""
        return context[start_char:end_char].strip()

    answer_ids = inputs["input_ids"][0, span_start:span_end + 1].tolist()
    return tok.decode(answer_ids, skip_special_tokens=True).strip()
| # ------------------------- | |
| # DISPLAY HELPERS | |
| # ------------------------- | |
def clamp01(x: float) -> float:
    """Limit *x* to the closed interval 0..1."""
    capped = min(1.0, x)
    return max(0.0, capped)
def confidence_bar_html(label: str, pct01: float, subtitle: str = "") -> str:
    """Render a labelled progress bar as an HTML snippet.

    *pct01* is limited to 0..1 and shown as a whole percentage, both as text
    and as the width of the filled part. *subtitle*, when non-empty, is
    rendered in a ``conf-sub`` line between the title row and the bar.
    *label* and *subtitle* are inserted into the markup as given.
    """
    pct = int(round(max(0.0, min(1.0, pct01)) * 100))
    if subtitle:
        sub = f"<div class='conf-sub'>{subtitle}</div>"
    else:
        sub = ""
    return f"""
<div class="conf-wrap">
<div class="conf-top">
<div class="conf-title">{label}</div>
<div class="conf-pct">{pct}%</div>
</div>
{sub}
<div class="conf-bar">
<div class="conf-fill" style="width:{pct}%;"></div>
</div>
</div>
"""
def format_sources_block(sources: List[Dict]) -> str:
    """Format retrieved sources as a numbered plain-text list.

    The first line is ``Sources:``; each source follows on its own line with
    its 1-based rank, source name (default ``PNF-EML_11022022.pdf``), page
    (``Page ?`` when missing or empty) and score to three decimals. Every
    source must provide a ``score`` key.
    """
    rows = ["Sources:"]
    rows.extend(
        f" [{rank}] {src.get('source','PNF-EML_11022022.pdf')} {src.get('page') or 'Page ?'} score={src['score']:.3f}"
        for rank, src in enumerate(sources, start=1)
    )
    return "\n".join(rows)
| # ------------------------- | |
| # MAIN PIPELINE | |
| # ------------------------- | |
def qa_system(question: str, model_key: str):
    """Answer *question* and build everything the UI displays.

    Pipeline: retrieve the top sources, take the best one as context, then
    extract an answer with the selected reader (when ``USE_QA_READER`` is
    on). If the reader fails or returns nothing, the entry's formatted
    ``forms_and_strengths`` text is used instead.

    Returns a 4-tuple: answer text, "Detected" HTML box, score-bar HTML and
    the plain-text sources list. A blank question yields empty placeholders.

    Changes: mis-encoded placeholder and separator characters are restored
    (em dash, bullet); the "Detected" line no longer ends in a dangling
    ``|`` when the page is empty; corpus text is HTML-escaped before it is
    placed in the markup.
    """
    import html  # local import: html is not imported at module level

    if not question or not question.strip():
        return (
            "",
            '<div class="meta_box">Detected: \u2014</div>',
            confidence_bar_html("Retrieval ranking score", 0.0, "\u2014"),
            "",
        )

    # Retrieve sources
    sources = rag_retrieve(question, topk_sources=TOPK_SOURCES)
    sources_text = format_sources_block(sources)

    # Best candidate passage
    best = sources[0]
    idx = best["idx"]
    rec = entries[meta[idx]["canonical"]]
    context = passages[idx]

    # Answer
    ans = ""
    if USE_QA_READER:
        try:
            ans = run_reader(question, context, model_key).strip()
        except Exception as e:
            # Deliberately not shown to the user; logged for the operator.
            print(f"[Reader error] {repr(e)}")
            ans = ""
    # Fallback (also the only path when the reader is disabled)
    if not ans:
        ans = pretty_answer(rec["forms_and_strengths"])

    # Detected drug display (misspelling tolerant). When no alias is found,
    # the retrieved entry itself is shown.
    alias, match_score, how = detect_drug_alias(question)
    if alias:
        shown = entries[aliases[alias]]
        fields = [shown["ingredient"], shown.get("page", ""), f"match={match_score:.2f} ({how})"]
    else:
        fields = [rec["ingredient"], rec.get("page", "")]
    # Empty fields (e.g. a missing page) are skipped rather than leaving " |".
    detected_txt = "Detected: " + " | ".join(str(x) for x in fields if x)
    meta_html = f'<div class="meta_box">{html.escape(detected_txt, quote=False)}</div>'

    # Important: this is NOT accuracy; it is a normalized ranking score (0..1).
    # NOTE(review): scores are min-max normalized and sources[0] is the top
    # item, so this value is always 1.0 and the bar always reads 100%.
    conf_html = confidence_bar_html(
        "Retrieval ranking score",
        float(best["score"]),
        f"Reader: {model_key} \u2022 Retrieval: {best.get('method','retrieval')} \u2022 TopK={TOPK_SOURCES}",
    )
    return ans, meta_html, conf_html, sources_text
def do_clear():
    """Return reset values for the Clear button.

    The 4-tuple is: empty text, the "Detected" placeholder box, a score bar
    at 0% and empty text. The placeholder character was mis-encoded and is
    restored as an em dash.
    """
    return (
        "",
        '<div class="meta_box">Detected: \u2014</div>',
        confidence_bar_html("Retrieval ranking score", 0.0, "\u2014"),
        "",
    )
| # ------------------------- | |
| # UI (Inter font like your screenshot) | |
| # ------------------------- | |
| # Stylesheet handed to demo.launch(css=CSS). It loads the Inter font, defines | |
| # the dark colour variables in :root, and styles the classes emitted by the UI | |
| # code: .header/.brand (title bar), .card (panels), .meta_box (detected drug), | |
| # .conf-* (score bar) and .small-note (tip text). | |
| CSS = """ | |
| @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700;800&display=swap'); | |
| :root{ | |
| --bg: #0b0f14; | |
| --card: rgba(255,255,255,0.06); | |
| --card2: rgba(255,255,255,0.08); | |
| --text: #e6edf3; | |
| --muted: rgba(230,237,243,0.72); | |
| --accent: #6d5cff; | |
| --border: rgba(255,255,255,0.12); | |
| } | |
| * { font-family: Inter, system-ui, -apple-system, Segoe UI, Roboto, Arial, sans-serif; } | |
| .gradio-container{ | |
| background: | |
| radial-gradient(1200px 500px at 20% 0%, rgba(109,92,255,0.20), transparent 55%), | |
| radial-gradient(1200px 500px at 80% 0%, rgba(0,180,255,0.12), transparent 55%), | |
| linear-gradient(180deg, var(--bg), #06080c); | |
| color: var(--text); | |
| } | |
| #app_wrap{ max-width: 1120px; margin: 0 auto; } | |
| .header{ | |
| padding: 18px 18px 8px 18px; | |
| border: 1px solid var(--border); | |
| background: linear-gradient(180deg, rgba(255,255,255,0.08), rgba(255,255,255,0.04)); | |
| border-radius: 18px; | |
| } | |
| .brand{ font-size: 28px; font-weight: 800; letter-spacing: 0.2px; } | |
| .card{ | |
| border: 1px solid var(--border); | |
| background: var(--card); | |
| border-radius: 18px; | |
| padding: 14px; | |
| } | |
| .card h3{ margin: 0 0 10px 0; font-weight: 800; } | |
| textarea, input{ border-radius: 14px !important; } | |
| button.primary{ | |
| background: var(--accent) !important; | |
| border: 1px solid rgba(109,92,255,0.45) !important; | |
| border-radius: 14px !important; | |
| font-weight: 800 !important; | |
| } | |
| button.secondary{ | |
| border-radius: 14px !important; | |
| font-weight: 800 !important; | |
| } | |
| .meta_box{ | |
| border: 1px solid var(--border); | |
| background: var(--card2); | |
| border-radius: 14px; | |
| padding: 10px 12px; | |
| color: var(--muted); | |
| font-size: 13px; | |
| margin-top: 10px; | |
| } | |
| .conf-wrap{ | |
| border: 1px solid var(--border); | |
| background: var(--card2); | |
| border-radius: 14px; | |
| padding: 12px; | |
| margin-top: 12px; | |
| } | |
| .conf-top{ | |
| display:flex; | |
| justify-content:space-between; | |
| align-items:baseline; | |
| gap: 12px; | |
| } | |
| .conf-title{ font-weight: 800; font-size: 14px; } | |
| .conf-pct{ font-weight: 900; font-size: 20px; } | |
| .conf-sub{ margin-top: 4px; color: var(--muted); font-size: 12px; } | |
| .conf-bar{ | |
| margin-top: 10px; | |
| height: 10px; | |
| border-radius: 999px; | |
| background: rgba(255,255,255,0.10); | |
| overflow: hidden; | |
| } | |
| .conf-fill{ | |
| height: 100%; | |
| border-radius: 999px; | |
| background: linear-gradient(90deg, rgba(109,92,255,1), rgba(0,180,255,0.9)); | |
| } | |
| .small-note{ color: var(--muted); font-size: 12px; margin-top: 8px; } | |
| """ | |
# UI layout: a title bar, then a two-column row with the question form on the
# left and the answer, detected drug, score bar and sources on the right.
# The dash in the title and header was mis-encoded; it is restored as U+2014.
with gr.Blocks(title="BonsAI \u2014 Pharmaceutical QA System (RAG + Model Switch)") as demo:
    with gr.Column(elem_id="app_wrap"):
        gr.HTML(
            """
<div class="header">
<div class="brand">BonsAI \u2014 Pharmaceutical QA System</div>
</div>
"""
        )
        gr.Markdown("")
        with gr.Row():
            with gr.Column(scale=7):
                with gr.Group(elem_classes="card"):
                    gr.HTML("<h3>Ask a Drug Question</h3>")
                    model_dd = gr.Dropdown(
                        choices=list(MODEL_CHOICES.keys()),
                        value=list(MODEL_CHOICES.keys())[0],
                        label="Select Reader Model",
                    )
                    q = gr.Textbox(
                        placeholder="Example: What are the available forms and strengths of Amoxicillin?",
                        lines=2,
                        label="Question",
                    )
                    with gr.Row():
                        clear_btn = gr.Button("Clear", elem_classes=["secondary"])
                        ask_btn = gr.Button("Submit", variant="primary", elem_classes=["primary"])
                    tip = "Tip: Misspellings are OK (e.g., amoxicilin, metformn). Switch models using the dropdown."
                    gr.HTML(f"<div class='small-note'>{tip}</div>")
            with gr.Column(scale=5):
                with gr.Group(elem_classes="card"):
                    gr.HTML("<h3>Answer</h3>")
                    ans = gr.Textbox(label="", lines=7)
                    meta_html = gr.HTML('<div class="meta_box">Detected: \u2014</div>')
                    conf_html = gr.HTML(confidence_bar_html("Retrieval ranking score", 0.0, "\u2014"))
                    sources_box = gr.Textbox(label="Sources (Top-k)", lines=9)

    # Submit button and Enter in the question box run the same pipeline.
    ask_btn.click(fn=qa_system, inputs=[q, model_dd], outputs=[ans, meta_html, conf_html, sources_box])
    q.submit(fn=qa_system, inputs=[q, model_dd], outputs=[ans, meta_html, conf_html, sources_box])

    # One handler resets every field. Previously two separate click handlers
    # were registered on the same button (one for the answer box, one for the
    # rest), costing two round trips per click.
    clear_btn.click(
        fn=lambda: ("", *do_clear()),
        inputs=None,
        outputs=[ans, q, meta_html, conf_html, sources_box],
    )
if __name__ == "__main__":
    # Port precedence: PORT, then GRADIO_SERVER_PORT, then 7860.
    fallback_port = os.environ.get("GRADIO_SERVER_PORT", 7860)
    port = int(os.environ.get("PORT", fallback_port))
    # Listen on all interfaces; no public share link; apply the custom CSS.
    demo.launch(server_name="0.0.0.0", server_port=port, share=False, css=CSS)