use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use image::guess_format;
use prost::Message as _;
use reqwest::Client;
use uuid::Uuid;

use crate::{
    app::{
        constant::EMPTY_STRING,
        lazy::DEFAULT_INSTRUCTIONS,
        model::{AppConfig, VisionAbility},
    },
    common::client::HTTP_CLIENT,
};

use super::{
    aiserver::v1::{
        conversation_message, image_proto, AzureState, ChatExternalLink, ConversationMessage,
        ExplicitContext, GetChatRequest, ImageProto, ModelDetails,
    },
    constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS},
    model::{Message, MessageContent, Role},
};

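/// Flattens OpenAI-style chat `inputs` into the pieces of a `GetChatRequest`:
/// the combined system instructions, a normalized conversation, and any
/// `@`-prefixed URLs found in the final user message.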
async fn process_chat_inputs(
    inputs: Vec<Message>,
    disable_vision: bool,
) -> (String, Vec<ConversationMessage>, Vec<String>) {
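    // Merge every system message (including the text parts of vision content)
    // into a single instruction block.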
    let instructions = inputs
        .iter()
        .filter(|input| input.role == Role::System)
        .map(|input| match &input.content {
            MessageContent::Text(text) => text.clone(),
            MessageContent::Vision(contents) => contents
                .iter()
                .filter_map(|content| {
                    if content.content_type == "text" {
                        content.text.clone()
                    } else {
                        None
                    }
                })
                .collect::<Vec<String>>()
                .join("\n"),
        })
        .collect::<Vec<String>>()
        .join("\n\n");

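    // Fall back to the default instructions when no system prompt was given.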
    let instructions = if instructions.is_empty() {
        DEFAULT_INSTRUCTIONS.clone()
    } else {
        instructions
    };

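    // Keep only user and assistant turns; system messages were consumed above.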
    let mut chat_inputs: Vec<Message> = inputs
        .into_iter()
        .filter(|input| input.role == Role::User || input.role == Role::Assistant)
        .collect();

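    // With no user/assistant turns at all, return a single empty human message.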
    if chat_inputs.is_empty() {
        return (
            instructions,
            vec![ConversationMessage {
                text: EMPTY_STRING.into(),
                r#type: conversation_message::MessageType::Human as i32,
                attached_code_chunks: vec![],
                codebase_context_chunks: vec![],
                commits: vec![],
                pull_requests: vec![],
                git_diffs: vec![],
                assistant_suggested_diffs: vec![],
                interpreter_results: vec![],
                images: vec![],
                attached_folders: vec![],
                approximate_lint_errors: vec![],
                bubble_id: Uuid::new_v4().to_string(),
                server_bubble_id: None,
                attached_folders_new: vec![],
                lints: vec![],
                user_responses_to_suggested_code_blocks: vec![],
                relevant_files: vec![],
                tool_results: vec![],
                notepads: vec![],
                is_capability_iteration: Some(false),
                capabilities: vec![],
                edit_trail_contexts: vec![],
                suggested_code_blocks: vec![],
                diffs_for_compressing_files: vec![],
                multi_file_linter_errors: vec![],
                diff_histories: vec![],
                recently_viewed_files: vec![],
                recent_locations_history: vec![],
                is_agentic: false,
                file_diff_trajectories: vec![],
                conversation_summary: None,
            }],
            vec![],
        );
    }

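    // Strip the "WebReferences:" block that may precede assistant text.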
    chat_inputs = chat_inputs
        .into_iter()
        .map(|mut input| {
            if let (Role::Assistant, MessageContent::Text(text)) = (&input.role, &input.content) {
                if text.starts_with("WebReferences:") {
                    if let Some(pos) = text.find("\n\n") {
                        input.content = MessageContent::Text(text[pos + 2..].to_owned());
                    }
                }
            }
            input
        })
        .collect();

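    // Normalize the conversation so that roles strictly alternate and it both
    // starts and ends with a user turn, padding with empty messages as needed.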
    if chat_inputs
        .first()
        .map_or(false, |input| input.role == Role::Assistant)
    {
        chat_inputs.insert(
            0,
            Message {
                role: Role::User,
                content: MessageContent::Text(EMPTY_STRING.into()),
            },
        );
    }

    let mut i = 1;
    while i < chat_inputs.len() {
        if chat_inputs[i].role == chat_inputs[i - 1].role {
            let insert_role = if chat_inputs[i].role == Role::User {
                Role::Assistant
            } else {
                Role::User
            };
            chat_inputs.insert(
                i,
                Message {
                    role: insert_role,
                    content: MessageContent::Text(EMPTY_STRING.into()),
                },
            );
        }
        i += 1;
    }

    if chat_inputs
        .last()
        .map_or(false, |input| input.role == Role::Assistant)
    {
        chat_inputs.push(Message {
            role: Role::User,
            content: MessageContent::Text(EMPTY_STRING.into()),
        });
    }

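    // Convert each turn into a `ConversationMessage`, resolving any attached
    // images along the way.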
    let mut messages = Vec::new();
    for input in chat_inputs {
        let (text, images) = match input.content {
            MessageContent::Text(text) => (text, vec![]),
            MessageContent::Vision(contents) => {
                let mut text_parts = Vec::new();
                let mut images = Vec::new();

                for content in contents {
                    match content.content_type.as_str() {
                        "text" => {
                            if let Some(text) = content.text {
                                text_parts.push(text);
                            }
                        }
                        "image_url" => {
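                            // Fetch and decode the image on a spawned task;
                            // failures silently drop the image.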
                            if !disable_vision {
                                if let Some(image_url) = &content.image_url {
                                    let url = image_url.url.clone();
                                    let client = HTTP_CLIENT.read().clone();
                                    let result = tokio::spawn(async move {
                                        fetch_image_data(&url, client).await
                                    });
                                    if let Ok(Ok((image_data, dimensions))) = result.await {
                                        images.push(ImageProto {
                                            data: image_data,
                                            dimension: dimensions,
                                        });
                                    }
                                }
                            }
                        }
                        _ => {}
                    }
                }
                (text_parts.join("\n"), images)
            }
        };

        messages.push(ConversationMessage {
            text,
            r#type: if input.role == Role::User {
                conversation_message::MessageType::Human as i32
            } else {
                conversation_message::MessageType::Ai as i32
            },
            attached_code_chunks: vec![],
            codebase_context_chunks: vec![],
            commits: vec![],
            pull_requests: vec![],
            git_diffs: vec![],
            assistant_suggested_diffs: vec![],
            interpreter_results: vec![],
            images,
            attached_folders: vec![],
            approximate_lint_errors: vec![],
            bubble_id: Uuid::new_v4().to_string(),
            server_bubble_id: None,
            attached_folders_new: vec![],
            lints: vec![],
            user_responses_to_suggested_code_blocks: vec![],
            relevant_files: vec![],
            tool_results: vec![],
            notepads: vec![],
            is_capability_iteration: None,
            capabilities: vec![],
            edit_trail_contexts: vec![],
            suggested_code_blocks: vec![],
            diffs_for_compressing_files: vec![],
            multi_file_linter_errors: vec![],
            diff_histories: vec![],
            recently_viewed_files: vec![],
            recent_locations_history: vec![],
            is_agentic: false,
            file_diff_trajectories: vec![],
            conversation_summary: None,
        });
    }

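    // Collect `@http(s)://...` tokens from the final human message; these are
    // passed along as the request's external links.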
    let mut urls = Vec::new();
    if let Some(last_msg) = messages.last() {
        if last_msg.r#type == conversation_message::MessageType::Human as i32 {
            let text = &last_msg.text;
            let mut chars = text.chars().peekable();

            while let Some(c) = chars.next() {
                if c == '@' {
                    let mut url = String::new();
                    while let Some(&next_char) = chars.peek() {
                        if next_char.is_whitespace() {
                            break;
                        }
                        url.push(chars.next().unwrap());
                    }
                    if let Ok(parsed_url) = url::Url::parse(&url) {
                        if parsed_url.scheme() == "http" || parsed_url.scheme() == "https" {
                            urls.push(url);
                        }
                    }
                }
            }
        }
    }

    (instructions, messages, urls)
}

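/// Resolves an image reference (data URL or remote URL) into raw bytes plus
/// optional dimensions, subject to the configured vision ability.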
async fn fetch_image_data(
    url: &str,
    client: Client,
) -> Result<(Vec<u8>, Option<image_proto::Dimension>), Box<dyn std::error::Error + Send + Sync>> {
    let vision_ability = AppConfig::get_vision_ability();

    match vision_ability {
        VisionAbility::None => Err("image support is disabled".into()),
        VisionAbility::Base64 => {
            if !url.starts_with("data:image/") {
                return Err("only base64-encoded images are supported".into());
            }
            process_base64_image(url)
        }
        VisionAbility::All => {
            if url.starts_with("data:image/") {
                process_base64_image(url)
            } else {
                process_http_image(url, client).await
            }
        }
    }
}

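/// Decodes a `data:image/...;base64,` URL, rejecting unsupported formats.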
fn process_base64_image(
    url: &str,
) -> Result<(Vec<u8>, Option<image_proto::Dimension>), Box<dyn std::error::Error + Send + Sync>> {
    let parts: Vec<&str> = url.split("base64,").collect();
    if parts.len() != 2 {
        return Err("invalid base64 image format".into());
    }

    let format = parts[0].to_lowercase();
    if !format.contains("png")
        && !format.contains("jpeg")
        && !format.contains("jpg")
        && !format.contains("webp")
        && !format.contains("gif")
    {
        return Err(ERR_UNSUPPORTED_IMAGE_FORMAT.into());
    }

    let image_data = BASE64.decode(parts[1])?;

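    // Animated (multi-frame) GIFs are not supported.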
if format.contains("gif") { |
|
if let Ok(frames) = gif::DecodeOptions::new().read_info(std::io::Cursor::new(&image_data)) { |
|
if frames.into_iter().count() > 1 { |
|
return Err(ERR_UNSUPPORTED_GIF.into()); |
|
} |
|
} |
|
} |
|
|
|
|
|
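    // Best-effort dimension probe; `None` if the bytes cannot be decoded.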
    let dimensions = if let Ok(img) = image::load_from_memory(&image_data) {
        Some(image_proto::Dimension {
            width: img.width() as i32,
            height: img.height() as i32,
        })
    } else {
        None
    };

    Ok((image_data, dimensions))
}

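/// Downloads an image over HTTP and validates its format before returning
/// the raw bytes and, when decodable, its dimensions.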
async fn process_http_image(
    url: &str,
    client: Client,
) -> Result<(Vec<u8>, Option<image_proto::Dimension>), Box<dyn std::error::Error + Send + Sync>> {
    let response = client.get(url).send().await?;
    let image_data = response.bytes().await?.to_vec();
    let format = guess_format(&image_data)?;

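    // PNG, JPEG, and WebP pass through; GIFs must be single-frame.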
    match format {
        image::ImageFormat::Png | image::ImageFormat::Jpeg | image::ImageFormat::WebP => {}
        image::ImageFormat::Gif => {
            if let Ok(frames) =
                gif::DecodeOptions::new().read_info(std::io::Cursor::new(&image_data))
            {
                if frames.into_iter().count() > 1 {
                    return Err(ERR_UNSUPPORTED_GIF.into());
                }
            }
        }
        _ => return Err(ERR_UNSUPPORTED_IMAGE_FORMAT.into()),
    }

    let dimensions = if let Ok(img) = image::load_from_memory_with_format(&image_data, format) {
        Some(image_proto::Dimension {
            width: img.width() as i32,
            height: img.height() as i32,
        })
    } else {
        None
    };

    Ok((image_data, dimensions))
}

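/// Builds a `GetChatRequest` from the chat inputs and encodes it as a
/// length-prefixed protobuf frame ready to send to the AI server.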
pub async fn encode_chat_message(
    inputs: Vec<Message>,
    model_name: &str,
    disable_vision: bool,
    enable_slow_pool: bool,
    is_search: bool,
) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
    let enable_slow_pool = enable_slow_pool.then_some(true);

    let (instructions, messages, urls) = process_chat_inputs(inputs, disable_vision).await;

    let explicit_context = if !instructions.trim().is_empty() {
        Some(ExplicitContext {
            context: instructions,
            repo_context: None,
        })
    } else {
        None
    };

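    // Link ids are derived from one random u16, incremented (with wraparound)
    // per link.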
    let base_uuid = rand::random::<u16>();
    let external_links = urls
        .into_iter()
        .enumerate()
        .map(|(i, url)| {
            let uuid = base_uuid.wrapping_add(i as u16);
            ChatExternalLink {
                url,
                uuid: uuid.to_string(),
            }
        })
        .collect();

    let chat = GetChatRequest {
        current_file: None,
        conversation: messages,
        repositories: vec![],
        explicit_context,
        workspace_root_path: None,
        code_blocks: vec![],
        model_details: Some(ModelDetails {
            model_name: Some(model_name.to_string()),
            api_key: None,
            enable_ghost_mode: None,
            azure_state: Some(AzureState {
                api_key: String::new(),
                base_url: String::new(),
                deployment: String::new(),
                use_azure: false,
            }),
            enable_slow_pool,
            openai_api_base_url: None,
        }),
        documentation_identifiers: vec![],
        request_id: Uuid::new_v4().to_string(),
        linter_errors: None,
        summary: None,
        summary_up_until_index: None,
        allow_long_file_scan: Some(false),
        is_bash: Some(false),
        conversation_id: Uuid::new_v4().to_string(),
        can_handle_filenames_after_language_ids: Some(true),
        use_web: if is_search {
            Some("full_search".to_string())
        } else {
            None
        },
        quotes: vec![],
        debug_info: None,
        workspace_id: None,
        external_links,
        commit_notes: vec![],
        long_context_mode: Some(LONG_CONTEXT_MODELS.contains(&model_name)),
        is_eval: Some(false),
        desired_max_tokens: None,
        context_ast: None,
        is_composer: None,
        runnable_code_blocks: Some(false),
        should_cache: Some(false),
    };

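    // Frame the message: a 5-byte envelope (0x00 flag byte followed by the
    // big-endian u32 payload length) precedes the protobuf body; the frame is
    // assembled here via a hex round-trip.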
    let mut encoded = Vec::new();
    chat.encode(&mut encoded)?;

    let len_prefix = format!("{:010X}", encoded.len());
    let content = hex::encode_upper(&encoded);

    Ok(hex::decode(len_prefix + &content)?)
}