File size: 3,087 Bytes
68677a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511bef8
c362d01
68677a4
 
 
 
 
 
 
 
 
 
511bef8
68677a4
c362d01
68677a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
from transformers import AutoModelForCausalLM, LlamaTokenizer
import torch

# from accelerate import (
#     init_empty_weights,
#     infer_auto_device_map,
#     load_checkpoint_and_dispatch,
# )
import os

import logging

# from transformers import logging as hf_logging
# hf_logging.set_verbosity_debug()

# Configure the root logger at import time so INFO-level (and higher) records
# are emitted.
logging.basicConfig(level=logging.INFO)


class EndpointHandler:
    """Callable handler that answers a text query about an image fetched by URL.

    The tokenizer and model are loaded once in ``__init__``; each call
    downloads the image named in the request, builds the model inputs and
    returns the decoded generation as a string.
    """

    # Seconds to wait on the image host (connect and read) before giving up,
    # so an unresponsive server cannot block the handler indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the tokenizer and the model, and move the model to the GPU.

        Args:
            path: Accepted for interface compatibility but not used here;
                the tokenizer and model are loaded from the fixed hub
                identifiers below.
        """
        self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

        self.model = (
            AutoModelForCausalLM.from_pretrained(
                "THUDM/cogvlm-grounding-generalist-hf",
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            .to("cuda")
            .eval()
        )

    def _fetch_image(self, img_uri: str) -> "Image.Image":
        """Download ``img_uri`` and return it as an RGB PIL image.

        Raises:
            requests.RequestException: on timeout, connection failure or a
                non-2xx HTTP status.
        """
        # The context manager releases the connection even if decoding fails.
        with requests.get(
            img_uri,
            stream=True,
            timeout=self.REQUEST_TIMEOUT,
        ) as response:
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            # Have urllib3 undo any Content-Encoding (gzip/deflate) so PIL
            # receives the actual image bytes.
            response.raw.decode_content = True
            # ``convert`` forces the pixel data to be read, so the image is
            # fully loaded before the response is closed.
            return Image.open(response.raw).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> str:
        """Run one query/image pair through the model.

        Args:
            data: Request payload. Must contain:
                ``"inputs"``  - the query passed to the model, and
                ``"img_uri"`` - URL of the image to download.

        Returns:
            The decoded text of the generated tokens.

        Raises:
            KeyError: if ``"inputs"`` or ``"img_uri"`` is missing.
            requests.RequestException: if the image cannot be downloaded.
        """
        # Read both required keys up front so a malformed request fails
        # before any network or GPU work is started.
        query = data["inputs"]
        img_uri = data["img_uri"]

        image = self._fetch_image(img_uri)

        features = self.model.build_conversation_input_ids(
            self.tokenizer, query=query, images=[image]
        )
        # Add a leading batch dimension and move everything to the GPU; the
        # image tensor is additionally cast to the dtype the model was
        # loaded with.
        inputs = {
            "input_ids": features["input_ids"].unsqueeze(0).to("cuda"),
            "token_type_ids": features["token_type_ids"].unsqueeze(0).to("cuda"),
            "attention_mask": features["attention_mask"].unsqueeze(0).to("cuda"),
            "images": [[features["images"][0].to("cuda").to(torch.bfloat16)]],
        }
        gen_kwargs = {"max_length": 2048, "do_sample": False}

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            # Drop the leading positions that correspond to the input ids,
            # keeping only what was generated after them.
            outputs = outputs[:, inputs["input_ids"].shape[1] :]
            result = self.tokenizer.decode(outputs[0])
        return result


# query = "How many houses are there in this cartoon?"
# image = Image.open(
#     requests.get(
#         "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true", stream=True
#     ).raw
# ).convert("RGB")
# inputs = model.build_conversation_input_ids(
#     tokenizer, query=query, history=[], images=[image], template_version="vqa"
# )  # vqa mode
# inputs = {
#     "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
#     "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
#     "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
#     "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)]],
# }
# gen_kwargs = {"max_length": 2048, "do_sample": False}

# with torch.no_grad():
#     outputs = model.generate(**inputs, **gen_kwargs)
#     outputs = outputs[:, inputs["input_ids"].shape[1] :]
#     print(tokenizer.decode(outputs[0]))