|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
use crate::core::Id; |
|
|
|
|
|
|
|
|
/// Origin of an [`AttentionState`].
///
/// Every variant has a lowercase name (see [`Role::as_str`]) and a one-byte
/// wire code used by the binary encoding:
/// `System` = 0, `User` = 1, `Assistant` = 2, `Tool` = 3, `Context` = 4.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// Name `"system"`, wire code 0.
    System,
    /// Name `"user"`, wire code 1.
    User,
    /// Name `"assistant"`, wire code 2.
    Assistant,
    /// Name `"tool"` (parsing also accepts `"function"`), wire code 3.
    Tool,
    /// Name `"context"` (parsing also accepts `"retrieved"`), wire code 4.
    Context,
}

impl Role {
    /// Roles indexed by wire code: `WIRE_ORDER[b]` is the role encoded as byte `b`.
    const WIRE_ORDER: [Role; 5] = [
        Role::System,
        Role::User,
        Role::Assistant,
        Role::Tool,
        Role::Context,
    ];

    /// Canonical lowercase name of this role.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::Tool => "tool",
            Self::Context => "context",
        }
    }

    /// Parses a role name case-insensitively.
    ///
    /// Accepts the canonical names plus the aliases `"function"` (→ `Tool`)
    /// and `"retrieved"` (→ `Context`). Returns `None` for anything else.
    pub fn from_str(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        let role = match normalized.as_str() {
            "system" => Self::System,
            "user" => Self::User,
            "assistant" => Self::Assistant,
            "tool" | "function" => Self::Tool,
            "context" | "retrieved" => Self::Context,
            _ => return None,
        };
        Some(role)
    }

    /// One-byte wire code for this role; inverse of [`Role::from_byte`].
    fn to_byte(&self) -> u8 {
        match *self {
            Self::System => 0,
            Self::User => 1,
            Self::Assistant => 2,
            Self::Tool => 3,
            Self::Context => 4,
        }
    }

    /// Decodes a wire code, returning `None` for bytes outside `0..=4`.
    fn from_byte(b: u8) -> Option<Self> {
        Self::WIRE_ORDER.get(usize::from(b)).copied()
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Serialized key/value-cache payload that can be attached to an
/// [`AttentionState`].
///
/// The dimension fields and `quantization` are stored and round-tripped
/// verbatim. Nothing in this type checks that `data` is consistent with them.
#[derive(Debug, Clone)]
pub struct CompressedKV {
    /// Model identifier (free-form string, stored as-is).
    pub model_id: String,
    /// Layer count (metadata only; not validated against `data`).
    pub num_layers: u32,
    /// Head count (metadata only; not validated against `data`).
    pub num_heads: u32,
    /// Per-head dimension (metadata only; not validated against `data`).
    pub head_dim: u32,
    /// Sequence length (metadata only; not validated against `data`).
    pub seq_len: u32,
    /// Quantization label. [`CompressedKV::placeholder`] uses `"none"`.
    pub quantization: String,
    /// Opaque cache bytes. Their layout is not interpreted by this type.
    pub data: Vec<u8>,
}

impl CompressedKV {
    /// Length of the raw cache payload (`data`) in bytes; header fields are not counted.
    pub fn size_bytes(&self) -> usize {
        self.data.len()
    }

    /// Returns an empty cache tagged with `model_id`.
    ///
    /// All dimensions are zero, quantization is `"none"`, and `data` is empty.
    pub fn placeholder(model_id: &str) -> Self {
        Self {
            model_id: model_id.to_string(),
            num_layers: 0,
            num_heads: 0,
            head_dim: 0,
            seq_len: 0,
            quantization: "none".to_string(),
            data: vec![],
        }
    }

    /// Encodes the cache as a little-endian record:
    ///
    /// ```text
    /// u32 model_len | model_id (UTF-8)
    /// u32 num_layers | u32 num_heads | u32 head_dim | u32 seq_len
    /// u32 quant_len | quantization (UTF-8)
    /// u64 data_len  | data
    /// ```
    ///
    /// String lengths are written with an `as u32` cast. A string of 4 GiB or
    /// more would therefore produce a corrupt record.
    pub fn to_bytes(&self) -> Vec<u8> {
        let model = self.model_id.as_bytes();
        let quant = self.quantization.as_bytes();
        // Exact encoded size, so the buffer never has to grow.
        let mut bytes =
            Vec::with_capacity(4 + model.len() + 16 + 4 + quant.len() + 8 + self.data.len());

        bytes.extend_from_slice(&(model.len() as u32).to_le_bytes());
        bytes.extend_from_slice(model);

        for dim in [self.num_layers, self.num_heads, self.head_dim, self.seq_len] {
            bytes.extend_from_slice(&dim.to_le_bytes());
        }

        bytes.extend_from_slice(&(quant.len() as u32).to_le_bytes());
        bytes.extend_from_slice(quant);

        bytes.extend_from_slice(&(self.data.len() as u64).to_le_bytes());
        bytes.extend_from_slice(&self.data);

        bytes
    }

    /// Decodes a record produced by [`CompressedKV::to_bytes`].
    ///
    /// Returns the cache and the number of bytes consumed. Any bytes after the
    /// record are ignored.
    ///
    /// Returns `None` in these cases:
    /// - the input is truncated;
    /// - a string is not valid UTF-8;
    /// - a declared length cannot be satisfied by `data`.
    ///
    /// All length arithmetic is checked, so hostile length fields cannot cause a panic.
    pub fn from_bytes(data: &[u8]) -> Option<(Self, usize)> {
        let mut offset = 0;

        let model_len = Self::read_u32(data, &mut offset)? as usize;
        let model_id = Self::read_string(data, &mut offset, model_len)?;

        let num_layers = Self::read_u32(data, &mut offset)?;
        let num_heads = Self::read_u32(data, &mut offset)?;
        let head_dim = Self::read_u32(data, &mut offset)?;
        let seq_len = Self::read_u32(data, &mut offset)?;

        let quant_len = Self::read_u32(data, &mut offset)? as usize;
        let quantization = Self::read_string(data, &mut offset, quant_len)?;

        // The payload length is an untrusted u64. Reject values that do not fit
        // in usize instead of truncating them with `as`.
        let data_len = usize::try_from(Self::read_u64(data, &mut offset)?).ok()?;
        let kv_data = Self::take_bytes(data, &mut offset, data_len)?.to_vec();

        Some((
            Self {
                model_id,
                num_layers,
                num_heads,
                head_dim,
                seq_len,
                quantization,
                data: kv_data,
            },
            offset,
        ))
    }

    /// Returns the `len` bytes at `*offset` and advances the cursor past them.
    ///
    /// Returns `None` if fewer than `len` bytes remain, including when
    /// `*offset + len` would overflow.
    fn take_bytes<'a>(data: &'a [u8], offset: &mut usize, len: usize) -> Option<&'a [u8]> {
        let end = offset.checked_add(len)?;
        let slice = data.get(*offset..end)?;
        *offset = end;
        Some(slice)
    }

    /// Reads a little-endian `u32` at the cursor.
    fn read_u32(data: &[u8], offset: &mut usize) -> Option<u32> {
        let raw: [u8; 4] = Self::take_bytes(data, offset, 4)?.try_into().ok()?;
        Some(u32::from_le_bytes(raw))
    }

    /// Reads a little-endian `u64` at the cursor.
    fn read_u64(data: &[u8], offset: &mut usize) -> Option<u64> {
        let raw: [u8; 8] = Self::take_bytes(data, offset, 8)?.try_into().ok()?;
        Some(u64::from_le_bytes(raw))
    }

    /// Reads `len` bytes at the cursor as a UTF-8 string.
    fn read_string(data: &[u8], offset: &mut usize, len: usize) -> Option<String> {
        let raw = Self::take_bytes(data, offset, len)?;
        std::str::from_utf8(raw).ok().map(str::to_owned)
    }
}
|
|
|
|
|
|
|
|
#[derive(Debug, Clone)] |
|
|
pub struct AttentionState { |
|
|
|
|
|
pub id: Id, |
|
|
|
|
|
|
|
|
pub timestamp_ms: u64, |
|
|
|
|
|
|
|
|
pub role: Role, |
|
|
|
|
|
|
|
|
pub text: String, |
|
|
|
|
|
|
|
|
pub embedding: Vec<f32>, |
|
|
|
|
|
|
|
|
pub kv_cache: Option<CompressedKV>, |
|
|
|
|
|
|
|
|
pub metadata: std::collections::HashMap<String, String>, |
|
|
} |
|
|
|
|
|
impl AttentionState { |
|
|
|
|
|
pub fn new(role: Role, text: String, embedding: Vec<f32>) -> Self { |
|
|
Self { |
|
|
id: Id::now(), |
|
|
timestamp_ms: std::time::SystemTime::now() |
|
|
.duration_since(std::time::UNIX_EPOCH) |
|
|
.unwrap() |
|
|
.as_millis() as u64, |
|
|
role, |
|
|
text, |
|
|
embedding, |
|
|
kv_cache: None, |
|
|
metadata: std::collections::HashMap::new(), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
pub fn with_kv_cache(mut self, kv: CompressedKV) -> Self { |
|
|
self.kv_cache = Some(kv); |
|
|
self |
|
|
} |
|
|
|
|
|
|
|
|
pub fn with_metadata(mut self, key: &str, value: &str) -> Self { |
|
|
self.metadata.insert(key.to_string(), value.to_string()); |
|
|
self |
|
|
} |
|
|
|
|
|
|
|
|
pub fn size_bytes(&self) -> usize { |
|
|
16 + |
|
|
8 + |
|
|
1 + |
|
|
self.text.len() + |
|
|
self.embedding.len() * 4 + |
|
|
self.kv_cache.as_ref().map(|kv| kv.size_bytes()).unwrap_or(0) + |
|
|
self.metadata.iter().map(|(k, v)| k.len() + v.len() + 8).sum::<usize>() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn to_bytes(&self) -> Vec<u8> { |
|
|
let mut bytes = Vec::new(); |
|
|
|
|
|
|
|
|
bytes.extend_from_slice(b"ATTN"); |
|
|
bytes.extend_from_slice(&1u32.to_le_bytes()); |
|
|
|
|
|
|
|
|
bytes.extend_from_slice(self.id.as_bytes()); |
|
|
|
|
|
|
|
|
bytes.extend_from_slice(&self.timestamp_ms.to_le_bytes()); |
|
|
|
|
|
|
|
|
bytes.push(self.role.to_byte()); |
|
|
|
|
|
|
|
|
let text_bytes = self.text.as_bytes(); |
|
|
bytes.extend_from_slice(&(text_bytes.len() as u32).to_le_bytes()); |
|
|
bytes.extend_from_slice(text_bytes); |
|
|
|
|
|
|
|
|
bytes.extend_from_slice(&(self.embedding.len() as u32).to_le_bytes()); |
|
|
for &v in &self.embedding { |
|
|
bytes.extend_from_slice(&v.to_le_bytes()); |
|
|
} |
|
|
|
|
|
|
|
|
if let Some(ref kv) = self.kv_cache { |
|
|
bytes.push(1); |
|
|
let kv_bytes = kv.to_bytes(); |
|
|
bytes.extend_from_slice(&(kv_bytes.len() as u64).to_le_bytes()); |
|
|
bytes.extend_from_slice(&kv_bytes); |
|
|
} else { |
|
|
bytes.push(0); |
|
|
} |
|
|
|
|
|
|
|
|
bytes.extend_from_slice(&(self.metadata.len() as u32).to_le_bytes()); |
|
|
for (key, value) in &self.metadata { |
|
|
let key_bytes = key.as_bytes(); |
|
|
let value_bytes = value.as_bytes(); |
|
|
bytes.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes()); |
|
|
bytes.extend_from_slice(key_bytes); |
|
|
bytes.extend_from_slice(&(value_bytes.len() as u32).to_le_bytes()); |
|
|
bytes.extend_from_slice(value_bytes); |
|
|
} |
|
|
|
|
|
bytes |
|
|
} |
|
|
|
|
|
|
|
|
pub fn from_bytes(data: &[u8]) -> Result<Self, AttentionError> { |
|
|
let mut offset = 0; |
|
|
|
|
|
|
|
|
if data.len() < 8 { |
|
|
return Err(AttentionError::InvalidFormat("Too short".into())); |
|
|
} |
|
|
if &data[0..4] != b"ATTN" { |
|
|
return Err(AttentionError::InvalidMagic); |
|
|
} |
|
|
offset += 4; |
|
|
|
|
|
|
|
|
let version = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); |
|
|
if version != 1 { |
|
|
return Err(AttentionError::UnsupportedVersion(version)); |
|
|
} |
|
|
offset += 4; |
|
|
|
|
|
|
|
|
if data.len() < offset + 16 { |
|
|
return Err(AttentionError::InvalidFormat("Missing ID".into())); |
|
|
} |
|
|
let mut id_bytes = [0u8; 16]; |
|
|
id_bytes.copy_from_slice(&data[offset..offset + 16]); |
|
|
let id = Id::from_bytes(id_bytes); |
|
|
offset += 16; |
|
|
|
|
|
|
|
|
if data.len() < offset + 8 { |
|
|
return Err(AttentionError::InvalidFormat("Missing timestamp".into())); |
|
|
} |
|
|
let timestamp_ms = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); |
|
|
offset += 8; |
|
|
|
|
|
|
|
|
if data.len() < offset + 1 { |
|
|
return Err(AttentionError::InvalidFormat("Missing role".into())); |
|
|
} |
|
|
let role = Role::from_byte(data[offset]) |
|
|
.ok_or_else(|| AttentionError::InvalidFormat("Invalid role".into()))?; |
|
|
offset += 1; |
|
|
|
|
|
|
|
|
if data.len() < offset + 4 { |
|
|
return Err(AttentionError::InvalidFormat("Missing text length".into())); |
|
|
} |
|
|
let text_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
|
|
offset += 4; |
|
|
|
|
|
if data.len() < offset + text_len { |
|
|
return Err(AttentionError::InvalidFormat("Text truncated".into())); |
|
|
} |
|
|
let text = String::from_utf8(data[offset..offset + text_len].to_vec()) |
|
|
.map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in text".into()))?; |
|
|
offset += text_len; |
|
|
|
|
|
|
|
|
if data.len() < offset + 4 { |
|
|
return Err(AttentionError::InvalidFormat("Missing embedding length".into())); |
|
|
} |
|
|
let emb_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
|
|
offset += 4; |
|
|
|
|
|
if data.len() < offset + emb_len * 4 { |
|
|
return Err(AttentionError::InvalidFormat("Embedding truncated".into())); |
|
|
} |
|
|
let mut embedding = Vec::with_capacity(emb_len); |
|
|
for _ in 0..emb_len { |
|
|
embedding.push(f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap())); |
|
|
offset += 4; |
|
|
} |
|
|
|
|
|
|
|
|
if data.len() < offset + 1 { |
|
|
return Err(AttentionError::InvalidFormat("Missing KV flag".into())); |
|
|
} |
|
|
let has_kv = data[offset] != 0; |
|
|
offset += 1; |
|
|
|
|
|
let kv_cache = if has_kv { |
|
|
if data.len() < offset + 8 { |
|
|
return Err(AttentionError::InvalidFormat("Missing KV length".into())); |
|
|
} |
|
|
let kv_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; |
|
|
offset += 8; |
|
|
|
|
|
if data.len() < offset + kv_len { |
|
|
return Err(AttentionError::InvalidFormat("KV data truncated".into())); |
|
|
} |
|
|
let (kv, _) = CompressedKV::from_bytes(&data[offset..offset + kv_len]) |
|
|
.ok_or_else(|| AttentionError::InvalidFormat("Invalid KV cache".into()))?; |
|
|
offset += kv_len; |
|
|
Some(kv) |
|
|
} else { |
|
|
None |
|
|
}; |
|
|
|
|
|
|
|
|
if data.len() < offset + 4 { |
|
|
return Err(AttentionError::InvalidFormat("Missing metadata count".into())); |
|
|
} |
|
|
let meta_count = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
|
|
offset += 4; |
|
|
|
|
|
let mut metadata = std::collections::HashMap::new(); |
|
|
for _ in 0..meta_count { |
|
|
|
|
|
if data.len() < offset + 4 { |
|
|
return Err(AttentionError::InvalidFormat("Missing key length".into())); |
|
|
} |
|
|
let key_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
|
|
offset += 4; |
|
|
|
|
|
if data.len() < offset + key_len { |
|
|
return Err(AttentionError::InvalidFormat("Key truncated".into())); |
|
|
} |
|
|
let key = String::from_utf8(data[offset..offset + key_len].to_vec()) |
|
|
.map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in key".into()))?; |
|
|
offset += key_len; |
|
|
|
|
|
|
|
|
if data.len() < offset + 4 { |
|
|
return Err(AttentionError::InvalidFormat("Missing value length".into())); |
|
|
} |
|
|
let value_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
|
|
offset += 4; |
|
|
|
|
|
if data.len() < offset + value_len { |
|
|
return Err(AttentionError::InvalidFormat("Value truncated".into())); |
|
|
} |
|
|
let value = String::from_utf8(data[offset..offset + value_len].to_vec()) |
|
|
.map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in value".into()))?; |
|
|
offset += value_len; |
|
|
|
|
|
metadata.insert(key, value); |
|
|
} |
|
|
|
|
|
Ok(Self { |
|
|
id, |
|
|
timestamp_ms, |
|
|
role, |
|
|
text, |
|
|
embedding, |
|
|
kv_cache, |
|
|
metadata, |
|
|
}) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// Errors returned when decoding [`AttentionState`] or [`AttentionBatch`]
/// records.
///
/// `PartialEq`/`Eq` are derived so callers can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AttentionError {
    /// The leading four-byte tag was not the expected magic value.
    InvalidMagic,
    /// The version field held a value this decoder does not support.
    UnsupportedVersion(u32),
    /// The record could not be decoded, e.g. it was truncated or held invalid
    /// UTF-8. The payload says what went wrong.
    InvalidFormat(String),
}

