# Source: chatGLM3-6B-Base / handler.py (exported from the Hugging Face Hub file page)
# Last commit: 693f1b3 "chore: update" by Marlon Wiprud
# File size: 4.08 kB; Hub security scan: No virus
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
from transformers import AutoModelForCausalLM, LlamaTokenizer
import torch
from accelerate import (
init_empty_weights,
infer_auto_device_map,
load_checkpoint_and_dispatch,
)
class EndpointHandler:
    """Inference Endpoints handler answering questions about an image with CogVLM.

    The model skeleton is built without materialising weights. Accelerate's
    ``infer_auto_device_map`` / ``load_checkpoint_and_dispatch`` then place
    its layers across the GPUs (and CPU) listed in ``max_memory``.
    """

    # Hub repository of the CogVLM chat checkpoint that is loaded and dispatched.
    model_id = "THUDM/cogvlm-chat-hf"
    # Tokenizer used to build the prompt and decode the answer.
    tokenizer_id = "lmsys/vicuna-7b-v1.5"
    # Per-device memory budgets handed to infer_auto_device_map (copied per use,
    # because Accelerate may normalise the values in place).
    max_memory = {0: "16GiB", 1: "16GiB", 2: "16GiB", 3: "16GiB", "cpu": "180GiB"}
    # Seconds to wait on the image server before giving up.
    image_timeout = 30

    def __init__(self, path=""):
        """Load the tokenizer and dispatch the CogVLM model across devices.

        Args:
            path: Directory supplied by the endpoint runtime. It is not used;
                the model is always loaded from ``model_id``.
        """
        import os

        from transformers.utils import cached_file

        self.tokenizer = LlamaTokenizer.from_pretrained(self.tokenizer_id)

        # Build the model inside init_empty_weights so its parameters are not
        # allocated here; load_checkpoint_and_dispatch fills them in below.
        # (The skeleton must not be moved with .to("cuda") at this point.)
        with init_empty_weights():
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )

        device_map = infer_auto_device_map(
            model,
            max_memory=dict(self.max_memory),
            # Never split a single decoder layer across two devices.
            no_split_module_classes=["CogVLMDecoderLayer"],
        )

        # from_pretrained has already populated the Hub cache. Resolve the actual
        # snapshot directory from it: a literal "~/..." path is never expanded,
        # and the snapshot sub-directory name differs per download.
        checkpoint_dir = os.path.dirname(cached_file(self.model_id, "config.json"))

        self.model = load_checkpoint_and_dispatch(
            model,
            checkpoint_dir,
            device_map=device_map,
        ).eval()

    def _fetch_image(self, img_uri):
        """Download ``img_uri`` and return it decoded as an RGB PIL image.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            requests.Timeout: if the server does not respond within ``image_timeout``.
        """
        with requests.get(img_uri, stream=True, timeout=self.image_timeout) as resp:
            resp.raise_for_status()
            # Let urllib3 undo any Content-Encoding (e.g. gzip) on the raw stream.
            resp.raw.decode_content = True
            # convert() forces PIL's lazy decode while the connection is still open.
            return Image.open(resp.raw).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> str:
        """Answer a single question about a single image.

        Args:
            data: Request payload with keys ``query`` (the question text) and
                ``img_uri`` (URL of the image to ask about).

        Returns:
            The decoded generated text, with the prompt tokens removed.
            Special tokens are not stripped.

        Raises:
            KeyError: if ``query`` or ``img_uri`` is missing from ``data``.
            requests.HTTPError: if the image URL returns an error status.
        """
        query = data["query"]
        img_uri = data["img_uri"]
        image = self._fetch_image(img_uri)

        features = self.model.build_conversation_input_ids(
            self.tokenizer,
            query=query,
            history=[],
            images=[image],
            template_version="vqa",  # vqa prompt template
        )
        # Add a leading batch dimension of 1 and move the tensors to the GPU.
        inputs = {
            "input_ids": features["input_ids"].unsqueeze(0).to("cuda"),
            "token_type_ids": features["token_type_ids"].unsqueeze(0).to("cuda"),
            "attention_mask": features["attention_mask"].unsqueeze(0).to("cuda"),
            "images": [[features["images"][0].to("cuda").to(torch.bfloat16)]],
        }
        gen_kwargs = {"max_length": 2048, "do_sample": False}  # greedy decoding

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # generate() returns prompt + continuation; keep only the new tokens.
        answer_ids = outputs[:, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(answer_ids[0])
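# Standalone usage example (not executed); assumes `model` is a loaded CogVLM model and `tokenizer` its tokenizer.
# query = "How many houses are there in this cartoon?"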
# image = Image.open(
# requests.get(
# "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true", stream=True
# ).raw
# ).convert("RGB")
# inputs = model.build_conversation_input_ids(
# tokenizer, query=query, history=[], images=[image], template_version="vqa"
# ) # vqa mode
# inputs = {
# "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
# "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
# "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
# "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)]],
# }
# gen_kwargs = {"max_length": 2048, "do_sample": False}
# with torch.no_grad():
# outputs = model.generate(**inputs, **gen_kwargs)
# outputs = outputs[:, inputs["input_ids"].shape[1] :]
# print(tokenizer.decode(outputs[0]))