| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import json |
| | import os |
| | import random |
| | import re |
| | from typing import Dict, Any, List, Optional |
| |
|
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | DEFAULT_MODEL = os.getenv( |
| | "BACTAI_LLM_PARSER_MODEL", |
| | "EphAsad/BactAID-v2", |
| | ) |
| |
|
| | |
| | MAX_FEWSHOT_EXAMPLES = int(os.getenv("BACTAI_LLM_FEWSHOT", "0")) |
| |
|
| | MAX_NEW_TOKENS = int(os.getenv("BACTAI_LLM_MAX_NEW_TOKENS", "256")) |
| |
|
| | DEBUG_LLM = os.getenv("BACTAI_LLM_DEBUG", "0").strip().lower() in { |
| | "1", "true", "yes", "y", "on" |
| | } |
| |
|
| | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| |
|
| | _tokenizer: Optional[AutoTokenizer] = None |
| | _model: Optional[AutoModelForSeq2SeqLM] = None |
| | _GOLD_EXAMPLES: Optional[List[Dict[str, Any]]] = None |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | ALL_FIELDS: List[str] = [ |
| | "Gram Stain", |
| | "Shape", |
| | "Motility", |
| | "Capsule", |
| | "Spore Formation", |
| | "Haemolysis", |
| | "Haemolysis Type", |
| | "Media Grown On", |
| | "Colony Morphology", |
| | "Oxygen Requirement", |
| | "Growth Temperature", |
| | "Catalase", |
| | "Oxidase", |
| | "Indole", |
| | "Urease", |
| | "Citrate", |
| | "Methyl Red", |
| | "VP", |
| | "H2S", |
| | "DNase", |
| | "ONPG", |
| | "Coagulase", |
| | "Gelatin Hydrolysis", |
| | "Esculin Hydrolysis", |
| | "Nitrate Reduction", |
| | "NaCl Tolerant (>=6%)", |
| | "Lipase Test", |
| | "Lysine Decarboxylase", |
| | "Ornithine Decarboxylase", |
| | "Ornitihine Decarboxylase", |
| | "Arginine dihydrolase", |
| | "Glucose Fermentation", |
| | "Lactose Fermentation", |
| | "Sucrose Fermentation", |
| | "Maltose Fermentation", |
| | "Mannitol Fermentation", |
| | "Sorbitol Fermentation", |
| | "Xylose Fermentation", |
| | "Rhamnose Fermentation", |
| | "Arabinose Fermentation", |
| | "Raffinose Fermentation", |
| | "Trehalose Fermentation", |
| | "Inositol Fermentation", |
| | "Gas Production", |
| | "TSI Pattern", |
| | "Colony Pattern", |
| | "Pigment", |
| | "Motility Type", |
| | "Odor", |
| | ] |
| |
|
| | SUGAR_FIELDS = [ |
| | "Glucose Fermentation", |
| | "Lactose Fermentation", |
| | "Sucrose Fermentation", |
| | "Maltose Fermentation", |
| | "Mannitol Fermentation", |
| | "Sorbitol Fermentation", |
| | "Xylose Fermentation", |
| | "Rhamnose Fermentation", |
| | "Arabinose Fermentation", |
| | "Raffinose Fermentation", |
| | "Trehalose Fermentation", |
| | "Inositol Fermentation", |
| | ] |
| |
|
| | PNV_FIELDS = { |
| | f for f in ALL_FIELDS |
| | if f not in { |
| | "Media Grown On", |
| | "Colony Morphology", |
| | "Growth Temperature", |
| | "Gram Stain", |
| | "Shape", |
| | "Oxygen Requirement", |
| | "Haemolysis Type", |
| | "TSI Pattern", |
| | "Colony Pattern", |
| | "Motility Type", |
| | "Odor", |
| | "Pigment", |
| | "Gas Production", |
| | } |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | FIELD_ALIASES: Dict[str, str] = { |
| | "Gram": "Gram Stain", |
| | "Gram stain": "Gram Stain", |
| | "Gram Stain Result": "Gram Stain", |
| |
|
| | "NaCl tolerance": "NaCl Tolerant (>=6%)", |
| | "NaCl Tolerant": "NaCl Tolerant (>=6%)", |
| | "Salt tolerance": "NaCl Tolerant (>=6%)", |
| | "Salt tolerant": "NaCl Tolerant (>=6%)", |
| | "6.5% NaCl": "NaCl Tolerant (>=6%)", |
| | "6% NaCl": "NaCl Tolerant (>=6%)", |
| |
|
| | "Growth temp": "Growth Temperature", |
| | "Growth temperature": "Growth Temperature", |
| | "Temperature growth": "Growth Temperature", |
| |
|
| | "Catalase test": "Catalase", |
| | "Oxidase test": "Oxidase", |
| | "Indole test": "Indole", |
| | "Urease test": "Urease", |
| | "Citrate test": "Citrate", |
| |
|
| | "Glucose fermentation": "Glucose Fermentation", |
| | "Lactose fermentation": "Lactose Fermentation", |
| | "Sucrose fermentation": "Sucrose Fermentation", |
| | "Maltose fermentation": "Maltose Fermentation", |
| | "Mannitol fermentation": "Mannitol Fermentation", |
| | "Sorbitol fermentation": "Sorbitol Fermentation", |
| | "Xylose fermentation": "Xylose Fermentation", |
| | "Rhamnose fermentation": "Rhamnose Fermentation", |
| | "Arabinose fermentation": "Arabinose Fermentation", |
| | "Raffinose fermentation": "Raffinose Fermentation", |
| | "Trehalose fermentation": "Trehalose Fermentation", |
| | "Inositol fermentation": "Inositol Fermentation", |
| |
|
| | |
| | "Voges–Proskauer Test": "VP", |
| | "Voges-Proskauer Test": "VP", |
| | "Voges–Proskauer": "VP", |
| | "Voges-Proskauer": "VP", |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _norm_str(s: Any) -> str: |
| | return str(s).strip() if s is not None else "" |
| |
|
| |
|
| | def _normalise_pnv_value(raw: Any) -> str: |
| | s = _norm_str(raw).lower() |
| | if not s: |
| | return "Unknown" |
| |
|
| | |
| | if any(x in s for x in {"positive", "pos", "+", "yes", "present", "detected", "reactive"}): |
| | return "Positive" |
| |
|
| | |
| | if any(x in s for x in {"negative", "neg", "-", "no", "none", "absent", "not detected", "no growth"}): |
| | return "Negative" |
| |
|
| | |
| | if any(x in s for x in {"variable", "mixed", "inconsistent"}): |
| | return "Variable" |
| |
|
| | return "Unknown" |
| |
|
| |
|
| | def _normalise_gram(raw: Any) -> str: |
| | s = _norm_str(raw).lower() |
| | if "positive" in s: |
| | return "Positive" |
| | if "negative" in s: |
| | return "Negative" |
| | if "variable" in s: |
| | return "Variable" |
| | return "Unknown" |
| |
|
| |
|
| | def _merge_ornithine_variants(fields: Dict[str, str]) -> Dict[str, str]: |
| | v = fields.get("Ornithine Decarboxylase") or fields.get("Ornitihine Decarboxylase") |
| | if v and v != "Unknown": |
| | fields["Ornithine Decarboxylase"] = v |
| | fields["Ornitihine Decarboxylase"] = v |
| | return fields |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | _NON_FERMENTER_PATTERNS = re.compile( |
| | r"\b(" |
| | r"non[-\s]?fermenter|" |
| | r"non[-\s]?fermentative|" |
| | r"asaccharolytic|" |
| | r"does not ferment (sugars|carbohydrates)|" |
| | r"no carbohydrate fermentation" |
| | r")\b", |
| | re.IGNORECASE, |
| | ) |
| |
|
| |
|
| | def _apply_global_sugar_logic(fields: Dict[str, str], original_text: str) -> Dict[str, str]: |
| | if not _NON_FERMENTER_PATTERNS.search(original_text): |
| | return fields |
| |
|
| | for sugar in SUGAR_FIELDS: |
| | if fields.get(sugar) in {"Positive", "Variable"}: |
| | continue |
| | fields[sugar] = "Negative" |
| |
|
| | return fields |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _get_project_root() -> str: |
| | return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
| |
|
| |
|
| | def _load_gold_examples() -> List[Dict[str, Any]]: |
| | global _GOLD_EXAMPLES |
| | if _GOLD_EXAMPLES is not None: |
| | return _GOLD_EXAMPLES |
| |
|
| | path = os.path.join(_get_project_root(), "data", "llm_gold_examples.json") |
| | try: |
| | with open(path, "r", encoding="utf-8") as f: |
| | data = json.load(f) |
| | _GOLD_EXAMPLES = data if isinstance(data, list) else [] |
| | except Exception: |
| | _GOLD_EXAMPLES = [] |
| |
|
| | return _GOLD_EXAMPLES |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | PROMPT_HEADER = """ |
| | You are a microbiology phenotype parser. |
| | |
| | Task: |
| | - Extract ONLY explicitly stated results from the input text. |
| | - Do NOT invent results. |
| | - If not stated, omit the field or use "Unknown". |
| | |
| | Output format: |
| | - Prefer "Field: Value" lines, one per line. |
| | - You may also output JSON if instructed. |
| | |
| | Use the exact schema keys where possible. |
| | """ |
| |
|
| | PROMPT_FOOTER = """ |
| | Input: |
| | \"\"\"<<PHENOTYPE>>\"\"\" |
| | |
| | Output: |
| | """ |
| |
|
| |
|
| | def _build_prompt(text: str) -> str: |
| | |
| | blocks: List[str] = [PROMPT_HEADER] |
| |
|
| | if MAX_FEWSHOT_EXAMPLES > 0: |
| | examples = _load_gold_examples() |
| | n = min(MAX_FEWSHOT_EXAMPLES, len(examples)) |
| | sampled = random.sample(examples, n) if n > 0 else [] |
| | for ex in sampled: |
| | inp = _norm_str(ex.get("input", "")) |
| | exp = ex.get("expected", {}) |
| | if not isinstance(exp, dict): |
| | exp = {} |
| | |
| | kv_lines = "\n".join([f"{k}: {v}" for k, v in exp.items()]) |
| | blocks.append(f'Example Input:\n"""{inp}"""\nExample Output:\n{kv_lines}\n') |
| |
|
| | blocks.append(PROMPT_FOOTER.replace("<<PHENOTYPE>>", text)) |
| | return "\n".join(blocks) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _load_model() -> None: |
| | global _model, _tokenizer |
| | if _model is not None and _tokenizer is not None: |
| | return |
| |
|
| | _tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL) |
| | _model = AutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL).to(DEVICE) |
| | _model.eval() |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | _JSON_OBJECT_RE = re.compile(r"\{[\s\S]*?\}") |
| |
|
| |
|
| | def _extract_first_json_object(text: str) -> Dict[str, Any]: |
| | m = _JSON_OBJECT_RE.search(text) |
| | if not m: |
| | return {} |
| | try: |
| | return json.loads(m.group(0)) |
| | except Exception: |
| | return {} |
| |
|
| |
|
| | |
| | _KV_LINE_RE = re.compile(r"^\s*([^:\n]{2,120})\s*:\s*(.*?)\s*$") |
| |
|
| |
|
| | def _extract_kv_pairs(text: str) -> Dict[str, Any]: |
| | """ |
| | Parse outputs like: |
| | Gram Stain: Positive |
| | Shape: Cocci |
| | ... |
| | """ |
| | out: Dict[str, Any] = {} |
| | for line in (text or "").splitlines(): |
| | line = line.strip() |
| | if not line: |
| | continue |
| | m = _KV_LINE_RE.match(line) |
| | if not m: |
| | continue |
| | k = _norm_str(m.group(1)) |
| | v = _norm_str(m.group(2)) |
| | if not k: |
| | continue |
| | out[k] = v |
| | return out |
| |
|
| |
|
| | def _apply_field_aliases(fields_raw: Dict[str, Any]) -> Dict[str, Any]: |
| | out: Dict[str, Any] = {} |
| | for k, v in fields_raw.items(): |
| | key = _norm_str(k) |
| | if not key: |
| | continue |
| | mapped = FIELD_ALIASES.get(key, key) |
| | out[mapped] = v |
| | return out |
| |
|
| |
|
| | def _clean_and_normalise(fields_raw: Dict[str, Any], original_text: str) -> Dict[str, str]: |
| | """ |
| | Keep only allowed fields and normalise values into your contract. |
| | """ |
| | cleaned: Dict[str, str] = {} |
| |
|
| | |
| | for field in ALL_FIELDS: |
| | if field not in fields_raw: |
| | continue |
| |
|
| | raw_val = fields_raw[field] |
| |
|
| | if field == "Gram Stain": |
| | cleaned[field] = _normalise_gram(raw_val) |
| | elif field in PNV_FIELDS: |
| | cleaned[field] = _normalise_pnv_value(raw_val) |
| | else: |
| | cleaned[field] = _norm_str(raw_val) or "Unknown" |
| |
|
| | cleaned = _merge_ornithine_variants(cleaned) |
| | cleaned = _apply_global_sugar_logic(cleaned, original_text) |
| | return cleaned |
| |
|
| |
|
| | def _merge_guard_fill_only_missing( |
| | llm_fields: Dict[str, str], |
| | existing_fields: Optional[Dict[str, Any]], |
| | ) -> Dict[str, str]: |
| | """ |
| | Merge guard: |
| | - If an existing field is present and not Unknown -> do NOT overwrite. |
| | - If existing is missing/Unknown -> allow llm value (if not Unknown). |
| | """ |
| | if not existing_fields or not isinstance(existing_fields, dict): |
| | return llm_fields |
| |
|
| | out = dict(existing_fields) |
| | for k, v in llm_fields.items(): |
| | if k not in ALL_FIELDS: |
| | continue |
| | existing_val = _norm_str(out.get(k, "")) |
| | existing_norm = _normalise_pnv_value(existing_val) if k in PNV_FIELDS else existing_val |
| |
|
| | |
| | fillable = (not existing_val) or (existing_val == "Unknown") or (existing_norm == "Unknown") |
| | if not fillable: |
| | continue |
| |
|
| | |
| | if _norm_str(v) and v != "Unknown": |
| | out[k] = v |
| |
|
| | |
| | final: Dict[str, str] = {} |
| | for k, v in out.items(): |
| | if k in ALL_FIELDS: |
| | final[k] = _norm_str(v) or "Unknown" |
| | return final |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def parse_llm(text: str, existing_fields: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: |
| | """ |
| | Parse phenotype text using local seq2seq model. |
| | |
| | Parameters |
| | ---------- |
| | text : str |
| | phenotype description |
| | |
| | existing_fields : dict | None |
| | Optional pre-parsed fields (e.g., from rules/ext). |
| | If provided, LLM will ONLY fill missing/Unknown fields. |
| | |
| | Returns |
| | ------- |
| | dict: |
| | { |
| | "parsed_fields": { ... }, |
| | "source": "llm_parser", |
| | "raw": <original text>, |
| | "decoded": <model output> (only when DEBUG on) |
| | } |
| | """ |
| | original = text or "" |
| | if not original.strip(): |
| | return { |
| | "parsed_fields": (existing_fields or {}) if isinstance(existing_fields, dict) else {}, |
| | "source": "llm_parser", |
| | "raw": original, |
| | } |
| |
|
| | _load_model() |
| | assert _tokenizer is not None and _model is not None |
| |
|
| | prompt = _build_prompt(original) |
| | inputs = _tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE) |
| |
|
| | with torch.no_grad(): |
| | output = _model.generate( |
| | **inputs, |
| | max_new_tokens=MAX_NEW_TOKENS, |
| | do_sample=False, |
| | temperature=0.0, |
| | ) |
| |
|
| | decoded = _tokenizer.decode(output[0], skip_special_tokens=True) |
| |
|
| | if DEBUG_LLM: |
| | print("=== LLM PROMPT (truncated) ===") |
| | print(prompt[:1500] + ("..." if len(prompt) > 1500 else "")) |
| | print("=== LLM RAW OUTPUT ===") |
| | print(decoded) |
| | print("======================") |
| |
|
| | |
| | parsed_obj = _extract_first_json_object(decoded) |
| | fields_raw = {} |
| |
|
| | if isinstance(parsed_obj, dict) and parsed_obj: |
| | if "parsed_fields" in parsed_obj and isinstance(parsed_obj.get("parsed_fields"), dict): |
| | fields_raw = dict(parsed_obj["parsed_fields"]) |
| | else: |
| | |
| | fields_raw = dict(parsed_obj) |
| |
|
| | |
| | if not fields_raw: |
| | fields_raw = _extract_kv_pairs(decoded) |
| |
|
| | |
| | fields_raw = _apply_field_aliases(fields_raw) |
| | cleaned = _clean_and_normalise(fields_raw, original) |
| |
|
| | |
| | if existing_fields is not None: |
| | cleaned = _merge_guard_fill_only_missing(cleaned, existing_fields) |
| |
|
| | out = { |
| | "parsed_fields": cleaned, |
| | "source": "llm_parser", |
| | "raw": original, |
| | } |
| | if DEBUG_LLM: |
| | out["decoded"] = decoded |
| | return out |