| // acl_runtime.h — per-rank ACL runtime init/teardown. | |
| class AclRuntime { | |
| public: | |
| AclRuntime() = default; | |
| ~AclRuntime() { shutdown(); } | |
| bool init(int device_id) { | |
| if (initialized_) return true; | |
| device_id_ = device_id; | |
| ACL_CHECK(aclInit(nullptr)); | |
| ACL_CHECK(aclrtSetDevice(device_id)); | |
| ACL_CHECK(aclrtCreateContext(&ctx_, device_id)); | |
| ACL_CHECK(aclrtCreateStream(&stream_)); | |
| initialized_ = true; | |
| return true; | |
| } | |
| void shutdown() { | |
| if (!initialized_) return; | |
| if (stream_) { aclrtDestroyStream(stream_); stream_ = nullptr; } | |
| if (ctx_) { aclrtDestroyContext(ctx_); ctx_ = nullptr; } | |
| aclrtResetDevice(device_id_); | |
| aclFinalize(); | |
| initialized_ = false; | |
| } | |
| void sync() { if (stream_) ACL_CHECK(aclrtSynchronizeStream(stream_)); } | |
| aclrtStream stream() const { return stream_; } | |
| int device_id() const { return device_id_; } | |
| private: | |
| bool initialized_ = false; | |
| int device_id_ = 0; | |
| aclrtContext ctx_ = nullptr; | |
| aclrtStream stream_ = nullptr; | |
| }; | |