impl std::fmt::Display for AttentionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AttentionError::InvalidMagic => write!(f, "Invalid magic bytes"),
            AttentionError::UnsupportedVersion(v) => write!(f, "Unsupported version: {}", v),
            AttentionError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
        }
    }
}

impl std::error::Error for AttentionError {}
|
|
|
|
|
|
|
|
#[derive(Debug, Clone)] |
|
|
pub struct AttentionBatch { |
|
|
|
|
|
pub states: Vec<AttentionState>, |
|
|
|
|
|
|
|
|
pub session_id: Option<Id>, |
|
|
|
|
|
|
|
|
pub document_id: Option<Id>, |
|
|
} |
|
|
|
|
|
impl AttentionBatch { |
|
|
pub fn new() -> Self { |
|
|
Self { |
|
|
states: Vec::new(), |
|
|
session_id: None, |
|
|
document_id: None, |
|
|
} |
|
|
} |
|
|
|
|
|
pub fn with_session(mut self, session_id: Id) -> Self { |
|
|
self.session_id = Some(session_id); |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn with_document(mut self, document_id: Id) -> Self { |
|
|
self.document_id = Some(document_id); |
|
|
self |
|
|
} |
|
|
|
|
|
pub fn add(&mut self, state: AttentionState) { |
|
|
self.states.push(state); |
|
|
} |
|
|
|
|
|
|
|
|
pub fn size_bytes(&self) -> usize { |
|
|
self.states.iter().map(|s| s.size_bytes()).sum() |
|
|
} |
|
|
|
|
|
|
|
|
pub fn to_bytes(&self) -> Vec<u8> { |
|
|
let mut bytes = Vec::new(); |
|
|
|
|
|
|
|
|
bytes.extend_from_slice(b"ATNB"); |
|
|
bytes.extend_from_slice(&1u32.to_le_bytes()); |
|
|
|
|
|
|
|
|
if let Some(sid) = self.session_id { |
|
|
bytes.push(1); |
|
|
bytes.extend_from_slice(sid.as_bytes()); |
|
|
} else { |
|
|
bytes.push(0); |
|
|
} |
|
|
|
|
|
|
|
|
if let Some(did) = self.document_id { |
|
|
bytes.push(1); |
|
|
bytes.extend_from_slice(did.as_bytes()); |
|
|
} else { |
|
|
bytes.push(0); |
|
|
} |
|
|
|
|
|
|
|
|
bytes.extend_from_slice(&(self.states.len() as u32).to_le_bytes()); |
|
|
|
|
|
|
|
|
for state in &self.states { |
|
|
let state_bytes = state.to_bytes(); |
|
|
bytes.extend_from_slice(&(state_bytes.len() as u64).to_le_bytes()); |
|
|
bytes.extend_from_slice(&state_bytes); |
|
|
} |
|
|
|
|
|
bytes |
|
|
} |
|
|
|
|
|
|
|
|
pub fn from_bytes(data: &[u8]) -> Result<Self, AttentionError> { |
|
|
let mut offset = 0; |
|
|
|
|
|
|
|
|
if data.len() < 8 { |
|
|
return Err(AttentionError::InvalidFormat("Too short".into())); |
|
|
} |
|
|
if &data[0..4] != b"ATNB" { |
|
|
return Err(AttentionError::InvalidMagic); |
|
|
} |
|
|
offset += 4; |
|
|
|
|
|
|
|
|
let version = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); |
|
|
if version != 1 { |
|
|
return Err(AttentionError::UnsupportedVersion(version)); |
|
|
} |
|
|
offset += 4; |
|
|
|
|
|
|
|
|
if data.len() < offset + 1 { |
|
|
return Err(AttentionError::InvalidFormat("Missing session flag".into())); |
|
|
} |
|
|
let has_session = data[offset] != 0; |
|
|
offset += 1; |
|
|
|
|
|
let session_id = if has_session { |
|
|
if data.len() < offset + 16 { |
|
|
return Err(AttentionError::InvalidFormat("Missing session ID".into())); |
|
|
} |
|
|
let mut id_bytes = [0u8; 16]; |
|
|
id_bytes.copy_from_slice(&data[offset..offset + 16]); |
|
|
offset += 16; |
|
|
Some(Id::from_bytes(id_bytes)) |
|
|
} else { |
|
|
None |
|
|
}; |
|
|
|
|
|
|
|
|
if data.len() < offset + 1 { |
|
|
return Err(AttentionError::InvalidFormat("Missing document flag".into())); |
|
|
} |
|
|
let has_document = data[offset] != 0; |
|
|
offset += 1; |
|
|
|
|
|
let document_id = if has_document { |
|
|
if data.len() < offset + 16 { |
|
|
return Err(AttentionError::InvalidFormat("Missing document ID".into())); |
|
|
} |
|
|
let mut id_bytes = [0u8; 16]; |
|
|
id_bytes.copy_from_slice(&data[offset..offset + 16]); |
|
|
offset += 16; |
|
|
Some(Id::from_bytes(id_bytes)) |
|
|
} else { |
|
|
None |
|
|
}; |
|
|
|
|
|
|
|
|
if data.len() < offset + 4 { |
|
|
return Err(AttentionError::InvalidFormat("Missing state count".into())); |
|
|
} |
|
|
let state_count = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
|
|
offset += 4; |
|
|
|
|
|
|
|
|
let mut states = Vec::with_capacity(state_count); |
|
|
for _ in 0..state_count { |
|
|
if data.len() < offset + 8 { |
|
|
return Err(AttentionError::InvalidFormat("Missing state length".into())); |
|
|
} |
|
|
let state_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; |
|
|
offset += 8; |
|
|
|
|
|
if data.len() < offset + state_len { |
|
|
return Err(AttentionError::InvalidFormat("State truncated".into())); |
|
|
} |
|
|
let state = AttentionState::from_bytes(&data[offset..offset + state_len])?; |
|
|
offset += state_len; |
|
|
states.push(state); |
|
|
} |
|
|
|
|
|
Ok(Self { |
|
|
states, |
|
|
session_id, |
|
|
document_id, |
|
|
}) |
|
|
} |
|
|
} |
|
|
|
|
|
impl Default for AttentionBatch { |
|
|
fn default() -> Self { |
|
|
Self::new() |
|
|
} |
|
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_role_roundtrip() {
        for role in [Role::System, Role::User, Role::Assistant, Role::Tool, Role::Context] {
            let byte = role.to_byte();
            let restored = Role::from_byte(byte).unwrap();
            assert_eq!(role, restored);
        }
    }

