Medical-VQA / src /models /multimodal_vqa.py
SpringWang08's picture
Deploy Medical VQA app
d63774a
import torch
from transformers import LlavaProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
class MultimodalVQA:
"""
Wrapper cho LLaVA-Med-7B tích hợp QLoRA 4-bit để huấn luyện trên Kaggle.
Sử dụng kiến trúc LLaVA-1.5 (microsoft/llava-med-v1.5-7b).
"""
def __init__(
self,
model_id="chaoyinshe/llava-med-v1.5-mistral-7b-hf",
lora_r=16,
lora_alpha=32,
lora_dropout=0.05,
lora_target_modules=None,
):
self.model_id = model_id
# 1. Cấu hình Quantization 4-bit (Tiết kiệm VRAM)
self.bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
# 2. Cấu hình LoRA (Chỉ huấn luyện một phần nhỏ tham số)
self.peft_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules or ["q_proj", "v_proj", "k_proj", "o_proj"],
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
def load_model(self, adapter_path=None, is_trainable=True):
print(f"[INFO] Đang tải LLaVA-Med-v1.5-7B với chế độ 4-bit...")
processor = LlavaProcessor.from_pretrained(self.model_id)
processor.tokenizer.padding_side = "left" # Bắt buộc cho decoder-only models
model = LlavaForConditionalGeneration.from_pretrained(
self.model_id,
quantization_config=self.bnb_config,
device_map="auto"
)
model.config.use_cache = False
# Chuẩn bị mô hình cho PEFT
model = prepare_model_for_kbit_training(model)
if adapter_path:
print(f"[INFO] Đang nạp adapter LoRA từ: {adapter_path}")
model = PeftModel.from_pretrained(model, adapter_path, is_trainable=is_trainable)
else:
model = get_peft_model(model, self.peft_config)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.print_trainable_parameters()
return model, processor
def generate_prompt_vi(self, question_en):
"""
Hàm hỗ trợ tạo prompt cho LLaVA-Med (EN).
Nhớ dùng Translation Layer trước khi gọi hàm này.
"""
return self.build_instruction_prompt(question_en, language="en", include_answer=False)
def build_instruction_prompt(self, question, language="vi", include_answer=False):
"""
Prompt thống nhất cho zero-shot, SFT và demo.
"""
if language == "vi":
instruction = "Chi tra loi bang tieng Viet, khong dung tieng Anh, thuat ngu y khoa chuan, ngan gon, toi da 10 tu."
else:
instruction = "Answer with standard medical terminology, concise, at most 10 words."
suffix = " ASSISTANT:" if not include_answer else ""
return f"USER: <image>\n{question}\n{instruction}{suffix}"