import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch

# --- Configuration ---
base_model_id = "Qwen/Qwen-1_8B-Chat"
lora_adapter_id = "jinv2/qwen-1_8b-hemiplegia-lora"  # Your HF Model ID
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# --- Load Model and Tokenizer ---
print("Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(lora_adapter_id, trust_remote_code=True)
    print(f"Successfully loaded tokenizer from {lora_adapter_id}.")
except Exception:
    print(f"Could not load tokenizer from {lora_adapter_id}, falling back to {base_model_id}.")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)

if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    else:  # Fallback for Qwen; make sure this ID is correct for your Qwen version
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>") if "<|endoftext|>" in tokenizer.vocab else 0
tokenizer.padding_side = "left"  # Important for generation
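# Why left padding: for decoder-only models, right padding would place pad
# tokens between the prompt and the generated continuation, corrupting output.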
print("Loading base model with quantization...") | |
quantization_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_compute_dtype=torch.float16 # As used in fine-tuning | |
) | |
base_model = AutoModelForCausalLM.from_pretrained( | |
base_model_id, | |
quantization_config=quantization_config, | |
trust_remote_code=True, | |
device_map={"":0} if device == "cuda" else "cpu" # Load directly to GPU if available, else CPU | |
) | |
print("Base model loaded.") | |
print(f"Loading LoRA adapter: {lora_adapter_id}...") | |
model = PeftModel.from_pretrained(base_model, lora_adapter_id) | |
model.eval() # Set to evaluation mode | |
print("LoRA adapter loaded and model is ready.") | |
if device == "cpu": # If on CPU, PEFT might not automatically move the full model if device_map wasn't used correctly for CPU | |
model = model.to(device) | |
print(f"Model explicitly moved to {device}") | |

# --- Prediction Function ---
def get_response(user_query):
    # "You are a medical Q&A assistant focused on hemiplegia, cerebral thrombosis, and paralysis."
    system_prompt_content = "你是一个专注于偏瘫、脑血栓、半身不遂领域的医疗问答助手。"
    # Construct the prompt using Qwen's ChatML format
    prompt = f"<|im_start|>system\n{system_prompt_content}<|im_end|>\n<|im_start|>user\n{user_query}<|im_end|>\n<|im_start|>assistant\n"
    # Cap the prompt at 512 - 150 tokens to leave room for max_new_tokens=150
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512 - 150).to(model.device)
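
    # Qwen's ChatML template ends each assistant turn with <|im_end|>, which may
    # differ from the tokenizer's default EOS token, so both are collected as stop IDs.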
    eos_token_ids_list = []
    if isinstance(tokenizer.eos_token_id, int):
        eos_token_ids_list.append(tokenizer.eos_token_id)
    try:
        im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        if im_end_token_id not in eos_token_ids_list:
            eos_token_ids_list.append(im_end_token_id)
    except KeyError:
        pass
    if not eos_token_ids_list and tokenizer.eos_token_id is not None:  # Fallback if the list is empty but a single eos_token_id exists
        eos_token_ids_list = [tokenizer.eos_token_id]
    elif not eos_token_ids_list:  # Absolute fallback
        print("Warning: EOS token ID list is empty. Generation might not stop correctly.")
        # This scenario should ideally be avoided by a robust tokenizer setup; a
        # hard-coded guess such as eos_token_ids_list = [tokenizer.vocab_size - 1]
        # would be a very risky fallback.
print(f"Generating response for query: '{user_query}'") | |
with torch.no_grad(): | |
outputs = model.generate( | |
**inputs, | |
max_new_tokens=150, | |
pad_token_id=tokenizer.pad_token_id, | |
eos_token_id=eos_token_ids_list if eos_token_ids_list else None, # Pass list or None | |
temperature=0.7, | |
top_p=0.9, | |
do_sample=True, | |
num_beams=1 # Use 1 for sampling, or >1 for beam search (do_sample=False then) | |
) | |
    # Decode only the newly generated tokens, skipping the prompt
    response_text = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    print(f"Raw response: '{response_text}'")
    return response_text.strip()
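
# Quick local smoke test (uncomment to try the model outside the Gradio UI):
# print(get_response("什么是脑血栓?"))  # "What is cerebral thrombosis?"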

# --- Gradio Interface ---
iface = gr.Interface(
    fn=get_response,
    inputs=gr.Textbox(lines=3, placeholder="请输入您关于偏瘫、脑血栓或半身不遂的问题...", label="您的问题 (Your Question)"),
    outputs=gr.Textbox(lines=5, label="模型回答 (Model Response)"),
    title="偏瘫脑血栓问答助手 (Hemiplegia/Stroke Q&A Assistant)",
    # Description: "A model fine-tuned from Qwen-1.8B-Chat with LoRA (jinv2/qwen-1_8b-hemiplegia-lora).
    # Related to 天算AI. **For medical advice, please consult a professional doctor.**"
    description="由 Qwen-1.8B-Chat LoRA 微调得到的模型 (jinv2/qwen-1_8b-hemiplegia-lora)。与天算AI相关。**医疗建议请咨询专业医生。**",
    examples=[
        ["偏瘫患者的早期康复锻炼有哪些?"],  # "What early rehabilitation exercises are there for hemiplegia patients?"
        ["什么是脑血栓?"],  # "What is cerebral thrombosis?"
        ["中风后如何进行语言恢复训练?"]  # "How is speech recovery training done after a stroke?"
    ],
    allow_flagging="never"  # Disable flagging for simplicity
)
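
# Optional: calling iface.queue() before launch() serializes concurrent requests,
# which helps on a CPU-only Space where each generation takes a while.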
if __name__ == "__main__":
    iface.launch()