Spaces:
Runtime error
Runtime error
# This file may have been modified by Flash-VStream Authors (“Flash-VStream Modifications”). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
# Based on https://github.com/haotian-liu/LLaVA. | |
""" | |
This file demonstrates an implementation of a multiprocess real-time long video understanding system, including a multiprocess logging module.
    main process: CLI server I/O, LLM inference
    process-1: logger listener
    process-2: frame generator
    process-3: frame memory manager
Author: Haoji Zhang, Haotian Liu | |
(This code is based on https://github.com/haotian-liu/LLaVA) | |
""" | |
import argparse | |
import requests | |
import logging | |
import torch | |
import numpy as np | |
import time | |
import os | |
from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN | |
from flash_vstream.conversation import conv_templates, SeparatorStyle | |
from flash_vstream.model.builder import load_pretrained_model | |
from flash_vstream.utils import disable_torch_init | |
from flash_vstream.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria | |
from torch.multiprocessing import Process, Queue, Manager | |
from transformers import TextStreamer | |
from decord import VideoReader | |
from datetime import datetime | |
from PIL import Image | |
from io import BytesIO | |
class _Metric: | |
def __init__(self): | |
self._latest_value = None | |
self._sum = 0.0 | |
self._max = 0.0 | |
self._count = 0 | |
def val(self): | |
return self._latest_value | |
def max(self): | |
return self._max | |
def avg(self): | |
if self._count == 0: | |
return float('nan') | |
return self._sum / self._count | |
def add(self, value): | |
self._latest_value = value | |
self._sum += value | |
self._count += 1 | |
if value > self._max: | |
self._max = value | |
def __str__(self): | |
latest_formatted = f"{self.val:.6f}" if self.val is not None else "None" | |
average_formatted = f"{self.avg:.6f}" | |
max_formatted = f"{self.max:.6f}" | |
return f"{latest_formatted} ({average_formatted}, {max_formatted})" | |
class MetricMeter:
    """Collection of per-key metric accumulators.

    Each key maps to one accumulator object that exposes ``val``, ``avg``
    and ``max`` attributes and a string form.
    """

    def __init__(self):
        self._metrics = {}

    def _require(self, key):
        """Return the accumulator stored under *key* or raise ``ValueError``."""
        tracker = self._metrics.get(key)
        if tracker is None:
            raise ValueError(f"No values have been added for key '{key}'.")
        return tracker

    def add(self, key, value):
        """Record *value* under *key*, creating the accumulator on first use."""
        tracker = self._metrics.get(key)
        if tracker is None:
            tracker = self._metrics[key] = _Metric()
        tracker.add(value)

    def val(self, key):
        """Return the latest value for *key*; raise ``ValueError`` if there is none."""
        tracker = self._metrics.get(key)
        if tracker is not None and tracker.val is not None:
            return tracker.val
        raise ValueError(f"No values have been added for key '{key}'.")

    def avg(self, key):
        """Return the average for *key*; raise ``ValueError`` for an unknown key."""
        return self._require(key).avg

    def max(self, key):
        """Return the maximum for *key*; raise ``ValueError`` for an unknown key."""
        return self._require(key).max

    def __getitem__(self, key):
        """Return the string form of the accumulator for *key*.

        Raises ``KeyError`` when nothing was recorded under *key*.
        """
        tracker = self._metrics.get(key)
        if tracker is not None:
            return str(tracker)
        raise KeyError(f"The key '{key}' does not exist.")
def load_image(image_file):
    """Open an image from a local path or an HTTP(S) URL and return it as RGB.

    A string starting with ``http://`` or ``https://`` is downloaded with
    ``requests`` and decoded from memory; anything else is handed to
    ``Image.open`` as a path.
    """
    is_remote = image_file.startswith(('http://', 'https://'))
    if is_remote:
        payload = requests.get(image_file).content
        source = BytesIO(payload)
    else:
        source = image_file
    return Image.open(source).convert('RGB')
