hperkins committed
Commit babad84 · verified · 1 Parent(s): eea7d6f

Update handler.py

Files changed (1):
  1. handler.py (+22 -21)
handler.py CHANGED
@@ -4,38 +4,40 @@ import torch
 import json
 import os
 
-# Set the PyTorch CUDA allocation to use expandable segments to avoid memory fragmentation
+# Set the environment variable to handle memory fragmentation
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 class EndpointHandler:
     def __init__(self, model_dir):
-        # Load the model with memory-efficient settings
+        # Load the model with automatic device dispatching
         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
             model_dir,
-            torch_dtype=torch.float16,  # Using FP16 for reduced memory usage
-            device_map="auto",  # Automatically assigns model layers to the available GPU(s)
+            torch_dtype=torch.float16,  # Use FP16 for memory efficiency
+            device_map="auto",  # Auto device dispatch across available GPUs
             low_cpu_mem_usage=True  # Minimize CPU memory usage
         )
         self.processor = AutoProcessor.from_pretrained(model_dir)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)  # Move model to the appropriate device
+        # No need to move model to device manually; device_map handles it
         self.model.eval()
 
-        # Enable gradient checkpointing for additional memory savings
+        # Enable gradient checkpointing for further memory optimization
         self.model.gradient_checkpointing_enable()
 
     def preprocess(self, request_data):
-        # Extract the 'messages' from the incoming request
+        # Handle the request and extract vision data (images, videos)
         messages = request_data.get('messages')
         if not messages:
             raise ValueError("Messages are required")
-
-        # Process the vision inputs (images, videos) from the messages
+
+        # Process vision input from the messages
         image_inputs, video_inputs = process_vision_info(messages)
+
         # Prepare text input for the chat model
         text = self.processor.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
+
         # Prepare inputs for the model (text + vision inputs)
         inputs = self.processor(
             text=[text],
@@ -45,30 +47,30 @@ class EndpointHandler:
             return_tensors="pt",
         )
 
-        return inputs.to(self.device)  # Move inputs to the correct device
+        return inputs.to(self.device)
 
     def inference(self, inputs):
-        # Perform inference with the model, ensuring memory-efficient execution
+        # Perform inference using memory-efficient settings
        with torch.no_grad():
             generated_ids = self.model.generate(
                 **inputs,
-                max_new_tokens=64,  # Reduce response length to conserve memory
-                num_beams=1,  # Set beam size to reduce memory usage
-                max_batch_size=1  # Keep batch size small to save memory
+                max_new_tokens=64,  # Reduce max tokens for memory optimization
+                num_beams=1,  # Reduce beam size to save memory
+                max_batch_size=1  # Keep batch size small to minimize memory usage
             )
 
-        # Trim generated output (remove input tokens from the generated output)
+        # Trim the output by removing input tokens from the generated output
         generated_ids_trimmed = [
             out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
         ]
 
-        # Clear CUDA cache after inference to release unused memory
+        # Clear CUDA memory cache after inference to free up memory
         torch.cuda.empty_cache()
 
         return generated_ids_trimmed
 
     def postprocess(self, inference_output):
-        # Decode generated output into human-readable text
+        # Decode the model's output into human-readable text
         output_text = self.processor.batch_decode(
             inference_output, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
@@ -76,15 +78,14 @@ class EndpointHandler:
 
     def __call__(self, request):
         try:
-            # Parse incoming JSON request
+            # Parse the JSON request
             request_data = json.loads(request)
-            # Preprocess inputs (text, images, videos)
+            # Preprocess the input data
             inputs = self.preprocess(request_data)
             # Perform inference
             outputs = self.inference(inputs)
-            # Postprocess model outputs
+            # Postprocess the output and return the result
             result = self.postprocess(outputs)
             return json.dumps({"result": result})
         except Exception as e:
-            # Handle any errors during execution
             return json.dumps({"error": str(e)})
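
For context, the handler expects the request to be a JSON string whose "messages" field follows the Qwen2-VL chat format consumed by process_vision_info. The snippet below is a minimal local smoke test sketching that contract; the model directory path and image URL are placeholder assumptions, not part of this commit.

import json

from handler import EndpointHandler

# Placeholder model directory; point this at a local Qwen2-VL checkpoint.
handler = EndpointHandler("./qwen2-vl-model")

# One user turn with an image and a text prompt, in Qwen2-VL chat format.
request = json.dumps({
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "https://example.com/sample.jpg"},  # placeholder URL
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
})

response = json.loads(handler(request))
print(response.get("result") or response.get("error"))

Since __call__ returns a JSON string in both the success and error paths, a caller only needs to check which key, "result" or "error", is present in the decoded response.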