|
use crate::chat::{ |
|
aiserver::v1::StreamChatResponse, |
|
error::{ChatError, StreamError}, |
|
}; |
|
use flate2::read::GzDecoder; |
|
use prost::Message; |
|
use std::{collections::BTreeMap, io::Read}; |
|
|
|
|
|
fn decompress_gzip(data: &[u8]) -> Option<Vec<u8>> { |
|
let mut decoder = GzDecoder::new(data); |
|
let mut decompressed = Vec::new(); |
|
|
|
match decoder.read_to_end(&mut decompressed) { |
|
Ok(_) => Some(decompressed), |
|
Err(_) => { |
|
|
|
None |
|
} |
|
} |
|
} |
|
|
|
/// Types that can render themselves as a Markdown fragment.
pub trait ToMarkdown {
    /// Returns the Markdown representation of `self` as an owned string.
    fn to_markdown(&self) -> String;
}
|
|
|
impl ToMarkdown for BTreeMap<String, String> { |
|
fn to_markdown(&self) -> String { |
|
if self.is_empty() { |
|
return String::new(); |
|
} |
|
|
|
let mut result = String::from("WebReferences:\n"); |
|
for (i, (url, title)) in self.iter().enumerate() { |
|
result.push_str(&format!("{}. [{}]({})\n", i + 1, title, url)); |
|
} |
|
result.push_str("\n"); |
|
result |
|
} |
|
} |
|
|
|
/// A single decoded event produced by the stream decoder.
#[derive(Debug, Clone, PartialEq)]
pub enum StreamMessage {
    /// Debug payload; filled from the response's `filled_prompt` field.
    Debug(String),

    /// Web citations, keyed by URL with the page title as the value.
    WebReference(BTreeMap<String, String>),

    /// Marker emitted for a frame whose declared payload length is zero.
    ContentStart,

    /// A chunk of response text (or web references rendered as Markdown when
    /// conversion is requested).
    Content(String),