def listener(queue, filename):
    """Drain log records from *queue* and write them to *filename*.

    Runs as sub process-1. A file handler is attached to the root logger,
    then records are pulled from the queue one by one and dispatched to the
    logger named in the record. A ``None`` item is the signal to finish.
    Errors while handling a record are reported on stderr and the loop
    carries on.

    Args:
        queue: queue carrying ``logging.LogRecord`` objects, terminated by ``None``.
        filename: path of the log file to write.
    """
    ############## Start sub process-1: Listener #############
    import sys
    import traceback

    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    # h = logging.StreamHandler(sys.stdout)
    handler = logging.FileHandler(filename)
    handler.setFormatter(
        logging.Formatter('%(asctime)s %(processName)-10s %(name)s %(levelname)-8s %(message)s')
    )
    root.addHandler(handler)
    try:
        while True:
            try:
                record = queue.get()
                if record is None:  # None is a signal to finish
                    break
                logger = logging.getLogger(record.name)
                logger.handle(record)  # No level or filter logic applied - just do it!
            except Exception:
                print('Whoops! Problem:', file=sys.stderr)
                traceback.print_exc(file=sys.stderr)
    finally:
        # Detach and close the handler so the log file is flushed and released
        # once the listener stops.
        root.removeHandler(handler)
        handler.close()
def worker_configurer(queue):
    """Route this process's log records into *queue*.

    Adds a single ``QueueHandler`` to the root logger and sets the root
    level to ``DEBUG``.

    Args:
        queue: queue that receives the records emitted in this process.
    """
    # ``import logging`` does not load the ``logging.handlers`` submodule, so
    # import the handler class explicitly instead of relying on some other
    # module having imported it first.
    from logging.handlers import QueueHandler

    root = logging.getLogger()
    root.addHandler(QueueHandler(queue))  # Just the one handler needed
    root.setLevel(logging.DEBUG)
def video_stream_similator(video_file, frame_queue, log_queue, video_fps=1.0, play_speed=1.0):
    """Simulate a live video stream by feeding frames into *frame_queue*.

    Runs as sub process-2. The video is decoded with ``VideoReader``,
    sub-sampled towards *video_fps*, and each sampled frame is repeated six
    times. The frames are then put on the queue one at a time as
    single-frame clips, sleeping ``1 / video_fps / play_speed`` seconds
    between puts. A ``None`` sentinel is always queued last, even when an
    error interrupts the stream, so the consumer is never left blocked.

    Args:
        video_file: path of the video to decode.
        frame_queue: queue receiving the frame clips and the final ``None``.
        log_queue: queue used for multiprocess logging.
        video_fps: target sampling rate in frames per second.
        play_speed: playback speed multiplier; larger values shorten the sleep.
    """
    ############## Start sub process-2: Simulator #############
    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    time_meter = MetricMeter()
    try:
        sleep_time = 1 / video_fps / play_speed
        vr = VideoReader(video_file)
        # Keep one frame out of every `frame_step`. Clamp to 1 because a zero
        # step (video_fps far above the source fps) makes range() raise.
        frame_step = max(1, round(vr.get_avg_fps() / video_fps))
        frame_idx = list(range(0, len(vr), frame_step))
        video = vr.get_batch(frame_idx).asnumpy()
        # Every sampled frame is emitted six times in a row.
        video = np.repeat(video, 6, axis=0)
        length = video.shape[0]
        logger.info(f'Simulator Process: start, length = {length}')

        last_start = None
        for start in range(length):
            start_time = time.perf_counter()
            end = start + 1
            frame_queue.put(video[start:end])
            if last_start is not None:
                time_meter.add('real_sleep', start_time - last_start)
                logger.info(f'Simulator: write {end - start} frames,\t{start} to {end},\treal_sleep={time_meter["real_sleep"]}')
            if end < length:
                time.sleep(sleep_time)
            last_start = start_time
    except Exception as e:
        print(f'Simulator Exception: {e}')
        logger.exception('Simulator Exception')
    finally:
        # Always tell the consumer that the stream is over.
        frame_queue.put(None)
    time.sleep(0.1)
    logger.info('Simulator Process: end')
