|
|
#include "ggml.h" |
|
|
#include "utils.hpp" |
|
|
|
|
|
zdnn_data_types ggml_zdnn_type_mapping(ggml_type type) { |
|
|
switch (type) { |
|
|
case GGML_TYPE_F32: |
|
|
return FP32; |
|
|
case GGML_TYPE_F16: |
|
|
return FP16; |
|
|
case GGML_TYPE_BF16: |
|
|
return BFLOAT; |
|
|
case GGML_TYPE_Q8_0: |
|
|
return INT8; |
|
|
case GGML_TYPE_I8: |
|
|
return INT8; |
|
|
case GGML_TYPE_I32: |
|
|
return INT32; |
|
|
default: |
|
|
GGML_ABORT("%s: fatal: unable to determine zTensor data type", |
|
|
__func__); |
|
|
break; |
|
|
} |
|
|
} |
|
|
|
|
|
void ggml_zdnn_create_tensor(zdnn_tensor_desc & pre_tfm_desc, |
|
|
zdnn_tensor_desc & tfm_desc, |
|
|
zdnn_ztensor & ztensor, |
|
|
const ggml_tensor * src, |
|
|
const int64_t * ne, |
|
|
const zdnn_data_layouts layout) { |
|
|
zdnn_init_pre_transformed_desc( |
|
|
layout, |
|
|
ggml_zdnn_type_mapping(src->type), |
|
|
&pre_tfm_desc, |
|
|
ne[3], ne[2], ne[1], ne[0] |
|
|
); |
|
|
|
|
|
ZDNN_CHECK(zdnn_generate_transformed_desc(&pre_tfm_desc, &tfm_desc)); |
|
|
ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&pre_tfm_desc, &tfm_desc, &ztensor)); |
|
|
} |
|
|
|
|
|
void ggml_zdnn_load_tensor(zdnn_ztensor & ztensor, void * buffer) { |
|
|
ZDNN_CHECK(zdnn_transform_ztensor(&ztensor, buffer)); |
|
|
} |
|
|
|
|
|
void ggml_zdnn_init_tensor(ggml_backend_zdnn_buffer * buffer, const ggml_tensor * tensor) { |
|
|
switch (tensor->op) { |
|
|
case GGML_OP_MUL_MAT: |
|
|
{ |
|
|
zdnn_init_pre_transformed_desc( |
|
|
ZDNN_2D, |
|
|
ggml_zdnn_type_mapping(tensor->type), |
|
|
&buffer->pre_tfm_desc, |
|
|
tensor->ne[1], tensor->ne[0] |
|
|
); |
|
|
} break; |
|
|
|
|
|
default: |
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
zdnn_init_pre_transformed_desc( |
|
|
ZDNN_NHWC, |
|
|
ggml_zdnn_type_mapping(tensor->type), |
|
|
&buffer->pre_tfm_desc, |
|
|
tensor->ne[3], tensor->ne[2], tensor->ne[1], tensor->ne[0] |
|
|
); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} break; |
|
|
} |
|
|
|
|
|
ZDNN_CHECK(zdnn_generate_transformed_desc(&buffer->pre_tfm_desc, &buffer->tfm_desc)); |
|
|
ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&buffer->pre_tfm_desc, &buffer->tfm_desc, &buffer->ztensor)); |
|
|
} |
|
|
|