gabrielearmento committed on
Commit
ef5ce50
1 Parent(s): 964ef5d

This commit refactors the `handler.py` file to improve the performance of the Visual Question Answering (VQA) model. The changes include:

Browse files

- Loading the VQA pipeline for the model
- Modifying the `__call__` method to extract the image and question from the request
- Performing the VQA using the pipeline

These changes aim to enhance the efficiency and accuracy of the VQA process.

Files changed (1) hide show
  1. handler.py +34 -20
handler.py CHANGED
@@ -1,24 +1,38 @@
1
- from typing import Dict, List, Any
2
- from transformers import AutoModel, AutoTokenizer
3
- from PIL import Image
4
 
5
- class EndpointHandler():
 
 
 
6
  def __init__(self, path=""):
7
- # Preload all the elements you are going to need at inference.
8
- self.model = AutoModel.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5-int4', trust_remote_code=True)
9
- self.tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5-int4', trust_remote_code=True)
10
 
11
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
- image_url = data.pop("image_url")
13
- image = Image.open(image_url).convert("RGB")
14
- message = data.pop("message")
15
- messages = [{'role': 'user', 'content': message}]
16
- return model.chat(
17
- image=image,
18
- msgs=msgs,
19
- tokenizer=self.tokenizer,
20
- sampling=True, # if sampling=False, beam_search will be used by default
21
- temperature=0.7,
22
- # system_prompt='' # pass system_prompt if needed
23
  )
24
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
 
 
2
 
3
+ from transformers import AutoModel, AutoTokenizer, pipeline
4
+
5
+
6
class EndpointHandler:
    """Custom inference handler that answers questions about an image.

    The handler builds a ``visual-question-answering`` pipeline once at
    start-up and reuses it for every request.
    """

    # Single source of truth for the checkpoint loaded by model and tokenizer.
    MODEL_ID = "openbmb/MiniCPM-Llama3-V-2_5-int4"

    def __init__(self, path=""):
        """Load the model, the tokenizer and the VQA pipeline.

        Args:
            path: Accepted for interface compatibility; it is not used
                because the checkpoint is always loaded from ``MODEL_ID``.
        """
        # NOTE(review): the checkpoint ships custom code (trust_remote_code);
        # confirm it is compatible with the "visual-question-answering"
        # pipeline before relying on this handler in production.
        model = AutoModel.from_pretrained(
            self.MODEL_ID,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL_ID,
            trust_remote_code=True,
        )
        self.pipeline = pipeline(
            "visual-question-answering",
            model=model,
            tokenizer=tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run visual question answering for one request.

        Args:
            data: Request payload. ``image`` and ``question`` are read either
                from the top level or from a nested ``inputs`` mapping.

        Returns:
            Whatever the underlying pipeline returns for the given image and
            question.

        Raises:
            ValueError: If ``image`` or ``question`` is missing from the
                payload.
        """
        # Accept both {"image": ..., "question": ...} and the wrapped form
        # {"inputs": {"image": ..., "question": ...}}.
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            payload = data

        image = payload.get("image")
        question = payload.get("question")

        # Fail early with a clear message instead of handing None to the
        # pipeline, which would fail with a much less helpful error.
        missing = [
            key
            for key, value in (("image", image), ("question", question))
            if value is None
        ]
        if missing:
            raise ValueError(
                "Missing required field(s) in request: " + ", ".join(missing)
            )

        # Perform the VQA
        return self.pipeline(image, question)
30
+
31
+
32
+ # if __name__ == "__main__":
33
+ # handler = EndpointHandler()
34
+ # data = {
35
+ # "image": "https://pwm.im-cdn.it/image/1524723057/xxl.jpg",
36
+ # "question": "Describe the image:",
37
+ # }
38
+ # print(handler(data))