    /// Parsing is case-insensitive and accepts the documented aliases.
    #[test]
    fn test_role_from_str_aliases() {
        assert_eq!(Role::from_str("SYSTEM"), Some(Role::System));
        assert_eq!(Role::from_str("function"), Some(Role::Tool));
        assert_eq!(Role::from_str("Retrieved"), Some(Role::Context));
        assert_eq!(Role::from_str("narrator"), None);
        assert!(Role::from_byte(5).is_none());
    }

    #[test]
    fn test_attention_state_roundtrip() {
        let state = AttentionState::new(
            Role::User,
            "Hello, how are you?".to_string(),
            vec![0.1, 0.2, 0.3, 0.4],
        )
        .with_metadata("turn", "1");

        let bytes = state.to_bytes();
        let restored = AttentionState::from_bytes(&bytes).unwrap();

        assert_eq!(state.role, restored.role);
        assert_eq!(state.text, restored.text);
        assert_eq!(state.embedding, restored.embedding);
        assert_eq!(state.metadata.get("turn"), restored.metadata.get("turn"));
    }

    #[test]
    fn test_attention_state_with_kv() {
        let kv = CompressedKV {
            model_id: "llama-3-8b".to_string(),
            num_layers: 32,
            num_heads: 32,
            head_dim: 128,
            seq_len: 10,
            quantization: "fp16".to_string(),
            data: vec![1, 2, 3, 4, 5],
        };

        let state = AttentionState::new(
            Role::Assistant,
            "I'm doing well!".to_string(),
            vec![0.5, 0.6, 0.7, 0.8],
        )
        .with_kv_cache(kv);

        let bytes = state.to_bytes();
        let restored = AttentionState::from_bytes(&bytes).unwrap();

        assert!(restored.kv_cache.is_some());
        let restored_kv = restored.kv_cache.unwrap();
        assert_eq!(restored_kv.model_id, "llama-3-8b");
        assert_eq!(restored_kv.num_layers, 32);
        assert_eq!(restored_kv.data, vec![1, 2, 3, 4, 5]);
    }

