oliver-aizip committed on
Commit 6c63a2d · 1 Parent(s): a8243a3

vllm backend swap v1

Files changed (2)
  1. requirements.txt +1 -0
  2. utils/models.py +80 -62
requirements.txt CHANGED
@@ -6,3 +6,4 @@ numpy==1.26.4
 openai>=1.60.2
 torch>=2.5.1
 tqdm==4.67.1
+vllm>=0.8.5
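
Not part of the commit, but a quick sanity check of the new dependency pin can help before launching the Space. The snippet below is a minimal sketch; it assumes vllm is already installed in the active environment and that the packaging helper (which vLLM itself depends on) is importable.

# Minimal sanity check (not in the commit): confirm the installed vLLM build
# satisfies the new requirements.txt floor of vllm>=0.8.5.
import vllm
from packaging.version import Version  # assumption: packaging is available (it ships with vLLM)

print("vllm version:", vllm.__version__)
assert Version(vllm.__version__) >= Version("0.8.5"), "expected vllm>=0.8.5 per requirements.txt"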
utils/models.py CHANGED
@@ -1,3 +1,6 @@
+import os
+os.environ['MKL_THREADING_LAYER'] = 'GNU'
+
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 from .prompts import format_rag_prompt
@@ -5,7 +8,7 @@ from .shared import generation_interrupt
 import threading
 import queue
 import time # Added for sleep
-
+from vllm import LLM, SamplingParams
 models = {
     "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
     "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
@@ -123,86 +126,101 @@ def run_inference(model_name, context, question, result_queue):
             if tokenizer.chat_template else False # Handle missing chat_template
         )

-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
+        # if tokenizer.pad_token is None:
+        #     tokenizer.pad_token = tokenizer.eos_token

-        # Check interrupt before loading the model
-        if generation_interrupt.is_set():
-            result_queue.put("")
-            return
+        # # Check interrupt before loading the model
+        # if generation_interrupt.is_set():
+        #     result_queue.put("")
+        #     return

-        model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
-        ).to(device)
-        model.eval() # Set model to evaluation mode
+        # model = AutoModelForCausalLM.from_pretrained(
+        #     model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
+        # ).to(device)
+        # model.eval() # Set model to evaluation mode

         text_input = format_rag_prompt(question, context, accepts_sys)

-        # Check interrupt before tokenization/template application
-        if generation_interrupt.is_set():
-            result_queue.put("")
-            return
-
-        actual_input = tokenizer.apply_chat_template(
-            text_input,
-            return_tensors="pt",
-            tokenize=True,
-            # Consider reducing max_length if context/question is very long
-            # max_length=tokenizer.model_max_length, # Use model's max length
-            # truncation=True, # Ensure truncation if needed
-            max_length=2048, # Keep original max_length for now
-            add_generation_prompt=True,
-        ).to(device)
-
-        # Ensure input does not exceed model max length after adding generation prompt
-        # This check might be redundant if tokenizer handles it, but good for safety
-        # if actual_input.shape[1] > tokenizer.model_max_length:
-        #     # Handle too long input - maybe truncate manually or raise error
-        #     print(f"Warning: Input length {actual_input.shape[1]} exceeds model max length {tokenizer.model_max_length}")
-        #     # Simple truncation (might lose important info):
-        #     # actual_input = actual_input[:, -tokenizer.model_max_length:]
-
-        input_length = actual_input.shape[1]
-        attention_mask = torch.ones_like(actual_input).to(device)
-
+        # # Check interrupt before tokenization/template application
+        # if generation_interrupt.is_set():
+        #     result_queue.put("")
+        #     return
+
+        # actual_input = tokenizer.apply_chat_template(
+        #     text_input,
+        #     return_tensors="pt",
+        #     tokenize=True,
+        #     # Consider reducing max_length if context/question is very long
+        #     # max_length=tokenizer.model_max_length, # Use model's max length
+        #     # truncation=True, # Ensure truncation if needed
+        #     max_length=2048, # Keep original max_length for now
+        #     add_generation_prompt=True,
+        # ).to(device)
+
+        # # Ensure input does not exceed model max length after adding generation prompt
+        # # This check might be redundant if tokenizer handles it, but good for safety
+        # # if actual_input.shape[1] > tokenizer.model_max_length:
+        # #     # Handle too long input - maybe truncate manually or raise error
+        # #     print(f"Warning: Input length {actual_input.shape[1]} exceeds model max length {tokenizer.model_max_length}")
+        # #     # Simple truncation (might lose important info):
+        # #     # actual_input = actual_input[:, -tokenizer.model_max_length:]
+
+        # input_length = actual_input.shape[1]
+        # attention_mask = torch.ones_like(actual_input).to(device)
+
+        # # Check interrupt before generation
+        # if generation_interrupt.is_set():
+        #     result_queue.put("")
+        #     return
+
+        # stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
+
+        # with torch.inference_mode():
+        #     outputs = model.generate(
+        #         actual_input,
+        #         attention_mask=attention_mask,
+        #         max_new_tokens=512,
+        #         pad_token_id=tokenizer.pad_token_id,
+        #         stopping_criteria=stopping_criteria,
+        #         do_sample=True, # Consider adding sampling parameters if needed
+        #         temperature=0.6,
+        #         top_p=0.9,
+        #     )
+
+        # # Check interrupt immediately after generation finishes or stops
+        # if generation_interrupt.is_set():
+        #     result = "" # Discard potentially partial result if interrupted
+        # else:
+        #     # Decode the generated tokens, excluding the input tokens
+        #     result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
+        llm = LLM(model_name, dtype=torch.bfloat16, hf_token=True, enforce_eager=True)
+        params = SamplingParams(
+            max_tokens=512,
+        )
+
         # Check interrupt before generation
         if generation_interrupt.is_set():
             result_queue.put("")
             return
-
-        stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
-
-        with torch.inference_mode():
-            outputs = model.generate(
-                actual_input,
-                attention_mask=attention_mask,
-                max_new_tokens=512,
-                pad_token_id=tokenizer.pad_token_id,
-                stopping_criteria=stopping_criteria,
-                do_sample=True, # Consider adding sampling parameters if needed
-                temperature=0.6,
-                top_p=0.9,
-            )
-
+        # Generate the response
+        outputs = llm.chat(
+            text_input,
+            sampling_params=params,
+            # stopping_criteria=StoppingCriteriaList([InterruptCriteria(generation_interrupt)]),
+        )
         # Check interrupt immediately after generation finishes or stops
-        if generation_interrupt.is_set():
-            result = "" # Discard potentially partial result if interrupted
-        else:
-            # Decode the generated tokens, excluding the input tokens
-            result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
-
-        result_queue.put(result)
+        result_queue.put(outputs[0].outputs[0].text)

     except Exception as e:
         print(f"Error in inference thread for {model_name}: {e}")
         # Put error message in queue for the main thread to handle/display
-        result_queue.put(f"Error generating response: {str(e)[:100]}...")
+        result_queue.put(f"Error generating response: {str(e)[:200]}...")

     finally:
         # Clean up resources within the thread
         del model
         del tokenizer
-        del actual_input
+        del text_input
         del outputs
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
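
For reference, here is a standalone sketch of the inference path this commit switches to, pulled out of run_inference so it can run on its own. The model id, system prompt, and user message are placeholder assumptions (the real app picks the model from the models dict and builds the message list via format_rag_prompt); max_tokens=512, bfloat16, and enforce_eager=True mirror the diff above.

# Standalone sketch of the new vLLM code path (assumptions noted inline).
from vllm import LLM, SamplingParams

MODEL_ID = "qwen/qwen2.5-1.5b-instruct"  # placeholder: any entry from the `models` dict

# enforce_eager=True mirrors the diff; it disables CUDA graph capture and runs eagerly.
llm = LLM(model=MODEL_ID, dtype="bfloat16", enforce_eager=True)

# The commit only sets max_tokens; other sampling settings stay at vLLM defaults.
params = SamplingParams(max_tokens=512)

# Chat-formatted input, standing in for what format_rag_prompt() is assumed to return.
messages = [
    {"role": "system", "content": "Answer strictly from the provided context."},
    {"role": "user", "content": "Context: <retrieved passages>\n\nQuestion: <user question>"},
]

# llm.chat() applies the model's chat template and generates the completion.
outputs = llm.chat(messages, sampling_params=params)

# Same extraction as the diff: first request, first candidate completion.
print(outputs[0].outputs[0].text)

Note that a transformers-style StoppingCriteria has no direct counterpart in llm.chat, which is presumably why the InterruptCriteria hook is left as a commented-out argument in the diff and the interrupt flag is now only checked before generation starts.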