|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef ONEAPI_DNNL_DNNL_THREADPOOL_HPP
|
|
|
#define ONEAPI_DNNL_DNNL_THREADPOOL_HPP
|
|
|
|
|
|
#include "oneapi/dnnl/dnnl.hpp"
|
|
|
#include "oneapi/dnnl/dnnl_threadpool.h"
|
|
|
|
|
|
#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace dnnl {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace threadpool_interop {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Creates a stream bound to the given engine and threadpool.
///
/// Calls dnnl_threadpool_interop_stream_create() and routes the resulting
/// status through dnnl::error::wrap_c_api() with the message
/// "could not create stream" before wrapping the C handle in a dnnl::stream.
///
/// @param aengine Engine whose C handle (aengine.get()) is passed to the
///     C API.
/// @param threadpool Threadpool pointer passed to the C API as-is.
/// @returns A dnnl::stream constructed from the newly created C handle.
inline dnnl::stream make_stream(
        const dnnl::engine &aengine, threadpool_iface *threadpool) {
    // Filled in by the C API; only consumed after the status check below.
    dnnl_stream_t raw_stream;
    const auto create_status = dnnl_threadpool_interop_stream_create(
            &raw_stream, aengine.get(), threadpool);
    dnnl::error::wrap_c_api(create_status, "could not create stream");
    return dnnl::stream(raw_stream);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Retrieves the threadpool associated with a stream.
///
/// Calls dnnl_threadpool_interop_stream_get_threadpool() and routes the
/// resulting status through dnnl::error::wrap_c_api() with the message
/// "could not get stream threadpool".
///
/// @param astream Stream whose C handle (astream.get()) is queried.
/// @returns The pointer reported by the C API, cast to threadpool_iface *.
inline threadpool_iface *get_threadpool(const dnnl::stream &astream) {
    // The C API reports the threadpool through an untyped out-pointer.
    void *raw_pool;
    const auto query_status = dnnl_threadpool_interop_stream_get_threadpool(
            astream.get(), &raw_pool);
    dnnl::error::wrap_c_api(query_status, "could not get stream threadpool");
    return static_cast<threadpool_iface *>(raw_pool);
}
|
|
|
|
|
|
|
|
|
/// C++ wrapper around dnnl_threadpool_interop_sgemm().
///
/// Every argument is forwarded unchanged, in the same order, to the C API;
/// the C status code it returns is converted to `status`.
/// NOTE(review): argument semantics (transposition flags, leading
/// dimensions, alpha/beta scaling) are defined by the C API and are
/// presumably BLAS-like -- confirm against dnnl_threadpool.h.
///
/// @param threadpool Threadpool pointer handed to the C API as-is.
/// @returns The C API status cast to `status`.
inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
        dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
        const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc,
        threadpool_iface *threadpool) {
    const auto c_status = dnnl_threadpool_interop_sgemm(transa, transb, M, N,
            K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool);
    return static_cast<status>(c_status);
}
|
|
|
|
|
|
/// C++ wrapper around dnnl_threadpool_interop_gemm_u8s8s32().
///
/// Takes a uint8_t matrix A, an int8_t matrix B and an int32_t matrix C.
/// Every argument is forwarded unchanged, in the same order, to the C API;
/// the C status code it returns is converted to `status`.
/// NOTE(review): the meaning of offsetc, ao, bo and co is defined by the
/// C API -- confirm against dnnl_threadpool.h.
///
/// @param threadpool Threadpool pointer handed to the C API as-is.
/// @returns The C API status cast to `status`.
inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
        dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
        float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
        threadpool_iface *threadpool) {
    const auto c_status = dnnl_threadpool_interop_gemm_u8s8s32(transa, transb,
            offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co,
            threadpool);
    return static_cast<status>(c_status);
}
|
|
|
|
|
|
|
|
|
/// C++ wrapper around dnnl_threadpool_interop_gemm_s8s8s32().
///
/// Takes int8_t matrices A and B and an int32_t matrix C.
/// Every argument is forwarded unchanged, in the same order, to the C API;
/// the C status code it returns is converted to `status`.
/// NOTE(review): the meaning of offsetc, ao, bo and co is defined by the
/// C API -- confirm against dnnl_threadpool.h.
///
/// @param threadpool Threadpool pointer handed to the C API as-is.
/// @returns The C API status cast to `status`.
inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
        dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
        float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co,
        threadpool_iface *threadpool) {
    const auto c_status = dnnl_threadpool_interop_gemm_s8s8s32(transa, transb,
            offsetc, M, N, K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co,
            threadpool);
    return static_cast<status>(c_status);
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|