Weiyun1025's picture
Upload folder using huggingface_hub
2abfccb verified
raw
history blame
11 kB
// AWS SDK: core
#include <aws/core/Aws.h>
#include <aws/core/auth/AWSCredentialsProvider.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/memory/AWSMemory.h>
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/core/utils/threading/Executor.h>
// AWS SDK: S3
#include <aws/s3/S3Client.h>
#include <aws/s3/S3Errors.h>
#include <aws/s3/model/AbortMultipartUploadRequest.h>
#include <aws/s3/model/Bucket.h>
#include <aws/s3/model/CompleteMultipartUploadRequest.h>
#include <aws/s3/model/CompletedMultipartUpload.h>
#include <aws/s3/model/CompletedPart.h>
#include <aws/s3/model/CreateMultipartUploadRequest.h>
#include <aws/s3/model/DeleteObjectRequest.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
#include <aws/s3/model/ListObjectsRequest.h>
#include <aws/s3/model/ListObjectsV2Request.h>
#include <aws/s3/model/Object.h>
#include <aws/s3/model/PutObjectRequest.h>
#include <aws/s3/model/UploadPartRequest.h>
// AWS SDK: transfer manager (multipart upload/download)
#include <aws/transfer/TransferHandle.h>
#include <aws/transfer/TransferManager.h>
// C++ standard library
#include <fstream>
#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
// Project header; kept after the headers above so it sees everything it saw before.
#include "s3client.h"
// Builds one {numeric code, symbolic name} initializer for ERROR_LIST from an
// Aws::S3::S3Errors enumerator; `#error_name` stringizes the enumerator name.
#define ERROR_ITEM(error_name)                                       \
    {                                                                \
        static_cast<int>(Aws::S3::S3Errors::error_name), #error_name \
    }

// Table pairing the integer value of each listed Aws::S3::S3Errors enumerator
// with its name. These integers are the same values the S3Client methods write
// into their `error_type` out-parameters.
// Declared const: the table is only ever read (get_error_list() hands out a copy).
static const std::list<std::pair<int, std::string>> ERROR_LIST = {
    ERROR_ITEM(INCOMPLETE_SIGNATURE),
    ERROR_ITEM(INTERNAL_FAILURE),
    ERROR_ITEM(INVALID_ACTION),
    ERROR_ITEM(INVALID_CLIENT_TOKEN_ID),
    ERROR_ITEM(INVALID_PARAMETER_COMBINATION),
    ERROR_ITEM(INVALID_QUERY_PARAMETER),
    ERROR_ITEM(INVALID_PARAMETER_VALUE),
    ERROR_ITEM(MISSING_ACTION),
    ERROR_ITEM(MISSING_AUTHENTICATION_TOKEN),
    ERROR_ITEM(MISSING_PARAMETER),
    ERROR_ITEM(OPT_IN_REQUIRED),
    ERROR_ITEM(REQUEST_EXPIRED),
    ERROR_ITEM(SERVICE_UNAVAILABLE),
    ERROR_ITEM(THROTTLING),
    ERROR_ITEM(VALIDATION),
    ERROR_ITEM(ACCESS_DENIED),
    ERROR_ITEM(RESOURCE_NOT_FOUND),
    ERROR_ITEM(UNRECOGNIZED_CLIENT),
    ERROR_ITEM(MALFORMED_QUERY_STRING),
    ERROR_ITEM(SLOW_DOWN),
    ERROR_ITEM(REQUEST_TIME_TOO_SKEWED),
    ERROR_ITEM(INVALID_SIGNATURE),
    ERROR_ITEM(SIGNATURE_DOES_NOT_MATCH),
    ERROR_ITEM(INVALID_ACCESS_KEY_ID),
    ERROR_ITEM(REQUEST_TIMEOUT),
    ERROR_ITEM(NETWORK_CONNECTION),
    ERROR_ITEM(UNKNOWN),
    ERROR_ITEM(BUCKET_ALREADY_EXISTS),
    ERROR_ITEM(BUCKET_ALREADY_OWNED_BY_YOU),
    ERROR_ITEM(NO_SUCH_BUCKET),
    ERROR_ITEM(NO_SUCH_KEY),
    ERROR_ITEM(NO_SUCH_UPLOAD),
    ERROR_ITEM(OBJECT_ALREADY_IN_ACTIVE_TIER),
    ERROR_ITEM(OBJECT_NOT_IN_ACTIVE_TIER),
};
// Returns a copy of the {error code, error name} table, so callers can
// translate the integer `error_type` values reported by S3Client methods
// into readable names. The caller owns the returned list.
std::list<std::pair<int, std::string>> get_error_list()
{
    std::list<std::pair<int, std::string>> table_copy(ERROR_LIST);
    return table_copy;
}
// SDK options shared by init_api() and shutdown_api(). Must stay mutable:
// init_api() writes the selected log level into it before calling Aws::InitAPI.
static Aws::SDKOptions options;

