// cursor/src/chat/adapter.rs
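//! Adapter that converts OpenAI-style chat messages into the length-prefixed
//! protobuf `GetChatRequest` consumed by the aiserver API.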
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use image::guess_format;
use prost::Message as _;
use reqwest::Client;
use uuid::Uuid;
use crate::{
    app::{
        constant::EMPTY_STRING,
        lazy::DEFAULT_INSTRUCTIONS,
        model::{AppConfig, VisionAbility},
    },
    common::client::HTTP_CLIENT,
};
use super::{
    aiserver::v1::{
        conversation_message, image_proto, AzureState, ChatExternalLink, ConversationMessage,
        ExplicitContext, GetChatRequest, ImageProto, ModelDetails,
    },
    constant::{ERR_UNSUPPORTED_GIF, ERR_UNSUPPORTED_IMAGE_FORMAT, LONG_CONTEXT_MODELS},
    model::{Message, MessageContent, Role},
};
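/// Normalizes OpenAI-style chat inputs into Cursor's conversation shape.
///
/// Returns the system instructions (falling back to `DEFAULT_INSTRUCTIONS`),
/// the proto messages with strictly alternating human/AI turns, and any
/// `@http(s)://...` links found in the final user turn.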
async fn process_chat_inputs(
inputs: Vec<Message>,
disable_vision: bool,
) -> (String, Vec<ConversationMessage>, Vec<String>) {
    // Collect the system-prompt instructions.
let instructions = inputs
.iter()
.filter(|input| input.role == Role::System)
.map(|input| match &input.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Vision(contents) => contents
.iter()
.filter_map(|content| {
if content.content_type == "text" {
content.text.clone()
} else {
None
}
})
.collect::<Vec<String>>()
.join("\n"),
})
.collect::<Vec<String>>()
.join("\n\n");
    // Fall back to the default instructions when none were provided.
let instructions = if instructions.is_empty() {
DEFAULT_INSTRUCTIONS.clone()
} else {
instructions
};
    // Keep only the user and assistant turns.
let mut chat_inputs: Vec<Message> = inputs
.into_iter()
.filter(|input| input.role == Role::User || input.role == Role::Assistant)
.collect();
    // Empty conversation: return a single empty human message.
if chat_inputs.is_empty() {
return (
instructions,
            vec![ConversationMessage {
                text: EMPTY_STRING.into(),
                r#type: conversation_message::MessageType::Human as i32,
                bubble_id: Uuid::new_v4().to_string(),
                is_capability_iteration: Some(false),
                // Every other field is its prost default (empty/None/false).
                ..Default::default()
            }],
vec![],
);
}
    // Strip the "WebReferences:" preamble from assistant messages.
chat_inputs = chat_inputs
.into_iter()
.map(|mut input| {
if let (Role::Assistant, MessageContent::Text(text)) = (&input.role, &input.content) {
if text.starts_with("WebReferences:") {
if let Some(pos) = text.find("\n\n") {
input.content = MessageContent::Text(text[pos + 2..].to_owned());
}
}
}
input
})
.collect();
    // If the first turn is from the assistant, prepend an empty user turn.
if chat_inputs
.first()
.map_or(false, |input| input.role == Role::Assistant)
{
chat_inputs.insert(
0,
Message {
role: Role::User,
content: MessageContent::Text(EMPTY_STRING.into()),
},
);
}
    // Insert an empty turn of the opposite role between consecutive turns
    // from the same role.
let mut i = 1;
while i < chat_inputs.len() {
if chat_inputs[i].role == chat_inputs[i - 1].role {
let insert_role = if chat_inputs[i].role == Role::User {
Role::Assistant
} else {
Role::User
};
chat_inputs.insert(
i,
Message {
role: insert_role,
content: MessageContent::Text(EMPTY_STRING.into()),
},
);
}
i += 1;
}
    // Ensure the conversation ends with a user turn.
if chat_inputs
.last()
.map_or(false, |input| input.role == Role::Assistant)
{
chat_inputs.push(Message {
role: Role::User,
content: MessageContent::Text(EMPTY_STRING.into()),
});
}
    // Convert the turns into proto messages.
let mut messages = Vec::new();
for input in chat_inputs {
let (text, images) = match input.content {
MessageContent::Text(text) => (text, vec![]),
MessageContent::Vision(contents) => {
let mut text_parts = Vec::new();
let mut images = Vec::new();
for content in contents {
match content.content_type.as_str() {
"text" => {
if let Some(text) = content.text {
text_parts.push(text);
}
}
"image_url" => {
if !disable_vision {
if let Some(image_url) = &content.image_url {
let url = image_url.url.clone();
let client = HTTP_CLIENT.read().clone();
let result = tokio::spawn(async move {
fetch_image_data(&url, client).await
});
if let Ok(Ok((image_data, dimensions))) = result.await {
images.push(ImageProto {
data: image_data,
dimension: dimensions,
});
}
}
}
}
_ => {}
}
}
(text_parts.join("\n"), images)
}
};
        messages.push(ConversationMessage {
            text,
            r#type: if input.role == Role::User {
                conversation_message::MessageType::Human as i32
            } else {
                conversation_message::MessageType::Ai as i32
            },
            images,
            bubble_id: Uuid::new_v4().to_string(),
            // Every other field is its prost default (empty/None/false).
            ..Default::default()
        });
}
    // Collect `@http(s)://...` links from the final user turn so they can be
    // passed along as external links.
    let mut urls = Vec::new();
if let Some(last_msg) = messages.last() {
if last_msg.r#type == conversation_message::MessageType::Human as i32 {
let text = &last_msg.text;
let mut chars = text.chars().peekable();
while let Some(c) = chars.next() {
if c == '@' {
let mut url = String::new();
while let Some(&next_char) = chars.peek() {
if next_char.is_whitespace() {
break;
}
url.push(chars.next().unwrap());
}
if let Ok(parsed_url) = url::Url::parse(&url) {
if parsed_url.scheme() == "http" || parsed_url.scheme() == "https" {
urls.push(url);
}
}
}
}
}
}
(instructions, messages, urls)
}
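/// Resolves an image reference (base64 data URI or HTTP URL) into raw bytes
/// plus optional dimensions, according to the configured `VisionAbility`.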
async fn fetch_image_data(
url: &str,
client: Client,
) -> Result<(Vec<u8>, Option<image_proto::Dimension>), Box<dyn std::error::Error + Send + Sync>> {
    // Acquire and release the config lock before entering any async operation.
let vision_ability = AppConfig::get_vision_ability();
match vision_ability {
        VisionAbility::None => Err("vision support is disabled".into()),
VisionAbility::Base64 => {
if !url.starts_with("data:image/") {
return Err("仅支持 base64 编码的图片".into());
}
process_base64_image(url)
}
VisionAbility::All => {
if url.starts_with("data:image/") {
process_base64_image(url)
} else {
process_http_image(url, client).await
}
}
}
}
/// Decodes a base64 `data:image/...` URI into raw bytes plus optional dimensions.
fn process_base64_image(
url: &str,
) -> Result<(Vec<u8>, Option<image_proto::Dimension>), Box<dyn std::error::Error + Send + Sync>> {
let parts: Vec<&str> = url.split("base64,").collect();
if parts.len() != 2 {
return Err("无效的 base64 图片格式".into());
}
    // Check the declared image format.
let format = parts[0].to_lowercase();
if !format.contains("png")
&& !format.contains("jpeg")
&& !format.contains("jpg")
&& !format.contains("webp")
&& !format.contains("gif")
{
return Err(ERR_UNSUPPORTED_IMAGE_FORMAT.into());
}
let image_data = BASE64.decode(parts[1])?;
    // Reject animated GIFs (more than one frame).
if format.contains("gif") {
if let Ok(frames) = gif::DecodeOptions::new().read_info(std::io::Cursor::new(&image_data)) {
if frames.into_iter().count() > 1 {
return Err(ERR_UNSUPPORTED_GIF.into());
}
}
}
    // Best-effort decode to obtain the image dimensions.
let dimensions = if let Ok(img) = image::load_from_memory(&image_data) {
Some(image_proto::Dimension {
width: img.width() as i32,
height: img.height() as i32,
})
} else {
None
};
Ok((image_data, dimensions))
}
/// Fetches an image over HTTP and returns its raw bytes plus optional dimensions.
async fn process_http_image(
url: &str,
client: Client,
) -> Result<(Vec<u8>, Option<image_proto::Dimension>), Box<dyn std::error::Error + Send + Sync>> {
let response = client.get(url).send().await?;
let image_data = response.bytes().await?.to_vec();
let format = guess_format(&image_data)?;
    // Check the sniffed image format.
match format {
image::ImageFormat::Png | image::ImageFormat::Jpeg | image::ImageFormat::WebP => {
            // These formats are supported as-is.
}
image::ImageFormat::Gif => {
if let Ok(frames) =
gif::DecodeOptions::new().read_info(std::io::Cursor::new(&image_data))
{
if frames.into_iter().count() > 1 {
return Err(ERR_UNSUPPORTED_GIF.into());
}
}
}
_ => return Err(ERR_UNSUPPORTED_IMAGE_FORMAT.into()),
}
    // Best-effort decode to obtain the image dimensions.
let dimensions = if let Ok(img) = image::load_from_memory_with_format(&image_data, format) {
Some(image_proto::Dimension {
width: img.width() as i32,
height: img.height() as i32,
})
} else {
None
};
Ok((image_data, dimensions))
}
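/// Builds and serializes a `GetChatRequest`, framed with the 5-byte
/// length-prefixed envelope the upstream endpoint expects.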
pub async fn encode_chat_message(
inputs: Vec<Message>,
model_name: &str,
disable_vision: bool,
enable_slow_pool: bool,
is_search: bool,
) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
    // Upstream expects `Option<bool>`: send `Some(true)` only when the slow
    // pool is explicitly enabled, otherwise omit the field.
    let enable_slow_pool = enable_slow_pool.then_some(true);
let (instructions, messages, urls) = process_chat_inputs(inputs, disable_vision).await;
let explicit_context = if !instructions.trim().is_empty() {
Some(ExplicitContext {
context: instructions,
repo_context: None,
})
} else {
None
};
    // The proto wants a string id per external link; derive sequential ids
    // from a random base so the ids are unique within this request.
    let base_id = rand::random::<u16>();
    let external_links = urls
        .into_iter()
        .enumerate()
        .map(|(i, url)| ChatExternalLink {
            url,
            uuid: base_id.wrapping_add(i as u16).to_string(),
        })
        .collect();
let chat = GetChatRequest {
current_file: None,
conversation: messages,
repositories: vec![],
explicit_context,
workspace_root_path: None,
code_blocks: vec![],
model_details: Some(ModelDetails {
model_name: Some(model_name.to_string()),
api_key: None,
enable_ghost_mode: None,
azure_state: Some(AzureState {
api_key: String::new(),
base_url: String::new(),
deployment: String::new(),
use_azure: false,
}),
enable_slow_pool,
openai_api_base_url: None,
}),
documentation_identifiers: vec![],
request_id: Uuid::new_v4().to_string(),
linter_errors: None,
summary: None,
summary_up_until_index: None,
allow_long_file_scan: Some(false),
is_bash: Some(false),
conversation_id: Uuid::new_v4().to_string(),
can_handle_filenames_after_language_ids: Some(true),
use_web: if is_search {
Some("full_search".to_string())
} else {
None
},
quotes: vec![],
debug_info: None,
workspace_id: None,
external_links,
commit_notes: vec![],
long_context_mode: Some(LONG_CONTEXT_MODELS.contains(&model_name)),
is_eval: Some(false),
desired_max_tokens: None,
context_ast: None,
is_composer: None,
runnable_code_blocks: Some(false),
should_cache: Some(false),
};
let mut encoded = Vec::new();
chat.encode(&mut encoded)?;
    // Frame the payload with a 5-byte envelope: a zero flag byte followed by
    // the big-endian u32 body length (equivalent to the previous hex
    // round-trip, without the intermediate string allocations).
    let mut framed = Vec::with_capacity(5 + encoded.len());
    framed.push(0);
    framed.extend_from_slice(&(encoded.len() as u32).to_be_bytes());
    framed.extend_from_slice(&encoded);
    Ok(framed)
}
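
// A minimal usage sketch, not part of the adapter itself. It assumes `tokio`
// (with the `macros` and `rt` features) is available as a dev-dependency and
// that `DEFAULT_INSTRUCTIONS` initializes without external configuration;
// the model name "gpt-4o" is illustrative only.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn encodes_a_simple_conversation() {
        let inputs = vec![Message {
            role: Role::User,
            content: MessageContent::Text("hello".to_string()),
        }];
        let framed = encode_chat_message(inputs, "gpt-4o", true, false, false)
            .await
            .expect("encoding should succeed");
        // The envelope is a zero flag byte plus the big-endian u32 body length.
        assert_eq!(framed[0], 0);
        let len = u32::from_be_bytes([framed[1], framed[2], framed[3], framed[4]]);
        assert_eq!(len as usize, framed.len() - 5);
    }
}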