Marlon Wiprud committed
Commit 736b696
1 Parent(s): aabdd3c

feat: setup handler

Files changed (2)
  1. handler.py +69 -0
  2. requirements.txt +7 -0
handler.py ADDED
@@ -0,0 +1,69 @@
+ from typing import Any, Dict
+
+ import requests
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, LlamaTokenizer
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # Preload everything needed at inference: the Vicuna tokenizer used
+         # by CogVLM, and the CogVLM chat model itself in bfloat16 on the GPU.
+         self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
+
+         self.model = (
+             AutoModelForCausalLM.from_pretrained(
+                 "THUDM/cogvlm-chat-hf",
+                 torch_dtype=torch.bfloat16,
+                 low_cpu_mem_usage=True,
+                 trust_remote_code=True,
+             )
+             .to("cuda")
+             .eval()
+         )
+
+     def __call__(self, data: Dict[str, Any]) -> str:
+         """
+         data args:
+             query (:obj:`str`): the question to ask about the image
+             img_uri (:obj:`str`): URL of the image to run the query against
+         Return:
+             A :obj:`str`: the generated answer, serialized and returned
+         """
+         query = data["query"]
+         img_uri = data["img_uri"]
+
+         # Fetch the image over HTTP and normalize it to RGB.
+         image = Image.open(
+             requests.get(
+                 img_uri,
+                 stream=True,
+             ).raw
+         ).convert("RGB")
+
+         # Build CogVLM conversation inputs in VQA mode.
+         inputs = self.model.build_conversation_input_ids(
+             self.tokenizer,
+             query=query,
+             history=[],
+             images=[image],
+             template_version="vqa",
+         )
+
+         # Add a batch dimension and move all tensors to the GPU.
+         inputs = {
+             "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
+             "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
+             "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
+             "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)]],
+         }
+
+         gen_kwargs = {"max_length": 2048, "do_sample": False}
+
+         with torch.no_grad():
+             outputs = self.model.generate(**inputs, **gen_kwargs)
+             # Strip the prompt tokens so only the generated answer remains.
+             outputs = outputs[:, inputs["input_ids"].shape[1] :]
+             response = self.tokenizer.decode(outputs[0])
+             return response
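
For reference, a minimal local smoke test of the handler might look like the sketch below; it assumes a CUDA-capable machine with the pinned dependencies installed, and reuses the example query and image URL from the CogVLM examples.

# Local smoke test for the handler (not part of the endpoint itself).
# Assumes a CUDA GPU and the dependencies pinned in requirements.txt.
from handler import EndpointHandler

handler = EndpointHandler()
response = handler(
    {
        "query": "How many houses are there in this cartoon?",
        "img_uri": "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true",
    }
)
print(response)
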
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch==2.1.0
+ transformers==4.35.0
+ accelerate==0.24.1
+ sentencepiece==0.1.99
+ einops==0.7.0
+ xformers==0.0.22.post7
+ triton==2.1.0
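
Once deployed, the endpoint expects a JSON body with "query" and "img_uri" keys, matching what __call__ reads from data. A hypothetical client call is sketched below; the endpoint URL and token are placeholders, not values from this repository, and it assumes the deserialized JSON body is passed straight to the handler.

# Hypothetical client call to a deployed endpoint; API_URL and the bearer
# token are placeholders, not values from this repository.
import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <HF_API_TOKEN>"}  # placeholder

payload = {
    "query": "How many houses are there in this cartoon?",
    "img_uri": "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true",
}

response = requests.post(API_URL, headers=HEADERS, json=payload)
print(response.json())  # the handler returns the decoded answer string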