|
#pragma once |
|
|
|
#include <c10/core/Allocator.h> |
|
#include <c10/core/DeviceType.h> |
|
|
|
|
|
|
|
namespace c10 { namespace hip { |
|
|
|
|
|
|
|
|
|
class HIPAllocatorMasqueradingAsCUDA final : public Allocator { |
|
Allocator* allocator_; |
|
public: |
|
explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator) |
|
: allocator_(allocator) {} |
|
DataPtr allocate(size_t size) const override { |
|
DataPtr r = allocator_->allocate(size); |
|
r.unsafe_set_device(Device(DeviceType::CUDA, r.device().index())); |
|
return r; |
|
} |
|
DeleterFnPtr raw_deleter() const override { |
|
return allocator_->raw_deleter(); |
|
} |
|
}; |
|
|
|
}} |
|
|