# app.py — Hugging Face Space by jinv2
# Commit af3f7dd ("Create app.py", verified); file size 5.01 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch
# --- Configuration ---
# Base checkpoint the LoRA adapter was trained on, plus the adapter's Hub repo id.
base_model_id = "Qwen/Qwen-1_8B-Chat"
lora_adapter_id = "jinv2/qwen-1_8b-hemiplegia-lora"  # Your HF Model ID

# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# --- Load Model and Tokenizer ---
print("Loading tokenizer...")
# Prefer the tokenizer stored in the adapter repo; fall back to the base model's.
try:
    tokenizer = AutoTokenizer.from_pretrained(lora_adapter_id, trust_remote_code=True)
except Exception as err:
    # Deliberately best-effort (any load failure triggers the fallback), but
    # report the cause instead of swallowing it silently.
    print(f"Could not load tokenizer from {lora_adapter_id} ({err!r}), falling back to {base_model_id}.")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
else:
    print(f"Successfully loaded tokenizer from {lora_adapter_id}.")

# generate() needs a pad id; reuse EOS when present, else Qwen's <|endoftext|>.
if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    else:
        # `.vocab` is not part of the base tokenizer API (remote-code tokenizers
        # such as Qwen's may not define it), so query the id directly. Tokenizers
        # signal "unknown" either by returning None, by raising, or by returning
        # the unk id; treat all of those as missing and fall back to id 0.
        try:
            endoftext_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
        except (KeyError, ValueError):
            endoftext_id = None
        if endoftext_id is None or endoftext_id == tokenizer.unk_token_id:
            endoftext_id = 0
        tokenizer.pad_token_id = endoftext_id
tokenizer.padding_side = "left"  # Important for generation
# bitsandbytes 4-bit loading has historically required a CUDA GPU (transformers
# rejects it when no GPU is visible), so only quantize when running on CUDA.
# On CPU, load the full-precision weights instead of crashing at startup.
if device == "cuda":
    print("Loading base model with quantization...")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # As used in fine-tuning
    )
    base_device_map = {"": 0}  # Put the whole model on the first GPU.
else:
    print("Loading base model without quantization (no GPU available)...")
    quantization_config = None
    base_device_map = "cpu"

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    trust_remote_code=True,
    device_map=base_device_map,
)
print("Base model loaded.")

print(f"Loading LoRA adapter: {lora_adapter_id}...")
model = PeftModel.from_pretrained(base_model, lora_adapter_id)
model.eval()  # Inference only: disable dropout etc.
print("LoRA adapter loaded and model is ready.")

if device == "cpu":
    # Defensive: make sure any adapter weights also live on the CPU. This is
    # valid only because the CPU path above no longer loads a quantized model.
    model = model.to(device)
    print(f"Model explicitly moved to {device}")
# --- Prediction Function ---
_SYSTEM_PROMPT = "你是一个专注于偏瘫、脑血栓、半身不遂领域的医疗问答助手。"
# Token budget: the prompt may use at most _MAX_CONTEXT_TOKENS - _MAX_NEW_TOKENS
# tokens so that prompt + generation fit in a 512-token window.
_MAX_CONTEXT_TOKENS = 512
_MAX_NEW_TOKENS = 150
_MAX_PROMPT_TOKENS = _MAX_CONTEXT_TOKENS - _MAX_NEW_TOKENS


def _lookup_token_id(token):
    """Return the id of *token*, or None when the tokenizer does not know it."""
    # Tokenizers report unknown tokens differently: some raise, Qwen's returns
    # None, most return unk_token_id. Normalise all of these to None.
    try:
        token_id = tokenizer.convert_tokens_to_ids(token)
    except (KeyError, ValueError):
        return None
    if token_id is None or token_id == tokenizer.unk_token_id:
        return None
    return token_id


def _stop_token_ids():
    """Collect ids that should end generation: EOS and ChatML's <|im_end|>."""
    stop_ids = []
    if isinstance(tokenizer.eos_token_id, int):
        stop_ids.append(tokenizer.eos_token_id)
    im_end_id = _lookup_token_id("<|im_end|>")
    if im_end_id is not None and im_end_id not in stop_ids:
        stop_ids.append(im_end_id)
    return stop_ids


def _render_chatml(query):
    """Format *query* as a single-turn ChatML prompt (Qwen chat format)."""
    return (
        f"<|im_start|>system\n{_SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{query}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )


def _build_prompt(user_query):
    """Return the ChatML prompt, shortening *user_query* if it exceeds the budget."""
    prompt = _render_chatml(user_query)
    overflow = len(tokenizer(prompt)["input_ids"]) - _MAX_PROMPT_TOKENS
    if overflow > 0:
        # Truncate the query itself rather than the formatted prompt: cutting
        # the prompt's tail would drop the "<|im_start|>assistant" header.
        query_ids = tokenizer(user_query, add_special_tokens=False)["input_ids"]
        keep = max(len(query_ids) - overflow, 0)
        short_query = tokenizer.decode(query_ids[:keep], skip_special_tokens=True)
        # A byte-level cut can split a multi-byte character; drop the
        # replacement character a lenient decoder may leave behind.
        prompt = _render_chatml(short_query.rstrip("\ufffd"))
    return prompt


def get_response(user_query):
    """Generate the assistant's answer to one user question.

    Args:
        user_query: Question text from the Gradio textbox.

    Returns:
        The newly generated text, decoded without special tokens and stripped.
        Empty or whitespace-only questions get a short prompt to type one.
    """
    if not user_query or not user_query.strip():
        return "请输入您的问题。"

    prompt = _build_prompt(user_query)
    # No truncation here: _build_prompt already enforced the token budget.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    stop_ids = _stop_token_ids()
    if not stop_ids:
        print("Warning: EOS token ID list is empty. Generation might not stop correctly.")

    print(f"Generating response for query: '{user_query}'")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=_MAX_NEW_TOKENS,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=stop_ids or None,  # None -> model's generation config
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=1,  # Pure sampling; beam search would need do_sample=False
        )
    # Drop the echoed prompt tokens; keep only the continuation.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    print(f"Raw response: '{response_text}'")
    return response_text.strip()
# --- Gradio Interface ---
# One multi-line question box in, one answer box out, with canned examples.
question_box = gr.Textbox(
    lines=3,
    placeholder="请输入您关于偏瘫、脑血栓或半身不遂的问题...",
    label="您的问题 (Your Question)",
)
answer_box = gr.Textbox(lines=5, label="模型回答 (Model Response)")
example_questions = [
    ["偏瘫患者的早期康复锻炼有哪些?"],
    ["什么是脑血栓?"],
    ["中风后如何进行语言恢复训练?"],
]

iface = gr.Interface(
    fn=get_response,
    inputs=question_box,
    outputs=answer_box,
    title="偏瘫脑血栓问答助手 (Hemiplegia/Stroke Q&A Assistant)",
    description="由 Qwen-1.8B-Chat LoRA 微调得到的模型 (jinv2/qwen-1_8b-hemiplegia-lora)。与天算AI相关。**医疗建议请咨询专业医生。**",
    examples=example_questions,
    allow_flagging="never",  # No flagging UI: keeps the demo simple.
)

# Start the web server only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()