nbroad HF staff commited on
Commit
2cdc125
1 Parent(s): 0876ee7

Upload 3 files

Browse files
Files changed (3) hide show
  1. handler.py +56 -0
  2. requirements.txt +2 -0
  3. test.ipynb +88 -0
handler.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from io import BytesIO
3
+ from typing import Dict, List, Any
4
+ from transformers import Pix2StructForConditionalGeneration, AutoProcessor
5
+ from PIL import Image
6
+ import torch
7
+
8
+ class EndpointHandler():
9
+
10
+ def __init__(self):
11
+
12
+ model_name = "google/pix2struct-infographics-vqa-large"
13
+
14
+
15
+ self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
16
+ self.processor = AutoProcessor.from_pretrained(model_name)
17
+ self.text_prompt = None #
18
+
19
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+ self.model.to(self.device)
22
+
23
+
24
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
25
+ """
26
+ Args:
27
+ data (:obj:):
28
+ includes the input data and the parameters for the inference.
29
+ Return:
30
+ a dictionary with the output of the model. The only key is `output` and the
31
+ value is a list of str.
32
+ """
33
+ inputs = data.pop("inputs", data)
34
+ parameters = data.pop("parameters", {})
35
+
36
+ if isinstance(inputs["image"], list):
37
+ img = [Image.open(BytesIO(base64.b64decode(img))) for img in inputs['image']]
38
+ else:
39
+ img = Image.open(BytesIO(base64.b64decode(inputs['image'])))
40
+
41
+ question = inputs['question']
42
+
43
+
44
+
45
+ with torch.inference_mode():
46
+ model_inputs = self.processor(images=img, text=question, return_tensors="pt").to(self.device)
47
+
48
+ raw_output = self.model.generate(**model_inputs, **parameters)
49
+
50
+ decoded_output = self.processor.batch_decode(raw_output, skip_special_tokens=True)
51
+
52
+
53
+ # postprocess the prediction
54
+ return {
55
+ "output": decoded_output
56
+ }
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers==4.35.1
2
+ sentencepiece==0.1.99
test.ipynb ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 6,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import handler\n",
10
+ "import importlib\n",
11
+ "\n",
12
+ "importlib.reload(handler)\n",
13
+ "\n",
14
+ "import handler\n",
15
+ "\n",
16
+ "h = handler.EndpointHandler()"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 13,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "import base64\n",
26
+ "from pathlib import Path\n",
27
+ "\n",
28
+ "if not Path(\"What-is-an-infographic.jpg\").exists():\n",
29
+ " !wget https://visme.co/blog/wp-content/uploads/2020/02/What-is-an-infographic.jpg\n",
30
+ "\n",
31
+ "with open(\"What-is-an-infographic.jpg\", \"rb\") as f:\n",
32
+ " b64 = base64.b64encode(f.read())\n",
33
+ "\n",
34
+ "question = \"What percent of information do we understand through body language?\"\n",
35
+ "\n",
36
+ "payload = {\n",
37
+ " \"inputs\": {\n",
38
+ " \"image\": [b64.decode(\"utf-8\")]*2, \n",
39
+ " \"question\": [question]*2\n",
40
+ " }, \n",
41
+ " \"parameters\":{\n",
42
+ " \"max_new_tokens\": 10,\n",
43
+ " }}"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": 14,
49
+ "metadata": {},
50
+ "outputs": [
51
+ {
52
+ "data": {
53
+ "text/plain": [
54
+ "{'output': ['55%', '55%']}"
55
+ ]
56
+ },
57
+ "execution_count": 14,
58
+ "metadata": {},
59
+ "output_type": "execute_result"
60
+ }
61
+ ],
62
+ "source": [
63
+ "h(payload)"
64
+ ]
65
+ }
66
+ ],
67
+ "metadata": {
68
+ "kernelspec": {
69
+ "display_name": "Python 3",
70
+ "language": "python",
71
+ "name": "python3"
72
+ },
73
+ "language_info": {
74
+ "codemirror_mode": {
75
+ "name": "ipython",
76
+ "version": 3
77
+ },
78
+ "file_extension": ".py",
79
+ "mimetype": "text/x-python",
80
+ "name": "python",
81
+ "nbconvert_exporter": "python",
82
+ "pygments_lexer": "ipython3",
83
+ "version": "3.10.12"
84
+ }
85
+ },
86
+ "nbformat": 4,
87
+ "nbformat_minor": 2
88
+ }