import os
from time import perf_counter
from typing import Any, List, Tuple, Union

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

from inference.core.entities.requests.cogvlm import CogVLMInferenceRequest
from inference.core.entities.responses.cogvlm import CogVLMResponse
from inference.core.env import (
    API_KEY,
    COGVLM_LOAD_4BIT,
    COGVLM_LOAD_8BIT,
    COGVLM_VERSION_ID,
    MODEL_CACHE_DIR,
)
from inference.core.models.base import Model, PreprocessReturnMetadata
from inference.core.utils.image_utils import load_image_rgb

# Prefer the GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class CogVLM(Model):
    def __init__(self, model_id=f"cogvlm/{COGVLM_VERSION_ID}", **kwargs):
        self.model_id = model_id
        self.endpoint = model_id
        self.api_key = API_KEY
        self.dataset_id, self.version_id = model_id.split("/")
        # 4-bit and 8-bit quantization are mutually exclusive.
        if COGVLM_LOAD_4BIT and COGVLM_LOAD_8BIT:
            raise ValueError(
                "Only one of environment variable `COGVLM_LOAD_4BIT` or `COGVLM_LOAD_8BIT` can be true"
            )
        self.cache_dir = os.path.join(MODEL_CACHE_DIR, self.endpoint)
        with torch.inference_mode():
            # CogVLM uses the Vicuna tokenizer; model weights are pulled from the THUDM hub repo.
            self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
            self.model = AutoModelForCausalLM.from_pretrained(
                f"THUDM/{self.version_id}",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                load_in_4bit=COGVLM_LOAD_4BIT,
                load_in_8bit=COGVLM_LOAD_8BIT,
                cache_dir=self.cache_dir,
            ).eval()
        self.task_type = "lmm"

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[Image.Image, PreprocessReturnMetadata]:
        pil_image = Image.fromarray(load_image_rgb(image))
        return pil_image, PreprocessReturnMetadata({})

    def postprocess(
        self,
        predictions: Tuple[str],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        return predictions[0]

    def predict(self, image_in: Image.Image, prompt="", history=None, **kwargs):
        images = [image_in]
        if history is None:
            history = []
        built_inputs = self.model.build_conversation_input_ids(
            self.tokenizer, query=prompt, history=history, images=images
        )  # chat mode
        inputs = {
            "input_ids": built_inputs["input_ids"].unsqueeze(0).to(DEVICE),
            "token_type_ids": built_inputs["token_type_ids"].unsqueeze(0).to(DEVICE),
            "attention_mask": built_inputs["attention_mask"].unsqueeze(0).to(DEVICE),
            "images": [[built_inputs["images"][0].to(DEVICE).to(torch.float16)]],
        }
        gen_kwargs = {"max_length": 2048, "do_sample": False}
        with torch.inference_mode():
            outputs = self.model.generate(**inputs, **gen_kwargs)
            # Drop the prompt tokens so only the newly generated continuation remains.
            outputs = outputs[:, inputs["input_ids"].shape[1] :]
            text = self.tokenizer.decode(outputs[0])
            # Strip the trailing end-of-sequence token.
            if text.endswith("</s>"):
                text = text[:-4]
            return (text,)

    def infer_from_request(self, request: CogVLMInferenceRequest) -> CogVLMResponse:
        # Time the full preprocess/predict/postprocess round trip.
        t1 = perf_counter()
        text = self.infer(**request.dict())
        response = CogVLMResponse(response=text)
        response.time = perf_counter() - t1
        return response


if __name__ == "__main__":
    m = CogVLM()
    m.infer()
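
# Minimal usage sketch, kept as comments so importing this module stays side-effect free.
# It assumes the base `Model.infer` forwards `image`/`prompt` kwargs through
# preprocess/predict/postprocess (as `infer_from_request` does via
# `self.infer(**request.dict())`) and that a local file "example.jpg" exists;
# both the call signature and the filename are assumptions for illustration.
#
#     model = CogVLM()
#     answer = model.infer(image="example.jpg", prompt="Describe this image.")
#     print(answer)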