fthor committed on
Commit 4b86165
1 Parent(s): 4f40293

Added handler.py and requirements.txt

Files changed (3)
  1. handler.py +146 -0
  2. requirements.txt +153 -0
  3. test_handler.py +63 -0
handler.py ADDED
@@ -0,0 +1,146 @@
+ import io
+ from typing import Dict, List, Any
+
+ import torch
+ from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration
+ from sentence_transformers import SentenceTransformer
+ from PIL import Image
+
+
+ def _fake_generate(n: int = 3):
+     # Deterministic stand-in for model.generate() so the handler can be exercised without a GPU.
+     generate = list()
+     for _ in range(n):
+         generate.append(torch.IntTensor([103, 23, 48, 498, 536]))
+     return torch.stack(generate, dim=0)
+
+
+ class EndpointHandler:
+     def __init__(self, use_cuda: bool = False, test_mode: bool = False):
+         # Preload all the elements needed at inference time.
+         self.test_mode = test_mode
+         # Upper bound on the total number of pixel values that fits in one generate() call on a 16GB T4.
+         self.MAXIMUM_PIXEL_VALUES = 3725568
+         self.quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16
+         )
+
+         self.embedder = SentenceTransformer('all-mpnet-base-v2')
+         self.model_id = "llava-hf/llava-1.5-7b-hf"
+         self.processor = AutoProcessor.from_pretrained(self.model_id)
+         if use_cuda:
+             self.model = LlavaForConditionalGeneration.from_pretrained(
+                 self.model_id,
+                 quantization_config=self.quantization_config,
+                 device_map="auto",
+                 low_cpu_mem_usage=True,
+             )
+         else:
+             # Testing without a CUDA device does not allow quantization.
+             self.model = LlavaForConditionalGeneration.from_pretrained(
+                 self.model_id,
+                 device_map="auto",
+                 low_cpu_mem_usage=True,
+             )
+
+     def text_to_image(self, image_batch, prompt):
+         prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
+         prompt_batch = [prompt for _ in range(len(image_batch))]
+
+         inputs = self.processor(prompt_batch, images=image_batch, padding=True, return_tensors="pt")
+
+         batched_inputs: List[Dict[str, torch.Tensor]] = list()
+         if inputs['pixel_values'].flatten().shape[0] > self.MAXIMUM_PIXEL_VALUES:
+             # Too many pixels for a single forward pass: split the inputs into smaller batches.
+             batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
+             i = 0
+             while i < len(inputs['pixel_values']):
+                 batch['input_ids'].append(inputs['input_ids'][i])
+                 batch['attention_mask'].append(inputs['attention_mask'][i])
+                 batch['pixel_values'].append(inputs['pixel_values'][i])
+
+                 if torch.cat(batch['pixel_values'], dim=0).flatten().shape[0] > self.MAXIMUM_PIXEL_VALUES:
+                     print(f'[{i}/{len(inputs["pixel_values"])}] - Reached max pixel values for batch prediction '
+                           f'on a T4 16GB GPU. Splitting into more batches.')
+                     # Remove the last added image: it pushes the batch over the limit and is retried
+                     # in the next batch.
+                     batch['input_ids'].pop()
+                     batch['attention_mask'].pop()
+                     batch['pixel_values'].pop()
+
+                     # Transform the lists into tensors.
+                     batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
+                     batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
+                     batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
+
+                     # Add to the batched_inputs and start a new batch.
+                     batched_inputs.append(batch)
+                     batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
+                 else:
+                     i += 1
+                     if i >= len(inputs['pixel_values']) and len(batch['input_ids']) > 0:
+                         batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
+                         batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
+                         batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
+
+                         # Add the final, partially filled batch.
+                         batched_inputs.append(batch)
+                         batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
+         else:
+             batched_inputs.append(inputs)
+
+         maurice_description = list()
+         maurice_embeddings = list()
+         for batch in batched_inputs:
+             # Load the batch on the model's device.
+             batch['input_ids'] = batch['input_ids'].to(self.model.device)
+             batch['attention_mask'] = batch['attention_mask'].to(self.model.device)
+             batch['pixel_values'] = batch['pixel_values'].to(self.model.device)
+             # output = self.model.generate(**batch, max_new_tokens=500, temperature=0.3)
+             if self.test_mode:
+                 output = _fake_generate(n=len(batch['input_ids']))
+             else:
+                 output = self.model.generate(**batch, max_new_tokens=500)
+             # Move the batch back to CPU to free GPU memory (`.to` returns a new tensor, so reassign).
+             batch['input_ids'] = batch['input_ids'].to('cpu')
+             batch['attention_mask'] = batch['attention_mask'].to('cpu')
+             batch['pixel_values'] = batch['pixel_values'].to('cpu')
+
+             generated_text = self.processor.batch_decode(output, skip_special_tokens=True)
+             output = output.to('cpu')
+
+             for text in generated_text:
+                 text_output = text.split("ASSISTANT:")[-1]
+                 text_embeddings = self.embedder.encode(text_output)
+                 maurice_description.append(text_output)
+                 maurice_embeddings.append(text_embeddings)
+
+         return maurice_description, maurice_embeddings
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (:obj:`list` of :obj:`bytes`): raw encoded images
+             prompt (:obj:`str`): instruction applied to every image
+         Return:
+             A :obj:`list` of :obj:`dict` with one description/embedding pair per image,
+             serialized and returned to the caller.
+         """
+         images = data['inputs']
+         prompt = data['prompt']
+
+         pil_images = list()
+         for image in images:
+             pil_images.append(Image.open(io.BytesIO(image)))
+
+         output_text, output_embedded = self.text_to_image(pil_images, prompt)
+
+         result = list()
+         for text, embed in zip(output_text, output_embedded):
+             result.append(
+                 dict(
+                     maurice_description=text,
+                     maurice_embedding=embed
+                 )
+             )
+         return result
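The handler returns one maurice_description / maurice_embedding pair per input image, with embeddings produced by all-mpnet-base-v2, so they can be compared downstream with cosine similarity. A minimal sketch of such a comparison (the `results` variable and the choice of comparing the first two entries are only illustrative, not part of this commit):

    from sentence_transformers import util

    # `results` is assumed to hold the list returned by EndpointHandler.__call__ for two or more images.
    emb_a = results[0]['maurice_embedding']
    emb_b = results[1]['maurice_embedding']

    # cos_sim accepts numpy arrays or tensors and returns a 1x1 similarity matrix here.
    similarity = util.cos_sim(emb_a, emb_b)
    print(float(similarity))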
requirements.txt ADDED
@@ -0,0 +1,153 @@
+ accelerate==0.25.0
+ aiofiles==23.2.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==3.7.1
+ appnope==0.1.3
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.3.0
+ asttokens==2.4.1
+ async-lru==2.0.4
+ attrs==23.1.0
+ Babel==2.14.0
+ beautifulsoup4==4.12.2
+ bitsandbytes==0.41.3.post2
+ bleach==6.1.0
+ certifi==2023.11.17
+ cffi==1.16.0
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ comm==0.2.0
+ contourpy==1.2.0
+ cycler==0.12.1
+ debugpy==1.8.0
+ decorator==5.1.1
+ defusedxml==0.7.1
+ executing==2.0.1
+ fastapi==0.105.0
+ fastjsonschema==2.19.0
+ ffmpy==0.3.1
+ filelock==3.13.1
+ fonttools==4.46.0
+ fqdn==1.5.1
+ fsspec==2023.12.2
+ gradio==4.10.0
+ gradio_client==0.7.3
+ h11==0.14.0
+ httpcore==1.0.2
+ httpx==0.25.2
+ huggingface-hub==0.19.4
+ idna==3.6
+ importlib-resources==6.1.1
+ ipykernel==6.27.1
+ ipython==8.18.1
+ ipywidgets==8.1.1
+ isoduration==20.11.0
+ jedi==0.19.1
+ Jinja2==3.1.2
+ json5==0.9.14
+ jsonpointer==2.4
+ jsonschema==4.20.0
+ jsonschema-specifications==2023.11.2
+ jupyter==1.0.0
+ jupyter-console==6.6.3
+ jupyter-events==0.9.0
+ jupyter-lsp==2.2.1
+ jupyter_client==8.6.0
+ jupyter_core==5.5.1
+ jupyter_server==2.12.1
+ jupyter_server_terminals==0.5.0
+ jupyterlab==4.0.9
+ jupyterlab-widgets==3.0.9
+ jupyterlab_pygments==0.3.0
+ jupyterlab_server==2.25.2
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib==3.8.2
+ matplotlib-inline==0.1.6
+ mdurl==0.1.2
+ mistune==3.0.2
+ mpmath==1.3.0
+ nbclient==0.9.0
+ nbconvert==7.13.0
+ nbformat==5.9.2
+ nest-asyncio==1.5.8
+ networkx==3.2.1
+ notebook==7.0.6
+ notebook_shim==0.2.3
+ numpy==1.26.2
+ orjson==3.9.10
+ overrides==7.4.0
+ packaging==23.2
+ pandas==2.1.4
+ pandocfilters==1.5.0
+ parso==0.8.3
+ pexpect==4.9.0
+ Pillow==10.1.0
+ platformdirs==4.1.0
+ prometheus-client==0.19.0
+ prompt-toolkit==3.0.43
+ psutil==5.9.6
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pycparser==2.21
+ pydantic==2.5.2
+ pydantic_core==2.14.5
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ python-json-logger==2.0.7
+ python-multipart==0.0.6
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ pyzmq==25.1.2
+ qtconsole==5.5.1
+ QtPy==2.4.1
+ referencing==0.32.0
+ regex==2023.10.3
+ requests==2.31.0
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rich==13.7.0
+ rpds-py==0.13.2
+ safetensors==0.4.1
+ scipy==1.11.4
+ semantic-version==2.10.0
+ Send2Trash==1.8.2
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.0
+ soupsieve==2.5
+ stack-data==0.6.3
+ starlette==0.27.0
+ sympy==1.12
+ terminado==0.18.0
+ tinycss2==1.2.1
+ tokenizers==0.15.0
+ tomlkit==0.12.0
+ toolz==0.12.0
+ torch==2.1.2
+ torchaudio==2.1.2
+ torchvision==0.16.2
+ tornado==6.4
+ tqdm==4.66.1
+ traitlets==5.14.0
+ transformers
+ typer==0.9.0
+ types-python-dateutil==2.8.19.14
+ typing_extensions==4.9.0
+ tzdata==2023.3
+ uri-template==1.3.0
+ urllib3==2.1.0
+ uvicorn==0.24.0.post1
+ wcwidth==0.2.12
+ webcolors==1.13
+ webencodings==0.5.1
+ websocket-client==1.7.0
+ websockets==11.0.3
+ widgetsnbextension==4.0.9
+ sentence_transformers
test_handler.py ADDED
@@ -0,0 +1,63 @@
+ from handler import EndpointHandler
+ from PIL import Image
+ import requests, json
+
+ # init handler
+ my_handler = EndpointHandler(use_cuda=False, test_mode=True)
+
+
+ # API_URL = "https://oncm9ojdmjwesag2.eu-west-1.aws.endpoints.huggingface.cloud"
+
+ # headers = {
+ #     "Authorization": "Bearer MY_API_TOKEN",
+ #     "Content-Type": "image/jpg"
+ # }
+
+ # def query(filename):
+ #     with open(filename, "rb") as f:
+ #         data = f.read()
+ #     response = requests.request("POST", API_URL, headers=headers, data=data)
+ #     return json.loads(response.content.decode("utf-8"))
+
+ # output = query("food.jpg")
+
+ # prepare sample payload
+ image_path = '/Users/francois/Documents/dev/Maurice/maurice/test_602.jpg'
+
+ with open(image_path, 'rb') as f:
+     img = f.read()
+
+ single_image = {
+     'inputs': [
+         img
+     ],
+     'prompt': 'Describe the image'
+ }
+
+ multiple_images = {
+     'inputs': [
+         img, img, img
+     ],
+     'prompt': 'Describe the image'
+ }
+
+ # test the handler
+ print(my_handler(single_image))
+ print(my_handler(multiple_images))
+
+
+ # non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "date": "2022-08-08"}
+ # holiday_payload = {"inputs": "Today is a though day", "date": "2022-07-04"}
+ #
+ #
+ #
+ # # test the handler
+ # non_holiday_pred = my_handler(non_holiday_payload)
+ # holiday_payload = my_handler(holiday_payload)
+ #
+ # # show results
+ # print("non_holiday_pred", non_holiday_pred)
+ # print("holiday_payload", holiday_payload)
+ #
+ # # non_holiday_pred [{'label': 'joy', 'score': 0.9985942244529724}]
+ # # holiday_payload [{'label': 'happy', 'score': 1}]
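The local test reads raw bytes from a file on disk; the same payload shape can also be built entirely in memory, which is handy when no sample image is at hand. A small sketch reusing `my_handler` from above (the solid-color 256x256 JPEG is a made-up stand-in for a real photo):

    import io
    from PIL import Image

    # Encode an in-memory JPEG and pass its raw bytes, exactly like the bytes read from disk above.
    buffer = io.BytesIO()
    Image.new('RGB', (256, 256), color=(120, 180, 90)).save(buffer, format='JPEG')
    in_memory_payload = {'inputs': [buffer.getvalue()], 'prompt': 'Describe the image'}
    print(my_handler(in_memory_payload))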