TETSU0701 committed
Commit 5fab04c · verified · 1 Parent(s): 5108765

Update app.py

Files changed (1)
  1. app.py +38 -40
app.py CHANGED
@@ -7,6 +7,9 @@ from Model import OmniPathWithInterTaskAttention
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import transformers
 import os
+from threading import Thread
+from transformers import TextIteratorStreamer
+
 
 # Force Gradio to use the English locale
 os.environ["GRADIO_LOCALE"] = "en"
@@ -53,8 +56,8 @@ def load_models():
     tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
     llm_model = AutoModelForCausalLM.from_pretrained(
         llm_model_name,
-        dtype="auto",
-        device_map="auto"
+        device_map="auto",
+        load_in_4bit=True
     )
 
     return classification_model, llm_model, tokenizer, label_mappings
@@ -105,20 +108,19 @@ def analyze_npy_file(npy_file):
         return None, f"An error occurred during processing: {str(e)}"
 
 def generate_response(message, chat_history, analysis_results):
-    """Generate response based on user message and analysis results"""
+    """Generate streamed LLM response"""
     if analysis_results is None:
-        return "Please upload an NPY file first to analyze the patient data.", chat_history
-
+        yield "Please upload an NPY file first to analyze the patient data.", chat_history
+        return
+
     pred_names = analysis_results["pred_names"]
     pred_scores = analysis_results["pred_scores"]
     patient_id = analysis_results["patient_id"]
-
-    # Build context from analysis results
+
     context = f"Patient {patient_id} analysis results:\n"
     for task, name in pred_names.items():
         context += f"- {task}: {name} (confidence: {pred_scores.get(task, 0.0):.3f})\n"
-
-    # Build prompt based on user message
+
     if "diagnosis" in message.lower() or "result" in message.lower():
         prompt = f"{context}\nBased on the above analysis results, provide a detailed diagnosis summary and interpretation."
     elif "treatment" in message.lower() or "therapy" in message.lower():
@@ -131,41 +133,37 @@ def generate_response(message, chat_history, analysis_results):
         prompt = f"{context}\nDescribe the histological characteristics and their significance."
     else:
         prompt = f"{context}\nUser question: {message}\nPlease provide a helpful response based on the analysis results."
-
-    try:
-        # Generate response using LLM
-        messages = [{"role": "user", "content": prompt}]
-        text = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-            enable_thinking=False
-        )
-
-        model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)
-        generated_ids = llm_model.generate(
+
+    messages = [{"role": "user", "content": prompt}]
+    text = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
+    )
+
+    model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    thread = Thread(
+        target=lambda: llm_model.generate(
             **model_inputs,
-            max_new_tokens=2048,
+            max_new_tokens=1024,  # 🚀 smaller output budget for faster responses
            do_sample=True,
            temperature=0.7,
+            top_p=0.9,
+            streamer=streamer
        )
-
-        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
-        try:
-            index = len(output_ids) - output_ids[::-1].index(151668)
-        except ValueError:
-            index = 0
-
-        response = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
-
-        # Add to chat history
-        chat_history.append((message, response))
-        return "", chat_history
-
-    except Exception as e:
-        error_msg = f"Error generating response: {str(e)}"
-        chat_history.append((message, error_msg))
-        return "", chat_history
+    )
+    thread.start()
+
+    partial = ""
+    for new_text in streamer:
+        partial += new_text
+        # stream the partial reply to the UI in real time
+        yield "", chat_history + [(message, partial)]
+
+    # once generation finishes, write the final reply back into the history
+    chat_history.append((message, partial))
+    yield "", chat_history
 
 def upload_file(npy_file, chat_history, analysis_results):
     """Handle file upload and initial analysis"""
 
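For reference, here is a minimal, self-contained sketch of the streaming pattern the rewritten generate_response relies on: generate() blocks until decoding is done, so it runs on a worker thread while the caller iterates a TextIteratorStreamer and yields the growing text. The model id is a placeholder, and enable_thinking is omitted here because it is model-specific.

# Sketch: incremental decoding with TextIteratorStreamer and a background thread.
# The model id is a placeholder; any chat-style causal LM works the same way.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "your-llm-name"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

def stream_reply(prompt):
    """Yield the reply text incrementally as the model produces it."""
    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # skip_prompt=True keeps the echoed prompt out of the streamed output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until finished, so run it on a worker thread
    thread = Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, "max_new_tokens": 256},
    )
    thread.start()
    partial = ""
    for chunk in streamer:   # the streamer yields decoded text pieces
        partial += chunk
        yield partial        # caller sees a growing snapshot, as the chat UI does
    thread.join()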