    /// `CompressedKV::from_bytes` reports exactly the bytes it consumed and
    /// ignores trailing data.
    #[test]
    fn test_compressed_kv_reports_consumed() {
        let mut bytes = CompressedKV::placeholder("m").to_bytes();
        let encoded_len = bytes.len();
        bytes.push(0xFF);

        let (restored, used) = CompressedKV::from_bytes(&bytes).unwrap();
        assert_eq!(used, encoded_len);
        assert_eq!(restored.model_id, "m");
        assert_eq!(restored.quantization, "none");
        assert!(restored.data.is_empty());
    }

    #[test]
    fn test_batch_roundtrip() {
        let mut batch = AttentionBatch::new().with_session(Id::now());

        batch.add(AttentionState::new(
            Role::User,
            "Question 1".to_string(),
            vec![0.1, 0.2],
        ));

        batch.add(AttentionState::new(
            Role::Assistant,
            "Answer 1".to_string(),
            vec![0.3, 0.4],
        ));

        let bytes = batch.to_bytes();
        let restored = AttentionBatch::from_bytes(&bytes).unwrap();

        assert_eq!(restored.states.len(), 2);
        assert_eq!(restored.states[0].text, "Question 1");
        assert_eq!(restored.states[1].text, "Answer 1");
        assert!(restored.session_id.is_some());
        assert!(restored.document_id.is_none());
    }

