File size: 5,011 Bytes
af3f7dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch

# --- Configuration ---
# Hugging Face Hub ids: the base checkpoint and the LoRA adapter trained on top of it.
base_model_id = "Qwen/Qwen-1_8B-Chat"
lora_adapter_id = "jinv2/qwen-1_8b-hemiplegia-lora"  # Your HF Model ID

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Load Model and Tokenizer ---
print("Loading tokenizer...")
try:
    # Try the adapter repo first; it may ship its own tokenizer files.
    tokenizer = AutoTokenizer.from_pretrained(lora_adapter_id, trust_remote_code=True)
    print(f"Successfully loaded tokenizer from {lora_adapter_id}.")
except Exception as err:
    # Best-effort fallback to the base model's tokenizer, but surface why the
    # adapter tokenizer could not be loaded instead of discarding the error.
    print(f"Could not load tokenizer from {lora_adapter_id}, falling back to {base_model_id}.")
    print(f"  Reason: {err!r}")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)

if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    else:
        # Fallback for Qwen tokenizers that define no EOS token; confirm this id is
        # correct for your Qwen version. `convert_tokens_to_ids` is available on every
        # tokenizer (unlike a `.vocab` attribute, which custom tokenizers may lack) and
        # signals an unknown token by returning None or the unk id rather than raising.
        endoftext_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
        if isinstance(endoftext_id, int) and endoftext_id != tokenizer.unk_token_id:
            tokenizer.pad_token_id = endoftext_id
        else:
            tokenizer.pad_token_id = 0

tokenizer.padding_side = "left"  # Important for generation

# bitsandbytes 4-bit loading has traditionally required a CUDA GPU (transformers
# refuses it otherwise), and transformers may also reject `.to()` on bnb-quantized
# models. So quantize only when a GPU is present; on CPU load the weights unquantized.
if device == "cuda":
    print("Loading base model with quantization...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # As used in fine-tuning
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=quantization_config,
        trust_remote_code=True,
        device_map={"": 0},  # Place the whole model on the first GPU
    )
else:
    print("CUDA not available; loading base model without quantization...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        trust_remote_code=True,
        device_map="cpu",
    )
print("Base model loaded.")

print(f"Loading LoRA adapter: {lora_adapter_id}...")
model = PeftModel.from_pretrained(base_model, lora_adapter_id)
model.eval()  # Set to evaluation mode (disables dropout)
print("LoRA adapter loaded and model is ready.")
if device == "cpu":
    # Safe now that the CPU path is unquantized; kept so any adapter tensors
    # PEFT created end up on the CPU as well.
    model = model.to(device)
    print(f"Model explicitly moved to {device}")


# --- Prediction Function ---
SYSTEM_PROMPT = "你是一个专注于偏瘫、脑血栓、半身不遂领域的医疗问答助手。"
MAX_CONTEXT_TOKENS = 512  # Budget shared by the prompt and the generated reply
MAX_NEW_TOKENS = 150


def _token_id_or_none(token):
    """Return the id of *token*, or None when the tokenizer does not know it.

    ``convert_tokens_to_ids`` usually does not raise for unknown tokens: depending
    on the tokenizer it returns None or the unk id, so both count as "missing".
    """
    try:
        token_id = tokenizer.convert_tokens_to_ids(token)
    except KeyError:
        return None
    if isinstance(token_id, int) and token_id != tokenizer.unk_token_id:
        return token_id
    return None


def _stop_token_ids():
    """Collect the token ids that end generation: the EOS id and ChatML ``<|im_end|>``."""
    stop_ids = []
    if isinstance(tokenizer.eos_token_id, int):
        stop_ids.append(tokenizer.eos_token_id)
    im_end_id = _token_id_or_none("<|im_end|>")
    if im_end_id is not None and im_end_id not in stop_ids:
        stop_ids.append(im_end_id)
    return stop_ids


def _encode_prompt(user_query, max_prompt_tokens):
    """Tokenize the ChatML prompt, truncating only the user query to fit the budget.

    Truncating the assembled prompt from the right would cut off the trailing
    ``<|im_start|>assistant`` header on long queries, leaving the model with no
    cue to answer. The fixed head and tail are therefore always kept whole.

    Returns:
        A dict with ``input_ids`` and ``attention_mask`` tensors of shape
        (1, prompt_len) on the model's device.
    """
    head = f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n<|im_start|>user\n"
    tail = "<|im_end|>\n<|im_start|>assistant\n"
    head_ids = tokenizer.encode(head, add_special_tokens=False)
    tail_ids = tokenizer.encode(tail, add_special_tokens=False)
    # NOTE(review): the query is user-controlled; if the tokenizer parses special
    # tokens inside plain text, a query containing "<|im_end|>" can break out of
    # the user turn. Consider disallowing special tokens for this segment.
    query_ids = tokenizer.encode(user_query, add_special_tokens=False)
    query_budget = max(max_prompt_tokens - len(head_ids) - len(tail_ids), 0)
    ids = torch.tensor([head_ids + query_ids[:query_budget] + tail_ids], device=model.device)
    return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}


def get_response(user_query):
    """Answer *user_query* with the LoRA-tuned chat model.

    Args:
        user_query: Question text entered in the Gradio textbox.

    Returns:
        The generated reply (prompt tokens removed, special tokens skipped),
        stripped of surrounding whitespace.
    """
    # Leave room for MAX_NEW_TOKENS generated tokens inside the context budget.
    inputs = _encode_prompt(user_query, MAX_CONTEXT_TOKENS - MAX_NEW_TOKENS)

    stop_ids = _stop_token_ids()
    if not stop_ids:
        print("Warning: EOS token ID list is empty. Generation might not stop correctly.")

    print(f"Generating response for query: '{user_query}'")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=stop_ids or None,  # None defers to the model's generation config
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=1,  # Plain sampling; beam search would need do_sample=False
        )

    prompt_len = inputs["input_ids"].shape[1]
    response_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    print(f"Raw response: '{response_text}'")
    return response_text.strip()

# --- Gradio Interface ---
# Input/output widgets and the clickable sample questions shown under the form.
question_box = gr.Textbox(
    lines=3,
    placeholder="请输入您关于偏瘫、脑血栓或半身不遂的问题...",
    label="您的问题 (Your Question)",
)
answer_box = gr.Textbox(lines=5, label="模型回答 (Model Response)")
example_questions = [
    ["偏瘫患者的早期康复锻炼有哪些?"],
    ["什么是脑血栓?"],
    ["中风后如何进行语言恢复训练?"],
]

# Single-function UI: each submitted question is passed to get_response.
iface = gr.Interface(
    fn=get_response,
    inputs=question_box,
    outputs=answer_box,
    title="偏瘫脑血栓问答助手 (Hemiplegia/Stroke Q&A Assistant)",
    description="由 Qwen-1.8B-Chat LoRA 微调得到的模型 (jinv2/qwen-1_8b-hemiplegia-lora)。与天算AI相关。**医疗建议请咨询专业医生。**",
    examples=example_questions,
    # NOTE(review): newer Gradio releases deprecate allow_flagging in favour of
    # flagging_mode; confirm against the Gradio version this Space pins.
    allow_flagging="never",  # No "Flag" button
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()