|
|
|
/// The 64 output symbols, indexed by their 6-bit value.
///
/// Order is lowercase letters, uppercase letters, digits, then `-` and `_`.
/// Note that this differs from the RFC 4648 alphabets, which start with the
/// uppercase letters, so text produced here is not interchangeable with them.
const BASE64_CHARS: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_";

/// Reverse mapping from a byte to its 6-bit value in [`BASE64_CHARS`].
///
/// Every byte that is not part of the alphabet maps to `-1`.
const BASE64_LOOKUP: [i8; 256] = {
    let mut table = [-1i8; 256];
    let mut value = 0usize;

    // A plain `while` is used because this initializer is evaluated at
    // compile time, where iterator-based `for` loops are unavailable.
    while value < BASE64_CHARS.len() {
        let symbol = BASE64_CHARS[value];
        // `value` is at most 63, so it always fits in an `i8`.
        table[symbol as usize] = value as i8;
        value += 1;
    }

    table
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn to_base64(bytes: &[u8]) -> String { |
|
|
|
let capacity = (bytes.len() + 2) / 3 * 4; |
|
let mut result = Vec::with_capacity(capacity); |
|
|
|
|
|
for chunk in bytes.chunks(3) { |
|
|
|
let b1 = chunk[0] as u32; |
|
let b2 = chunk.get(1).map_or(0, |&b| b as u32); |
|
let b3 = chunk.get(2).map_or(0, |&b| b as u32); |
|
|
|
let n = (b1 << 16) | (b2 << 8) | b3; |
|
|
|
|
|
result.push(BASE64_CHARS[(n >> 18) as usize]); |
|
result.push(BASE64_CHARS[((n >> 12) & 0x3F) as usize]); |
|
|
|
|
|
if chunk.len() > 1 { |
|
result.push(BASE64_CHARS[((n >> 6) & 0x3F) as usize]); |
|
|
|
if chunk.len() > 2 { |
|
result.push(BASE64_CHARS[(n & 0x3F) as usize]); |
|
} |
|
} |
|
} |
|
|
|
|
|
unsafe { String::from_utf8_unchecked(result) } |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn from_base64(input: &str) -> Option<Vec<u8>> { |
|
let input = input.as_bytes(); |
|
|
|
|
|
if input.is_empty() || input.len() % 4 == 1 { |
|
return None; |
|
} |
|
|
|
|
|
if input.iter().any(|&b| BASE64_LOOKUP[b as usize] == -1) { |
|
return None; |
|
} |
|
|
|
|
|
let capacity = input.len() / 4 * 3; |
|
let mut result = Vec::with_capacity(capacity); |
|
|
|
|
|
let mut chunks = input.chunks_exact(4); |
|
for chunk in &mut chunks { |
|
|
|
let n1 = BASE64_LOOKUP[chunk[0] as usize] as u32; |
|
let n2 = BASE64_LOOKUP[chunk[1] as usize] as u32; |
|
let n3 = BASE64_LOOKUP[chunk[2] as usize] as u32; |
|
let n4 = BASE64_LOOKUP[chunk[3] as usize] as u32; |
|
|
|
|
|
let n = (n1 << 18) | (n2 << 12) | (n3 << 6) | n4; |
|
result.push((n >> 16) as u8); |
|
result.push(((n >> 8) & 0xFF) as u8); |
|
result.push((n & 0xFF) as u8); |
|
} |
|
|
|
|
|
let remainder = chunks.remainder(); |
|
if !remainder.is_empty() { |
|
let n1 = BASE64_LOOKUP[remainder[0] as usize] as u32; |
|
let n2 = BASE64_LOOKUP[remainder[1] as usize] as u32; |
|
|
|
let mut n = (n1 << 18) | (n2 << 12); |
|
result.push((n >> 16) as u8); |
|
|
|
|
|
if remainder.len() > 2 { |
|
let n3 = BASE64_LOOKUP[remainder[2] as usize] as u32; |
|
n |= n3 << 6; |
|
result.push(((n >> 8) & 0xFF) as u8); |
|
} |
|
} |
|
|
|
Some(result) |
|
} |
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding and then decoding must give back the original bytes for
    /// every tail length (one, two and three leftover bytes) and for a
    /// larger buffer.
    #[test]
    fn test_base64_roundtrip() {
        let samples: [&[u8]; 6] = [
            &[0, 1, 2, 3],
            &[255, 254, 253],
            &[0],
            &[0, 1],
            &[0, 1, 2],
            &[255; 1000],
        ];

        for &original in &samples {
            let text = to_base64(original);
            let restored = from_base64(&text).unwrap();
            assert_eq!(original, restored.as_slice());
        }
    }

    /// Empty input, a dangling single symbol and bytes outside the alphabet
    /// are rejected; a well-formed four-symbol group is accepted.
    #[test]
    fn test_invalid_input() {
        for &rejected in &["", "a", "!@#$", "YWJj!"] {
            assert_eq!(from_base64(rejected), None);
        }

        assert!(from_base64("YWJj").is_some());
    }
}
|
|