add custom handler

Files changed:
- .gitignore                         +13  -0
- handler.py                         +18 -24
- model-00001-of-00004.safetensors    +0  -3
- model-00002-of-00004.safetensors    +0  -3
- model-00003-of-00004.safetensors    +0  -3
- model-00004-of-00004.safetensors    +0  -3
- requirements.txt                    +6  -0
- test_handler.py                    +28  -0
.gitignore ADDED
@@ -0,0 +1,13 @@
+# gitignore
+
+.venv
+test_image_1.jpg
+test_image_2.jpg
+test_image_3.jpg
+test_image_4.jpg
+test_image_5.jpg
+test_image_6.jpg
+test_image_7.jpg
+test_image_8.jpg
+test_image_9.jpg
+
handler.py CHANGED
@@ -1,33 +1,27 @@
-from PIL import Image
 import torch
+from PIL import Image
 from transformers import AutoModel, AutoTokenizer
 
-class …:
-    def __init__(self):
-        # Load the model and tokenizer with appropriate weights
+class EndpointHandler:
+    def __init__(self, path):
         self.model = AutoModel.from_pretrained(
-            …,
+            path,
             trust_remote_code=True,
-            attn_implementation='sdpa',
-            torch_dtype=torch.…
-        )
-
-        self.tokenizer = AutoTokenizer.from_pretrained(…
+            attn_implementation='sdpa',
+            torch_dtype=torch.float16
+        )
+        self.model = self.model.eval().cuda()
+        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
 
-    def …:
-        …
-        …
-        question = inputs.get("question", "Extract all data in the image. Be extremely careful to ensure that you don't miss anything. It's imperative that you extract and digitize everything on that page.")
+    def __call__(self, data):
+        image = Image.open(data['inputs']['image'].file).convert('RGB')
+        question = data['inputs'].get("question", "Extract all data in the image. Be extremely careful to ensure that you don't miss anything. It's imperative that you extract and digitize everything on that page.")
         msgs = [{'role': 'user', 'content': [image, question]}]
-        return msgs
-
-    def inference(self, msgs):
-        # Run inference on the model
-        result = self.model.chat(image=None, msgs=msgs, tokenizer=self.tokenizer)
-        return result
 
-
-
-
+        res = self.model.chat(
+            image=None,
+            msgs=msgs,
+            tokenizer=self.tokenizer
+        )
 
-
+        return {"generated_text": res}
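Once deployed behind a Hugging Face Inference Endpoint, a handler with this `__call__` shape is driven over HTTP. The snippet below is a minimal client sketch, not part of the commit: the endpoint URL and token are placeholders, and the multipart field names (`image`, `question`) are assumptions inferred from how `__call__` reads `data['inputs']`; the exact request shape depends on how the serving layer deserializes uploads.

```python
import requests

# Placeholders: substitute your own deployment's URL and access token.
ENDPOINT_URL = "https://your-endpoint.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

def query(image_path: str, question: str) -> dict:
    # Sends the image as a multipart file plus a form field. This assumes the
    # serving layer surfaces the upload to the handler as data['inputs']['image']
    # (an object with a .file attribute), which is what handler.py reads.
    with open(image_path, "rb") as f:
        response = requests.post(
            ENDPOINT_URL,
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
            files={"image": f},
            data={"question": question},
        )
    response.raise_for_status()
    return response.json()  # expected shape: {"generated_text": "..."}

if __name__ == "__main__":
    print(query("test_image_1.jpg", "What is in this image?")["generated_text"])
```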
model-00001-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:327b5e63ac7729d00b0b17233a0d648a623fc080cc1069dce9cb79478cdaf2b5
-size 4874808328

model-00002-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e6d7af1f12ba8c72fec16ba3f3951df5643bef24de8b0af5baec67290c07ca4f
-size 4932751496

model-00003-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e9ba1332c74a92b9c8af5fbc2c2da1baa8cd5ed4fd8056b98a586ff4bc3e54cf
-size 4330865648

model-00004-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:16fcfafa8caaa67e16f9dae7cc79bf081eb59d4e248723a8e79d9678e42732fe
-size 2060017080
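Removing these four LFS pointers drops roughly 16 GB of weight shards from the repository. Presumably the model is now resolved through the handler's `AutoModel.from_pretrained(path, ...)` call, with `path` pointing at a Hub repo that still carries the weights (test_handler.py below passes "openbmb/MiniCPM-V-2_6").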
requirements.txt ADDED
@@ -0,0 +1,6 @@
+Pillow==10.1.0
+torch==2.1.2
+torchvision==0.16.2
+transformers==4.40.0
+sentencepiece==0.1.99
+flash-attn==2.3.6
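For local testing these pins install with `pip install -r requirements.txt`, with one caveat: `flash-attn` compiles a CUDA extension and generally needs `torch` already present at build time, so installing it last (or with `pip install flash-attn==2.3.6 --no-build-isolation`) is a common workaround. That caveat is general flash-attn packaging behavior, not something stated in this commit.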
test_handler.py ADDED
@@ -0,0 +1,28 @@
+from handler import EndpointHandler
+from PIL import Image
+import io
+
+def test_endpoint():
+    # Initialize the handler
+    handler = EndpointHandler("openbmb/MiniCPM-V-2_6")
+
+    # Load a test image
+    with open("test_image.jpg", "rb") as image_file:
+        image_bytes = image_file.read()
+
+    # Create a mock request data
+    mock_data = {
+        "inputs": {
+            "image": type('MockFile', (), {'file': io.BytesIO(image_bytes)})(),
+            "question": "What is in this image?"
+        }
+    }
+
+    # Call the handler
+    result = handler(mock_data)
+
+    # Print the result
+    print(result["generated_text"])
+
+if __name__ == "__main__":
+    test_endpoint()
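Running `python test_handler.py` needs a CUDA GPU (the handler calls `.cuda()` unconditionally), network access to download the openbmb/MiniCPM-V-2_6 weights, and a local `test_image.jpg`. A variant that sweeps the nine numbered test images named in the new .gitignore could look like the sketch below; the file names come from .gitignore, everything else mirrors test_handler.py, and none of this is part of the commit itself.

```python
import io

from handler import EndpointHandler

class MockUpload:
    # Stand-in for the upload object the handler expects:
    # anything exposing a .file attribute that PIL can open.
    def __init__(self, raw: bytes):
        self.file = io.BytesIO(raw)

def main():
    # One-time model load; requires a CUDA GPU and downloads weights from the Hub.
    handler = EndpointHandler("openbmb/MiniCPM-V-2_6")

    # test_image_1.jpg .. test_image_9.jpg are the files listed in the new .gitignore.
    for i in range(1, 10):
        with open(f"test_image_{i}.jpg", "rb") as f:
            data = {
                "inputs": {
                    "image": MockUpload(f.read()),
                    "question": "What is in this image?",
                }
            }
        print(f"--- test_image_{i}.jpg ---")
        print(handler(data)["generated_text"])

if __name__ == "__main__":
    main()
```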