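"""Gradio Space that compares Medical VQA model variants side by side.

Direction A (A1/A2) encodes Vietnamese questions with PhoBERT and decodes
answers with an LSTM (A1) or Transformer (A2) head; direction B
(B1/B2/DPO/PPO/SOUP) shares one quantized LLaVA base, with B1 running
zero-shot and the other variants as LoRA adapters. Final answers are
rewritten for display by the model named in ANSWER_REWRITE_MODEL_ID.
"""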
import asyncio
import gc
import os
import time
from pathlib import Path
from typing import Any

import gradio as gr
import pandas as pd
import torch
import yaml
from huggingface_hub import hf_hub_download
from peft import PeftModel
from PIL import Image
from transformers import AutoTokenizer, LlavaForConditionalGeneration, LlavaProcessor

from src.engine.medical_eval import (
    _build_b1_prompt,
    _build_bad_words_ids,
    _en_to_vi_direct,
    _extract_key_medical_term,
    _normalize_closed_answer,
)
from src.models.medical_vqa_model import MedicalVQAModelA
from src.models.multimodal_vqa import MultimodalVQA
from src.utils.answer_rewriter import MedicalAnswerRewriter
from src.utils.text_utils import normalize_answer, postprocess_answer
from src.utils.translator import MedicalTranslator
from src.utils.visualization import MedicalImageTransform
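
# Defaults for the answer rewriter and tokenizer parallelism; values already
# set in the Space's environment take precedence (setdefault never overwrites).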
os.environ.setdefault("ANSWER_REWRITE_MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct")
os.environ.setdefault("ANSWER_REWRITE_USE_4BIT", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

ROOT_DIR = Path(__file__).resolve().parent
CONFIG_PATH = ROOT_DIR / "configs" / "medical_vqa.yaml"

VARIANT_ORDER = ["A1", "A2", "B1", "B2", "DPO", "PPO", "SOUP"]
MODEL_DISPLAY_NAMES = {
    "A1": "A1 LSTM",
    "A2": "A2 Transformer",
    "B1": "B1 Zero-shot",
    "B2": "B2 Fine-tuned",
    "DPO": "DPO Alignment",
    "PPO": "PPO RL refinement",
    "SOUP": "SOUP Model Soup",
}
HF_MODEL_REPOS = {
    "A1": "SpringWang08/medical-vqa-a1",
    "A2": "SpringWang08/medical-vqa-a2",
    "B1": "chaoyinshe/llava-med-v1.5-mistral-7b-hf",
    "B2": "SpringWang08/medical-vqa-b2",
    "DPO": "SpringWang08/medical-vqa-dpo",
    "PPO": "SpringWang08/medical-vqa-ppo",
    "SOUP": "SpringWang08/medical-vqa-soup",
}
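
# Hyperparameters below are read from configs/medical_vqa.yaml; the .get()
# fallbacks keep the demo runnable when a key is missing from the file.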
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    CFG = yaml.safe_load(f)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ANSWER_MAX_WORDS = int(CFG["data"].get("answer_max_words", 10))
IMAGE_SIZE = int(CFG["data"].get("image_size", 224))
MAX_QUESTION_LEN = int(CFG["data"].get("max_question_len", 64))
MAX_ANSWER_LEN = int(CFG["data"].get("max_answer_len", 20))
MODEL_A_CFG = CFG.get("model_a", {})
MODEL_B_CFG = CFG.get("model_b", {})
EVAL_CFG = CFG.get("eval", {})
PHOBERT_MODEL = MODEL_A_CFG.get("phobert_model", "vinai/phobert-base")
LLAVA_MODEL_ID = MODEL_B_CFG.get("model_name", HF_MODEL_REPOS["B1"])

qa_tokenizer = None
image_transform = MedicalImageTransform(size=IMAGE_SIZE)
translator = MedicalTranslator(device=DEVICE.type)
rewriter = MedicalAnswerRewriter()
loaded_a_models: dict[str, dict[str, Any]] = {}
llava_bundle: dict[str, Any] | None = None
b_lock = asyncio.Lock()
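

# Lazily create the PhoBERT tokenizer shared by the direction-A models,
# falling back to EOS or SEP when the tokenizer defines no pad token.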
def _ensure_qa_tokenizer():
    global qa_tokenizer
    if qa_tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(PHOBERT_MODEL)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token or tokenizer.sep_token
        qa_tokenizer = tokenizer
    return qa_tokenizer


def _looks_closed_question(question: str) -> bool:
    normalized = normalize_answer(question)
    closed_prefixes = (
        "có ",
        "không ",
        "phải ",
        "đây có",
        "hình ảnh có",
        "ảnh có",
        "is ",
        "are ",
        "does ",
        "do ",
        "can ",
        "has ",
    )
    open_prefixes = ("what ", "where ", "when ", "who ", "which ", "how ", "why ")
    if normalized.startswith(open_prefixes):
        return False
    if normalized.startswith(closed_prefixes):
        return True
    return any(word in normalized.split() for word in {"có", "không", "normal", "abnormal"})


def _prepare_question_text(question: str) -> tuple[str, str]:
    question = (question or "").strip()
    if not question:
        return "", ""
    # B1 benefits from English input when the user provides it; otherwise it
    # still works with the concise Vietnamese instruction used in the notebook.
    return question, question
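

# Prefer a local checkpoint when one exists (useful for offline runs);
# otherwise download the weights from the Hugging Face Hub.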
def _download_direction_a_checkpoint(variant: str) -> str:
    filename = f"medical_vqa_{variant}_best.pth"
    local_path = ROOT_DIR / "checkpoints" / filename
    if local_path.exists():
        return str(local_path)
    return hf_hub_download(repo_id=HF_MODEL_REPOS[variant], filename=filename)
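

# Build and cache a direction-A model (A1 uses an LSTM decoder, A2 a
# Transformer decoder) and restore its weights from the checkpoint payload.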
def _ensure_direction_a_model(variant: str) -> dict[str, Any]:
    if variant in loaded_a_models:
        return loaded_a_models[variant]
    tokenizer = _ensure_qa_tokenizer()
    ckpt_path = _download_direction_a_checkpoint(variant)
    decoder_type = "lstm" if variant == "A1" else "transformer"
    model = MedicalVQAModelA(
        decoder_type=decoder_type,
        vocab_size=len(tokenizer),
        hidden_size=int(MODEL_A_CFG.get("hidden_size", 768)),
        phobert_model=PHOBERT_MODEL,
    ).to(DEVICE)
    payload = torch.load(ckpt_path, map_location=DEVICE)
    state_dict = (
        payload["model_state_dict"]
        if isinstance(payload, dict) and "model_state_dict" in payload
        else payload
    )
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    bundle = {
        "variant": variant,
        "family": "A",
        "model": model,
        "tokenizer": tokenizer,
        "checkpoint": HF_MODEL_REPOS[variant],
    }
    loaded_a_models[variant] = bundle
    return bundle
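

# Direction B shares a single quantized LLaVA base across all LoRA variants;
# the MultimodalVQA wrapper supplies the model id and BitsAndBytes config.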
def _build_llava_base_and_processor():
    if not torch.cuda.is_available():
        raise RuntimeError("B1/B2/DPO/PPO require a CUDA GPU on the Hugging Face Space.")
    wrapper = MultimodalVQA(
        model_id=LLAVA_MODEL_ID,
        lora_r=int(MODEL_B_CFG.get("lora_r", 16)),
        lora_alpha=int(MODEL_B_CFG.get("lora_alpha", 32)),
        lora_dropout=float(MODEL_B_CFG.get("lora_dropout", 0.05)),
        lora_target_modules=MODEL_B_CFG.get("lora_target_modules"),
    )
    processor = LlavaProcessor.from_pretrained(wrapper.model_id)
    processor.tokenizer.padding_side = "left"
    base_model = LlavaForConditionalGeneration.from_pretrained(
        wrapper.model_id,
        quantization_config=wrapper.bnb_config,
        device_map="auto",
    )
    base_model.config.use_cache = False
    return wrapper, processor, base_model
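

# Load the LLaVA base once and attach every fine-tuned LoRA adapter under its
# variant name, so prediction can switch adapters without reloading weights.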
def _ensure_llava_bundle() -> dict[str, Any]:
    global llava_bundle
    if llava_bundle is not None:
        return llava_bundle
    wrapper, processor, base_model = _build_llava_base_and_processor()
    adapter_variants = ["B2", "DPO", "PPO", "SOUP"]
    first_variant = adapter_variants[0]
    model = PeftModel.from_pretrained(
        base_model,
        HF_MODEL_REPOS[first_variant],
        adapter_name=first_variant,
        is_trainable=False,
    )
    for variant in adapter_variants[1:]:
        model.load_adapter(HF_MODEL_REPOS[variant], adapter_name=variant, is_trainable=False)
    model.eval()
    llava_bundle = {
        "family": "B",
        "model": model,
        "processor": processor,
        "wrapper": wrapper,
        "checkpoint": LLAVA_MODEL_ID,
        "adapter_name_map": {variant: variant for variant in adapter_variants},
    }
    return llava_bundle
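

# Direction A inference: grayscale image plus PhoBERT-encoded question.
# Closed (yes/no) questions are answered by the classification head; open
# questions are decoded with beam search over the answer vocabulary.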
def _predict_direction_a(bundle: dict[str, Any], question_vi: str, image: Image.Image) -> dict[str, str]:
    model = bundle["model"]
    tokenizer = bundle["tokenizer"]
    image_tensor = image_transform(image.convert("L")).unsqueeze(0).to(DEVICE)
    inputs = tokenizer(
        question_vi,
        padding="max_length",
        truncation=True,
        max_length=MAX_QUESTION_LEN,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs["attention_mask"].to(DEVICE)
    is_closed = _looks_closed_question(question_vi)
    with torch.inference_mode():
        logits_closed, pred_ids = model.inference(
            image_tensor,
            input_ids,
            attention_mask,
            beam_width=int(EVAL_CFG.get("beam_width_a", 5)),
            max_len=MAX_ANSWER_LEN,
        )
    if is_closed:
        prediction_raw = "có" if logits_closed.argmax(dim=1).item() == 1 else "không"
        prediction = prediction_raw
    else:
        prediction_raw = tokenizer.decode(pred_ids[0], skip_special_tokens=True)
        prediction = postprocess_answer(prediction_raw, max_words=ANSWER_MAX_WORDS)
    return {"prediction": prediction, "prediction_raw": prediction_raw}
async def _predict_direction_b(
    bundle: dict[str, Any],
    question_vi: str,
    question_en: str,
    image: Image.Image,
    variant: str,
) -> dict[str, str]:
    model = bundle["model"]
    processor = bundle["processor"]
    wrapper = bundle["wrapper"]
    is_closed = _looks_closed_question(question_en if variant == "B1" else question_vi)
    question_for_variant = question_en if variant == "B1" else question_vi
    adapter_name = bundle.get("adapter_name_map", {}).get(variant)
    if variant == "B1":
        prompt = _build_b1_prompt(question_for_variant, ANSWER_MAX_WORDS)
        num_beams = int(EVAL_CFG.get("beam_width_b_open", 5))
        max_new_tokens = int(EVAL_CFG.get("max_new_tokens_b_open", ANSWER_MAX_WORDS + 6))
    else:
        prompt = wrapper.build_instruction_prompt(question_for_variant, language="vi", include_answer=False)
        if is_closed:
            num_beams = int(EVAL_CFG.get("beam_width_b_closed", 1))
            max_new_tokens = int(EVAL_CFG.get("max_new_tokens_b_closed", 4))
        else:
            num_beams = int(EVAL_CFG.get("beam_width_b_open", 5))
            max_new_tokens = int(EVAL_CFG.get("max_new_tokens_b_open", ANSWER_MAX_WORDS + 6))
    bad_words_ids = _build_bad_words_ids(processor, variant)
    inputs = processor(text=[prompt], images=[image.convert("RGB")], return_tensors="pt", padding=True).to(DEVICE)
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    async with b_lock:
        if adapter_name and hasattr(model, "set_adapter"):
            model.set_adapter(adapter_name)
        if variant == "B1" and hasattr(model, "disable_adapter"):
            # B1 is zero-shot: generate with every LoRA adapter disabled.
            context = model.disable_adapter()
        else:
            context = torch.inference_mode()
        with context, torch.inference_mode():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                num_beams=num_beams,
                early_stopping=num_beams > 1,
                bad_words_ids=bad_words_ids,
            )
    input_token_len = inputs.input_ids.shape[1]
    pred_raw = processor.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0].strip()
    if variant == "B1":
        pred_en = _extract_key_medical_term(pred_raw, 50)
        if is_closed:
            prediction = _normalize_closed_answer(question_vi, question_en, pred_en, pred_en)
        else:
            prediction = _en_to_vi_direct(pred_en)
            if prediction is None:
                prediction = translator.translate_en2vi(pred_en)
            prediction = postprocess_answer(prediction, max_words=ANSWER_MAX_WORDS)
    else:
        prediction = _normalize_closed_answer(question_vi, question_en, pred_raw) if is_closed else pred_raw
        prediction = postprocess_answer(prediction, max_words=ANSWER_MAX_WORDS)
    return {"prediction": prediction, "prediction_raw": pred_raw}
async def _predict_variant(variant: str, question: str, image: Image.Image) -> dict[str, Any]:
    start = time.perf_counter()
    try:
        question_vi, question_en = _prepare_question_text(question)
        if variant in {"A1", "A2"}:
            bundle = _ensure_direction_a_model(variant)
            out = _predict_direction_a(bundle, question_vi, image)
        else:
            bundle = _ensure_llava_bundle()
            out = await _predict_direction_b(bundle, question_vi, question_en, image, variant)
        answer_for_rewrite = out["prediction"] or out["prediction_raw"]
        rewritten = rewriter.rewrite(
            question=question_vi,
            answer=answer_for_rewrite,
            language="vi",
            source_model=variant,
        )
        return {
            "model": variant,
            "Model": MODEL_DISPLAY_NAMES.get(variant, variant),
            "prediction": rewritten,
            "Prediction": rewritten,
            "prediction_before_rewrite": out["prediction"],
            "raw": out["prediction_raw"],
            "answer_used_for_rewrite": answer_for_rewrite,
            "checkpoint": HF_MODEL_REPOS.get(variant, ""),
            "latency_ms": round((time.perf_counter() - start) * 1000, 2),
            "status": "ok",
        }
    except Exception as exc:
        return {
            "model": variant,
            "Model": MODEL_DISPLAY_NAMES.get(variant, variant),
            "prediction": "",
            "Prediction": "",
            "prediction_before_rewrite": "",
            "raw": "",
            "answer_used_for_rewrite": "",
            "checkpoint": HF_MODEL_REPOS.get(variant, ""),
            "latency_ms": round((time.perf_counter() - start) * 1000, 2),
            "status": f"error: {exc}",
        }
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
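

# Gradio handler: validate the inputs, then run each selected variant in turn
# (direction-B calls already serialize on b_lock, so no parallel gathering).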
def predict_all(image: Image.Image, question: str, selected_models: list[str]) -> pd.DataFrame:
    if image is None:
        raise gr.Error("Please upload a medical image.")
    if not question or not question.strip():
        raise gr.Error("Please enter a question.")
    variants = selected_models or VARIANT_ORDER

    async def _run():
        rows = []
        for variant in variants:
            rows.append(await _predict_variant(variant, question, image))
        return rows

    rows = asyncio.run(_run())
    return pd.DataFrame(rows)[["Model", "Prediction"]]
| CSS = """ | |
| .gradio-container { max-width: 1180px !important; } | |
| #run-btn { height: 44px; } | |
| """ | |
with gr.Blocks(css=CSS, title="Medical VQA Compare") as demo:
    gr.Markdown("# Medical VQA Compare")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Medical image", type="pil", image_mode="RGB", sources=["upload", "clipboard"])
            question_input = gr.Textbox(
                label="Question",
                # Default kept in Vietnamese because the models and the
                # closed-question heuristic expect Vietnamese input
                # ("Is there any abnormality in this image?").
                value="Hình ảnh này có bất thường không?",
                lines=2,
            )
            model_input = gr.CheckboxGroup(
                label="Models",
                choices=VARIANT_ORDER,
                value=VARIANT_ORDER,
            )
            run_button = gr.Button("Run prediction", variant="primary", elem_id="run-btn")
        with gr.Column(scale=2):
            output_table = gr.Dataframe(
                label="Results",
                headers=["Model", "Prediction"],
                wrap=True,
            )
    run_button.click(
        fn=predict_all,
        inputs=[image_input, question_input, model_input],
        outputs=output_table,
        show_progress="full",
    )
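

# Queue with a concurrency limit of 1 so only one prediction run holds the
# GPU at a time.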
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1).launch(server_name="0.0.0.0", server_port=7860)