hperkins committed
Commit 057b8f0
Parent(s): b8812a1

Update handler.py

Files changed (1): handler.py (+20 -4)
handler.py CHANGED
@@ -5,15 +5,21 @@ import json
 
 class Qwen2VL7bHandler:
     def __init__(self):
-        # Load the model and processor for Qwen2-VL-7B
+        # Load the model and processor for Qwen2-VL-7B with FP16 precision and flash attention enabled
         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
+            "Qwen/Qwen2-VL-7B-Instruct",
+            torch_dtype=torch.float16,
+            attn_implementation="flash_attention_2",  # Enable flash attention for efficiency
+            device_map="auto"  # Automatically assign devices for model
         )
         self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(self.device)
         self.model.eval()
 
+        # Enable gradient checkpointing to save memory during inference
+        self.model.gradient_checkpointing_enable()
+
     def preprocess(self, request_data):
         # Handle image and video input from the request
         messages = request_data.get('messages')
@@ -42,12 +48,22 @@ class Qwen2VL7bHandler:
     def inference(self, inputs):
         # Perform inference with the model
         with torch.no_grad():
-            generated_ids = self.model.generate(**inputs, max_new_tokens=128)
+            # Generate the output with memory-efficient settings
+            generated_ids = self.model.generate(
+                **inputs,
+                max_new_tokens=128,  # Limit output length
+                num_beams=1,  # Set beam size to reduce memory consumption
+                max_batch_size=1  # Set batch size to 1 for memory optimization
+            )
 
         # Trim the output (remove input tokens from generated output)
         generated_ids_trimmed = [
-            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
         ]
+
+        # Clear the CUDA cache after inference to release unused memory
+        torch.cuda.empty_cache()
+
         return generated_ids_trimmed
 
     def postprocess(self, inference_output):
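
For reference, a minimal usage sketch of the updated handler. This is hypothetical and not part of the commit: it assumes handler.py is importable as a module, that preprocess() returns processor-built tensors (including input_ids) on the model's device, and that postprocess() decodes the trimmed ids back to text; the request payload follows the chat-message format the Qwen2-VL processor expects.

    from handler import Qwen2VL7bHandler  # assumed import path; the class lives in handler.py

    handler = Qwen2VL7bHandler()

    # Example request payload in the chat-message format used by Qwen2-VL's processor
    request_data = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": "https://example.com/demo.jpg"},
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ]
    }

    inputs = handler.preprocess(request_data)    # assumed: builds model inputs (incl. input_ids) via the processor
    trimmed_ids = handler.inference(inputs)      # per the diff: greedy generation, prompt tokens trimmed, CUDA cache cleared
    response = handler.postprocess(trimmed_ids)  # assumed: decodes the trimmed ids back to text
    print(response)
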