fredaddy commited on
Commit
bb2a012
·
1 Parent(s): e56d8f8

add custom handler

Browse files
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gitignore
2
+
3
+ .venv
4
+ test_image_1.jpg
5
+ test_image_2.jpg
6
+ test_image_3.jpg
7
+ test_image_4.jpg
8
+ test_image_5.jpg
9
+ test_image_6.jpg
10
+ test_image_7.jpg
11
+ test_image_8.jpg
12
+ test_image_9.jpg
13
+
handler.py CHANGED
@@ -1,33 +1,27 @@
1
- from PIL import Image
2
  import torch
 
3
  from transformers import AutoModel, AutoTokenizer
4
 
5
- class ModelHandler:
6
- def __init__(self):
7
- # Load the model and tokenizer with appropriate weights
8
  self.model = AutoModel.from_pretrained(
9
- 'fredaddy/MiniCPM-V-2_6',
10
  trust_remote_code=True,
11
- attn_implementation='sdpa',
12
- torch_dtype=torch.bfloat16
13
- ).eval().cuda()
14
-
15
- self.tokenizer = AutoTokenizer.from_pretrained('fredaddy/MiniCPM-V-2_6', trust_remote_code=True)
16
 
17
- def preprocess(self, inputs):
18
- # Preprocess image input
19
- image = Image.open(inputs['image'].file).convert('RGB')
20
- question = inputs.get("question", "Extract all data in the image. Be extremely careful to ensure that you don't miss anything. It's imperative that you extract and digitize everything on that page.")
21
  msgs = [{'role': 'user', 'content': [image, question]}]
22
- return msgs
23
-
24
- def inference(self, msgs):
25
- # Run inference on the model
26
- result = self.model.chat(image=None, msgs=msgs, tokenizer=self.tokenizer)
27
- return result
28
 
29
- def postprocess(self, result):
30
- # Postprocess the output from the model
31
- return {"generated_text": result}
 
 
32
 
33
- service = ModelHandler()
 
 
1
  import torch
2
+ from PIL import Image
3
  from transformers import AutoModel, AutoTokenizer
4
 
5
+ class EndpointHandler:
6
+ def __init__(self, path):
 
7
  self.model = AutoModel.from_pretrained(
8
+ path,
9
  trust_remote_code=True,
10
+ attn_implementation='sdpa',
11
+ torch_dtype=torch.float16
12
+ )
13
+ self.model = self.model.eval().cuda()
14
+ self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
15
 
16
+ def __call__(self, data):
17
+ image = Image.open(data['inputs']['image'].file).convert('RGB')
18
+ question = data['inputs'].get("question", "Extract all data in the image. Be extremely careful to ensure that you don't miss anything. It's imperative that you extract and digitize everything on that page.")
 
19
  msgs = [{'role': 'user', 'content': [image, question]}]
 
 
 
 
 
 
20
 
21
+ res = self.model.chat(
22
+ image=None,
23
+ msgs=msgs,
24
+ tokenizer=self.tokenizer
25
+ )
26
 
27
+ return {"generated_text": res}
model-00001-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:327b5e63ac7729d00b0b17233a0d648a623fc080cc1069dce9cb79478cdaf2b5
3
- size 4874808328
 
 
 
 
model-00002-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e6d7af1f12ba8c72fec16ba3f3951df5643bef24de8b0af5baec67290c07ca4f
3
- size 4932751496
 
 
 
 
model-00003-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e9ba1332c74a92b9c8af5fbc2c2da1baa8cd5ed4fd8056b98a586ff4bc3e54cf
3
- size 4330865648
 
 
 
 
model-00004-of-00004.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:16fcfafa8caaa67e16f9dae7cc79bf081eb59d4e248723a8e79d9678e42732fe
3
- size 2060017080
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Pillow==10.1.0
2
+ torch==2.1.2
3
+ torchvision==0.16.2
4
+ transformers==4.40.0
5
+ sentencepiece==0.1.99
6
+ flash-attn==2.3.6
test_handler.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+ from PIL import Image
3
+ import io
4
+
5
+ def test_endpoint():
6
+ # Initialize the handler
7
+ handler = EndpointHandler("openbmb/MiniCPM-V-2_6")
8
+
9
+ # Load a test image
10
+ with open("test_image.jpg", "rb") as image_file:
11
+ image_bytes = image_file.read()
12
+
13
+ # Create a mock request data
14
+ mock_data = {
15
+ "inputs": {
16
+ "image": type('MockFile', (), {'file': io.BytesIO(image_bytes)})(),
17
+ "question": "What is in this image?"
18
+ }
19
+ }
20
+
21
+ # Call the handler
22
+ result = handler(mock_data)
23
+
24
+ # Print the result
25
+ print(result["generated_text"])
26
+
27
+ if __name__ == "__main__":
28
+ test_endpoint()