Testing by reloading the model on __call__ with quantization in case GPU is not available during __init__
handler.py CHANGED (+17 -8)
@@ -14,10 +14,13 @@ def _fake_generate(n: int = 3):
 
 
 class EndpointHandler():
-    def __init__(self, path="",
+    def __init__(self, path="", test_mode: bool= False):
         # Preload all the elements you are going to need at inference.
         # pseudo:
         # self.model= load_model(path)
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        use_cuda = self.device == 'cuda'
+
         self.test_mode = test_mode
         self.MAXIMUM_PIXEL_VALUES = 3725568
         self.quantization_config = BitsAndBytesConfig(
@@ -29,12 +32,7 @@ class EndpointHandler():
         self.model_id = "llava-hf/llava-1.5-7b-hf"
         self.processor = AutoProcessor.from_pretrained(self.model_id)
         if use_cuda:
-            self.model = LlavaForConditionalGeneration.from_pretrained(
-                self.model_id,
-                quantization_config=self.quantization_config,
-                device_map="auto",
-                low_cpu_mem_usage=True,
-            )
+            self.load_quantized()
         else:
             # Testing without CUDA device does not allow quantization
             self.model = LlavaForConditionalGeneration.from_pretrained(
@@ -43,6 +41,15 @@ class EndpointHandler():
             low_cpu_mem_usage=True,
         )
 
+    def load_quantized(self):
+        print('Loading model with quantization')
+        self.model = LlavaForConditionalGeneration.from_pretrained(
+            self.model_id,
+            quantization_config=self.quantization_config,
+            device_map="auto",
+            low_cpu_mem_usage=True,
+        )
+
     def text_to_image(self, image_batch, prompt):
         prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
         prompt_batch = [prompt for _ in range(len(image_batch))]
@@ -123,7 +130,9 @@
         Return:
            A :obj:`list` | `dict`: will be serialized and returned
         """
-
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if device != self.device and device == 'cuda':
+            self.load_quantized()
         images = data['inputs']
         prompt = data['prompt']
 
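For reference, a minimal sketch of the pattern this commit ends up with: remember which device was visible when __init__ ran, re-check on every __call__, and reload the model with quantization if a GPU has appeared in the meantime. The BitsAndBytesConfig body is an assumption (the diff truncates it), and two details deliberately differ from the committed code: the sketch compares `device.type == 'cuda'` rather than comparing a `torch.device` against the string `'cuda'`, and it updates `self.device` after the reload so later calls do not reload again.

```python
# Sketch only: mirrors the names in the diff (EndpointHandler, load_quantized,
# llava-hf/llava-1.5-7b-hf); batching and the _fake_generate test path are omitted.
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration


class EndpointHandler:
    def __init__(self, path: str = "", test_mode: bool = False):
        self.test_mode = test_mode
        self.model_id = "llava-hf/llava-1.5-7b-hf"
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        # Remember which device was visible at startup so __call__ can detect
        # a GPU that only becomes available after initialization.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Assumed config: the actual BitsAndBytesConfig body is cut off in the diff.
        self.quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        if self.device.type == "cuda":
            self.load_quantized()
        else:
            # bitsandbytes quantization needs a CUDA device, so CPU-only test
            # runs fall back to a plain full-precision load.
            self.model = LlavaForConditionalGeneration.from_pretrained(
                self.model_id,
                low_cpu_mem_usage=True,
            )

    def load_quantized(self):
        print("Loading model with quantization")
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_id,
            quantization_config=self.quantization_config,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

    def __call__(self, data: dict):
        # Re-check the device on each request: if a GPU appeared after
        # __init__ ran, reload once with quantization and remember the device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device != self.device and device.type == "cuda":
            self.load_quantized()
            self.device = device
        images = data["inputs"]
        prompt = data["prompt"]
        ...  # run inference with self.processor / self.model
```

As committed, `self.device` is never updated after `load_quantized()`, so once a GPU shows up the quantized load would repeat on every request; caching the new device, as above, makes the reload happen exactly once.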