|
import faulthandler |
|
faulthandler.enable() |
|
import os |
|
import random |
|
import time |
|
import signal |
|
from multiprocessing import Process, Queue, Event |
|
import numpy as np |
|
from rkllm_binding import * |
|
from rknnlite.api.rknn_lite import RKNNLite |
|
import threading |
|
import librosa |
|
from transformers import WhisperFeatureExtractor |
|
|
|
|
|
def audio_encoder_process(load_ready_queue, embedding_queue, audio_path_queue, start_event):
    """Worker process that runs the RKNN audio encoder.

    Loads the encoder model and the Whisper feature extractor, reports
    readiness on ``load_ready_queue``, waits for ``start_event``, then serves
    requests: for each path read from ``audio_path_queue`` it pushes either an
    embedding array or the string "ERROR" onto ``embedding_queue``.  The
    sentinel path "STOP" terminates the loop.
    """
    AUDIO_ENCODER_PATH = "audio_encoder.rknn"

    audio_encoder = RKNNLite(verbose=False)
    model_size = os.path.getsize(AUDIO_ENCODER_PATH)
    print(f"Start loading audio encoder model (size: {model_size / 1024 / 1024:.2f} MB)")
    start_time = time.time()
    # RKNN APIs report failure via a non-zero return code instead of raising;
    # the original ignored it and crashed obscurely later.
    ret = audio_encoder.load_rknn(AUDIO_ENCODER_PATH)
    if ret != 0:
        raise RuntimeError(f"load_rknn('{AUDIO_ENCODER_PATH}') failed with code {ret}")
    end_time = time.time()
    print(f"Audio encoder loaded in {end_time - start_time:.2f} seconds")
    ret = audio_encoder.init_runtime()
    if ret != 0:
        raise RuntimeError(f"init_runtime failed with code {ret}")

    # "." assumes the preprocessor config (preprocessor_config.json) sits in
    # the current working directory — TODO confirm deployment layout.
    feature_extractor = WhisperFeatureExtractor.from_pretrained(".")

    load_ready_queue.put("audio_ready")

    # Block until the coordinator confirms every worker is ready.
    start_event.wait()

    def process_audio(audio_path, audio_encoder, feature_extractor):
        """Encode one audio file; return the embedding array or None on failure."""
        try:
            print("Start audio inference...")
            audio, _ = librosa.load(audio_path, sr=feature_extractor.sampling_rate)
            feature_extractor_output = feature_extractor(
                audio,
                sampling_rate=feature_extractor.sampling_rate,
                return_attention_mask=True,
                padding="max_length"
            )

            start_time = time.time()
            # NOTE(review): assumes input_features / attention_mask come back
            # as numpy arrays (no return_tensors requested) — verify against
            # the installed transformers version.
            audio_embeddings = audio_encoder.inference(inputs=[
                feature_extractor_output.input_features.astype(np.float32),
                feature_extractor_output.attention_mask.astype(np.float32)
            ], data_format="nhwc")[0].astype(np.float32)
            end_time = time.time()
            print(f"Audio encoder inference time: {end_time - start_time:.2f} seconds")

            # Trim padded frames: map the unpadded mask length through two
            # stride-2 stages to the corresponding encoder output length.
            effective_length = feature_extractor_output.attention_mask.sum(-1)[0]
            effective_length = (effective_length - 1) // 2 + 1
            output_lengths = (effective_length - 2) // 2 + 1
            audio_embeddings = audio_embeddings[:, :output_lengths]
            print(audio_embeddings.shape)
            return audio_embeddings
        except Exception as e:
            # Deliberately broad: a bad file must not kill the worker process.
            print(f"Error processing audio: {e}")
            return None

    # Serve requests until the coordinator sends the STOP sentinel.
    while True:
        audio_path = audio_path_queue.get()
        if audio_path == "STOP":
            break
        embeddings = process_audio(audio_path, audio_encoder, feature_extractor)
        if embeddings is not None:
            embedding_queue.put(embeddings)
        else:
            embedding_queue.put("ERROR")