def frame_memory_manager(model, image_processor, frame_queue, log_queue):
    """Consume frame clips from *frame_queue* and embed them into the model.

    Runs as sub process-3. Each clip is preprocessed with *image_processor*,
    moved to the model's device as float16 and passed to
    ``model.embed_video_streaming``. The loop ends when ``None`` is read
    from the queue. The latency of the first clip is logged but kept out of
    the running statistics. Exceptions are printed and the loop continues.

    Args:
        model: model exposing ``device`` and ``embed_video_streaming``.
        image_processor: object whose ``preprocess`` returns ``pixel_values``.
        frame_queue: queue of frame clips terminated by ``None``.
        log_queue: queue used for multiprocess logging.
    """
    ############## Start sub process-3: Memory Manager #############
    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    latency_meter = MetricMeter()
    logger.info('MemManager Process: start')
    frames_seen = 0
    while True:
        try:
            clip = frame_queue.get()
            t_begin = time.perf_counter()
            if clip is None:
                logger.info('MemManager: Ooops, get None')
                break

            n_frames = clip.shape[0]
            logger.info(f'MemManager: get {n_frames} frames from queue')

            pixel_values = image_processor.preprocess(clip, return_tensors='pt')['pixel_values']
            image_tensor = pixel_values.unsqueeze(0).to(model.device, dtype=torch.float16)

            logger.info('MemManager: Start embedding')
            with torch.inference_mode():
                model.embed_video_streaming(image_tensor)
            logger.info('MemManager: End embedding')

            elapsed = time.perf_counter() - t_begin
            if frames_seen == 0:
                # The first clip is reported but kept out of the statistics.
                logger.info(f'MemManager: embedded {n_frames} frames,\tidx={frames_seen},\tmemory_latency={elapsed:.6f}, not logged')
            else:
                latency_meter.add('memory_latency', elapsed)
                logger.info(f'MemManager: embedded {n_frames} frames,\tidx={frames_seen},\tmemory_latency={latency_meter["memory_latency"]}')
            frames_seen += n_frames
        except Exception as e:
            print(f'MemManager Exception: {e}')
    time.sleep(0.1)
    logger.info('MemManager Process: end')
def _stop_processes(processes):
    """Terminate every started process in *processes* and wait for it to exit."""
    for proc in processes:
        proc.terminate()
    for proc in processes:
        proc.join()


