# cogvlm-chat-hf / handler.py
import os
from typing import Any, Dict

import requests
import torch
from accelerate import (
    infer_auto_device_map,
    init_empty_weights,
    load_checkpoint_and_dispatch,
)
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer
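

# Custom handler for a Hugging Face Inference Endpoint serving THUDM/cogvlm-chat-hf.
# The model skeleton is built on the meta device, a device map is inferred for four
# GPUs (plus CPU offload), and the checkpoint is dispatched across them with
# accelerate. Text is tokenized with the Vicuna v1.5 tokenizer, which is what the
# CogVLM reference examples pair the model with.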
class EndpointHandler:
    def __init__(self, path=""):
        # Preload everything needed at inference time.
        # Single-process alternatives, kept for reference:
        # self.pipeline = pipeline(
        #     "text-generation", model="THUDM/cogvlm-chat-hf", trust_remote_code=True
        # )
        # self.model = AutoModelForCausalLM.from_pretrained(
        #     "THUDM/cogvlm-chat-hf", trust_remote_code=True
        # )

        self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

        # Single-GPU alternative, kept for reference:
        # self.model = (
        #     AutoModelForCausalLM.from_pretrained(
        #         "THUDM/cogvlm-chat-hf",
        #         torch_dtype=torch.bfloat16,
        #         low_cpu_mem_usage=True,
        #         trust_remote_code=True,
        #     )
        #     .to("cuda")
        #     .eval()
        # )
        # DISTRIBUTED GPUS
        # Build the model skeleton on the meta device (no memory allocated), infer a
        # placement across four GPUs plus CPU offload, then load and dispatch the
        # real weights.
        with init_empty_weights():
            self.model = AutoModelForCausalLM.from_pretrained(
                "THUDM/cogvlm-chat-hf",
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )

        device_map = infer_auto_device_map(
            self.model,
            max_memory={
                0: "16GiB",
                1: "16GiB",
                2: "16GiB",
                3: "16GiB",
                "cpu": "180GiB",
            },
            no_split_module_classes=["CogVLMDecoderLayer"],
        )

        # Local snapshot of the model weights, typically under
        # '~/.cache/huggingface/hub/models--THUDM--cogvlm-chat-hf/snapshots/<revision>'
        # (on this instance:
        # '/home/ec2-user/.cache/huggingface/hub/models--THUDM--cogvlm-chat-hf/snapshots/8abca878c4257412c4c38eeafaed3fe27a036730').
        checkpoint = os.path.expanduser(
            "~/.cache/huggingface/hub/models--THUDM--cogvlm-chat-hf/snapshots/8abca878c4257412c4c38eeafaed3fe27a036730"
        )
        self.model = load_checkpoint_and_dispatch(
            self.model,
            checkpoint,
            device_map=device_map,
            no_split_module_classes=["CogVLMDecoderLayer"],
        )
        self.model = self.model.eval()
        ## DISTRIBUTED GPUS
    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            query (:obj:`str`): the question to ask about the image
            img_uri (:obj:`str`): URL of the image to run VQA on
        Return:
            :obj:`str`: the decoded model response, serialized and returned
        """
        query = data["query"]
        img_uri = data["img_uri"]

        # Fetch the image and normalize it to RGB.
        image = Image.open(
            requests.get(
                img_uri,
                stream=True,
            ).raw
        ).convert("RGB")

        # Build CogVLM conversation inputs in VQA mode.
        inputs = self.model.build_conversation_input_ids(
            self.tokenizer,
            query=query,
            history=[],
            images=[image],
            template_version="vqa",
        )
        inputs = {
            "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
            "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
            "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
            "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)]],
        }
        gen_kwargs = {"max_length": 2048, "do_sample": False}

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            # Drop the prompt tokens and decode only the newly generated ones.
            outputs = outputs[:, inputs["input_ids"].shape[1]:]
            response = self.tokenizer.decode(outputs[0])

        return response
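

# Example request payload handled by __call__ above (assuming the Inference Endpoint
# forwards the JSON request body as `data`):
#   {"query": "What is shown in this image?", "img_uri": "https://example.com/image.jpg"}
# The returned value is the decoded answer string.

# Standalone snippet showing direct model usage, kept for reference: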
# query = "How many houses are there in this cartoon?"
# image = Image.open(
# requests.get(
# "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true", stream=True
# ).raw
# ).convert("RGB")
# inputs = model.build_conversation_input_ids(
# tokenizer, query=query, history=[], images=[image], template_version="vqa"
# ) # vqa mode
# inputs = {
# "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
# "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
# "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
# "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)]],
# }
# gen_kwargs = {"max_length": 2048, "do_sample": False}
# with torch.no_grad():
# outputs = model.generate(**inputs, **gen_kwargs)
# outputs = outputs[:, inputs["input_ids"].shape[1] :]
# print(tokenizer.decode(outputs[0]))
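

# Minimal local smoke test for the handler above. This is a sketch: it assumes the
# machine matches the device map in __init__ (four ~16GiB GPUs) and that the example
# image URL from the reference snippet is still reachable.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler(
        {
            "query": "How many houses are there in this cartoon?",
            "img_uri": "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true",
        }
    )
    print(result)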