Added handler.py and requirements.txt
Browse files- handler.py +146 -0
- requirements.txt +153 -0
- test_handler.py +63 -0
handler.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
from typing import Dict, List, Any
|
3 |
+
import torch
|
4 |
+
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
5 |
+
from transformers import BitsAndBytesConfig
|
6 |
+
from sentence_transformers import SentenceTransformer, util
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
def _fake_generate(n: int = 3):
|
10 |
+
generate = list()
|
11 |
+
for _ in range(n):
|
12 |
+
generate.append(torch.IntTensor([103, 23, 48, 498, 536]))
|
13 |
+
return torch.stack(generate, dim=0)
|
14 |
+
|
15 |
+
|
16 |
+
class EndpointHandler():
    """Describe images with LLaVA 1.5 and embed each description.

    ``__call__`` takes raw image bytes plus a prompt, generates one text
    description per image and encodes every description with a
    sentence-transformers model.
    """

    # Processor outputs that are sliced per image and moved to the model device.
    _TENSOR_KEYS = ('input_ids', 'attention_mask', 'pixel_values')

    def __init__(self, use_cuda: bool = False, test_mode: bool = False):
        """Load the processor, the LLaVA model and the sentence embedder.

        Args:
            use_cuda: when truthy, load the model 4-bit quantized with
                bitsandbytes; otherwise load it without quantization.
            test_mode: when true, ``text_to_image`` calls ``_fake_generate``
                instead of ``model.generate``.
        """
        # NOTE(review): hosted custom handlers are commonly instantiated with the
        # model path as the first positional argument, which would land in
        # ``use_cuda`` here - confirm against the deployment before changing
        # this signature.
        self.test_mode = test_mode
        # Largest number of pixel values sent to the model in a single generate
        # call; larger requests are split into several batches.
        self.MAXIMUM_PIXEL_VALUES = 3725568
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )

        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        self.model_id = "llava-hf/llava-1.5-7b-hf"
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        if use_cuda:
            self.model = LlavaForConditionalGeneration.from_pretrained(
                self.model_id,
                quantization_config=self.quantization_config,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
        else:
            # Testing without CUDA device does not allow quantization
            self.model = LlavaForConditionalGeneration.from_pretrained(
                self.model_id,
                device_map="auto",
                low_cpu_mem_usage=True,
            )

    def _slice_inputs(self, inputs, start, stop):
        """Return the ``[start:stop)`` rows of every tensor in ``_TENSOR_KEYS``."""
        return {key: inputs[key][start:stop] for key in self._TENSOR_KEYS}

    def _split_into_batches(self, inputs):
        """Greedily group processed images so each batch fits the pixel budget.

        Returns ``[inputs]`` unchanged when everything fits in one batch.
        Otherwise images are packed in order; a new batch starts as soon as
        adding the next image would exceed ``MAXIMUM_PIXEL_VALUES``.
        """
        pixel_values = inputs['pixel_values']
        if pixel_values.numel() <= self.MAXIMUM_PIXEL_VALUES:
            return [inputs]

        total = len(pixel_values)
        batches = []
        start = 0  # index of the first image of the batch being filled
        used = 0   # pixel values already in the batch being filled
        for index in range(total):
            size = pixel_values[index].numel()
            # ``index > start`` keeps an image that is over budget on its own in
            # a batch by itself instead of producing an empty batch.
            if index > start and used + size > self.MAXIMUM_PIXEL_VALUES:
                print(f'[{index}/{total}] - Reached max pixel values for batch prediction on T4 '
                      f'16GB GPU. Will split in more batches')
                batches.append(self._slice_inputs(inputs, start, index))
                start, used = index, 0
            used += size
        batches.append(self._slice_inputs(inputs, start, total))
        return batches

    def text_to_image(self, image_batch, prompt):
        """Generate a description and an embedding for every image.

        Despite the name, this goes from images to text.

        Args:
            image_batch: sequence of images accepted by the processor.
            prompt: instruction placed between the ``USER: <image>`` and
                ``ASSISTANT:`` markers of the chat template.

        Returns:
            A ``(descriptions, embeddings)`` tuple of two lists with one entry
            per image, in input order. Both are empty for an empty batch.
        """
        if len(image_batch) == 0:
            # Nothing to describe; skip the processor and the model entirely.
            return [], []

        full_prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
        prompts = [full_prompt] * len(image_batch)

        inputs = self.processor(prompts, images=image_batch, padding=True, return_tensors="pt")

        descriptions = []
        embeddings = []
        for batch in self._split_into_batches(inputs):
            # Copy the tensors to the model device into a separate mapping so
            # the CPU originals in ``batch`` are left untouched.
            model_inputs = dict(batch)
            for key in self._TENSOR_KEYS:
                model_inputs[key] = batch[key].to(self.model.device)

            if self.test_mode:
                output = _fake_generate(n=len(model_inputs['input_ids']))
            else:
                output = self.model.generate(**model_inputs, max_new_tokens=500)

            # Tensor.to() is not in-place: keep the CPU copy of the output and
            # drop the device copies of the inputs so their memory can be freed.
            output = output.to('cpu')
            del model_inputs

            for text in self.processor.batch_decode(output, skip_special_tokens=True):
                # The decoded text still contains the prompt; keep the answer only.
                answer = text.split("ASSISTANT:")[-1]
                descriptions.append(answer)
                embeddings.append(self.embedder.encode(answer))

        return descriptions, embeddings

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs: iterable of encoded image files as bytes-like objects
                (each one is wrapped in ``io.BytesIO`` and opened with PIL)
            prompt (:obj: `str`): instruction sent to the model with every image
        Return:
            A :obj:`list` of `dict`, one per image, with the keys
            ``maurice_description`` and ``maurice_embedding``
        """
        images = data['inputs']
        prompt = data['prompt']

        pil_images = [Image.open(io.BytesIO(image)) for image in images]

        output_text, output_embedded = self.text_to_image(pil_images, prompt)

        return [
            dict(maurice_description=text, maurice_embedding=embed)
            for text, embed in zip(output_text, output_embedded)
        ]
|
145 |
+
|
146 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.25.0
|
2 |
+
aiofiles==23.2.1
|
3 |
+
altair==5.2.0
|
4 |
+
annotated-types==0.6.0
|
5 |
+
anyio==3.7.1
|
6 |
+
appnope==0.1.3
|
7 |
+
argon2-cffi==23.1.0
|
8 |
+
argon2-cffi-bindings==21.2.0
|
9 |
+
arrow==1.3.0
|
10 |
+
asttokens==2.4.1
|
11 |
+
async-lru==2.0.4
|
12 |
+
attrs==23.1.0
|
13 |
+
Babel==2.14.0
|
14 |
+
beautifulsoup4==4.12.2
|
15 |
+
bitsandbytes==0.41.3.post2
|
16 |
+
bleach==6.1.0
|
17 |
+
certifi==2023.11.17
|
18 |
+
cffi==1.16.0
|
19 |
+
charset-normalizer==3.3.2
|
20 |
+
click==8.1.7
|
21 |
+
colorama==0.4.6
|
22 |
+
comm==0.2.0
|
23 |
+
contourpy==1.2.0
|
24 |
+
cycler==0.12.1
|
25 |
+
debugpy==1.8.0
|
26 |
+
decorator==5.1.1
|
27 |
+
defusedxml==0.7.1
|
28 |
+
executing==2.0.1
|
29 |
+
fastapi==0.105.0
|
30 |
+
fastjsonschema==2.19.0
|
31 |
+
ffmpy==0.3.1
|
32 |
+
filelock==3.13.1
|
33 |
+
fonttools==4.46.0
|
34 |
+
fqdn==1.5.1
|
35 |
+
fsspec==2023.12.2
|
36 |
+
gradio==4.10.0
|
37 |
+
gradio_client==0.7.3
|
38 |
+
h11==0.14.0
|
39 |
+
httpcore==1.0.2
|
40 |
+
httpx==0.25.2
|
41 |
+
huggingface-hub==0.19.4
|
42 |
+
idna==3.6
|
43 |
+
importlib-resources==6.1.1
|
44 |
+
ipykernel==6.27.1
|
45 |
+
ipython==8.18.1
|
46 |
+
ipywidgets==8.1.1
|
47 |
+
isoduration==20.11.0
|
48 |
+
jedi==0.19.1
|
49 |
+
Jinja2==3.1.2
|
50 |
+
json5==0.9.14
|
51 |
+
jsonpointer==2.4
|
52 |
+
jsonschema==4.20.0
|
53 |
+
jsonschema-specifications==2023.11.2
|
54 |
+
jupyter==1.0.0
|
55 |
+
jupyter-console==6.6.3
|
56 |
+
jupyter-events==0.9.0
|
57 |
+
jupyter-lsp==2.2.1
|
58 |
+
jupyter_client==8.6.0
|
59 |
+
jupyter_core==5.5.1
|
60 |
+
jupyter_server==2.12.1
|
61 |
+
jupyter_server_terminals==0.5.0
|
62 |
+
jupyterlab==4.0.9
|
63 |
+
jupyterlab-widgets==3.0.9
|
64 |
+
jupyterlab_pygments==0.3.0
|
65 |
+
jupyterlab_server==2.25.2
|
66 |
+
kiwisolver==1.4.5
|
67 |
+
markdown-it-py==3.0.0
|
68 |
+
MarkupSafe==2.1.3
|
69 |
+
matplotlib==3.8.2
|
70 |
+
matplotlib-inline==0.1.6
|
71 |
+
mdurl==0.1.2
|
72 |
+
mistune==3.0.2
|
73 |
+
mpmath==1.3.0
|
74 |
+
nbclient==0.9.0
|
75 |
+
nbconvert==7.13.0
|
76 |
+
nbformat==5.9.2
|
77 |
+
nest-asyncio==1.5.8
|
78 |
+
networkx==3.2.1
|
79 |
+
notebook==7.0.6
|
80 |
+
notebook_shim==0.2.3
|
81 |
+
numpy==1.26.2
|
82 |
+
orjson==3.9.10
|
83 |
+
overrides==7.4.0
|
84 |
+
packaging==23.2
|
85 |
+
pandas==2.1.4
|
86 |
+
pandocfilters==1.5.0
|
87 |
+
parso==0.8.3
|
88 |
+
pexpect==4.9.0
|
89 |
+
Pillow==10.1.0
|
90 |
+
platformdirs==4.1.0
|
91 |
+
prometheus-client==0.19.0
|
92 |
+
prompt-toolkit==3.0.43
|
93 |
+
psutil==5.9.6
|
94 |
+
ptyprocess==0.7.0
|
95 |
+
pure-eval==0.2.2
|
96 |
+
pycparser==2.21
|
97 |
+
pydantic==2.5.2
|
98 |
+
pydantic_core==2.14.5
|
99 |
+
pydub==0.25.1
|
100 |
+
Pygments==2.17.2
|
101 |
+
pyparsing==3.1.1
|
102 |
+
python-dateutil==2.8.2
|
103 |
+
python-json-logger==2.0.7
|
104 |
+
python-multipart==0.0.6
|
105 |
+
pytz==2023.3.post1
|
106 |
+
PyYAML==6.0.1
|
107 |
+
pyzmq==25.1.2
|
108 |
+
qtconsole==5.5.1
|
109 |
+
QtPy==2.4.1
|
110 |
+
referencing==0.32.0
|
111 |
+
regex==2023.10.3
|
112 |
+
requests==2.31.0
|
113 |
+
rfc3339-validator==0.1.4
|
114 |
+
rfc3986-validator==0.1.1
|
115 |
+
rich==13.7.0
|
116 |
+
rpds-py==0.13.2
|
117 |
+
safetensors==0.4.1
|
118 |
+
scipy==1.11.4
|
119 |
+
semantic-version==2.10.0
|
120 |
+
Send2Trash==1.8.2
|
121 |
+
shellingham==1.5.4
|
122 |
+
six==1.16.0
|
123 |
+
sniffio==1.3.0
|
124 |
+
soupsieve==2.5
|
125 |
+
stack-data==0.6.3
|
126 |
+
starlette==0.27.0
|
127 |
+
sympy==1.12
|
128 |
+
terminado==0.18.0
|
129 |
+
tinycss2==1.2.1
|
130 |
+
tokenizers==0.15.0
|
131 |
+
tomlkit==0.12.0
|
132 |
+
toolz==0.12.0
|
133 |
+
torch==2.1.2
|
134 |
+
torchaudio==2.1.2
|
135 |
+
torchvision==0.16.2
|
136 |
+
tornado==6.4
|
137 |
+
tqdm==4.66.1
|
138 |
+
traitlets==5.14.0
|
139 |
+
transformers
|
140 |
+
typer==0.9.0
|
141 |
+
types-python-dateutil==2.8.19.14
|
142 |
+
typing_extensions==4.9.0
|
143 |
+
tzdata==2023.3
|
144 |
+
uri-template==1.3.0
|
145 |
+
urllib3==2.1.0
|
146 |
+
uvicorn==0.24.0.post1
|
147 |
+
wcwidth==0.2.12
|
148 |
+
webcolors==1.13
|
149 |
+
webencodings==0.5.1
|
150 |
+
websocket-client==1.7.0
|
151 |
+
websockets==11.0.3
|
152 |
+
widgetsnbextension==4.0.9
|
153 |
+
sentence_transformers
|
test_handler.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from handler import EndpointHandler
|
2 |
+
from PIL import Image
|
3 |
+
import requests, json
|
4 |
+
|
5 |
+
# Reference only: querying the deployed endpoint instead of the local handler.
#
# API_URL = "https://oncm9ojdmjwesag2.eu-west-1.aws.endpoints.huggingface.cloud"
#
# headers = {
#     "Authorization": "Bearer MY_API_TOKEN",
#     "Content-Type": "image/jpg"
# }
#
# def query(filename):
#     with open(filename, "rb") as f:
#         data = f.read()
#     response = requests.request("POST", API_URL, headers=headers, data=data)
#     return json.loads(response.content.decode("utf-8"))
#
# output = query("food.jpg")

# Sample image used when no path is given on the command line.
DEFAULT_IMAGE_PATH = '/Users/francois/Documents/dev/Maurice/maurice/test_602.jpg'


def main(image_path: str = DEFAULT_IMAGE_PATH) -> None:
    """Run the handler locally on one image and on three copies of it.

    Prints the handler result for both payloads.
    """
    # init handler
    my_handler = EndpointHandler(use_cuda=False, test_mode=True)

    # prepare sample payload
    with open(image_path, 'rb') as f:
        img = f.read()

    single_image = {
        'inputs': [
            img
        ],
        'prompt': 'Describe the image'
    }

    multiple_images = {
        'inputs': [
            img, img, img
        ],
        'prompt': 'Describe the image'
    }

    # test the handler
    print(my_handler(single_image))
    print(my_handler(multiple_images))


# Guarded so that importing this module (for example by a test collector)
# does not load the model or read the sample image.
if __name__ == '__main__':
    import sys

    # Optional first CLI argument overrides the sample image path.
    main(*sys.argv[1:2])
|
47 |
+
|
48 |
+
|
49 |
+
# non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "date": "2022-08-08"}
|
50 |
+
# holiday_payload = {"inputs": "Today is a though day", "date": "2022-07-04"}
|
51 |
+
#
|
52 |
+
#
|
53 |
+
#
|
54 |
+
# # test the handler
|
55 |
+
# non_holiday_pred=my_handler(non_holiday_payload)
|
56 |
+
# holiday_payload=my_handler(holiday_payload)
|
57 |
+
#
|
58 |
+
# # show results
|
59 |
+
# print("non_holiday_pred", non_holiday_pred)
|
60 |
+
# print("holiday_payload", holiday_payload)
|
61 |
+
#
|
62 |
+
# # non_holiday_pred [{'label': 'joy', 'score': 0.9985942244529724}]
|
63 |
+
# # holiday_payload [{'label': 'happy', 'score': 1}]
|