    /// Marker emitted when a JSON frame's payload is exactly two bytes long.
    StreamEnd,
}
|
|
|
impl StreamMessage { |
|
fn convert_web_ref_to_content(self) -> Self { |
|
match self { |
|
StreamMessage::WebReference(refs) => StreamMessage::Content(refs.to_markdown()), |
|
other => other, |
|
} |
|
} |
|
} |
|
|
|
pub struct StreamDecoder { |
|
buffer: Vec<u8>, |
|
first_result: Option<Vec<StreamMessage>>, |
|
first_result_ready: bool, |
|
first_result_taken: bool, |
|
} |
|
|
|
impl StreamDecoder { |
|
pub fn new() -> Self { |
|
Self { |
|
buffer: Vec::new(), |
|
first_result: None, |
|
first_result_ready: false, |
|
first_result_taken: false, |
|
} |
|
} |
|
|
|
pub fn take_first_result(&mut self) -> Option<Vec<StreamMessage>> { |
|
if !self.buffer.is_empty() { |
|
return None; |
|
} |
|
if self.first_result.is_some() { |
|
self.first_result_taken = true; |
|
} |
|
self.first_result.take() |
|
} |
|
|
|
#[cfg(test)] |
|
fn is_incomplete(&self) -> bool { |
|
!self.buffer.is_empty() |
|
} |
|
|
|
pub fn is_first_result_ready(&self) -> bool { |
|
self.first_result_ready |
|
} |
|
|
|
pub fn decode(&mut self, data: &[u8], convert_web_ref: bool) -> Result<Vec<StreamMessage>, StreamError> { |
|
self.buffer.extend_from_slice(data); |
|
|
|
if self.buffer.len() < 5 { |
|
if self.buffer.len() == 0 { |
|
return Err(StreamError::EmptyStream); |
|
} |
|
crate::debug_println!("数据长度小于5字节,当前数据: {}", hex::encode(&self.buffer)); |
|
return Err(StreamError::DataLengthLessThan5); |
|
} |
|
|
|
let mut messages = Vec::new(); |
|
let mut offset = 0; |
|
|
|
while offset + 5 <= self.buffer.len() { |
|
let msg_type = self.buffer[offset]; |
|
let msg_len = u32::from_be_bytes([ |
|
self.buffer[offset + 1], |
|
self.buffer[offset + 2], |
|
self.buffer[offset + 3], |
|
self.buffer[offset + 4], |
|
]) as usize; |
|
|
|
if msg_len == 0 { |
|
offset += 5; |
|
messages.push(StreamMessage::ContentStart); |
|
continue; |
|
} |
|
|
|
if offset + 5 + msg_len > self.buffer.len() { |
|
break; |
|
} |
|
|
|
let msg_data = &self.buffer[offset + 5..offset + 5 + msg_len]; |
|
|
|
match self.process_message(msg_type, msg_data)? { |
|
Some(msg) => { |
|
if convert_web_ref { |
|
messages.push(msg.convert_web_ref_to_content()); |
|
} else { |
|
messages.push(msg); |
|
} |
|
} |
|
_ => {} |
|
} |
|
|
|
offset += 5 + msg_len; |
|
} |
|
|
|
self.buffer.drain(..offset); |
|
|
|
if !self.first_result_taken && !messages.is_empty() { |
|
if self.first_result.is_none() { |
|
self.first_result = Some(messages.clone()); |
|
} else if !self.first_result_ready { |
|
self.first_result.as_mut().unwrap().extend(messages.clone()); |
|
} |
|
} |
|
if !self.first_result_ready { |
|
self.first_result_ready = self.first_result.is_some() && self.buffer.is_empty() && !self.first_result_taken; |
|
} |
|
Ok(messages) |
|
} |
|
|
|
fn process_message( |
|
&self, |
|
msg_type: u8, |
|
msg_data: &[u8], |
|
) -> Result<Option<StreamMessage>, StreamError> { |
|
match msg_type { |
|
0 => self.handle_text_message(msg_data), |
|
1 => self.handle_gzip_message(msg_data), |
|
2 => self.handle_json_message(msg_data), |
|
3 => self.handle_gzip_json_message(msg_data), |
|
t => { |
|
eprintln!("收到未知消息类型: {},请尝试联系开发者以获取支持", t); |
|
crate::debug_println!("消息类型: {},消息内容: {}", t, hex::encode(msg_data)); |
|
Ok(None) |
|
} |
|
} |
|
} |
|
|
|
fn handle_text_message(&self, msg_data: &[u8]) -> Result<Option<StreamMessage>, StreamError> { |
|
if let Ok(response) = StreamChatResponse::decode(msg_data) { |
|
|
|
if !response.text.is_empty() { |
|
Ok(Some(StreamMessage::Content(response.text))) |
|
} else if let Some(filled_prompt) = response.filled_prompt { |
|
Ok(Some(StreamMessage::Debug(filled_prompt))) |
|
} else if let Some(web_citation) = response.web_citation { |
|
let mut refs = BTreeMap::new(); |
|
for reference in web_citation.references { |
|
refs.insert(reference.url, reference.title); |
|
} |
|
Ok(Some(StreamMessage::WebReference(refs))) |
|
} else { |
|
Ok(None) |
|
} |
|
} else { |
|
Ok(None) |
|
} |
|
} |
|
|
|
fn handle_gzip_message(&self, msg_data: &[u8]) -> Result<Option<StreamMessage>, StreamError> { |
|
if let Some(text) = decompress_gzip(msg_data) { |
|
if let Ok(response) = StreamChatResponse::decode(&text[..]) { |
|
|
|
if !response.text.is_empty() { |
|
Ok(Some(StreamMessage::Content(response.text))) |
|
} else if let Some(filled_prompt) = response.filled_prompt { |
|
Ok(Some(StreamMessage::Debug(filled_prompt))) |
|
} else if let Some(web_citation) = response.web_citation { |
|
let mut refs = BTreeMap::new(); |
|
for reference in web_citation.references { |
|
refs.insert(reference.url, reference.title); |
|
} |
|
Ok(Some(StreamMessage::WebReference(refs))) |
|
} else { |
|
Ok(None) |
|
} |
|
} else { |
|
Ok(None) |
|
} |
|
} else { |
|
Ok(None) |
|
} |
|
} |
|
|
|
fn handle_json_message(&self, msg_data: &[u8]) -> Result<Option<StreamMessage>, StreamError> { |
|
if msg_data.len() == 2 { |
|
return Ok(Some(StreamMessage::StreamEnd)); |
|
} |
|
if let Ok(text) = String::from_utf8(msg_data.to_vec()) { |
|
|
|
if let Ok(error) = serde_json::from_str::<ChatError>(&text) { |
|
return Err(StreamError::ChatError(error)); |
|
} |
|
} |
|
Ok(None) |
|
} |
|
|
|
fn handle_gzip_json_message( |
|
&self, |
|
msg_data: &[u8], |
|
) -> Result<Option<StreamMessage>, StreamError> { |
|
if let Some(text) = decompress_gzip(msg_data) { |
|
if text.len() == 2 { |
|
return Ok(Some(StreamMessage::StreamEnd)); |
|
} |
|
if let Ok(text) = String::from_utf8(text) { |
|
|
|
if let Ok(error) = serde_json::from_str::<ChatError>(&text) { |
|
return Err(StreamError::ChatError(error)); |
|
} |
|
} |
|
} |
|
Ok(None) |
|
} |
|
} |
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the hex-encoded stream fixture and converts it to raw bytes.
    ///
    /// Surrounding whitespace (such as a trailing newline) is trimmed so it
    /// is not mistaken for hex digits.
    fn fixture_bytes() -> Vec<u8> {
        include_str!("../../../tests/data/stream_data.txt")
            .trim()
            .as_bytes()
            .chunks(2)
            .map(|pair| {
                let hex_str = std::str::from_utf8(pair).expect("fixture must be ASCII hex");
                u8::from_str_radix(hex_str, 16).expect("fixture must contain valid hex digits")
            })
            .collect()
    }

    /// Returns the length of the frame at the start of `bytes` (header plus
    /// declared payload), clamped to the bytes actually available so that
    /// slicing a truncated fixture cannot panic.
    fn find_next_message_boundary(bytes: &[u8]) -> usize {
        if bytes.len() < 5 {
            return bytes.len();
        }
        let msg_len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
        msg_len.saturating_add(5).min(bytes.len())
    }

    /// Formats `bytes` as an upper-case hex string.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{:02X}", b)).collect()
    }

    /// Feeds the whole fixture to the decoder in one call and prints every
    /// decoded message.
    #[test]
    fn test_single_chunk() {
        let bytes = fixture_bytes();
        let mut decoder = StreamDecoder::new();

        match decoder.decode(&bytes, false) {
            Ok(messages) => {
                for message in messages {
                    match message {
                        StreamMessage::StreamEnd => {
                            println!("流结束");
                            break;
                        }
                        StreamMessage::Content(msg) => {
                            println!("消息内容: {}", msg);
                        }
                        StreamMessage::WebReference(refs) => {
                            println!("网页引用:");
                            for (i, (url, title)) in refs.iter().enumerate() {
                                println!("{}. {} - {}", i, url, title);
                            }
                        }
                        StreamMessage::Debug(prompt) => {
                            println!("调试信息: {}", prompt);
                        }
                        StreamMessage::ContentStart => {
                            println!("流开始");
                        }
                    }
                }
            }
            Err(e) => {
                println!("解析错误: {}", e);
            }
        }

        if decoder.is_incomplete() {
            println!("数据不完整");
        }
    }

    /// Feeds the fixture to the decoder one frame at a time and prints every
    /// decoded message together with the hex of the frame it was decoded
    /// from.
    #[test]
    fn test_multiple_chunks() {
        let bytes = fixture_bytes();
        let mut decoder = StreamDecoder::new();
        let mut offset = 0;

        'frames: while offset < bytes.len() {
            let remaining_bytes = &bytes[offset..];
            let msg_boundary = find_next_message_boundary(remaining_bytes);
            let current_msg_bytes = &remaining_bytes[..msg_boundary];
            let hex_str = bytes_to_hex(current_msg_bytes);

            match decoder.decode(current_msg_bytes, false) {
                Ok(messages) => {
                    for message in messages {
                        match message {
                            StreamMessage::StreamEnd => {
                                println!("流结束 [hex: {}]", hex_str);
                                // End of stream: stop feeding further frames.
                                break 'frames;
                            }
                            StreamMessage::Content(msg) => {
                                println!("消息内容 [hex: {}]: {}", hex_str, msg);
                            }
                            StreamMessage::WebReference(refs) => {
                                println!("网页引用 [hex: {}]:", hex_str);
                                for (i, (url, title)) in refs.iter().enumerate() {
                                    println!("{}. {} - {}", i, url, title);
                                }
                            }
                            StreamMessage::Debug(prompt) => {
                                println!("调试信息 [hex: {}]: {}", hex_str, prompt);
                            }
                            StreamMessage::ContentStart => {
                                println!("流开始 [hex: {}]", hex_str);
                            }
                        }
                    }

                    if decoder.is_incomplete() {
                        println!("数据不完整 [hex: {}]", hex_str);
                        break;
                    }
                    offset += msg_boundary;
                }
                Err(e) => {
                    println!("解析错误 [hex: {}]: {}", hex_str, e);
                    break;
                }
            }
        }
    }
}
|
|