jinv2 committed on
Commit 14b9d25 · verified · 1 Parent(s): ba352fc

Update app.py

Files changed (1)
  1. app.py +72 -41
app.py CHANGED
@@ -1,61 +1,83 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer  # Removed BitsAndBytesConfig, as we are not quantizing for CPU
 from peft import PeftModel
 import torch
+import os  # Ensure os is imported for potential path joining if needed
 
 # --- Configuration ---
 base_model_id = "Qwen/Qwen-1_8B-Chat"
 lora_adapter_id = "jinv2/qwen-1_8b-hemiplegia-lora"  # Your HF Model ID
-device = "cuda" if torch.cuda.is_available() else "cpu"
+# device = "cuda" if torch.cuda.is_available() else "cpu"  # Will always be "cpu" on the free tier
+device = "cpu"  # Explicitly set to CPU for this configuration
 print(f"Using device: {device}")
 
 # --- Load Model and Tokenizer ---
 print("Loading tokenizer...")
 try:
+    # Try loading the tokenizer from the LoRA repo first, as it may contain adapter-specific settings
     tokenizer = AutoTokenizer.from_pretrained(lora_adapter_id, trust_remote_code=True)
     print(f"Successfully loaded tokenizer from {lora_adapter_id}.")
-except Exception:
-    print(f"Could not load tokenizer from {lora_adapter_id}, falling back to {base_model_id}.")
+except Exception as e_lora_tok:
+    print(f"Could not load tokenizer from {lora_adapter_id} (Error: {e_lora_tok}), falling back to {base_model_id}.")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
 
+# Set pad_token if it is not already set
 if tokenizer.pad_token_id is None:
     if tokenizer.eos_token_id is not None:
         tokenizer.pad_token_id = tokenizer.eos_token_id
-    else:  # Fallback for Qwen; ensure this ID is correct for your Qwen version
-        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>") if "<|endoftext|>" in tokenizer.vocab else 0
+        tokenizer.pad_token = tokenizer.eos_token
+        print(f"Set tokenizer.pad_token_id to eos_token_id: {tokenizer.pad_token_id}")
+    else:
+        # Fallback for Qwen; ensure this ID is correct for your Qwen version
+        try:
+            qwen_eos_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
+            tokenizer.pad_token_id = qwen_eos_id
+            tokenizer.pad_token = "<|endoftext|>"
+            print(f"Set tokenizer.pad_token_id to the ID of '<|endoftext|>': {tokenizer.pad_token_id}")
+        except KeyError:
+            tokenizer.pad_token_id = 0  # Absolute fallback, very risky
+            tokenizer.pad_token = tokenizer.decode([0])
+            print(f"CRITICAL WARNING: Could not set pad_token_id reliably. Set to 0 ('{tokenizer.pad_token}').")
 
 tokenizer.padding_side = "left"  # Important for generation
 
-print("Loading base model with quantization...")
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_compute_dtype=torch.float16  # As used in fine-tuning
-)
-base_model = AutoModelForCausalLM.from_pretrained(
-    base_model_id,
-    quantization_config=quantization_config,
-    trust_remote_code=True,
-    device_map={"": 0} if device == "cuda" else "cpu"  # Load directly to the GPU if available, else the CPU
-)
-print("Base model loaded.")
+print("Loading base model (NO QUANTIZATION, running on CPU)...")
+# IMPORTANT: For CPU, we cannot use bitsandbytes 4-bit quantization.
+# We load the model in its original precision (float16/bfloat16 may work if memory allows and the CPU supports it).
+# This will be much slower and more memory-intensive.
+try:
+    base_model = AutoModelForCausalLM.from_pretrained(
+        base_model_id,
+        trust_remote_code=True,
+        torch_dtype=torch.float32,  # Use float32 on CPU for maximum compatibility; bfloat16 might work on some newer CPUs
+        # device_map="auto" would likely map to CPU; being explicit here
+        device_map={"": device}  # Ensure all model parts are on the correct device
+    )
+    print("Base model loaded.")
+except Exception as e_load_model:
+    print(f"Error loading base model: {e_load_model}")
+    raise  # Re-raise the exception to stop the app if model loading fails
 
 print(f"Loading LoRA adapter: {lora_adapter_id}...")
-model = PeftModel.from_pretrained(base_model, lora_adapter_id)
-model.eval()  # Set to evaluation mode
-print("LoRA adapter loaded and model is ready.")
-if device == "cpu":  # On CPU, PEFT might not automatically move the full model if device_map was not set up for CPU
-    model = model.to(device)
-    print(f"Model explicitly moved to {device}")
+try:
+    # PEFT should still work on CPU. The base model should already be on the CPU before applying the adapter.
+    model = PeftModel.from_pretrained(base_model, lora_adapter_id)
+    model.eval()  # Set to evaluation mode
+    model = model.to(device)  # Ensure the final PEFT model is on the CPU
+    print("LoRA adapter loaded and model is on CPU, ready for inference.")
+except Exception as e_load_adapter:
+    print(f"Error loading LoRA adapter: {e_load_adapter}")
+    raise
 
 
 # --- Prediction Function ---
 def get_response(user_query):
     system_prompt_content = "你是一个专注于偏瘫、脑血栓、半身不遂领域的医疗问答助手。"
 
-    # Construct prompt using Qwen's ChatML format
     prompt = f"<|im_start|>system\n{system_prompt_content}<|im_end|>\n<|im_start|>user\n{user_query}<|im_end|>\n<|im_start|>assistant\n"
 
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512-150).to(model.device)  # Leave space for generation
+    # Ensure inputs are on the same device as the model
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512-150).to(model.device)
 
     eos_token_ids_list = []
     if isinstance(tokenizer.eos_token_id, int):
@@ -65,25 +87,31 @@ def get_response(user_query):
         if im_end_token_id not in eos_token_ids_list:
             eos_token_ids_list.append(im_end_token_id)
     except KeyError: pass
-    if not eos_token_ids_list and tokenizer.eos_token_id is not None:  # Fallback if the list is empty but a single eos_token_id exists
-        eos_token_ids_list = [tokenizer.eos_token_id]
-    elif not eos_token_ids_list:  # Absolute fallback
-        print("Warning: EOS token ID list is empty. Generation might not stop correctly.")
-        # Attempt to use a known Qwen EOS ID if possible; otherwise generation might be problematic.
-        # This scenario should ideally be avoided by robust tokenizer setup.
-        # eos_token_ids_list = [tokenizer.vocab_size - 1]  # Very risky fallback
+
+    # Fallback if eos_token_ids_list is still empty
+    if not eos_token_ids_list:
+        if tokenizer.eos_token_id is not None:
+            eos_token_ids_list = [tokenizer.eos_token_id]
+        else:
+            print("Warning: EOS token ID list is empty and eos_token_id is None. Generation might not stop correctly.")
+            # Attempt to use a known Qwen EOS ID if possible; otherwise generation might be problematic.
+            try:
+                eos_token_ids_list = [tokenizer.convert_tokens_to_ids("<|endoftext|>")]
+            except KeyError:
+                eos_token_ids_list = [tokenizer.vocab_size - 1 if tokenizer.vocab_size else 0]  # Very risky fallback
+
 
-    print(f"Generating response for query: '{user_query}'")
-    with torch.no_grad():
+    print(f"Generating response for query: '{user_query}' on device: {model.device}")
+    with torch.no_grad():  # Inference does not need gradient computation
         outputs = model.generate(
             **inputs,
             max_new_tokens=150,
             pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=eos_token_ids_list if eos_token_ids_list else None,  # Pass a list or None
+            eos_token_id=eos_token_ids_list if eos_token_ids_list else None,
             temperature=0.7,
             top_p=0.9,
             do_sample=True,
-            num_beams=1  # Use 1 for sampling, or >1 for beam search (with do_sample=False)
+            num_beams=1
         )
 
     response_text = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
@@ -95,15 +123,18 @@ iface = gr.Interface(
     fn=get_response,
     inputs=gr.Textbox(lines=3, placeholder="请输入您关于偏瘫、脑血栓或半身不遂的问题...", label="您的问题 (Your Question)"),
     outputs=gr.Textbox(lines=5, label="模型回答 (Model Response)"),
-    title="偏瘫脑血栓问答助手 (Hemiplegia/Stroke Q&A Assistant)",
-    description="由 Qwen-1.8B-Chat LoRA 微调得到的模型 (jinv2/qwen-1_8b-hemiplegia-lora)。与天算AI相关。**医疗建议请咨询专业医生。**",
+    title="偏瘫脑血栓问答助手 (CPU Version - Expect Slow Response)",
+    description=(
+        "由 Qwen-1.8B-Chat LoRA 微调得到的模型 (jinv2/qwen-1_8b-hemiplegia-lora)。与天算AI相关。\n"
+        "**重要:此版本运行在 CPU 上,无量化,响应会非常慢。医疗建议请咨询专业医生。**"
+    ),
     examples=[
         ["偏瘫患者的早期康复锻炼有哪些?"],
         ["什么是脑血栓?"],
         ["中风后如何进行语言恢复训练?"]
     ],
-    allow_flagging="never"  # Disable flagging for simplicity
+    allow_flagging="never"
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()  # debug=True can be helpful for local testing, but not for Spaces deployment
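For reference, a minimal local smoke test of the updated CPU-only script might look like the sketch below. This is an assumption-laden example, not part of the commit: it assumes the file above is saved as app.py in the working directory, that get_response returns the decoded reply (as in the full file), and that the ~1.8B-parameter model fits in RAM in float32; the first run downloads several GB and CPU generation is slow.

```python
# smoke_test.py — hypothetical local check for the CPU-only app.py above (not part of the commit).
# Importing app executes its module-level code: it downloads Qwen/Qwen-1_8B-Chat,
# applies the jinv2/qwen-1_8b-hemiplegia-lora adapter, and builds the Gradio interface,
# but does not call iface.launch(), since that is guarded by the __main__ check.
from app import get_response

if __name__ == "__main__":
    # One of the example questions from the Gradio interface ("What is cerebral thrombosis?")
    answer = get_response("什么是脑血栓?")
    print(answer)
```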