|
|
#include <cstdio>   // std::fprintf, stderr
#include <cstdlib>  // std::abort

#ifdef TEST_ON_CUDA

#include <mma.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
|
|
|
|
|
// Alias the CUDA warp-level matrix API (nvcuda::wmma, declared in <mma.h>)
// as `wmma`. The HIP branch of this header binds the same name to rocwmma,
// so shared code can spell `wmma::` on either backend.
namespace wmma = nvcuda::wmma;
|
|
|
|
|
// LIB_CALL(call): evaluate a CUDA runtime API call exactly once and abort the
// process if it does not return cudaSuccess. Host-only (uses std::fprintf and
// std::abort).
//
// On failure, writes the call site (__FILE__:__LINE__), the stringified
// expression and cudaGetErrorString() of the returned code to stderr, then
// calls std::abort(), so a failure is diagnosable rather than a bare SIGABRT.
//
// Why it is written this way:
//  - `call` is parenthesized, and its result is held in the deliberately
//    obscure local `lib_call_err_`, so arguments that mention a caller
//    variable named e.g. `err` (LIB_CALL(err)) are evaluated correctly
//    instead of self-initializing a shadowing local.
//  - do { ... } while (0) makes the expansion a single statement, safe inside
//    an unbraced if/else.
//  - Kernel launches do not return an error code; check them with
//    LIB_CALL(cudaGetLastError()) after the launch.
#define LIB_CALL(call)                                                        \
    do {                                                                      \
        cudaError_t lib_call_err_ = (call);                                   \
        if (lib_call_err_ != cudaSuccess) {                                   \
            std::fprintf(stderr, "%s:%d: %s failed: %s\n", __FILE__,          \
                         __LINE__, #call,                                     \
                         cudaGetErrorString(lib_call_err_));                  \
            std::abort();                                                     \
        }                                                                     \
    } while (0)
|
|
|
|
|
// HOST_TYPE(Name) expands to the CUDA identifier `cudaName` (for example
// HOST_TYPE(Malloc) -> cudaMalloc, HOST_TYPE(Stream_t) -> cudaStream_t).
// The HIP branch of this header defines the same macro with a `hip` prefix.
// The argument is pasted exactly as written: operands of ## are not
// macro-expanded before pasting.
#define HOST_TYPE(name) cuda ## name
|
|
|
|
|
#else |
|
|
|
|
|
#ifndef HIP_HEADERS__ |
|
|
#include <hip/hip_runtime.h> |
|
|
#include <hip/hip_fp8.h> |
|
|
#include <hip/hip_fp16.h> |
|
|
#include <rocwmma/rocwmma.hpp> |
|
|
#define HIP_HEADERS__ |
|
|
#endif |
|
|
|
|
|
// Expose rocWMMA under the same `wmma` name that the CUDA branch of this
// header binds to nvcuda::wmma, so shared code can spell `wmma::` on either
// backend.
// NOTE(review): rocWMMA is modeled on nvcuda::wmma but is not guaranteed to
// be a drop-in match — verify each fragment shape/type used is supported.
namespace wmma = rocwmma;
|
|
|
|
|
// LIB_CALL(call): evaluate a HIP runtime API call exactly once and abort the
// process if it does not return hipSuccess. Host-only (uses std::fprintf and
// std::abort).
//
// On failure, writes the call site (__FILE__:__LINE__), the stringified
// expression and hipGetErrorString() of the returned code to stderr, then
// calls std::abort(), so a failure is diagnosable rather than a bare SIGABRT.
//
// Why it is written this way:
//  - `call` is parenthesized, and its result is held in the deliberately
//    obscure local `lib_call_err_`, so arguments that mention a caller
//    variable named e.g. `err` (LIB_CALL(err)) are evaluated correctly
//    instead of self-initializing a shadowing local.
//  - do { ... } while (0) makes the expansion a single statement, safe inside
//    an unbraced if/else.
//  - Kernel launches do not return an error code; check them with
//    LIB_CALL(hipGetLastError()) after the launch.
#define LIB_CALL(call)                                                        \
    do {                                                                      \
        hipError_t lib_call_err_ = (call);                                    \
        if (lib_call_err_ != hipSuccess) {                                    \
            std::fprintf(stderr, "%s:%d: %s failed: %s\n", __FILE__,          \
                         __LINE__, #call,                                     \
                         hipGetErrorString(lib_call_err_));                   \
            std::abort();                                                     \
        }                                                                     \
    } while (0)
|
|
|
|
|
// HOST_TYPE(Name) expands to the HIP identifier `hipName` (for example
// HOST_TYPE(Malloc) -> hipMalloc, HOST_TYPE(Stream_t) -> hipStream_t).
// The CUDA branch of this header defines the same macro with a `cuda` prefix.
// The argument is pasted exactly as written: operands of ## are not
// macro-expanded before pasting.
#define HOST_TYPE(name) hip ## name
|
|
|
|
|
#endif |
|
|
|
|
|
|