def main(args):
    """Run the streaming video QA demo.

    Starts the log listener (process-1), loads the model, then starts the
    frame simulator (process-2) and the frame memory manager (process-3).
    The main process then asks the model a fixed question every five
    seconds and logs latency statistics. All child processes are terminated
    and joined when the loop ends for any reason, including an exception or
    Ctrl-C, so non-daemon children are not left running and the parent
    does not hang at exit.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
    """
    # torch.multiprocessing.log_to_stderr(logging.DEBUG)
    torch.multiprocessing.set_start_method('spawn', force=True)
    disable_torch_init()

    log_queue = Queue()
    frame_queue = Queue(maxsize=10)

    ############## Start listener process #############
    p1 = Process(target=listener, args=(log_queue, args.log_file))
    p1.start()

    try:
        ############## Start main process #############
        worker_configurer(log_queue)
        logger = logging.getLogger(__name__)

        model_name = get_model_name_from_path(args.model_path)
        tokenizer, model, image_processor, _context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

        logger.info(f'Using conv_mode={args.conv_mode}')

        conv = conv_templates[args.conv_mode].copy()
        if "mpt" in model_name.lower():
            roles = ('user', 'assistant')
        else:
            roles = conv.roles

        with Manager() as manager:
            image_tensor = None
            model.use_video_streaming_mode = True
            model.video_embedding_memory = manager.list()
            if args.video_max_frames is not None:
                model.config.video_max_frames = args.video_max_frames
                logger.info(f'Important: set model.config.video_max_frames = {model.config.video_max_frames}')

            logger.info(f'Important: set video_fps = {args.video_fps}')
            logger.info(f'Important: set play_speed = {args.play_speed}')

            workers = []
            try:
                ############## Start simulator process #############
                p2 = Process(target=video_stream_similator,
                             args=(args.video_file, frame_queue, log_queue, args.video_fps, args.play_speed))
                p2.start()
                workers.append(p2)

                ############## Start memory manager process #############
                p3 = Process(target=frame_memory_manager,
                             args=(model, image_processor, frame_queue, log_queue))
                p3.start()
                workers.append(p3)

                # start QA server
                start_time = datetime.now()
                time_meter = MetricMeter()
                conv_cnt = 0
                last_conv_start_time = None
                while True:
                    time.sleep(5)
                    # Interactive input is disabled; the same question is asked every round.
                    # inp = input(f"{roles[0]}: ")
                    inp = "what is in the video?"
                    if not inp:
                        print("exit...")
                        break

                    # Get the current time
                    now = datetime.now()
                    conv_start_time = time.perf_counter()
                    # Format the current time as a string
                    current_time = now.strftime("%H:%M:%S")
                    duration = now.timestamp() - start_time.timestamp()
                    # Print the current time
                    print("\nCurrent Time:", current_time, "Run for:", duration)
                    print(f"{roles[0]}: {inp}", end="\n")
                    print(f"{roles[1]}: ", end="")

                    # every conversation is a new conversation
                    conv = conv_templates[args.conv_mode].copy()
                    inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                    conv.append_message(conv.roles[0], inp)
                    conv.append_message(conv.roles[1], None)
                    prompt = conv.get_prompt()

                    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
                    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
                    keywords = [stop_str]
                    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
                    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

                    llm_start_time = time.perf_counter()
                    with torch.inference_mode():
                        output_ids = model.generate(
                            input_ids,
                            images=image_tensor,
                            do_sample=args.temperature > 0,
                            temperature=args.temperature,
                            max_new_tokens=args.max_new_tokens,
                            streamer=streamer,
                            use_cache=True,
                            stopping_criteria=[stopping_criteria]
                        )
                    llm_end_time = time.perf_counter()

                    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
                    conv.messages[-1][-1] = outputs
                    conv_end_time = time.perf_counter()

                    # The first round is reported but kept out of the running statistics.
                    if conv_cnt > 0:
                        time_meter.add('conv_latency', conv_end_time - conv_start_time)
                        time_meter.add('llm_latency', llm_end_time - llm_start_time)
                        time_meter.add('real_sleep', conv_start_time - last_conv_start_time)
                        logger.info(f'CliServer: idx={conv_cnt},\treal_sleep={time_meter["real_sleep"]},\tconv_latency={time_meter["conv_latency"]},\tllm_latency={time_meter["llm_latency"]}')
                    else:
                        logger.info(f'CliServer: idx={conv_cnt},\tconv_latency={conv_end_time - conv_start_time},\tllm_latency={llm_end_time - llm_start_time}')
                    conv_cnt += 1
                    last_conv_start_time = conv_start_time
            finally:
                # Stop the workers while the manager that backs the shared
                # embedding list is still running.
                _stop_processes(workers)
    finally:
        _stop_processes([p1])
        print("All processes finished.")
if __name__ == "__main__":
    # Command-line options, registered in the order they appear in --help.
    cli_options = (
        ("--model-path", dict(type=str, default="facebook/opt-350m")),
        ("--model-base", dict(type=str, default=None)),
        ("--image-file", dict(type=str, default=None)),
        ("--video-file", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda")),
        ("--conv-mode", dict(type=str, default="vicuna_v1")),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--log-file", dict(type=str, default="tmp_cli.log")),
        ("--use_1process", dict(action="store_true")),
        ("--video_max_frames", dict(type=int, default=None)),
        ("--video_fps", dict(type=float, default=1.0)),
        ("--play_speed", dict(type=float, default=1.0)),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)