// Maps the textual log-level names accepted by init_api() to the SDK's
// LogLevel enumerators. Declared const: the map is a read-only lookup table.
static const std::unordered_map<std::string, Aws::Utils::Logging::LogLevel> log_level_map = {
    {"off", Aws::Utils::Logging::LogLevel::Off},
    {"fatal", Aws::Utils::Logging::LogLevel::Fatal},
    {"error", Aws::Utils::Logging::LogLevel::Error},
    {"warn", Aws::Utils::Logging::LogLevel::Warn},
    {"info", Aws::Utils::Logging::LogLevel::Info},
    {"debug", Aws::Utils::Logging::LogLevel::Debug},
    {"trace", Aws::Utils::Logging::LogLevel::Trace},
};
// Initializes the AWS SDK.
//
// level: log-level name; one of the keys of log_level_map ("off", "fatal",
//        "error", "warn", "info", "debug", "trace"). Any other value is
//        ignored and the log level already stored in `options` is used.
void init_api(const std::string &level)
{
    const auto match = log_level_map.find(level);
    const bool is_known_level = (match != log_level_map.end());
    if (is_known_level)
    {
        options.loggingOptions.logLevel = match->second;
    }

    Aws::InitAPI(options);
}
// Shuts the AWS SDK down, passing the same file-level `options` object that
// init_api() hands to Aws::InitAPI.
void shutdown_api()
{
    Aws::ShutdownAPI(options);
}
// Builds the underlying Aws::S3::S3Client from explicit credentials and
// endpoint settings.
//
// ak / sk        : access key id and secret key.
// endpoint       : written to ClientConfiguration::endpointOverride.
// verify_ssl     : written to ClientConfiguration::verifySSL.
// enable_https   : selects the HTTPS scheme when true, HTTP otherwise.
// use_dual_stack : written to ClientConfiguration::useDualStack.
// threads_num    : stored for the multipart upload/download helpers, which
//                  pass it to their PooledThreadExecutor.
S3Client::S3Client(const std::string &ak, const std::string &sk, const std::string &endpoint, bool verify_ssl, bool enable_https, bool use_dual_stack, int threads_num)
{
    Aws::Client::ClientConfiguration config;
    config.endpointOverride = endpoint.c_str();
    config.verifySSL = verify_ssl;
    config.useDualStack = use_dual_stack;
    config.scheme = enable_https ? Aws::Http::Scheme::HTTPS : Aws::Http::Scheme::HTTP;

    Aws::Auth::AWSCredentials cred(ak.c_str(), sk.c_str());

    // Payload signing is set to Never. NOTE(review): the trailing `false` is
    // presumably the SDK's useVirtualAddressing flag (i.e. path-style
    // addressing) -- confirm against the SDK version in use.
    this->client = Aws::MakeShared<Aws::S3::S3Client>("S3Client", cred, config, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, false);

    // Used by the multipart (TransferManager) helpers.
    this->threads_num = threads_num;
}
// Nothing to release by hand: `client` is assigned from Aws::MakeShared in the
// constructor (presumably a shared_ptr member -- see s3client.h), so no
// explicit delete is performed here.
S3Client::~S3Client() = default;
// Downloads an object (or a byte range of it) into memory.
//
// bucket, key   : object to fetch.
// error_type    : on failure, receives the integer value of the
//                 Aws::S3::S3Errors code (see get_error_list()).
// error_message : on failure, receives the error text when one is available.
// result        : on success, receives exactly the bytes of the response body.
//                 On a truncated body it holds only the bytes actually read.
// range         : optional byte range WITHOUT the "bytes=" prefix
//                 (e.g. "0-1023"); an empty string requests the whole object.
//
// Returns 0 on success, -1 on failure. A response body shorter than the
// advertised content length is reported as a failure instead of silently
// returning a zero-padded buffer.
int S3Client::get_object(const std::string &bucket, const std::string &key, int &error_type, std::string &error_message, std::string &result, std::string &range)
{
    Aws::S3::Model::GetObjectRequest object_request;
    object_request.SetBucket(bucket.c_str());
    object_request.SetKey(key.c_str());
    if (!range.empty())
    {
        object_request.SetRange(("bytes=" + range).c_str());
    }

    // The request is actually sent here.
    auto get_object_outcome = this->client->GetObject(object_request);
    if (!get_object_outcome.IsSuccess())
    {
        const auto &error = get_object_outcome.GetError();
        const auto &message = error.GetMessage();
        if (!message.empty())
        {
            error_message = message.c_str();
        }
        error_type = static_cast<int>(error.GetErrorType());
        return -1;
    }

    auto &&get_result = get_object_outcome.GetResultWithOwnership();
    auto &retrieved_object = get_result.GetBody();

    // Guard against a negative length so resize() is never asked for a
    // wrapped-around, enormous size.
    const long long content_length = get_result.GetContentLength();
    const std::streamsize expected = content_length > 0 ? static_cast<std::streamsize>(content_length) : 0;
    result.resize(static_cast<std::string::size_type>(expected));

    // read() may deliver fewer bytes than requested; keep reading until the
    // buffer is full or the stream stops producing data.
    std::streamsize read_offset = 0;
    while (read_offset < expected)
    {
        retrieved_object.read(&result[static_cast<std::string::size_type>(read_offset)], expected - read_offset);
        const std::streamsize read_count = retrieved_object.gcount();
        if (read_count <= 0)
        {
            break;
        }
        read_offset += read_count;
    }

    if (read_offset != expected)
    {
        // Body ended early: drop the unread (zero-filled) tail and report it.
        result.resize(static_cast<std::string::size_type>(read_offset));
        error_type = static_cast<int>(Aws::S3::S3Errors::NETWORK_CONNECTION);
        error_message = "get_object: response body truncated, read " + std::to_string(read_offset) + " of " + std::to_string(expected) + " bytes";
        return -1;
    }
    return 0;
}
// Uploads an in-memory buffer as a single object with one PutObject request.
//
// bucket, key   : destination object.
// data          : bytes to upload; copied into an in-memory stream first.
// error_type    : on failure, receives the integer value of the
//                 Aws::S3::S3Errors code (see get_error_list()).
// error_message : on failure, receives the error text when one is available.
//
// Returns 0 on success, -1 on failure.
int S3Client::put_object(const std::string &bucket, const std::string &key, const std::string &data, int &error_type, std::string &error_message)
{
    // Stage the payload in a stream the request can use as its body.
    const std::shared_ptr<Aws::IOStream> body_stream = Aws::MakeShared<Aws::StringStream>("");
    (*body_stream) << data;

    Aws::S3::Model::PutObjectRequest request;
    request.SetBucket(bucket.c_str());
    request.SetKey(key.c_str());
    request.SetBody(body_stream);

    auto outcome = this->client->PutObject(request);
    if (!outcome.IsSuccess())
    {
        const auto &failure = outcome.GetError();
        const auto &text = failure.GetMessage();
        if (!text.empty())
        {
            error_message = text.c_str();
        }
        error_type = static_cast<int>(failure.GetErrorType());
        return -1;
    }
    return 0;
}
// Downloads an object to a local file through the SDK TransferManager.
// A PooledThreadExecutor constructed with `threads_num` is created for this
// call, and the calling thread blocks until the transfer finishes.
//
// bucket, key   : object to download.
// filename      : local path opened for binary output.
// error_type    : on failure, receives the integer value of the last
//                 transfer error's code (see get_error_list()).
// error_message : on failure, receives the last transfer error's message.
//
// Returns 0 when the transfer status is COMPLETED, -1 otherwise.
int S3Client::multipart_download_concurrency(const std::string &bucket, const std::string &key, const std::string &filename, int &error_type, std::string &error_message)
{
    auto executor = Aws::MakeShared<Aws::Utils::Threading::PooledThreadExecutor>("executor", this->threads_num);
    Aws::Transfer::TransferManagerConfiguration transfer_config(executor.get());
    transfer_config.s3Client = this->client;
    auto transfer_manager = Aws::Transfer::TransferManager::Create(transfer_config);

    // Factory the transfer manager calls to obtain the destination stream.
    // `filename` is captured by copy.
    const auto open_destination = [filename]() {
        return Aws::New<Aws::FStream>("S3_DOWNLOAD", filename.c_str(), std::ios_base::out | std::ios_base::binary);
    };

    auto downloadHandle = transfer_manager->DownloadFile(bucket.c_str(), key.c_str(), open_destination);
    downloadHandle->WaitUntilFinished(); // Block the calling thread until the download is complete.

    if (downloadHandle->GetStatus() == Aws::Transfer::TransferStatus::COMPLETED)
    {
        return 0;
    }

    const auto failure = downloadHandle->GetLastError();
    error_message = failure.GetMessage().c_str();
    error_type = static_cast<int>(failure.GetErrorType());
    return -1;
}
// Uploads a local file through the SDK TransferManager.
// A PooledThreadExecutor constructed with `threads_num` is created for this
// call, and the calling thread blocks until the transfer finishes.
//
// bucket, key   : destination object.
// filename      : local file to upload.
// error_type    : on failure, receives the integer value of the last
//                 transfer error's code (see get_error_list()).
// error_message : on failure, receives the last transfer error's message.
//
// The upload always passes "text/plain" as the content-type argument and an
// empty metadata map.
//
// Returns 0 when the transfer status is COMPLETED, -1 otherwise.
int S3Client::multipart_upload_concurrency(const std::string bucket, const std::string key, const std::string filename, int &error_type, std::string &error_message)
{
    auto executor = Aws::MakeShared<Aws::Utils::Threading::PooledThreadExecutor>("executor", this->threads_num);
    Aws::Transfer::TransferManagerConfiguration transfer_config(executor.get());
    transfer_config.s3Client = this->client;
    auto transfer_manager = Aws::Transfer::TransferManager::Create(transfer_config);

    auto uploadHandle = transfer_manager->UploadFile(filename.c_str(), bucket.c_str(), key.c_str(), "text/plain", Aws::Map<Aws::String, Aws::String>());
    uploadHandle->WaitUntilFinished(); // Block the calling thread until the upload is complete.

    if (uploadHandle->GetStatus() == Aws::Transfer::TransferStatus::COMPLETED)
    {
        return 0;
    }

    const auto failure = uploadHandle->GetLastError();
    error_message = failure.GetMessage().c_str();
    error_type = static_cast<int>(failure.GetErrorType());
    return -1;
}
// Deletes one object.
//
// bucket, key   : object to delete.
// error_type    : passed BY VALUE, so nothing written to it would reach the
//                 caller; it is unused here.
// error_message : on failure, receives the error text when one is available.
//
// Returns 1 when the request succeeded and 0 when it failed.
// NOTE: this is the opposite convention from get_object()/put_object(), which
// return 0 on success; it is kept as-is because callers depend on it.
int S3Client::delete_obj(const std::string &bucket, const std::string &key, int error_type, std::string &error_message)
{
    (void)error_type; // by-value parameter: cannot carry an error code back

    Aws::S3::Model::DeleteObjectRequest request;
    request.WithBucket(bucket.c_str()).WithKey(key.c_str());

    auto outcome = this->client->DeleteObject(request);
    if (!outcome.IsSuccess())
    {
        // Surface the failure reason instead of dropping it.
        const auto &message = outcome.GetError().GetMessage();
        if (!message.empty())
        {
            error_message = message.c_str();
        }
        return 0;
    }
    return 1;
}
// Checks whether an object exists by issuing a single HeadObject request.
//
// bucket, key   : object to probe.
// error_type    : passed BY VALUE, so nothing written to it would reach the
//                 caller; it is unused here.
// error_message : not modified by this function.
//
// Returns 1 when HeadObject succeeds and 0 otherwise. Any failure (missing
// key, access denied, network error, ...) yields 0; callers cannot tell these
// cases apart through this interface.
int S3Client::contains(const std::string &bucket, const std::string &key, int error_type, std::string &error_message)
{
    (void)error_type;    // by-value parameter: cannot carry an error code back
    (void)error_message; // intentionally left untouched

    Aws::S3::Model::HeadObjectRequest request;
    request.WithBucket(bucket.c_str()).WithKey(key.c_str());

    // One round trip only; a second HeadObject call would just repeat the
    // same request.
    const auto outcome = this->client->HeadObject(request);
    return outcome.IsSuccess();
}
// Lists the first path component, relative to the prefix `key`, of every
// object whose key starts with `key`.
//
// For each object key the prefix is stripped and the remainder is cut after
// the first '/' (the '/' is kept), so with prefix "a/" the key "a/b/c.txt"
// yields "b/" and the key "a/d.txt" yields "d.txt". One entry is produced per
// object; entries are not de-duplicated.
//
// bucket        : bucket to list.
// key           : key prefix to list under.
// error_type    : passed BY VALUE, so the code assigned below does not reach
//                 the caller.
// error_message : on failure, receives the error text when one is available.
//
// Returns the collected entries, or an empty vector if any request fails.
// A single ListObjects response holds a bounded page of keys, so the request
// is repeated with a marker until the service reports the listing complete.
std::vector<std::string> S3Client::list(const std::string &bucket, const std::string &key, int error_type, std::string &error_message)
{
    std::vector<std::string> res;
    const std::string::size_type prefix_len = key.size();

    Aws::S3::Model::ListObjectsRequest request;
    request.WithBucket(bucket.c_str()).WithPrefix(key.c_str());

    while (true)
    {
        auto outcome = this->client->ListObjects(request);
        if (!outcome.IsSuccess())
        {
            const auto &error = outcome.GetError();
            const auto &message = error.GetMessage();
            if (!message.empty())
            {
                error_message = message.c_str();
            }
            // Has no effect for the caller while error_type is by value.
            error_type = static_cast<int>(error.GetErrorType());
            // Do not hand back a partial listing as if it were complete.
            res.clear();
            return res;
        }

        const auto &page = outcome.GetResult();
        const auto &objects = page.GetContents();
        res.reserve(res.size() + objects.size());
        for (const auto &object : objects)
        {
            const std::string full_path = object.GetKey().c_str();
            const std::string::size_type slash = full_path.find('/', prefix_len);
            if (slash == std::string::npos)
            {
                // No further '/': the entry is the whole remainder.
                res.push_back(full_path.substr(prefix_len));
            }
            else
            {
                // Keep everything up to and including the '/'.
                res.push_back(full_path.substr(prefix_len, slash - prefix_len + 1));
            }
        }

        // Stop when the listing is complete. The empty-page check guards
        // against looping forever on a service that keeps reporting truncation.
        if (!page.GetIsTruncated() || objects.empty())
        {
            break;
        }

        // NextMarker is only returned when a delimiter is set; otherwise the
        // last key of this page is the marker for the next one.
        const auto &next_marker = page.GetNextMarker();
        request.SetMarker(next_marker.empty() ? objects.back().GetKey() : next_marker);
    }
    return res;
}