Spaces:
Running
Running
// Presents an interface to a storage backend such as S3 or local disk.
// The interface is a simple key-value store, which maps well onto S3.
// For now the interface fetches a file and stores it at a specific
// location on disk. This is not ideal for S3, but it is a start.
// Ideally we would stream the file from S3 directly into the index,
// but the current implementation of hnswlib makes this complicated.
// Once we move to our own implementation of hnswlib we can support
// streaming from S3.
use super::{config::StorageConfig, Storage}; | |
use crate::config::{Configurable, WorkerConfig}; | |
use crate::errors::ChromaError; | |
use async_trait::async_trait; | |
use aws_sdk_s3; | |
use aws_sdk_s3::error::SdkError; | |
use aws_sdk_s3::operation::create_bucket::CreateBucketError; | |
use aws_smithy_types::byte_stream::ByteStream; | |
use std::clone::Clone; | |
use std::io::Write; | |
struct S3Storage { | |
bucket: String, | |
client: aws_sdk_s3::Client, | |
} | |
impl S3Storage { | |
fn new(bucket: &str, client: aws_sdk_s3::Client) -> S3Storage { | |
return S3Storage { | |
bucket: bucket.to_string(), | |
client: client, | |
}; | |
} | |
async fn create_bucket(&self) -> Result<(), String> { | |
// Creates a public bucket with default settings in the region. | |
// This should only be used for testing and in production | |
// the bucket should be provisioned ahead of time. | |
let res = self | |
.client | |
.create_bucket() | |
.bucket(self.bucket.clone()) | |
.send() | |
.await; | |
match res { | |
Ok(_) => { | |
println!("created bucket {}", self.bucket); | |
return Ok(()); | |
} | |
Err(e) => match e { | |
SdkError::ServiceError(err) => match err.into_err() { | |
CreateBucketError::BucketAlreadyExists(msg) => { | |
println!("bucket already exists: {}", msg); | |
return Ok(()); | |
} | |
CreateBucketError::BucketAlreadyOwnedByYou(msg) => { | |
println!("bucket already owned by you: {}", msg); | |
return Ok(()); | |
} | |
e => { | |
println!("error: {}", e.to_string()); | |
return Err::<(), String>(e.to_string()); | |
} | |
}, | |
_ => { | |
println!("error: {}", e); | |
return Err::<(), String>(e.to_string()); | |
} | |
}, | |
} | |
} | |
} | |
impl Configurable for S3Storage { | |
async fn try_from_config(config: &WorkerConfig) -> Result<Self, Box<dyn ChromaError>> { | |
match &config.storage { | |
StorageConfig::S3(s3_config) => { | |
let config = aws_config::load_from_env().await; | |
let client = aws_sdk_s3::Client::new(&config); | |
let storage = S3Storage::new(&s3_config.bucket, client); | |
return Ok(storage); | |
} | |
} | |
} | |
} | |
impl Storage for S3Storage { | |
async fn get(&self, key: &str, path: &str) -> Result<(), String> { | |
let mut file = std::fs::File::create(path); | |
let res = self | |
.client | |
.get_object() | |
.bucket(self.bucket.clone()) | |
.key(key) | |
.send() | |
.await; | |
match res { | |
Ok(mut res) => { | |
match file { | |
Ok(mut file) => { | |
while let bytes = res.body.next().await { | |
match bytes { | |
Some(bytes) => match bytes { | |
Ok(bytes) => { | |
file.write_all(&bytes).unwrap(); | |
} | |
Err(e) => { | |
println!("error: {}", e); | |
return Err::<(), String>(e.to_string()); | |
} | |
}, | |
None => { | |
// Stream is done | |
return Ok(()); | |
} | |
} | |
} | |
} | |
Err(e) => { | |
println!("error: {}", e); | |
return Err::<(), String>(e.to_string()); | |
} | |
} | |
return Ok(()); | |
} | |
Err(e) => { | |
println!("error: {}", e); | |
return Err::<(), String>(e.to_string()); | |
} | |
} | |
} | |
async fn put(&self, key: &str, path: &str) -> Result<(), String> { | |
// Puts from a file on disk to s3. | |
let bytestream = ByteStream::from_path(path).await; | |
match bytestream { | |
Ok(bytestream) => { | |
let res = self | |
.client | |
.put_object() | |
.bucket(self.bucket.clone()) | |
.key(key) | |
.body(bytestream) | |
.send() | |
.await; | |
match res { | |
Ok(_) => { | |
println!("put object {} to bucket {}", key, self.bucket); | |
return Ok(()); | |
} | |
Err(e) => { | |
println!("error: {}", e); | |
return Err::<(), String>(e.to_string()); | |
} | |
} | |
} | |
Err(e) => { | |
println!("error: {}", e); | |
return Err::<(), String>(e.to_string()); | |
} | |
} | |
} | |
} | |
// Test-only code: `tempfile` and the MinIO round-trip must not be compiled into normal builds.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Round-trips a small file through S3-compatible storage: write it locally,
    /// `put` it, `get` it back into a second path, and compare the contents.
    ///
    /// Expects a server on `http://127.0.0.1:9000` that accepts the
    /// `minio` / `minio123` credentials (e.g. a local MinIO).
    // NOTE(review): no test attribute is visible on this fn, so the test harness never
    // runs it. It presumably needs an async test attribute such as `#[tokio::test]`
    // (plus `#[ignore]`, since it depends on a live server). Confirm which async runtime
    // the crate's dev-dependencies provide before adding one.
    async fn test_get() {
        // Static credentials for the local test server.
        let cred = aws_sdk_s3::config::Credentials::new(
            "minio",
            "minio123",
            None,
            None,
            "loaded-from-env",
        );
        // Path-style addressing (`host/bucket/key`), presumably because a bare
        // localhost endpoint has no per-bucket DNS names.
        let config = aws_sdk_s3::config::Builder::new()
            .endpoint_url("http://127.0.0.1:9000".to_string())
            .credentials_provider(cred)
            .behavior_version_latest()
            .region(aws_sdk_s3::config::Region::new("us-east-1"))
            .force_path_style(true)
            .build();
        let client = aws_sdk_s3::Client::from_conf(config);
        let storage = S3Storage::new("test", client);
        storage.create_bucket().await.unwrap();

        let tmp_dir = tempdir().unwrap();
        let test_file_in = tmp_dir.path().join("test_file_in");
        let test_file_out = tmp_dir.path().join("test_file_out");
        let test_data = "test data";
        std::fs::write(&test_file_in, test_data).unwrap();

        let path_in = test_file_in
            .to_str()
            .expect("temp dir path should be valid UTF-8");
        let path_out = test_file_out
            .to_str()
            .expect("temp dir path should be valid UTF-8");
        storage.put("test", path_in).await.unwrap();
        storage.get("test", path_out).await.unwrap();

        let contents = std::fs::read_to_string(&test_file_out).unwrap();
        assert_eq!(contents, test_data);
    }
}