""" A model worker executes the model. """ import os import json import time import uuid import asyncio import requests import argparse import threading from threading import Thread from functools import partial from typing import Iterator, List, Optional, Tuple import uvicorn from fastapi import FastAPI, Request, BackgroundTasks from fastapi.responses import StreamingResponse import torch import decord import numpy as np from PIL import Image from decord import VideoReader, cpu from transformers import TextIteratorStreamer from videollama2.constants import WORKER_HEART_BEAT_INTERVAL from videollama2.utils import (build_logger, server_error_msg, pretty_print_semaphore) from videollama2.model.builder import load_pretrained_model from videollama2.mm_utils import process_images, process_videos, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria, tokenizer_MMODAL_token from videollama2.mm_utils import chunk_list, frame_expansion from videollama2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_TOKEN, NUM_FRAMES, MMODAL_TOKEN_INDEX GB = 1 << 30 worker_id = str(uuid.uuid4())[:6] logger = build_logger("model_worker", f"model_worker_{worker_id}.log") global_counter = 0 model_semaphore = None # variable_content = os.getenv('MY_VARIABLE', '') # KEYWORDS_LIST = set(variable_content.split('\n')) KEYWORDS_LIST = [] path = 'assets/keywords.txt' if os.path.exists(path): with open(path, 'r', encoding='utf-8') as file: for line in file: KEYWORDS_LIST.append(line.strip()) else: KEYWORDS_LIST = [] KEYWORD_BLOCK_MESSAGE2 = "The output contains political, erotic and other unsafe content that violates local laws. Please re-enter your question." KEYWORD_BLOCK_MESSAGE1 = "Your input question contains political, erotic and other unsafe content that violates local laws. Please re-enter your question." 
STREAM_CHECK_MULTIPLE = 20  # run the keyword safety check every N streamed chunks


def heart_beat_worker(controller):
    """Background loop: periodically ping the controller so it knows this worker is alive."""
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


def safety_check(text, history=None) -> Optional[str]:
    """Return a block message if model *output* contains a blocked keyword, else None.

    `history` is accepted for interface compatibility but currently unused.
    NOTE(review): `text` is lower-cased but the keywords are not — an
    upper-case entry in keywords.txt would never match; confirm intent.
    """
    if len(KEYWORDS_LIST) > 0 and any(x in text.lower() for x in KEYWORDS_LIST):
        print('############')
        return KEYWORD_BLOCK_MESSAGE2
    return None


def input_safety_check(text) -> Optional[str]:
    """Return a block message if the *input* question contains a blocked keyword, else None."""
    if len(KEYWORDS_LIST) > 0 and any(x in text.lower() for x in KEYWORDS_LIST):
        print('######## Input keyword alarm triggered:', text)
        return KEYWORD_BLOCK_MESSAGE1
    return None


class ModelWorker:
    """Hosts one model, registers with the controller, and reports heart beats."""

    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        self.model_path = model_path
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        # Derive a display name from the path when none is given; keep the
        # parent dir for "checkpoint-*" dirs so the name stays meaningful.
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name
        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
        # Multimodal handling is keyed off the model name.
        self.is_multimodal = 'videollama2' in self.model_name.lower() or 'vlb' in self.model_name.lower()

        if not no_register:
            self.register_to_controller()
            # daemon=True: the heart-beat loop runs forever, and a non-daemon
            # thread would keep the interpreter alive after the server exits.
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,), daemon=True)
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """Announce this worker (and its status) to the controller.

        Raises RuntimeError if the controller rejects the registration.
        """
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        # Raise explicitly instead of `assert`: asserts vanish under `python -O`,
        # which would silently ignore a failed registration.
        if r.status_code != 200:
            raise RuntimeError(f"Failed to register worker: HTTP {r.status_code}")

    def send_heart_beat(self):
        """Report liveness and queue length to the controller, re-registering if it forgot us."""
logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " f"global_counter: {global_counter}") url = self.controller_addr + "/receive_heart_beat" while True: try: ret = requests.post(url, json={ "worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5) exist = ret.json()["exist"] break except requests.exceptions.RequestException as e: logger.error(f"heart beat error: {e}") time.sleep(5) if not exist: self.register_to_controller() def get_queue_length(self): if model_semaphore is None: return 0 else: return args.limit_model_concurrency - model_semaphore._value + (len( model_semaphore._waiters) if model_semaphore._waiters is not None else 0) def get_status(self): return { "model_names": [self.model_name], "speed": 1, "queue_length": self.get_queue_length(), } @torch.inference_mode() def generate_stream(self, params): tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor prompt = params["prompt"] ori_prompt = prompt images_or_videos = params.get("images", None) #print("Input images:", images_or_videos) num_image_tokens = 0 modal_list = [] if images_or_videos is not None and len(images_or_videos) and self.is_multimodal: if len(images_or_videos) > 0: if len(images_or_videos) != prompt.count(DEFAULT_IMAGE_TOKEN) and len(images_or_videos) != (prompt.count(DEFAULT_VIDEO_TOKEN)): raise ValueError("Number of images/videos does not match number of /