|
|
|
|
|
def llm_process(load_ready_queue, embedding_queue, prompt_queue, inference_done_queue, start_event):
    """Worker process that runs the RKLLM language model.

    Loads the model, reports readiness on ``load_ready_queue``, then loops:
    each prompt from ``prompt_queue`` is paired with the matching audio
    embedding from ``embedding_queue`` and generation is streamed to stdout
    via the RKLLM callback; completion is reported on
    ``inference_done_queue``.  The sentinel prompt "STOP" terminates the loop.
    (``start_event`` is accepted for signature parity with the audio worker;
    this process synchronizes on the queues instead.)
    """
    MODEL_PATH = "/home/firefly/qwen.rkllm"
    handle = None
    import locale

    # Pick the progress-banner language from the system locale.
    # locale.getdefaultlocale() is deprecated since Python 3.11; use
    # getlocale() with an environment fallback instead.
    system_lang = locale.getlocale()[0] or os.environ.get("LANG", "")
    is_chinese = system_lang and system_lang.startswith('zh')

    # Playful status lines shown while waiting for the first token.
    progress_messages_zh = [
        "🚀 启动量子加速引擎...",
        "🧠 神经网络正在苏醒...",
        "🔄 并行宇宙计算进行中...",
        "🌟 正在注入能量矩阵...",
        "🔥 CPU已经到达工作温度,全力运转中...",
        "🎯 特征向量正在跳跃式生长...",
        "🎭 多头注意力机制开始营业...",
        "💨 散热风扇已经进入超音速状态...",
        "📚 语义解析器正在啃食数据...",
        "🔍 上下文关联分析师正在加班...",
        "🎨 视觉特征正在调色盘中混合...",
        "🤝 跨模态对齐正在相亲相爱中...",
        "⚡ 深度特征提取器已经深入地心...",
        "🧪 神经网络正在炼丹中...",
        "🎲 张量计算已经进入量子态...",
        "📦 模型参数正在装箱搬运...",
        "⚖️ 权重矩阵正在天平上找平衡...",
        "🗺 语义向量正在绘制航海图...",
        "🎭 注意力头们正在开会讨论...",
        "🏗 残差模块正在搭建天梯...",
        "🌈 激活函数正在调制彩虹...",
        "🎮 张量核心正在玩魔方...",
        "🎪 循环神经网络正在马戏团表演...",
        "🎨 特征图正在画饼充饥...",
        "🔮 模型正在占卜未来...",
        "🎯 优化器正在进行火箭轨道计算...",
        "🎪 批归一化正在杂技表演...",
        "🎭 Dropout正在玩捉迷藏...",
        "🌪 梯度正在形成龙卷风...",
        "🎢 反向传播正在过山车..."
    ]

    progress_messages_en = [
        "Loading...",
        "Extracting...",
        "Image fusion in progress...",
        "Matrix multiplication...",
        "Chip heating up...",
        "Feature vector calculation...",
        "Attention mechanism processing...",
        "Fan speed increasing...",
        "Semantic parsing...",
        "Context analysis...",
        "Visual feature encoding...",
        "Cross-modal alignment...",
        "Deep feature extraction...",
        "Neural network inference...",
        "Tensor operations...",
        "Loading model parameters...",
        "Weight matrix calculation...",
        "Semantic vector mapping...",
        "Multi-head attention...",
        "Residual connection..."
    ]

    progress_messages = progress_messages_zh if is_chinese else progress_messages_en

    progress_stop_event = threading.Event()

    def show_progress():
        """Print rotating status messages until the stop event is set."""
        while not progress_stop_event.is_set():
            for msg in progress_messages:
                if progress_stop_event.is_set():
                    break
                print(f"{msg}", flush=True)
                time.sleep(random.uniform(0.1, 0.4))

    def signal_handler(signum, frame):
        """SIGINT handler: abort any in-flight run and release the model.

        BUGFIX: the original declared ``global handle``, but ``handle`` is a
        local of llm_process, so the handler raised NameError on Ctrl-C;
        reading it through the closure picks up the current value.
        Parameter renamed from ``signal`` to avoid shadowing the module.
        """
        print("Ctrl-C pressed, exiting...")
        if handle:
            abort(handle)
            destroy(handle)
        exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    os.environ["RKLLM_LOG_LEVEL"] = "1"

    inference_count = 0
    inference_start_time = 0

    def result_callback(result, userdata, state):
        """RKLLM streaming callback: print tokens and report run status."""
        nonlocal inference_start_time, inference_count
        if state == LLMCallState.RKLLM_RUN_NORMAL:
            if inference_count == 0:
                # First token: stop the progress banner and report latency.
                progress_stop_event.set()
                first_token_time = time.time()
                print("🎉 完成!")
                print(f"\nTime to first token: {first_token_time - inference_start_time:.2f} seconds")
            inference_count += 1
            print(result.contents.text.decode(), end="", flush=True)
        elif state == LLMCallState.RKLLM_RUN_FINISH:
            print("\n\n(finished)")
            inference_done_queue.put("DONE")
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            print("\nError occurred during LLM call")
            inference_done_queue.put("ERROR")

    # Model parameters.  The img_* markers delimit where the multimodal
    # (here: audio) embedding is spliced into the token stream.
    param = create_default_param()
    param.model_path = MODEL_PATH.encode()
    param.img_start = "<|audio_bos|>".encode()
    param.img_end = "<|audio_eos|>".encode()
    param.img_content = "<|AUDIO|>".encode()
    param.max_context_len = 768
    param.max_new_tokens = 256
    extend_param = RKLLMExtendParam()
    extend_param.base_domain_id = 1
    param.extend_param = extend_param

    model_size = os.path.getsize(MODEL_PATH)
    print(f"Start loading language model (size: {model_size / 1024 / 1024:.2f} MB)")
    start_time = time.time()
    handle = init(param, result_callback)
    end_time = time.time()
    print(f"Language model loaded in {end_time - start_time:.2f} seconds")

    load_ready_queue.put("llm_ready")

    infer_param = RKLLMInferParam()
    infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE.value

    while True:
        prompt = prompt_queue.get()
        print(f"Received prompt: ===={prompt}\n====")
        if prompt == "STOP":
            break

        inference_count = 0
        progress_stop_event.clear()

        progress_thread = threading.Thread(target=show_progress)
        progress_thread.daemon = True
        # BUGFIX: the thread was created but never started, so the progress
        # banner never appeared.
        progress_thread.start()

        # The audio worker sends either an embedding array or "ERROR".
        audio_embeddings = embedding_queue.get()
        if isinstance(audio_embeddings, str) and audio_embeddings == "ERROR":
            progress_stop_event.set()
            print("Error processing audio")
            continue
        print(audio_embeddings.shape)
        rkllm_input = create_rkllm_input(RKLLMInputType.RKLLM_INPUT_MULTIMODAL,
                                         prompt=prompt,
                                         image_embed=audio_embeddings)
        print("Start LLM inference...")
        inference_start_time = time.time()
        # Blocks until generation completes; tokens stream via result_callback.
        run(handle, rkllm_input, infer_param, None)

    destroy(handle)
