fthor committed
Commit 270cb79
1 Parent(s): 5909e50

Testing by reloading the model on __call__ with quantization, in case a GPU is not available during __init__

Files changed (1):
  1. handler.py +17 -8
handler.py CHANGED
@@ -14,10 +14,13 @@ def _fake_generate(n: int = 3):
 
 
 class EndpointHandler():
-    def __init__(self, path="", use_cuda: bool = True, test_mode: bool= False):
+    def __init__(self, path="", test_mode: bool= False):
         # Preload all the elements you are going to need at inference.
         # pseudo:
         # self.model= load_model(path)
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        use_cuda = self.device == 'cuda'
+
         self.test_mode = test_mode
         self.MAXIMUM_PIXEL_VALUES = 3725568
         self.quantization_config = BitsAndBytesConfig(
@@ -29,12 +32,7 @@ class EndpointHandler():
         self.model_id = "llava-hf/llava-1.5-7b-hf"
         self.processor = AutoProcessor.from_pretrained(self.model_id)
         if use_cuda:
-            self.model = LlavaForConditionalGeneration.from_pretrained(
-                self.model_id,
-                quantization_config=self.quantization_config,
-                device_map="auto",
-                low_cpu_mem_usage=True,
-            )
+            self.load_quantized()
         else:
             # Testing without CUDA device does not allow quantization
             self.model = LlavaForConditionalGeneration.from_pretrained(
@@ -43,6 +41,15 @@ class EndpointHandler():
                 low_cpu_mem_usage=True,
             )
 
+    def load_quantized(self):
+        print('Loading model with quantization')
+        self.model = LlavaForConditionalGeneration.from_pretrained(
+            self.model_id,
+            quantization_config=self.quantization_config,
+            device_map="auto",
+            low_cpu_mem_usage=True,
+        )
+
     def text_to_image(self, image_batch, prompt):
         prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
         prompt_batch = [prompt for _ in range(len(image_batch))]
@@ -123,7 +130,9 @@ class EndpointHandler():
         Return:
             A :obj:`list` | `dict`: will be serialized and returned
         """
-
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if device != self.device and device == 'cuda':
+            self.load_quantized()
         images = data['inputs']
         prompt = data['prompt']
 
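
The pattern this commit introduces — remember the device seen at __init__ and reload the quantized model on __call__ once a GPU appears — is sketched below in isolation. This is a minimal illustration, not the committed code: LazyQuantizedHandler and its loader callables are hypothetical stand-ins for the LlavaForConditionalGeneration.from_pretrained calls above, and it compares device.type rather than relying on torch.device-to-string equality.

import torch

class LazyQuantizedHandler:
    """Minimal sketch: reload a quantized model once CUDA becomes visible."""

    def __init__(self, load_quantized, load_cpu):
        # load_quantized / load_cpu are hypothetical caller-supplied loaders,
        # standing in for the from_pretrained calls in the diff above.
        self._load_quantized = load_quantized
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # No GPU at init time: fall back to the unquantized CPU model.
        self.model = load_quantized() if self.device.type == 'cuda' else load_cpu()

    def __call__(self, data):
        # Re-check on every request: a GPU may only become available after
        # __init__ has already fallen back to the CPU model.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if device.type == 'cuda' and device != self.device:
            self.model = self._load_quantized()
            self.device = device
        return self.model(data)

Once the reload has happened, later calls see device == self.device and skip the loader, so under this sketch the quantized model is fetched at most once.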