Update handler.py
handler.py  +20 -4  CHANGED
@@ -5,15 +5,21 @@ import json
 
 class Qwen2VL7bHandler:
     def __init__(self):
-        # Load the model and processor for Qwen2-VL-7B
+        # Load the model and processor for Qwen2-VL-7B with FP16 precision and flash attention enabled
         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2-VL-7B-Instruct",
+            "Qwen/Qwen2-VL-7B-Instruct",
+            torch_dtype=torch.float16,
+            attn_implementation="flash_attention_2",  # Enable flash attention for efficiency
+            device_map="auto"  # Automatically assign devices for model
         )
         self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(self.device)
         self.model.eval()
 
+        # Enable gradient checkpointing to save memory during inference
+        self.model.gradient_checkpointing_enable()
+
     def preprocess(self, request_data):
         # Handle image and video input from the request
         messages = request_data.get('messages')
@@ -42,12 +48,22 @@ class Qwen2VL7bHandler:
     def inference(self, inputs):
         # Perform inference with the model
         with torch.no_grad():
-
+            # Generate the output with memory-efficient settings
+            generated_ids = self.model.generate(
+                **inputs,
+                max_new_tokens=128,  # Limit output length
+                num_beams=1,  # Set beam size to reduce memory consumption
+                max_batch_size=1  # Set batch size to 1 for memory optimization
+            )
 
         # Trim the output (remove input tokens from generated output)
         generated_ids_trimmed = [
-            out_ids[len(in_ids)
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
         ]
+
+        # Clear the CUDA cache after inference to release unused memory
+        torch.cuda.empty_cache()
+
         return generated_ids_trimmed
 
     def postprocess(self, inference_output):
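For context, a minimal sketch of how this handler might be driven end to end. It assumes the surrounding file imports `torch`, `AutoProcessor`, and `Qwen2VLForConditionalGeneration` from `transformers`, and that `preprocess` returns processor tensors exposing `input_ids`; the request payload shape and module path below are assumptions for illustration, not part of this commit.

```python
# Hypothetical driver for the handler in this commit; the payload shape and
# the "handler" module path are assumptions for illustration only.
from handler import Qwen2VL7bHandler

request_data = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "https://example.com/sample.jpg"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
}

handler = Qwen2VL7bHandler()
inputs = handler.preprocess(request_data)   # build processor tensors from the messages
trimmed = handler.inference(inputs)         # generate, then strip the prompt tokens
result = handler.postprocess(trimmed)       # decode token ids to text (not shown in this diff)
print(result)
```

On the generation settings: `num_beams=1` and `max_new_tokens=128` keep memory bounded, and since `device_map="auto"` already places the weights, the explicit `self.model.to(self.device)` call may be unnecessary on that path. Whether `generate()` accepts `max_batch_size` depends on the installed `transformers` version, so that argument may need to be dropped if generation reports an unused keyword.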