|
|
|
def main():
    """Spawn the encoder and LLM worker processes and run an interactive
    read-eval loop on stdin until Ctrl-C."""
    import re  # hoisted: the original re-imported inside the loop body

    load_ready_queue = Queue()
    embedding_queue = Queue()
    audio_path_queue = Queue()
    prompt_queue = Queue()
    inference_done_queue = Queue()
    start_event = Event()

    audio_process = Process(target=audio_encoder_process,
                            args=(load_ready_queue, embedding_queue, audio_path_queue, start_event))
    lm_process = Process(target=llm_process,
                         args=(load_ready_queue, embedding_queue, prompt_queue, inference_done_queue, start_event))

    audio_process.start()
    # Stagger the second start so the two models do not load simultaneously —
    # presumably to limit peak memory/NPU contention; confirm on target HW.
    time.sleep(10)
    lm_process.start()

    # Wait for both workers to report they finished loading.
    ready_count = 0
    while ready_count < 2:
        status = load_ready_queue.get()
        print(f"Received ready signal: {status}")
        ready_count += 1

    print("All models loaded, starting interactive mode...")
    start_event.set()

    try:
        while True:
            print("""
Enter your input (3 empty lines to start inference, Ctrl+C to exit, for example:
这是什么声音{{glass-breaking.wav}}?
What kind of sound is in {{./test.mp3}}?
Describe the audio in {{./test.mp3}}
这是什么动物的叫声{{./jntm.mp3}}?
):
""")
            user_input = []
            empty_lines = 0

            # Read until three consecutive blank lines.  Every line — blank
            # or not — is buffered so the terminator can be sliced off below.
            while empty_lines < 3:
                line = input()
                if line.strip() == "":
                    empty_lines += 1
                else:
                    empty_lines = 0
                user_input.append(line)

            # Drop exactly the three terminating blank lines.
            full_input = "\n".join(user_input[:-3])
            audio_match = re.search(r'\{\{(.+?)\}\}', full_input)
            if not audio_match:
                # The {{...}} placeholder carries an audio file path.
                print("No audio path found in input")
                continue

            audio_path = audio_match.group(1)

            # Qwen chat template; "Audio 1: <image>" is where the LLM worker
            # splices in the audio embedding.
            prompt = f"""<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Audio 1: <image>
{full_input.replace(audio_match.group(0), '')}<|im_end|>
<|im_start|>assistant
"""
            audio_path_queue.put(audio_path)
            prompt_queue.put(prompt)

            # Block until the LLM worker reports the run finished.
            status = inference_done_queue.get()
            if status == "ERROR":
                print("Inference failed")

    except KeyboardInterrupt:
        print("\nExiting...")
        # Tell both workers to exit their serve loops.
        audio_path_queue.put("STOP")
        prompt_queue.put("STOP")

    audio_process.join()
    lm_process.join()
|
|
|
# Script entry point: only run the interactive demo when executed directly.
if __name__ == "__main__":

    main()
|
|
|
|
|
|