    /// Header validation: a short buffer, a wrong magic value and an unknown
    /// version each map to the right error variant.
    #[test]
    fn test_rejects_bad_header() {
        assert!(matches!(
            AttentionState::from_bytes(b"ATT"),
            Err(AttentionError::InvalidFormat(_))
        ));
        assert!(matches!(
            AttentionState::from_bytes(b"NOPE\x01\0\0\0"),
            Err(AttentionError::InvalidMagic)
        ));
        assert!(matches!(
            AttentionState::from_bytes(b"ATTN\x02\0\0\0"),
            Err(AttentionError::UnsupportedVersion(2))
        ));
        // A state record is not a batch record.
        assert!(matches!(
            AttentionBatch::from_bytes(b"ATTN\x01\0\0\0"),
            Err(AttentionError::InvalidMagic)
        ));
    }

    /// Every strict prefix of a valid encoding must be rejected with an error (never a panic).
    #[test]
    fn test_truncated_inputs_are_errors() {
        let state = AttentionState::new(Role::Tool, "out".to_string(), vec![1.0, 2.0])
            .with_kv_cache(CompressedKV::placeholder("m"))
            .with_metadata("k", "v");
        let bytes = state.to_bytes();
        for cut in 0..bytes.len() {
            assert!(
                AttentionState::from_bytes(&bytes[..cut]).is_err(),
                "state prefix of {} bytes was accepted",
                cut
            );
        }

        let mut batch = AttentionBatch::new().with_document(Id::now());
        batch.add(state);
        let bytes = batch.to_bytes();
        assert!(AttentionBatch::from_bytes(&bytes).unwrap().document_id.is_some());
        for cut in 0..bytes.len() {
            assert!(
                AttentionBatch::from_bytes(&bytes[..cut]).is_err(),
                "batch prefix of {} bytes was accepted",
                cut
            );
        }
    }
}
|
|
|