oliver-aizip committed on
Commit 6b26b26 · 1 Parent(s): c4fe1db

proper threaded generation interrupt
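
The change relies on two pieces that are not shown in this diff: generation_interrupt, imported from utils/shared.py, and the InterruptCriteria stopping criterion passed to model.generate(). Below is a minimal sketch of what they presumably look like, assuming generation_interrupt is a module-level threading.Event and InterruptCriteria is a custom StoppingCriteria that polls it; neither definition is part of this commit.

# Sketch only (assumed, not part of this commit): how utils/shared.py and
# InterruptCriteria are likely wired, based on how the diff uses them.
import threading
from transformers import StoppingCriteria

# Shared cancellation flag; the UI / main thread calls .set() to cancel.
generation_interrupt = threading.Event()

class InterruptCriteria(StoppingCriteria):
    """Stops model.generate() as soon as the shared event is set."""

    def __init__(self, interrupt_event: threading.Event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # generate() evaluates the criteria after every decoding step;
        # returning True ends generation early.
        return self.interrupt_event.is_set()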

Files changed (1)
  1. utils/models.py +118 -35
utils/models.py CHANGED
@@ -2,6 +2,9 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
+import threading
+import queue
+import time  # Added for sleep
 
 models = {
     "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
@@ -42,84 +45,164 @@ def generate_summaries(example, model_a_name, model_b_name):
     if generation_interrupt.is_set():
         return "", ""
 
-    summary_a = run_inference(models[model_a_name], context_text, question)
-
-    if generation_interrupt.is_set():
+    # Use a queue to get results from threads
+    result_queue_a = queue.Queue()
+    thread_a = threading.Thread(target=run_inference, args=(models[model_a_name], context_text, question, result_queue_a))
+    thread_a.start()
+
+    summary_a = ""
+    while thread_a.is_alive():
+        if generation_interrupt.is_set():
+            print(f"Interrupting model A ({model_a_name})...")
+            # The InterruptCriteria within the thread will handle stopping generate
+            # We return early from the main control flow.
+            thread_a.join(timeout=1.0)  # Give thread a moment to potentially stop
+            return "", ""
+        try:
+            summary_a = result_queue_a.get(timeout=0.1)  # Check queue periodically
+            break  # Got result
+        except queue.Empty:
+            continue  # Still running, check interrupt again
+
+    # If thread finished but we didn't get a result (e.g., interrupted just before putting in queue)
+    if not summary_a and not result_queue_a.empty():
+        summary_a = result_queue_a.get_nowait()
+    elif not summary_a and generation_interrupt.is_set():  # Check interrupt again if thread finished quickly
+        return "", ""
+
+
+    if generation_interrupt.is_set():  # Check between models
         return summary_a, ""
-
-    summary_b = run_inference(models[model_b_name], context_text, question)
+
+    # --- Model B ---
+    result_queue_b = queue.Queue()
+    thread_b = threading.Thread(target=run_inference, args=(models[model_b_name], context_text, question, result_queue_b))
+    thread_b.start()
+
+    summary_b = ""
+    while thread_b.is_alive():
+        if generation_interrupt.is_set():
+            print(f"Interrupting model B ({model_b_name})...")
+            thread_b.join(timeout=1.0)
+            return summary_a, ""  # Return summary_a obtained so far
+        try:
+            summary_b = result_queue_b.get(timeout=0.1)
+            break
+        except queue.Empty:
+            continue
+
+    if not summary_b and not result_queue_b.empty():
+        summary_b = result_queue_b.get_nowait()
+    elif not summary_b and generation_interrupt.is_set():
+        return summary_a, ""
+
+
     return summary_a, summary_b
 
-def run_inference(model_name, context, question):
+
+# Modified run_inference to run in a thread and use a queue for results
+def run_inference(model_name, context, question, result_queue):
     """
-    Run inference using the specified model.
+    Run inference using the specified model. Designed to be run in a thread.
+    Puts the result or an error string into the result_queue.
     """
+    # Check interrupt at the very beginning of the thread
     if generation_interrupt.is_set():
-        return ""
-
+        result_queue.put("")
+        return
+
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = None
+    tokenizer = None
+    result = ""
 
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
         accepts_sys = (
             "System role not supported" not in tokenizer.chat_template
+            if tokenizer.chat_template else False  # Handle missing chat_template
         )
-
+
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-
+
+        # Check interrupt before loading the model
        if generation_interrupt.is_set():
-            return ""
-
+            result_queue.put("")
+            return
+
         model = AutoModelForCausalLM.from_pretrained(
             model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
         ).to(device)
+        model.eval()  # Set model to evaluation mode
 
         text_input = format_rag_prompt(question, context, accepts_sys)
 
+        # Check interrupt before tokenization/template application
         if generation_interrupt.is_set():
-            return ""
-
+            result_queue.put("")
+            return
+
         actual_input = tokenizer.apply_chat_template(
             text_input,
             return_tensors="pt",
             tokenize=True,
-            max_length=2048,
+            # Consider reducing max_length if context/question is very long
+            # max_length=tokenizer.model_max_length,  # Use model's max length
+            # truncation=True,  # Ensure truncation if needed
+            max_length=2048,  # Keep original max_length for now
             add_generation_prompt=True,
         ).to(device)
 
+        # Ensure input does not exceed model max length after adding generation prompt
+        # This check might be redundant if tokenizer handles it, but good for safety
+        # if actual_input.shape[1] > tokenizer.model_max_length:
+        #     # Handle too long input - maybe truncate manually or raise error
+        #     print(f"Warning: Input length {actual_input.shape[1]} exceeds model max length {tokenizer.model_max_length}")
+        #     # Simple truncation (might lose important info):
+        #     # actual_input = actual_input[:, -tokenizer.model_max_length:]
+
         input_length = actual_input.shape[1]
         attention_mask = torch.ones_like(actual_input).to(device)
-
+
+        # Check interrupt before generation
         if generation_interrupt.is_set():
-            return ""
-
+            result_queue.put("")
+            return
+
         stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
-
+
         with torch.inference_mode():
             outputs = model.generate(
                 actual_input,
                 attention_mask=attention_mask,
-                max_new_tokens=512,
+                max_new_tokens=512,
                 pad_token_id=tokenizer.pad_token_id,
-                stopping_criteria=stopping_criteria
+                stopping_criteria=stopping_criteria,
+                do_sample=True,  # Consider adding sampling parameters if needed
+                temperature=0.6,
+                top_p=0.9,
             )
 
+        # Check interrupt immediately after generation finishes or stops
         if generation_interrupt.is_set():
-            return ""
-
-        result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
-
-        return result
-
+            result = ""  # Discard potentially partial result if interrupted
+        else:
+            # Decode the generated tokens, excluding the input tokens
+            result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
+
+        result_queue.put(result)
+
     except Exception as e:
-        print(f"Error in inference: {e}")
-        return f"Error generating response: {str(e)[:100]}..."
-
+        print(f"Error in inference thread for {model_name}: {e}")
+        # Put error message in queue for the main thread to handle/display
+        result_queue.put(f"Error generating response: {str(e)[:100]}...")
+
     finally:
-        if 'model' in locals():
-            del model
-        if 'tokenizer' in locals():
-            del tokenizer
+        # Clean up resources within the thread
+        del model
+        del tokenizer
+        del actual_input
+        del outputs
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
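
For context, a hedged usage sketch (not from this commit) of how the interrupt is expected to be driven from the caller's side: some other thread, typically the UI, sets generation_interrupt, the InterruptCriteria stops the in-flight generate() call, and the polling loops in generate_summaries return whatever has been produced so far. The shape of the example dict and the wrapping thread below are assumptions made for illustration only.

# Usage sketch (assumed: the example dict layout and that generate_summaries /
# generation_interrupt are importable as shown).
import threading
import time

from utils.models import generate_summaries
from utils.shared import generation_interrupt

example = {"context": "Some retrieved passage...", "question": "What does it say?"}

def run():
    # Model names come from the models dict in utils/models.py.
    a, b = generate_summaries(example, "Qwen2.5-1.5b-Instruct", "Qwen2.5-1.5b-Instruct")
    print("Model A:", a)
    print("Model B:", b)

generation_interrupt.clear()          # start from a clean, un-cancelled state
worker = threading.Thread(target=run)
worker.start()

time.sleep(3)                         # e.g. the user clicks a "Stop" button in the UI
generation_interrupt.set()            # InterruptCriteria halts generate(); polling loops return early
worker.join()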