Spaces:
Running
Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing (llama/10597)
Browse files
* Vulkan: Implement VK_KHR_cooperative_matrix support in the matrix-matrix multiplication shader
* Improve performance with better q4_k and q5_k dequant and store unrolling
* Add Vulkan MUL_MAT and MUL_MAT_ID accumulator precision selection
* Rework mulmat shader selection and compilation logic, avoid compiling shaders that won't get used by the device
* Vulkan: Implement accumulator switch for specific mul mat mat shaders
* Vulkan: Unroll more loops for more mul mat mat performance
* Vulkan: Add VK_AMD_shader_core_properties2 support to read Compute Unit count for split_k logic
* Disable coopmat support on AMD proprietary driver
* Remove redundant checks
* Add environment variable GGML_VK_DISABLE_COOPMAT to disable VK_KHR_cooperative_matrix support
* Fix rebase typo
* Fix coopmat2 MUL_MAT_ID pipeline selection
|
@@ -1,7 +1,8 @@
|
|
| 1 |
#include "ggml-vulkan.h"
|
| 2 |
#include <vulkan/vulkan_core.h>
|
| 3 |
-
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
|
| 4 |
#include <chrono>
|
|
|
|
| 5 |
#endif
|
| 6 |
|
| 7 |
#include <vulkan/vulkan.hpp>
|
|
@@ -169,8 +170,22 @@ struct vk_device_struct {
|
|
| 169 |
bool uma;
|
| 170 |
bool coopmat2;
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
size_t idx;
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
vk_matmul_pipeline pipeline_matmul_f32;
|
| 175 |
vk_matmul_pipeline pipeline_matmul_f32_f16;
|
| 176 |
vk_matmul_pipeline2 pipeline_matmul_f16;
|
|
@@ -181,10 +196,10 @@ struct vk_device_struct {
|
|
| 181 |
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
|
| 182 |
|
| 183 |
vk_matmul_pipeline pipeline_matmul_id_f32;
|
| 184 |
-
|
| 185 |
-
|
| 186 |
|
| 187 |
-
|
| 188 |
|
| 189 |
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
| 190 |
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
|
|
@@ -1325,6 +1340,18 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
|
|
| 1325 |
return {64, 64};
|
| 1326 |
};
|
| 1327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1328 |
|
| 1329 |
static void ggml_vk_load_shaders(vk_device& device) {
|
| 1330 |
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
|
@@ -1382,12 +1409,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 1382 |
m_align = 64;
|
| 1383 |
s_align = 32;
|
| 1384 |
} else {
|
| 1385 |
-
|
| 1386 |
-
|
| 1387 |
-
|
| 1388 |
-
|
| 1389 |
-
|
| 1390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1391 |
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
| 1392 |
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
|
| 1393 |
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
|
|
@@ -1428,25 +1468,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 1428 |
// assert mul_mat_mat_id shaders will fit.
|
| 1429 |
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
|
| 1430 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1431 |
}
|
| 1432 |
|
| 1433 |
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1434 |
device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1435 |
|
| 1436 |
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1437 |
-
device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1438 |
-
device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1439 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1440 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1441 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1442 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1443 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1444 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1445 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1446 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1447 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1448 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1449 |
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1450 |
|
| 1451 |
std::vector<std::future<void>> compiles;
|
| 1452 |
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
|
|
@@ -1543,119 +1594,191 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
| 1543 |
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1544 |
|
| 1545 |
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 1546 |
-
|
| 1547 |
-
|
| 1548 |
-
|
| 1549 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1550 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1551 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1552 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1553 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1554 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1555 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1556 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1557 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1558 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1559 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1560 |
#undef CREATE_MM
|
| 1561 |
#undef CREATE_MM2
|
| 1562 |
} else
|
| 1563 |
-
#endif
|
| 1564 |
-
if (device->
|
| 1565 |
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
| 1566 |
-
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
| 1567 |
-
|
| 1568 |
-
|
| 1569 |
-
|
| 1570 |
-
|
| 1571 |
-
|
| 1572 |
-
|
| 1573 |
-
|
| 1574 |
-
|
| 1575 |
-
|
| 1576 |
-
|
| 1577 |
-
|
| 1578 |
-
|
| 1579 |
-
|
| 1580 |
-
|
| 1581 |
-
|
| 1582 |
-
CREATE_MM(
|
| 1583 |
-
CREATE_MM(
|
| 1584 |
-
|
| 1585 |
-
CREATE_MM(
|
| 1586 |
-
CREATE_MM(
|
| 1587 |
-
|
| 1588 |
-
|
| 1589 |
-
|
| 1590 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1591 |
|
| 1592 |
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
| 1593 |
-
if (device->
|
| 1594 |
-
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
|
| 1595 |
-
|
| 1596 |
-
|
| 1597 |
-
|
| 1598 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, ,
|
| 1599 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, ,
|
| 1600 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, ,
|
| 1601 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, ,
|
| 1602 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, ,
|
| 1603 |
-
|
| 1604 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, ,
|
| 1605 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, ,
|
| 1606 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, ,
|
| 1607 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, ,
|
| 1608 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, ,
|
| 1609 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, ,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1610 |
}
|
| 1611 |
#undef CREATE_MM
|
| 1612 |
} else {
|
| 1613 |
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
| 1614 |
-
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
| 1615 |
-
|
| 1616 |
-
|
| 1617 |
-
|
| 1618 |
-
|
| 1619 |
-
|
| 1620 |
-
|
| 1621 |
-
|
| 1622 |
-
|
| 1623 |
-
|
| 1624 |
-
|
| 1625 |
-
|
| 1626 |
-
|
| 1627 |
-
|
| 1628 |
-
CREATE_MM(
|
| 1629 |
-
CREATE_MM(
|
| 1630 |
-
CREATE_MM(
|
| 1631 |
-
CREATE_MM(
|
| 1632 |
-
|
| 1633 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
| 1634 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
| 1635 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
| 1636 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
| 1637 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
| 1638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1639 |
|
| 1640 |
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
| 1641 |
-
if (device->
|
| 1642 |
-
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
|
| 1643 |
-
CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
|
| 1644 |
-
CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
|
| 1645 |
-
|
| 1646 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1647 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1648 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1649 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1650 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1651 |
-
|
| 1652 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1653 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1654 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1655 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1656 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1657 |
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
|
| 1658 |
}
|
|
|
|
| 1659 |
#undef CREATE_MM
|
| 1660 |
}
|
| 1661 |
|
|
@@ -1851,8 +1974,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 1851 |
bool fp16_compute = false;
|
| 1852 |
bool maintenance4_support = false;
|
| 1853 |
bool sm_builtins = false;
|
|
|
|
| 1854 |
bool pipeline_robustness = false;
|
| 1855 |
bool coopmat2_support = false;
|
|
|
|
| 1856 |
|
| 1857 |
// Check if maintenance4 is supported
|
| 1858 |
for (const auto& properties : ext_props) {
|
|
@@ -1864,10 +1989,18 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 1864 |
fp16_compute = true;
|
| 1865 |
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
|
| 1866 |
sm_builtins = true;
|
|
|
|
|
|
|
| 1867 |
} else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
|
| 1868 |
pipeline_robustness = true;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1869 |
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
| 1870 |
-
!getenv("
|
| 1871 |
coopmat2_support = true;
|
| 1872 |
}
|
| 1873 |
}
|
|
@@ -1876,11 +2009,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 1876 |
vk::PhysicalDeviceMaintenance3Properties props3;
|
| 1877 |
vk::PhysicalDeviceMaintenance4Properties props4;
|
| 1878 |
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
|
|
|
| 1879 |
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
|
|
|
|
| 1880 |
props2.pNext = &props3;
|
| 1881 |
props3.pNext = &subgroup_props;
|
|
|
|
| 1882 |
|
| 1883 |
-
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&
|
| 1884 |
|
| 1885 |
if (maintenance4_support) {
|
| 1886 |
last_struct->pNext = (VkBaseOutStructure *)&props4;
|
|
@@ -1890,6 +2026,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 1890 |
last_struct->pNext = (VkBaseOutStructure *)&sm_props;
|
| 1891 |
last_struct = (VkBaseOutStructure *)&sm_props;
|
| 1892 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1893 |
|
| 1894 |
#if defined(VK_NV_cooperative_matrix2)
|
| 1895 |
vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
|
|
@@ -1905,7 +2045,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 1905 |
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
|
| 1906 |
|
| 1907 |
if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
|
| 1908 |
-
device->max_memory_allocation_size = std::
|
| 1909 |
} else if (maintenance4_support) {
|
| 1910 |
device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
|
| 1911 |
} else {
|
|
@@ -1917,15 +2057,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 1917 |
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
| 1918 |
if (sm_builtins) {
|
| 1919 |
device->shader_core_count = sm_props.shaderSMCount;
|
|
|
|
|
|
|
| 1920 |
} else {
|
| 1921 |
device->shader_core_count = 0;
|
| 1922 |
}
|
| 1923 |
|
| 1924 |
-
const
|
| 1925 |
-
const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
|
| 1926 |
|
| 1927 |
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
| 1928 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1929 |
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
| 1930 |
|
| 1931 |
// Try to find a non-graphics compute queue and transfer-focused queues
|
|
@@ -1976,6 +2123,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 1976 |
device_extensions.push_back("VK_EXT_pipeline_robustness");
|
| 1977 |
}
|
| 1978 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1979 |
#if defined(VK_NV_cooperative_matrix2)
|
| 1980 |
VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
|
| 1981 |
coopmat2_features.pNext = nullptr;
|
|
@@ -1993,6 +2150,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 1993 |
|
| 1994 |
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
|
| 1995 |
|
|
|
|
|
|
|
| 1996 |
if (coopmat2_support) {
|
| 1997 |
#if defined(VK_NV_cooperative_matrix2)
|
| 1998 |
if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
|
|
@@ -2083,6 +2242,74 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 2083 |
if (device->fp16) {
|
| 2084 |
device_extensions.push_back("VK_KHR_shader_float16_int8");
|
| 2085 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2086 |
device->name = GGML_VK_NAME + std::to_string(idx);
|
| 2087 |
|
| 2088 |
device_create_info = {
|
|
@@ -2098,6 +2325,37 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
| 2098 |
ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
|
| 2099 |
|
| 2100 |
// Shaders
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2101 |
ggml_vk_load_shaders(device);
|
| 2102 |
|
| 2103 |
if (!device->single_queue) {
|
|
@@ -2155,15 +2413,24 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
| 2155 |
|
| 2156 |
bool fp16_storage = false;
|
| 2157 |
bool fp16_compute = false;
|
|
|
|
| 2158 |
|
| 2159 |
for (auto properties : ext_props) {
|
| 2160 |
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
| 2161 |
fp16_storage = true;
|
| 2162 |
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
|
| 2163 |
fp16_compute = true;
|
|
|
|
|
|
|
| 2164 |
}
|
| 2165 |
}
|
| 2166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2167 |
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
|
| 2168 |
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
|
| 2169 |
|
|
@@ -2186,13 +2453,28 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
| 2186 |
vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
|
| 2187 |
vk11_features.pNext = &vk12_features;
|
| 2188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2189 |
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
| 2190 |
|
| 2191 |
fp16 = fp16 && vk12_features.shaderFloat16;
|
| 2192 |
|
|
|
|
|
|
|
| 2193 |
std::string device_name = props2.properties.deviceName.data();
|
| 2194 |
-
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu\n",
|
| 2195 |
-
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size);
|
| 2196 |
|
| 2197 |
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
|
| 2198 |
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
|
|
@@ -2428,7 +2710,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
| 2428 |
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
| 2429 |
return ctx->device->pipeline_matmul_f32_f16;
|
| 2430 |
}
|
| 2431 |
-
if (prec == GGML_PREC_DEFAULT && ctx->device->
|
| 2432 |
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
| 2433 |
return ctx->device->pipeline_matmul_f16_f32.f16acc;
|
| 2434 |
}
|
|
@@ -2469,7 +2751,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
| 2469 |
assert(src1_type == GGML_TYPE_F16);
|
| 2470 |
return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
|
| 2471 |
}
|
| 2472 |
-
return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
| 2473 |
}
|
| 2474 |
|
| 2475 |
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
|
@@ -2498,16 +2780,25 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
|
| 2498 |
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
|
| 2499 |
}
|
| 2500 |
|
| 2501 |
-
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
|
| 2502 |
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
|
| 2503 |
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
| 2504 |
return ctx->device->pipeline_matmul_id_f32;
|
| 2505 |
}
|
| 2506 |
-
if (
|
| 2507 |
-
|
| 2508 |
-
|
| 2509 |
-
|
| 2510 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2511 |
}
|
| 2512 |
|
| 2513 |
GGML_ASSERT(src1_type == GGML_TYPE_F32);
|
|
@@ -2529,7 +2820,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
|
| 2529 |
return nullptr;
|
| 2530 |
}
|
| 2531 |
|
| 2532 |
-
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
|
| 2533 |
}
|
| 2534 |
|
| 2535 |
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
|
@@ -3119,62 +3410,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
|
|
| 3119 |
return split_k;
|
| 3120 |
}
|
| 3121 |
|
| 3122 |
-
static vk_pipeline
|
| 3123 |
-
if (m <= 32 || n <= 32) {
|
| 3124 |
-
return aligned ? mmp->a_s : mmp->s;
|
| 3125 |
-
}
|
| 3126 |
-
return aligned ? mmp->a_m : mmp->m;
|
| 3127 |
-
|
| 3128 |
-
GGML_UNUSED(ctx);
|
| 3129 |
-
}
|
| 3130 |
-
|
| 3131 |
-
static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
|
| 3132 |
-
return aligned ? mmp->a_m : mmp->m;
|
| 3133 |
-
|
| 3134 |
-
GGML_UNUSED(ctx);
|
| 3135 |
-
}
|
| 3136 |
-
|
| 3137 |
-
static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
|
| 3138 |
-
return aligned ? mmp->a_s : mmp->s;
|
| 3139 |
-
|
| 3140 |
-
GGML_UNUSED(ctx);
|
| 3141 |
-
}
|
| 3142 |
-
|
| 3143 |
-
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
|
| 3144 |
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
|
| 3145 |
-
switch (ctx->device->vendor_id) {
|
| 3146 |
-
case VK_VENDOR_ID_AMD:
|
| 3147 |
-
return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
|
| 3148 |
-
case VK_VENDOR_ID_APPLE:
|
| 3149 |
-
return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
|
| 3150 |
-
case VK_VENDOR_ID_INTEL:
|
| 3151 |
-
return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
|
| 3152 |
-
default:
|
| 3153 |
-
break;
|
| 3154 |
-
}
|
| 3155 |
|
| 3156 |
if (ctx->device->coopmat2) {
|
| 3157 |
-
if ((m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) {
|
| 3158 |
return aligned ? mmp->a_l : mmp->l;
|
| 3159 |
}
|
| 3160 |
-
if ((m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) {
|
| 3161 |
return aligned ? mmp->a_m : mmp->m;
|
| 3162 |
}
|
| 3163 |
return aligned ? mmp->a_s : mmp->s;
|
| 3164 |
}
|
| 3165 |
|
| 3166 |
-
if (m <= 32 || n <= 32) {
|
| 3167 |
return aligned ? mmp->a_s : mmp->s;
|
| 3168 |
}
|
| 3169 |
-
if (m <= 64 || n <= 64) {
|
| 3170 |
return aligned ? mmp->a_m : mmp->m;
|
| 3171 |
}
|
| 3172 |
return aligned ? mmp->a_l : mmp->l;
|
| 3173 |
}
|
| 3174 |
|
| 3175 |
-
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
|
| 3176 |
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
|
| 3177 |
-
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
|
| 3178 |
}
|
| 3179 |
|
| 3180 |
static void ggml_vk_matmul(
|
|
@@ -3201,6 +3461,33 @@ static void ggml_vk_matmul(
|
|
| 3201 |
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
|
| 3202 |
}
|
| 3203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3204 |
static void ggml_vk_matmul_id(
|
| 3205 |
ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
|
| 3206 |
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
|
|
@@ -3350,10 +3637,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
| 3350 |
const int y_ne = ne11 * ne10;
|
| 3351 |
const int d_ne = ne11 * ne01;
|
| 3352 |
|
| 3353 |
-
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
|
| 3354 |
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
| 3355 |
|
| 3356 |
-
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
|
| 3357 |
|
| 3358 |
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
|
| 3359 |
|
|
@@ -3904,7 +4191,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
| 3904 |
|
| 3905 |
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
| 3906 |
|
| 3907 |
-
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
|
| 3908 |
|
| 3909 |
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
| 3910 |
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
|
@@ -3920,10 +4207,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
| 3920 |
const uint64_t y_ne = ne11 * ne10;
|
| 3921 |
const uint64_t d_ne = ne21 * ne20;
|
| 3922 |
|
| 3923 |
-
const uint32_t kpad = ggml_vk_align_size(ne10,
|
| 3924 |
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
| 3925 |
|
| 3926 |
-
vk_pipeline pipeline =
|
| 3927 |
|
| 3928 |
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
| 3929 |
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
|
@@ -5504,19 +5791,27 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
| 5504 |
for (size_t i = 0; i < x_ne; i++) {
|
| 5505 |
if (std::is_same<float, X_TYPE>()) {
|
| 5506 |
x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
|
|
|
|
|
|
|
|
|
| 5507 |
} else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
|
| 5508 |
x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
|
|
|
|
|
|
|
|
|
|
| 5509 |
} else {
|
| 5510 |
GGML_ABORT("fatal error");
|
| 5511 |
}
|
| 5512 |
}
|
| 5513 |
for (size_t i = 0; i < y_ne; i++) {
|
| 5514 |
if (std::is_same<float, Y_TYPE>()) {
|
| 5515 |
-
|
| 5516 |
-
y[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
|
|
|
| 5517 |
} else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
|
| 5518 |
-
|
| 5519 |
-
y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
|
|
|
|
| 5520 |
} else {
|
| 5521 |
GGML_ABORT("fatal error");
|
| 5522 |
}
|
|
@@ -5600,7 +5895,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
| 5600 |
double err = std::fabs(d[i] - d_chk[i]);
|
| 5601 |
avg_err += err;
|
| 5602 |
|
| 5603 |
-
if (err > 0.05f && first_err_n == -1) {
|
| 5604 |
first_err_b = i / (m * n);
|
| 5605 |
first_err_n = (i % (m * n)) / m;
|
| 5606 |
first_err_m = (i % (m * n)) % m;
|
|
@@ -5613,12 +5908,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
| 5613 |
|
| 5614 |
std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
|
| 5615 |
|
| 5616 |
-
if (avg_err > 0.1) {
|
| 5617 |
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
|
| 5618 |
std::cerr << "Actual result: " << std::endl << std::endl;
|
| 5619 |
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
| 5620 |
-
std::cerr << std::endl;
|
| 5621 |
-
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
|
| 5622 |
std::cerr << "Expected result: " << std::endl << std::endl;
|
| 5623 |
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
| 5624 |
|
|
@@ -5801,13 +6094,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
| 5801 |
vk_pipeline p;
|
| 5802 |
std::string shname;
|
| 5803 |
if (shader_size == 0) {
|
| 5804 |
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
|
| 5805 |
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
|
| 5806 |
} else if (shader_size == 1) {
|
| 5807 |
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
|
| 5808 |
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
|
| 5809 |
} else if (shader_size == 2) {
|
| 5810 |
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
|
| 5811 |
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
|
| 5812 |
} else {
|
| 5813 |
GGML_ASSERT(0);
|
|
@@ -5817,13 +6110,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
| 5817 |
|
| 5818 |
if (k != kpad) {
|
| 5819 |
if (shader_size == 0) {
|
| 5820 |
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
|
| 5821 |
shname = std::string(ggml_type_name(quant)) + "_S";
|
| 5822 |
} else if (shader_size == 1) {
|
| 5823 |
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
|
| 5824 |
shname = std::string(ggml_type_name(quant)) + "_M";
|
| 5825 |
} else if (shader_size == 2) {
|
| 5826 |
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
|
| 5827 |
shname = std::string(ggml_type_name(quant)) + "_L";
|
| 5828 |
} else {
|
| 5829 |
GGML_ASSERT(0);
|
|
@@ -5982,105 +6275,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
| 5982 |
|
| 5983 |
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
| 5984 |
#if defined(GGML_VULKAN_RUN_TESTS)
|
| 5985 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
|
| 5986 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
|
| 5987 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
|
| 5988 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
|
| 5989 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
|
| 5990 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
|
| 5991 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
|
| 5992 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
|
| 5993 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
|
| 5994 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
|
| 5995 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
|
| 5996 |
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
|
| 5997 |
-
|
| 5998 |
-
ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
|
| 5999 |
-
|
| 6000 |
-
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
|
| 6001 |
-
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
|
| 6002 |
-
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
|
| 6003 |
-
// ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
|
| 6004 |
-
// ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
|
| 6005 |
-
// ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
|
| 6006 |
-
|
| 6007 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
|
| 6008 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
|
| 6009 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
|
| 6010 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
|
| 6011 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
|
| 6012 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
|
| 6013 |
-
|
| 6014 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
|
| 6015 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
|
| 6016 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
|
| 6017 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
|
| 6018 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
|
| 6019 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
|
| 6020 |
-
|
| 6021 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
|
| 6022 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
|
| 6023 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
|
| 6024 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
|
| 6025 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
|
| 6026 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
|
| 6027 |
-
|
| 6028 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
|
| 6029 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
|
| 6030 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
|
| 6031 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
|
| 6032 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
|
| 6033 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
|
| 6034 |
-
|
| 6035 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
|
| 6036 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
|
| 6037 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
|
| 6038 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
|
| 6039 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
|
| 6040 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
|
| 6041 |
-
|
| 6042 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
|
| 6043 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
|
| 6044 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
|
| 6045 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
|
| 6046 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
|
| 6047 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
|
| 6048 |
-
|
| 6049 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
|
| 6050 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
|
| 6051 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
|
| 6052 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
|
| 6053 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
|
| 6054 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
|
| 6055 |
-
|
| 6056 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
|
| 6057 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
|
| 6058 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
|
| 6059 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
|
| 6060 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
|
| 6061 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
|
| 6062 |
-
|
| 6063 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
|
| 6064 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
|
| 6065 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
|
| 6066 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
|
| 6067 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
|
| 6068 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
|
| 6069 |
-
|
| 6070 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
|
| 6071 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
|
| 6072 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
|
| 6073 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
|
| 6074 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
|
| 6075 |
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
|
| 6076 |
-
|
| 6077 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
|
| 6078 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
|
| 6079 |
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
|
| 6080 |
-
|
| 6081 |
-
std::cerr << std::endl;
|
| 6082 |
-
|
| 6083 |
const std::vector<size_t> vals {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6084 |
8, 8, 8,
|
| 6085 |
100, 46, 576,
|
| 6086 |
623, 111, 128,
|
|
@@ -6093,15 +6294,6 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
| 6093 |
49, 49, 128,
|
| 6094 |
128, 49, 49,
|
| 6095 |
4096, 49, 4096,
|
| 6096 |
-
11008, 49, 4096,
|
| 6097 |
-
4096, 49, 11008,
|
| 6098 |
-
32000, 49, 4096,
|
| 6099 |
-
512, 512, 128,
|
| 6100 |
-
128, 512, 512,
|
| 6101 |
-
4096, 512, 4096,
|
| 6102 |
-
11008, 512, 4096,
|
| 6103 |
-
4096, 512, 11008,
|
| 6104 |
-
32000, 512, 4096,
|
| 6105 |
};
|
| 6106 |
const size_t num_it = 100;
|
| 6107 |
|
|
@@ -6109,10 +6301,45 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
| 6109 |
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
|
| 6110 |
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
|
| 6111 |
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
|
| 6112 |
-
|
| 6113 |
-
|
| 6114 |
-
|
| 6115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6116 |
}
|
| 6117 |
|
| 6118 |
GGML_ABORT("fatal error");
|
|
@@ -7200,8 +7427,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 7200 |
case GGML_OP_MUL_MAT_ID:
|
| 7201 |
{
|
| 7202 |
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
| 7203 |
-
|
| 7204 |
-
|
| 7205 |
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
| 7206 |
return false;
|
| 7207 |
}
|
|
|
|
| 1 |
#include "ggml-vulkan.h"
|
| 2 |
#include <vulkan/vulkan_core.h>
|
| 3 |
+
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS)
|
| 4 |
#include <chrono>
|
| 5 |
+
#include "ggml-cpu.h"
|
| 6 |
#endif
|
| 7 |
|
| 8 |
#include <vulkan/vulkan.hpp>
|
|
|
|
| 170 |
bool uma;
|
| 171 |
bool coopmat2;
|
| 172 |
|
| 173 |
+
bool coopmat_support;
|
| 174 |
+
bool coopmat_acc_f32_support;
|
| 175 |
+
bool coopmat_acc_f16_support;
|
| 176 |
+
uint32_t coopmat_m;
|
| 177 |
+
uint32_t coopmat_n;
|
| 178 |
+
uint32_t coopmat_k;
|
| 179 |
+
|
| 180 |
size_t idx;
|
| 181 |
|
| 182 |
+
bool mul_mat_l;
|
| 183 |
+
bool mul_mat_m;
|
| 184 |
+
bool mul_mat_s;
|
| 185 |
+
bool mul_mat_id_l;
|
| 186 |
+
bool mul_mat_id_m;
|
| 187 |
+
bool mul_mat_id_s;
|
| 188 |
+
|
| 189 |
vk_matmul_pipeline pipeline_matmul_f32;
|
| 190 |
vk_matmul_pipeline pipeline_matmul_f32_f16;
|
| 191 |
vk_matmul_pipeline2 pipeline_matmul_f16;
|
|
|
|
| 196 |
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
|
| 197 |
|
| 198 |
vk_matmul_pipeline pipeline_matmul_id_f32;
|
| 199 |
+
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
| 200 |
+
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
| 201 |
|
| 202 |
+
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
| 203 |
|
| 204 |
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
| 205 |
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
|
|
|
|
| 1340 |
return {64, 64};
|
| 1341 |
};
|
| 1342 |
|
| 1343 |
+
// Estimates the shared memory a matrix-multiplication shader would use for the
// given warptile configuration and reports whether that estimate fits within
// device->properties.limits.maxComputeSharedMemorySize.
//
// device     - device whose coopmat_support, fp16, subgroup_size and shared
//              memory limit are read.
// warptile   - configuration vector; this function reads indices 0, 1, 2, 3, 7
//              and 8 (presumably workgroup size, tile dimensions and
//              cooperative-matrix dimensions — confirm against the shader's
//              specialization constants).
// mul_mat_id - when true, an extra fixed-size term is added to the estimate.
//
// Returns true when the summed estimate is less than or equal to the limit.
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
|
| 1344 |
+
// Needs to be kept up to date on shader changes
|
| 1345 |
+
// Padding added to warptile[3] in the load buffer estimate: 8 with coopmat support, otherwise 1.
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
|
| 1346 |
+
// Element size used for the load buffers: half precision when device->fp16 is set, else float.
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
| 1347 |
+
// Integer quotient. NOTE(review): this is 0 if warptile[0] < device->subgroup_size, which would
// make the coopmat_stage division below divide by zero — callers presumably guarantee
// warptile[0] >= subgroup_size; confirm.
const uint32_t warps = warptile[0] / device->subgroup_size;
|
| 1348 |
+
|
| 1349 |
+
// Bytes for the load buffers: (warptile[1] + warptile[2]) rows of (warptile[3] + padding) elements.
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
| 1350 |
+
// Fixed 3072 uint32_t entries, counted only for mul_mat_id (presumably the shader's row id table — confirm).
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
|
| 1351 |
+
// Only counted with coopmat support: warptile[7] * warptile[8] floats divided (integer division) by warps.
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
| 1352 |
+
|
| 1353 |
+
// The configuration is supported when the total estimate does not exceed the device limit.
return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
|
| 1354 |
+
}
|
| 1355 |
|
| 1356 |
static void ggml_vk_load_shaders(vk_device& device) {
|
| 1357 |
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
|
|
|
| 1409 |
m_align = 64;
|
| 1410 |
s_align = 32;
|
| 1411 |
} else {
|
| 1412 |
+
// Matrix cores require different warp group sizes
|
| 1413 |
+
const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
|
| 1414 |
+
const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
|
| 1415 |
+
const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
|
| 1416 |
+
const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
|
| 1417 |
+
const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
|
| 1418 |
+
const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
|
| 1419 |
+
const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
|
| 1420 |
+
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
|
| 1421 |
+
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
|
| 1422 |
+
|
| 1423 |
+
l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
|
| 1424 |
+
m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
|
| 1425 |
+
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
|
| 1426 |
+
|
| 1427 |
+
l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
|
| 1428 |
+
m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
|
| 1429 |
+
s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
|
| 1430 |
+
|
| 1431 |
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
| 1432 |
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
|
| 1433 |
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
|
|
|
|
| 1468 |
// assert mul_mat_mat_id shaders will fit.
|
| 1469 |
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
|
| 1470 |
}
|
| 1471 |
+
// Disable medium and large matrix multiplication if not enough shared memory is available
|
| 1472 |
+
// Check mmq warptiles as the largest configuration
|
| 1473 |
+
// Throw an error if there is not enough shared memory for any matrix multiplication
|
| 1474 |
+
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
|
| 1475 |
+
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
|
| 1476 |
+
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
|
| 1477 |
+
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
|
| 1478 |
+
device->mul_mat_m = false;
|
| 1479 |
+
device->mul_mat_l = false;
|
| 1480 |
+
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
|
| 1481 |
+
device->mul_mat_l = false;
|
| 1482 |
+
}
|
| 1483 |
+
|
| 1484 |
+
// Disable mul_mat_id if not enough shared memory is available
|
| 1485 |
+
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
|
| 1486 |
+
device->mul_mat_id_s = false;
|
| 1487 |
+
device->mul_mat_id_m = false;
|
| 1488 |
+
device->mul_mat_id_l = false;
|
| 1489 |
+
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
|
| 1490 |
+
device->mul_mat_id_m = false;
|
| 1491 |
+
device->mul_mat_id_l = false;
|
| 1492 |
+
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
|
| 1493 |
+
device->mul_mat_id_l = false;
|
| 1494 |
+
}
|
| 1495 |
}
|
| 1496 |
|
| 1497 |
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1498 |
device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
|
| 1499 |
|
| 1500 |
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1501 |
|
| 1502 |
std::vector<std::future<void>> compiles;
|
| 1503 |
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
|
|
|
|
| 1594 |
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
| 1595 |
|
| 1596 |
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 1597 |
+
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 1598 |
+
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
| 1599 |
+
|
| 1600 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1601 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1602 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1603 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1604 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1605 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1606 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1607 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1608 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1609 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1610 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
| 1611 |
#undef CREATE_MM
|
| 1612 |
#undef CREATE_MM2
|
| 1613 |
} else
|
| 1614 |
+
#endif // defined(VK_NV_cooperative_matrix2)
|
| 1615 |
+
if (device->coopmat_support) {
|
| 1616 |
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
| 1617 |
+
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
| 1618 |
+
if (device->mul_mat ## ID ## _l) \
|
| 1619 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
| 1620 |
+
if (device->mul_mat ## ID ## _m) \
|
| 1621 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
| 1622 |
+
if (device->mul_mat ## ID ## _s) \
|
| 1623 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
| 1624 |
+
if (device->mul_mat ## ID ## _l) \
|
| 1625 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
| 1626 |
+
if (device->mul_mat ## ID ## _m) \
|
| 1627 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
| 1628 |
+
if (device->mul_mat ## ID ## _s) \
|
| 1629 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
| 1630 |
+
|
| 1631 |
+
// Create 2 variants, {f16,f32} accumulator
|
| 1632 |
+
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
| 1633 |
+
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
| 1634 |
+
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
| 1635 |
+
|
| 1636 |
+
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1637 |
+
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1638 |
+
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1639 |
+
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1640 |
+
|
| 1641 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1642 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1643 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1644 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1645 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1646 |
+
|
| 1647 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1648 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1649 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1650 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1651 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1652 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1653 |
|
| 1654 |
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
| 1655 |
+
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
| 1656 |
+
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 1657 |
+
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 1658 |
+
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 1659 |
+
|
| 1660 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1661 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1662 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1663 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1664 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1665 |
+
|
| 1666 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1667 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1668 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1669 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1670 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1671 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1672 |
+
}
|
| 1673 |
+
#undef CREATE_MM
|
| 1674 |
+
} else if (device->fp16) {
|
| 1675 |
+
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
| 1676 |
+
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
| 1677 |
+
if (device->mul_mat ## ID ## _l) \
|
| 1678 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
| 1679 |
+
if (device->mul_mat ## ID ## _m) \
|
| 1680 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
| 1681 |
+
if (device->mul_mat ## ID ## _s) \
|
| 1682 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
| 1683 |
+
if (device->mul_mat ## ID ## _l) \
|
| 1684 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
| 1685 |
+
if (device->mul_mat ## ID ## _m) \
|
| 1686 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
| 1687 |
+
if (device->mul_mat ## ID ## _s) \
|
| 1688 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
| 1689 |
+
|
| 1690 |
+
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1691 |
+
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1692 |
+
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1693 |
+
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1694 |
+
|
| 1695 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1696 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1697 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1698 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1699 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1700 |
+
|
| 1701 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1702 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1703 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1704 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1705 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1706 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1707 |
+
|
| 1708 |
+
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
| 1709 |
+
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
| 1710 |
+
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 1711 |
+
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 1712 |
+
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 1713 |
+
|
| 1714 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1715 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1716 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1717 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1718 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1719 |
+
|
| 1720 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1721 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1722 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1723 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1724 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1725 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1726 |
}
|
| 1727 |
#undef CREATE_MM
|
| 1728 |
} else {
|
| 1729 |
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
| 1730 |
+
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
| 1731 |
+
if (device->mul_mat ## ID ## _l) \
|
| 1732 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
| 1733 |
+
if (device->mul_mat ## ID ## _m) \
|
| 1734 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
| 1735 |
+
if (device->mul_mat ## ID ## _s) \
|
| 1736 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
| 1737 |
+
if (device->mul_mat ## ID ## _l) \
|
| 1738 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
| 1739 |
+
if (device->mul_mat ## ID ## _m) \
|
| 1740 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
| 1741 |
+
if (device->mul_mat ## ID ## _s) \
|
| 1742 |
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
| 1743 |
+
|
| 1744 |
+
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1745 |
+
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1746 |
+
CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1747 |
+
CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
| 1748 |
+
|
| 1749 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1750 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1751 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1752 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1753 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1754 |
+
|
| 1755 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1756 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1757 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1758 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1759 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1760 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
| 1761 |
|
| 1762 |
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
| 1763 |
+
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
| 1764 |
+
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 1765 |
+
CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 1766 |
+
CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
| 1767 |
+
|
| 1768 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1769 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1770 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1771 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1772 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1773 |
+
|
| 1774 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1775 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1776 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1777 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1778 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1779 |
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
| 1780 |
}
|
| 1781 |
+
#undef CREATE_MM2
|
| 1782 |
#undef CREATE_MM
|
| 1783 |
}
|
| 1784 |
|
|
|
|
| 1974 |
bool fp16_compute = false;
|
| 1975 |
bool maintenance4_support = false;
|
| 1976 |
bool sm_builtins = false;
|
| 1977 |
+
bool amd_shader_core_properties2 = false;
|
| 1978 |
bool pipeline_robustness = false;
|
| 1979 |
bool coopmat2_support = false;
|
| 1980 |
+
device->coopmat_support = false;
|
| 1981 |
|
| 1982 |
// Check if maintenance4 is supported
|
| 1983 |
for (const auto& properties : ext_props) {
|
|
|
|
| 1989 |
fp16_compute = true;
|
| 1990 |
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
|
| 1991 |
sm_builtins = true;
|
| 1992 |
+
} else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
|
| 1993 |
+
amd_shader_core_properties2 = true;
|
| 1994 |
} else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
|
| 1995 |
pipeline_robustness = true;
|
| 1996 |
+
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
|
| 1997 |
+
!getenv("GGML_VK_DISABLE_COOPMAT")) {
|
| 1998 |
+
device->coopmat_support = true;
|
| 1999 |
+
device->coopmat_m = 0;
|
| 2000 |
+
device->coopmat_n = 0;
|
| 2001 |
+
device->coopmat_k = 0;
|
| 2002 |
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
| 2003 |
+
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
| 2004 |
coopmat2_support = true;
|
| 2005 |
}
|
| 2006 |
}
|
|
|
|
| 2009 |
vk::PhysicalDeviceMaintenance3Properties props3;
|
| 2010 |
vk::PhysicalDeviceMaintenance4Properties props4;
|
| 2011 |
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
| 2012 |
+
vk::PhysicalDeviceDriverProperties driver_props;
|
| 2013 |
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
|
| 2014 |
+
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
|
| 2015 |
props2.pNext = &props3;
|
| 2016 |
props3.pNext = &subgroup_props;
|
| 2017 |
+
subgroup_props.pNext = &driver_props;
|
| 2018 |
|
| 2019 |
+
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
|
| 2020 |
|
| 2021 |
if (maintenance4_support) {
|
| 2022 |
last_struct->pNext = (VkBaseOutStructure *)&props4;
|
|
|
|
| 2026 |
last_struct->pNext = (VkBaseOutStructure *)&sm_props;
|
| 2027 |
last_struct = (VkBaseOutStructure *)&sm_props;
|
| 2028 |
}
|
| 2029 |
+
if (amd_shader_core_properties2) {
|
| 2030 |
+
last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
|
| 2031 |
+
last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
|
| 2032 |
+
}
|
| 2033 |
|
| 2034 |
#if defined(VK_NV_cooperative_matrix2)
|
| 2035 |
vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
|
|
|
|
| 2045 |
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
|
| 2046 |
|
| 2047 |
if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
|
| 2048 |
+
device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
|
| 2049 |
} else if (maintenance4_support) {
|
| 2050 |
device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
|
| 2051 |
} else {
|
|
|
|
| 2057 |
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
| 2058 |
if (sm_builtins) {
|
| 2059 |
device->shader_core_count = sm_props.shaderSMCount;
|
| 2060 |
+
} else if (amd_shader_core_properties2) {
|
| 2061 |
+
device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
|
| 2062 |
} else {
|
| 2063 |
device->shader_core_count = 0;
|
| 2064 |
}
|
| 2065 |
|
| 2066 |
+
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
|
|
|
|
| 2067 |
|
| 2068 |
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
| 2069 |
|
| 2070 |
+
if (device->vendor_id == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
|
| 2071 |
+
// Intel drivers don't support coopmat properly yet
|
| 2072 |
+
// Only RADV supports coopmat properly on AMD
|
| 2073 |
+
device->coopmat_support = false;
|
| 2074 |
+
}
|
| 2075 |
+
|
| 2076 |
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
| 2077 |
|
| 2078 |
// Try to find a non-graphics compute queue and transfer-focused queues
|
|
|
|
| 2123 |
device_extensions.push_back("VK_EXT_pipeline_robustness");
|
| 2124 |
}
|
| 2125 |
|
| 2126 |
+
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
|
| 2127 |
+
coopmat_features.pNext = nullptr;
|
| 2128 |
+
coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
|
| 2129 |
+
coopmat_features.cooperativeMatrix = VK_FALSE;
|
| 2130 |
+
|
| 2131 |
+
if (device->coopmat_support) {
|
| 2132 |
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
|
| 2133 |
+
last_struct = (VkBaseOutStructure *)&coopmat_features;
|
| 2134 |
+
}
|
| 2135 |
+
|
| 2136 |
#if defined(VK_NV_cooperative_matrix2)
|
| 2137 |
VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
|
| 2138 |
coopmat2_features.pNext = nullptr;
|
|
|
|
| 2150 |
|
| 2151 |
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
|
| 2152 |
|
| 2153 |
+
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
| 2154 |
+
|
| 2155 |
if (coopmat2_support) {
|
| 2156 |
#if defined(VK_NV_cooperative_matrix2)
|
| 2157 |
if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
|
|
|
|
| 2242 |
if (device->fp16) {
|
| 2243 |
device_extensions.push_back("VK_KHR_shader_float16_int8");
|
| 2244 |
}
|
| 2245 |
+
|
| 2246 |
+
if (device->coopmat_support) {
|
| 2247 |
+
// Query supported shapes
|
| 2248 |
+
std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
|
| 2249 |
+
|
| 2250 |
+
PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
|
| 2251 |
+
(PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
|
| 2252 |
+
|
| 2253 |
+
uint32_t cm_props_num;
|
| 2254 |
+
|
| 2255 |
+
pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
|
| 2256 |
+
|
| 2257 |
+
cm_props.resize(cm_props_num);
|
| 2258 |
+
|
| 2259 |
+
for (auto& prop : cm_props) {
|
| 2260 |
+
prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
|
| 2261 |
+
}
|
| 2262 |
+
|
| 2263 |
+
pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
|
| 2264 |
+
|
| 2265 |
+
VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
|
| 2266 |
+
|
| 2267 |
+
for (auto& prop : cm_props) {
|
| 2268 |
+
VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
|
| 2269 |
+
|
| 2270 |
+
if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
|
| 2271 |
+
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
|
| 2272 |
+
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
|
| 2273 |
+
) {
|
| 2274 |
+
if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
|
| 2275 |
+
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
|
| 2276 |
+
// coopmat sizes not set yet
|
| 2277 |
+
if (device->coopmat_m == 0) {
|
| 2278 |
+
device->coopmat_acc_f32_support = true;
|
| 2279 |
+
device->coopmat_m = prop.MSize;
|
| 2280 |
+
device->coopmat_n = prop.NSize;
|
| 2281 |
+
device->coopmat_k = prop.KSize;
|
| 2282 |
+
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
| 2283 |
+
// Only enable if shape is identical
|
| 2284 |
+
device->coopmat_acc_f32_support = true;
|
| 2285 |
+
}
|
| 2286 |
+
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
|
| 2287 |
+
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
|
| 2288 |
+
// coopmat sizes not set yet
|
| 2289 |
+
if (device->coopmat_m == 0) {
|
| 2290 |
+
device->coopmat_acc_f16_support = true;
|
| 2291 |
+
device->coopmat_m = prop.MSize;
|
| 2292 |
+
device->coopmat_n = prop.NSize;
|
| 2293 |
+
device->coopmat_k = prop.KSize;
|
| 2294 |
+
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
| 2295 |
+
// Only enable if shape is identical
|
| 2296 |
+
device->coopmat_acc_f16_support = true;
|
| 2297 |
+
}
|
| 2298 |
+
}
|
| 2299 |
+
}
|
| 2300 |
+
}
|
| 2301 |
+
|
| 2302 |
+
if (device->coopmat_m == 0) {
|
| 2303 |
+
// No suitable matmul mode found
|
| 2304 |
+
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
|
| 2305 |
+
device->coopmat_support = false;
|
| 2306 |
+
}
|
| 2307 |
+
}
|
| 2308 |
+
|
| 2309 |
+
if (device->coopmat_support) {
|
| 2310 |
+
device_extensions.push_back("VK_KHR_cooperative_matrix");
|
| 2311 |
+
}
|
| 2312 |
+
|
| 2313 |
device->name = GGML_VK_NAME + std::to_string(idx);
|
| 2314 |
|
| 2315 |
device_create_info = {
|
|
|
|
| 2325 |
ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
|
| 2326 |
|
| 2327 |
// Shaders
|
| 2328 |
+
// Disable matmul tile sizes early if performance low or not supported
|
| 2329 |
+
switch (device->vendor_id) {
|
| 2330 |
+
#ifndef GGML_VULKAN_RUN_TESTS
|
| 2331 |
+
case VK_VENDOR_ID_AMD:
|
| 2332 |
+
case VK_VENDOR_ID_INTEL:
|
| 2333 |
+
device->mul_mat_l = false;
|
| 2334 |
+
device->mul_mat_m = true;
|
| 2335 |
+
device->mul_mat_s = true;
|
| 2336 |
+
device->mul_mat_id_l = false;
|
| 2337 |
+
device->mul_mat_id_m = true;
|
| 2338 |
+
device->mul_mat_id_s = true;
|
| 2339 |
+
break;
|
| 2340 |
+
case VK_VENDOR_ID_APPLE:
|
| 2341 |
+
device->mul_mat_l = false;
|
| 2342 |
+
device->mul_mat_m = true;
|
| 2343 |
+
device->mul_mat_s = false;
|
| 2344 |
+
device->mul_mat_id_l = false;
|
| 2345 |
+
device->mul_mat_id_m = true;
|
| 2346 |
+
device->mul_mat_id_s = false;
|
| 2347 |
+
break;
|
| 2348 |
+
#endif
|
| 2349 |
+
default:
|
| 2350 |
+
device->mul_mat_l = true;
|
| 2351 |
+
device->mul_mat_m = true;
|
| 2352 |
+
device->mul_mat_s = true;
|
| 2353 |
+
device->mul_mat_id_l = true;
|
| 2354 |
+
device->mul_mat_id_m = true;
|
| 2355 |
+
device->mul_mat_id_s = true;
|
| 2356 |
+
break;
|
| 2357 |
+
}
|
| 2358 |
+
|
| 2359 |
ggml_vk_load_shaders(device);
|
| 2360 |
|
| 2361 |
if (!device->single_queue) {
|
|
|
|
| 2413 |
|
| 2414 |
bool fp16_storage = false;
|
| 2415 |
bool fp16_compute = false;
|
| 2416 |
+
bool coopmat_support = false;
|
| 2417 |
|
| 2418 |
for (auto properties : ext_props) {
|
| 2419 |
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
| 2420 |
fp16_storage = true;
|
| 2421 |
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
|
| 2422 |
fp16_compute = true;
|
| 2423 |
+
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
|
| 2424 |
+
coopmat_support = true;
|
| 2425 |
}
|
| 2426 |
}
|
| 2427 |
|
| 2428 |
+
if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
|
| 2429 |
+
// Intel drivers don't support coopmat properly yet
|
| 2430 |
+
// Only RADV supports coopmat properly on AMD
|
| 2431 |
+
coopmat_support = false;
|
| 2432 |
+
}
|
| 2433 |
+
|
| 2434 |
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
|
| 2435 |
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
|
| 2436 |
|
|
|
|
| 2453 |
vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
|
| 2454 |
vk11_features.pNext = &vk12_features;
|
| 2455 |
|
| 2456 |
+
// Pointer to the last chain element
|
| 2457 |
+
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
|
| 2458 |
+
|
| 2459 |
+
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
|
| 2460 |
+
coopmat_features.pNext = nullptr;
|
| 2461 |
+
coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
|
| 2462 |
+
coopmat_features.cooperativeMatrix = VK_FALSE;
|
| 2463 |
+
|
| 2464 |
+
if (coopmat_support) {
|
| 2465 |
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
|
| 2466 |
+
last_struct = (VkBaseOutStructure *)&coopmat_features;
|
| 2467 |
+
}
|
| 2468 |
+
|
| 2469 |
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
| 2470 |
|
| 2471 |
fp16 = fp16 && vk12_features.shaderFloat16;
|
| 2472 |
|
| 2473 |
+
coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
|
| 2474 |
+
|
| 2475 |
std::string device_name = props2.properties.deviceName.data();
|
| 2476 |
+
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %d\n",
|
| 2477 |
+
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, coopmat_support);
|
| 2478 |
|
| 2479 |
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
|
| 2480 |
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
|
|
|
|
| 2710 |
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
| 2711 |
return ctx->device->pipeline_matmul_f32_f16;
|
| 2712 |
}
|
| 2713 |
+
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) {
|
| 2714 |
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
| 2715 |
return ctx->device->pipeline_matmul_f16_f32.f16acc;
|
| 2716 |
}
|
|
|
|
| 2751 |
assert(src1_type == GGML_TYPE_F16);
|
| 2752 |
return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
|
| 2753 |
}
|
| 2754 |
+
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
| 2755 |
}
|
| 2756 |
|
| 2757 |
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
|
|
|
| 2780 |
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
|
| 2781 |
}
|
| 2782 |
|
| 2783 |
+
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
|
| 2784 |
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
|
| 2785 |
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
| 2786 |
return ctx->device->pipeline_matmul_id_f32;
|
| 2787 |
}
|
| 2788 |
+
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) {
|
| 2789 |
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
| 2790 |
+
return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
|
| 2791 |
+
}
|
| 2792 |
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
| 2793 |
+
return ctx->device->pipeline_matmul_id_f16.f16acc;
|
| 2794 |
+
}
|
| 2795 |
+
} else {
|
| 2796 |
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
| 2797 |
+
return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
|
| 2798 |
+
}
|
| 2799 |
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
| 2800 |
+
return ctx->device->pipeline_matmul_id_f16.f32acc;
|
| 2801 |
+
}
|
| 2802 |
}
|
| 2803 |
|
| 2804 |
GGML_ASSERT(src1_type == GGML_TYPE_F32);
|
|
|
|
| 2820 |
return nullptr;
|
| 2821 |
}
|
| 2822 |
|
| 2823 |
+
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
|
| 2824 |
}
|
| 2825 |
|
| 2826 |
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
|
|
|
| 3410 |
return split_k;
|
| 3411 |
}
|
| 3412 |
|
| 3413 |
+
// Chooses which size variant (s, m or l) of the matmul pipeline set `mmp` to
// use for an m x n problem. `aligned` selects the a_s/a_m/a_l members instead
// of s/m/l. The device flags mul_mat_s / mul_mat_m / mul_mat_l gate each size;
// when a preferred size is disabled the selection falls through to the next.
// `type_a` is accepted for the callers' benefit but is not read here.
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type type_a) {
    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");

    const bool small_enabled  = ctx->device->mul_mat_s;
    const bool medium_enabled = ctx->device->mul_mat_m;
    const bool large_enabled  = ctx->device->mul_mat_l;

    // True when m and n are exact multiples of the pipeline's first two
    // wg_denoms entries. Only invoked after the matching enable flag has been
    // checked, so a pipeline that was never created is not dereferenced.
    const auto divides_evenly = [m, n](const vk_pipeline& candidate) {
        return (m % candidate->wg_denoms[0]) == 0 && (n % candidate->wg_denoms[1]) == 0;
    };

    if (ctx->device->coopmat2) {
        // coopmat2 path: try the largest tile that divides the problem evenly.
        const bool only_large_left = !medium_enabled && !small_enabled;
        if ((large_enabled && divides_evenly(mmp->l)) || only_large_left) {
            if (aligned) {
                return mmp->a_l;
            }
            return mmp->l;
        }
        if ((medium_enabled && divides_evenly(mmp->m)) || !small_enabled) {
            if (aligned) {
                return mmp->a_m;
            }
            return mmp->m;
        }
        if (aligned) {
            return mmp->a_s;
        }
        return mmp->s;
    }

    // Default path: pick by absolute problem size, smallest tile first.
    const bool only_small_left = !medium_enabled && !large_enabled;
    if ((small_enabled && (m <= 32 || n <= 32)) || only_small_left) {
        if (aligned) {
            return mmp->a_s;
        }
        return mmp->s;
    }
    if ((medium_enabled && (m <= 64 || n <= 64)) || !large_enabled) {
        if (aligned) {
            return mmp->a_m;
        }
        return mmp->m;
    }
    if (aligned) {
        return mmp->a_l;
    }
    return mmp->l;
}
|
| 3434 |
|
| 3435 |
+
// Returns the `align` value of the aligned pipeline variant that
// ggml_vk_guess_matmul_pipeline would select for an m x n problem.
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type type_a) {
    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");

    // Always query with aligned == true: only the alignment requirement is wanted.
    const vk_pipeline aligned_variant = ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, type_a);
    return aligned_variant->align;
}
|
| 3439 |
|
| 3440 |
static void ggml_vk_matmul(
|
|
|
|
| 3461 |
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
|
| 3462 |
}
|
| 3463 |
|
| 3464 |
+
// Chooses which size variant (s, m or l) of the MUL_MAT_ID pipeline set `mmp`
// to use for an m x n problem. `aligned` selects the a_s/a_m/a_l members
// instead of s/m/l. The device flags mul_mat_id_s / mul_mat_id_m /
// mul_mat_id_l gate each size; a disabled size falls through to the next one.
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
    // Trace under this function's own name so the id and non-id selection
    // paths can be told apart in debug logs.
    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ")");

    if (ctx->device->coopmat2) {
        // coopmat2 path: prefer the largest enabled tile whose wg_denoms divide
        // m and n exactly. Each enable flag is tested before the pipeline is
        // dereferenced, so a size that was never created is not touched.
        if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
            return aligned ? mmp->a_l : mmp->l;
        }
        if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
            return aligned ? mmp->a_m : mmp->m;
        }
        return aligned ? mmp->a_s : mmp->s;
    }

    // Default path: pick by absolute problem size, smallest tile first.
    if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
        return aligned ? mmp->a_s : mmp->s;
    }
    if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
        return aligned ? mmp->a_m : mmp->m;
    }
    return aligned ? mmp->a_l : mmp->l;
}
|
| 3485 |
+
|
| 3486 |
+
// Returns the `align` value of the aligned MUL_MAT_ID pipeline variant that
// ggml_vk_guess_matmul_id_pipeline would select for an m x n problem.
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
    // Trace under this function's own name rather than the non-id variant's.
    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline_align(" << m << ", " << n << ")");
    // Always query with aligned == true: only the alignment requirement is wanted.
    return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
}
|
| 3490 |
+
|
| 3491 |
static void ggml_vk_matmul_id(
|
| 3492 |
ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
|
| 3493 |
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
|
|
|
|
| 3637 |
const int y_ne = ne11 * ne10;
|
| 3638 |
const int d_ne = ne11 * ne01;
|
| 3639 |
|
| 3640 |
+
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, src0->type));
|
| 3641 |
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
| 3642 |
|
| 3643 |
+
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, src0->type);
|
| 3644 |
|
| 3645 |
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
|
| 3646 |
|
|
|
|
| 4191 |
|
| 4192 |
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
| 4193 |
|
| 4194 |
+
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
| 4195 |
|
| 4196 |
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
| 4197 |
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
|
|
|
| 4207 |
const uint64_t y_ne = ne11 * ne10;
|
| 4208 |
const uint64_t d_ne = ne21 * ne20;
|
| 4209 |
|
| 4210 |
+
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
|
| 4211 |
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
| 4212 |
|
| 4213 |
+
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
|
| 4214 |
|
| 4215 |
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
| 4216 |
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
|
|
|
| 5791 |
for (size_t i = 0; i < x_ne; i++) {
|
| 5792 |
if (std::is_same<float, X_TYPE>()) {
|
| 5793 |
x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
| 5794 |
+
// x[i] = 1.0f;
|
| 5795 |
+
// x[i] = i + 1;
|
| 5796 |
+
// x[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
| 5797 |
} else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
|
| 5798 |
x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
|
| 5799 |
+
// x[i] = ggml_fp32_to_fp16(1.0f);
|
| 5800 |
+
// x[i] = ggml_fp32_to_fp16(i + 1);
|
| 5801 |
+
// x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
|
| 5802 |
} else {
|
| 5803 |
GGML_ABORT("fatal error");
|
| 5804 |
}
|
| 5805 |
}
|
| 5806 |
for (size_t i = 0; i < y_ne; i++) {
|
| 5807 |
if (std::is_same<float, Y_TYPE>()) {
|
| 5808 |
+
y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
| 5809 |
+
// y[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
| 5810 |
+
// y[i] = i + 1;
|
| 5811 |
} else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
|
| 5812 |
+
y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
|
| 5813 |
+
// y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
|
| 5814 |
+
// y[i] = ggml_fp32_to_fp16(i + 1);
|
| 5815 |
} else {
|
| 5816 |
GGML_ABORT("fatal error");
|
| 5817 |
}
|
|
|
|
| 5895 |
double err = std::fabs(d[i] - d_chk[i]);
|
| 5896 |
avg_err += err;
|
| 5897 |
|
| 5898 |
+
if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
|
| 5899 |
first_err_b = i / (m * n);
|
| 5900 |
first_err_n = (i % (m * n)) / m;
|
| 5901 |
first_err_m = (i % (m * n)) % m;
|
|
|
|
| 5908 |
|
| 5909 |
std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
|
| 5910 |
|
| 5911 |
+
if (avg_err > 0.1 || std::isnan(avg_err)) {
|
| 5912 |
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
|
| 5913 |
std::cerr << "Actual result: " << std::endl << std::endl;
|
| 5914 |
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
|
|
|
|
|
|
| 5915 |
std::cerr << "Expected result: " << std::endl << std::endl;
|
| 5916 |
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
| 5917 |
|
|
|
|
| 6094 |
vk_pipeline p;
|
| 6095 |
std::string shname;
|
| 6096 |
if (shader_size == 0) {
|
| 6097 |
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
|
| 6098 |
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
|
| 6099 |
} else if (shader_size == 1) {
|
| 6100 |
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
|
| 6101 |
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
|
| 6102 |
} else if (shader_size == 2) {
|
| 6103 |
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
|
| 6104 |
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
|
| 6105 |
} else {
|
| 6106 |
GGML_ASSERT(0);
|
|
|
|
| 6110 |
|
| 6111 |
if (k != kpad) {
|
| 6112 |
if (shader_size == 0) {
|
| 6113 |
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
|
| 6114 |
shname = std::string(ggml_type_name(quant)) + "_S";
|
| 6115 |
} else if (shader_size == 1) {
|
| 6116 |
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
|
| 6117 |
shname = std::string(ggml_type_name(quant)) + "_M";
|
| 6118 |
} else if (shader_size == 2) {
|
| 6119 |
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
|
| 6120 |
shname = std::string(ggml_type_name(quant)) + "_L";
|
| 6121 |
} else {
|
| 6122 |
GGML_ASSERT(0);
|
|
|
|
| 6275 |
|
| 6276 |
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
| 6277 |
#if defined(GGML_VULKAN_RUN_TESTS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6278 |
const std::vector<size_t> vals {
|
| 6279 |
+
512, 512, 128,
|
| 6280 |
+
128, 512, 512,
|
| 6281 |
+
4096, 512, 4096,
|
| 6282 |
+
11008, 512, 4096,
|
| 6283 |
+
4096, 512, 11008,
|
| 6284 |
+
32000, 512, 4096,
|
| 6285 |
8, 8, 8,
|
| 6286 |
100, 46, 576,
|
| 6287 |
623, 111, 128,
|
|
|
|
| 6294 |
49, 49, 128,
|
| 6295 |
128, 49, 49,
|
| 6296 |
4096, 49, 4096,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6297 |
};
|
| 6298 |
const size_t num_it = 100;
|
| 6299 |
|
|
|
|
| 6301 |
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
|
| 6302 |
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
|
| 6303 |
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
|
| 6304 |
+
std::cerr << '\n';
|
| 6305 |
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
|
| 6306 |
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
|
| 6307 |
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
|
| 6308 |
+
std::cerr << '\n';
|
| 6309 |
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
|
| 6310 |
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
|
| 6311 |
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
|
| 6312 |
+
std::cerr << '\n' << std::endl;
|
| 6313 |
+
|
| 6314 |
+
if (vals[i + 2] % 32 == 0) {
|
| 6315 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
|
| 6316 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
|
| 6317 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
|
| 6318 |
+
std::cerr << '\n';
|
| 6319 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
|
| 6320 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
|
| 6321 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
|
| 6322 |
+
std::cerr << '\n';
|
| 6323 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
|
| 6324 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
|
| 6325 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
|
| 6326 |
+
std::cerr << '\n' << std::endl;
|
| 6327 |
+
}
|
| 6328 |
+
|
| 6329 |
+
if (vals[i + 2] % 256 == 0) {
|
| 6330 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
|
| 6331 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
|
| 6332 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
|
| 6333 |
+
std::cerr << '\n';
|
| 6334 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
|
| 6335 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
|
| 6336 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
|
| 6337 |
+
std::cerr << '\n';
|
| 6338 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
|
| 6339 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
|
| 6340 |
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
|
| 6341 |
+
std::cerr << '\n' << std::endl;
|
| 6342 |
+
}
|
| 6343 |
}
|
| 6344 |
|
| 6345 |
GGML_ABORT("fatal error");
|
|
|
|
| 7427 |
case GGML_OP_MUL_MAT_ID:
|
| 7428 |
{
|
| 7429 |
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
| 7430 |
+
const vk_device& device = ggml_vk_get_device(ctx->device);
|
| 7431 |
+
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
|
| 7432 |
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
| 7433 |
return false;
|
| 7434 |
}
|
|
@@ -7,6 +7,12 @@
|
|
| 7 |
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
| 8 |
#endif
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
#ifdef MUL_MAT_ID
|
| 11 |
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
| 12 |
#endif
|
|
@@ -57,6 +63,7 @@ layout (push_constant) uniform parameter
|
|
| 57 |
#endif
|
| 58 |
} p;
|
| 59 |
|
|
|
|
| 60 |
layout (constant_id = 1) const uint BM = 64;
|
| 61 |
layout (constant_id = 2) const uint BN = 64;
|
| 62 |
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
|
|
@@ -65,13 +72,26 @@ layout (constant_id = 5) const uint WN = 32;
|
|
| 65 |
layout (constant_id = 6) const uint WMITER = 2;
|
| 66 |
layout (constant_id = 7) const uint TM = 4;
|
| 67 |
layout (constant_id = 8) const uint TN = 2;
|
| 68 |
-
layout (constant_id = 9) const uint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
shared FLOAT_TYPE buf_a[BM *
|
| 71 |
-
shared FLOAT_TYPE buf_b[BN *
|
| 72 |
|
| 73 |
#ifdef MUL_MAT_ID
|
| 74 |
shared u16vec2 row_ids[3072];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
#endif
|
| 76 |
|
| 77 |
void main() {
|
|
@@ -98,17 +118,32 @@ void main() {
|
|
| 98 |
const uint ik = gl_WorkGroupID.x / blocks_m;
|
| 99 |
const uint ic = gl_WorkGroupID.y;
|
| 100 |
|
| 101 |
-
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
| 102 |
-
const uint warp_r = warp_i % (BM / WM);
|
| 103 |
-
const uint warp_c = warp_i / (BM / WM);
|
| 104 |
-
|
| 105 |
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
| 106 |
const uint WSUBM = WM / WMITER;
|
| 107 |
const uint WSUBN = WN / WNITER;
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
const uint tiw = gl_LocalInvocationID.x % WARP;
|
|
|
|
| 110 |
const uint tiwr = tiw % (WSUBM / TM);
|
| 111 |
const uint tiwc = tiw / (WSUBM / TM);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
|
| 114 |
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
|
|
@@ -156,21 +191,31 @@ void main() {
|
|
| 156 |
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
|
| 157 |
#endif
|
| 158 |
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
FLOAT_TYPE cache_a[WMITER * TM];
|
| 161 |
FLOAT_TYPE cache_b[WNITER * TN];
|
| 162 |
|
| 163 |
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
| 164 |
-
sums[i] = 0.0f;
|
| 165 |
}
|
|
|
|
| 166 |
|
| 167 |
-
|
| 168 |
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
|
| 169 |
|
| 170 |
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
| 171 |
#if LOAD_VEC_A == 8
|
| 172 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 173 |
-
const uint buf_idx = (loadc_a + l) *
|
| 174 |
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
|
| 175 |
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
|
| 176 |
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
|
|
@@ -181,21 +226,21 @@ void main() {
|
|
| 181 |
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
|
| 182 |
#elif LOAD_VEC_A == 4
|
| 183 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 184 |
-
const uint buf_idx = (loadc_a + l) *
|
| 185 |
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
|
| 186 |
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
|
| 187 |
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
|
| 188 |
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
|
| 189 |
#else
|
| 190 |
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
|
| 191 |
-
buf_a[(loadc_a + l) *
|
| 192 |
} else {
|
| 193 |
-
buf_a[(loadc_a + l) *
|
| 194 |
}
|
| 195 |
#endif
|
| 196 |
#elif defined(DATA_A_Q4_0)
|
| 197 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 198 |
-
const uint buf_idx = (loadc_a + l) *
|
| 199 |
|
| 200 |
const uint ib = idx / 16;
|
| 201 |
const uint iqs = idx & 0xF;
|
|
@@ -208,7 +253,7 @@ void main() {
|
|
| 208 |
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 209 |
#elif defined(DATA_A_Q4_1)
|
| 210 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 211 |
-
const uint buf_idx = (loadc_a + l) *
|
| 212 |
|
| 213 |
const uint ib = idx / 16;
|
| 214 |
const uint iqs = idx & 0xF;
|
|
@@ -222,7 +267,7 @@ void main() {
|
|
| 222 |
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 223 |
#elif defined(DATA_A_Q5_0)
|
| 224 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 225 |
-
const uint buf_idx = (loadc_a + l) *
|
| 226 |
|
| 227 |
const uint ib = idx / 16;
|
| 228 |
const uint iqs = idx & 0xF;
|
|
@@ -237,7 +282,7 @@ void main() {
|
|
| 237 |
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 238 |
#elif defined(DATA_A_Q5_1)
|
| 239 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 240 |
-
const uint buf_idx = (loadc_a + l) *
|
| 241 |
|
| 242 |
const uint ib = idx / 16;
|
| 243 |
const uint iqs = idx & 0xF;
|
|
@@ -253,7 +298,7 @@ void main() {
|
|
| 253 |
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 254 |
#elif defined(DATA_A_Q8_0)
|
| 255 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 256 |
-
const uint buf_idx = (loadc_a + l) *
|
| 257 |
|
| 258 |
const uint ib = idx / 16;
|
| 259 |
const uint iqs = (idx & 0xF) * 2;
|
|
@@ -265,7 +310,7 @@ void main() {
|
|
| 265 |
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
| 266 |
#elif defined(DATA_A_Q2_K)
|
| 267 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 268 |
-
const uint buf_idx = (loadc_a + l) *
|
| 269 |
|
| 270 |
const uint ib = idx / 128; // 2 values per idx
|
| 271 |
const uint iqs = idx % 128; // 0..127
|
|
@@ -284,7 +329,7 @@ void main() {
|
|
| 284 |
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
| 285 |
#elif defined(DATA_A_Q3_K)
|
| 286 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 287 |
-
const uint buf_idx = (loadc_a + l) *
|
| 288 |
|
| 289 |
const uint ib = idx / 128; // 2 values per idx
|
| 290 |
const uint iqs = idx % 128; // 0..127
|
|
@@ -308,7 +353,7 @@ void main() {
|
|
| 308 |
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
|
| 309 |
#elif defined(DATA_A_Q4_K)
|
| 310 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 311 |
-
const uint buf_idx = (loadc_a + l) *
|
| 312 |
|
| 313 |
const uint ib = idx / 128; // 2 values per idx
|
| 314 |
const uint iqs = idx % 128; // 0..127
|
|
@@ -320,15 +365,20 @@ void main() {
|
|
| 320 |
|
| 321 |
const vec2 loadd = vec2(data_a[ib].d);
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
const float d = loadd.x * sc;
|
| 333 |
const float m = -loadd.y * mbyte;
|
| 334 |
|
|
@@ -336,7 +386,7 @@ void main() {
|
|
| 336 |
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
|
| 337 |
#elif defined(DATA_A_Q5_K)
|
| 338 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 339 |
-
const uint buf_idx = (loadc_a + l) *
|
| 340 |
|
| 341 |
const uint ib = idx / 128; // 2 values per idx
|
| 342 |
const uint iqs = idx % 128; // 0..127
|
|
@@ -351,15 +401,20 @@ void main() {
|
|
| 351 |
|
| 352 |
const vec2 loadd = vec2(data_a[ib].d);
|
| 353 |
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
const float d = loadd.x * sc;
|
| 364 |
const float m = -loadd.y * mbyte;
|
| 365 |
|
|
@@ -367,7 +422,7 @@ void main() {
|
|
| 367 |
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
|
| 368 |
#elif defined(DATA_A_Q6_K)
|
| 369 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 370 |
-
const uint buf_idx = (loadc_a + l) *
|
| 371 |
|
| 372 |
const uint ib = idx / 128; // 2 values per idx
|
| 373 |
const uint iqs = idx % 128; // 0..127
|
|
@@ -386,7 +441,7 @@ void main() {
|
|
| 386 |
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
| 387 |
#elif defined(DATA_A_IQ4_NL)
|
| 388 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 389 |
-
const uint buf_idx = (loadc_a + l) *
|
| 390 |
|
| 391 |
const uint ib = idx / 16;
|
| 392 |
const uint iqs = idx & 0xF;
|
|
@@ -407,7 +462,7 @@ void main() {
|
|
| 407 |
#else
|
| 408 |
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
| 409 |
#endif
|
| 410 |
-
const uint buf_idx = (loadc_b + l) *
|
| 411 |
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
|
| 412 |
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
|
| 413 |
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
|
|
@@ -423,24 +478,24 @@ void main() {
|
|
| 423 |
#else
|
| 424 |
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
| 425 |
#endif
|
| 426 |
-
const uint buf_idx = (loadc_b + l) *
|
| 427 |
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
|
| 428 |
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
|
| 429 |
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
|
| 430 |
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
|
| 431 |
#elif !MUL_MAT_ID
|
| 432 |
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
|
| 433 |
-
buf_b[(loadc_b + l) *
|
| 434 |
} else {
|
| 435 |
-
buf_b[(loadc_b + l) *
|
| 436 |
}
|
| 437 |
#else
|
| 438 |
const uint row_i = ic * BN + loadc_b + l;
|
| 439 |
if (row_i < _ne1) {
|
| 440 |
const u16vec2 row_idx = row_ids[row_i];
|
| 441 |
-
buf_b[(loadc_b + l) *
|
| 442 |
} else {
|
| 443 |
-
buf_b[(loadc_b + l) *
|
| 444 |
}
|
| 445 |
#endif
|
| 446 |
}
|
|
@@ -450,16 +505,30 @@ void main() {
|
|
| 450 |
pos_a += BK / LOAD_VEC_A;
|
| 451 |
pos_b += BK / LOAD_VEC_B;
|
| 452 |
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
// Load from shared into cache
|
| 455 |
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
| 456 |
[[unroll]] for (uint j = 0; j < TM; j++) {
|
| 457 |
-
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) *
|
| 458 |
}
|
| 459 |
}
|
| 460 |
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
| 461 |
[[unroll]] for (uint j = 0; j < TN; j++) {
|
| 462 |
-
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) *
|
| 463 |
}
|
| 464 |
}
|
| 465 |
|
|
@@ -468,12 +537,13 @@ void main() {
|
|
| 468 |
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
| 469 |
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
| 470 |
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
| 471 |
-
sums[sums_idx] = fma(
|
| 472 |
}
|
| 473 |
}
|
| 474 |
}
|
| 475 |
}
|
| 476 |
}
|
|
|
|
| 477 |
|
| 478 |
barrier();
|
| 479 |
}
|
|
@@ -485,6 +555,54 @@ void main() {
|
|
| 485 |
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
| 486 |
#endif
|
| 487 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
| 489 |
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
| 490 |
|
|
@@ -496,7 +614,7 @@ void main() {
|
|
| 496 |
if (row_i >= _ne1) break;
|
| 497 |
|
| 498 |
const u16vec2 row_idx = row_ids[row_i];
|
| 499 |
-
#endif
|
| 500 |
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
| 501 |
#ifdef MUL_MAT_ID
|
| 502 |
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
|
@@ -504,9 +622,10 @@ void main() {
|
|
| 504 |
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
| 505 |
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
| 506 |
}
|
| 507 |
-
#endif
|
| 508 |
}
|
| 509 |
}
|
| 510 |
}
|
| 511 |
}
|
|
|
|
| 512 |
}
|
|
|
|
| 7 |
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
| 8 |
#endif
|
| 9 |
|
| 10 |
+
#ifdef COOPMAT
|
| 11 |
+
#extension GL_KHR_cooperative_matrix : enable
|
| 12 |
+
#extension GL_KHR_memory_scope_semantics : enable
|
| 13 |
+
#extension GL_KHR_shader_subgroup_basic : enable
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
#ifdef MUL_MAT_ID
|
| 17 |
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
| 18 |
#endif
|
|
|
|
| 63 |
#endif
|
| 64 |
} p;
|
| 65 |
|
| 66 |
+
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
|
| 67 |
layout (constant_id = 1) const uint BM = 64;
|
| 68 |
layout (constant_id = 2) const uint BN = 64;
|
| 69 |
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
|
|
|
|
| 72 |
layout (constant_id = 6) const uint WMITER = 2;
|
| 73 |
layout (constant_id = 7) const uint TM = 4;
|
| 74 |
layout (constant_id = 8) const uint TN = 2;
|
| 75 |
+
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
|
| 76 |
+
layout (constant_id = 10) const uint WARP = 32;
|
| 77 |
+
|
| 78 |
+
#ifdef COOPMAT
|
| 79 |
+
#define SHMEM_STRIDE (BK + 8)
|
| 80 |
+
#else
|
| 81 |
+
#define SHMEM_STRIDE (BK + 1)
|
| 82 |
+
#endif
|
| 83 |
|
| 84 |
+
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
| 85 |
+
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
| 86 |
|
| 87 |
#ifdef MUL_MAT_ID
|
| 88 |
shared u16vec2 row_ids[3072];
|
| 89 |
+
#endif // MUL_MAT_ID
|
| 90 |
+
|
| 91 |
+
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
| 92 |
+
|
| 93 |
+
#ifdef COOPMAT
|
| 94 |
+
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
| 95 |
#endif
|
| 96 |
|
| 97 |
void main() {
|
|
|
|
| 118 |
const uint ik = gl_WorkGroupID.x / blocks_m;
|
| 119 |
const uint ic = gl_WorkGroupID.y;
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
| 122 |
const uint WSUBM = WM / WMITER;
|
| 123 |
const uint WSUBN = WN / WNITER;
|
| 124 |
|
| 125 |
+
#ifdef COOPMAT
|
| 126 |
+
const uint warp_i = gl_SubgroupID;
|
| 127 |
+
|
| 128 |
+
const uint tiw = gl_SubgroupInvocationID;
|
| 129 |
+
|
| 130 |
+
const uint cms_per_row = WM / TM;
|
| 131 |
+
const uint cms_per_col = WN / TN;
|
| 132 |
+
|
| 133 |
+
const uint storestride = WARP / TM;
|
| 134 |
+
const uint store_r = tiw % TM;
|
| 135 |
+
const uint store_c = tiw / TM;
|
| 136 |
+
#else
|
| 137 |
+
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
| 138 |
+
|
| 139 |
const uint tiw = gl_LocalInvocationID.x % WARP;
|
| 140 |
+
|
| 141 |
const uint tiwr = tiw % (WSUBM / TM);
|
| 142 |
const uint tiwc = tiw / (WSUBM / TM);
|
| 143 |
+
#endif
|
| 144 |
+
|
| 145 |
+
const uint warp_r = warp_i % (BM / WM);
|
| 146 |
+
const uint warp_c = warp_i / (BM / WM);
|
| 147 |
|
| 148 |
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
|
| 149 |
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
|
|
|
|
| 191 |
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
|
| 192 |
#endif
|
| 193 |
|
| 194 |
+
#ifdef COOPMAT
|
| 195 |
+
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
| 196 |
+
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
| 197 |
+
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
| 198 |
+
|
| 199 |
+
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
| 200 |
+
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
| 201 |
+
}
|
| 202 |
+
#else
|
| 203 |
+
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
| 204 |
FLOAT_TYPE cache_a[WMITER * TM];
|
| 205 |
FLOAT_TYPE cache_b[WNITER * TN];
|
| 206 |
|
| 207 |
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
| 208 |
+
sums[i] = ACC_TYPE(0.0f);
|
| 209 |
}
|
| 210 |
+
#endif
|
| 211 |
|
| 212 |
+
for (uint block = start_k; block < end_k; block += BK) {
|
| 213 |
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
|
| 214 |
|
| 215 |
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
| 216 |
#if LOAD_VEC_A == 8
|
| 217 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 218 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
| 219 |
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
|
| 220 |
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
|
| 221 |
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
|
|
|
|
| 226 |
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
|
| 227 |
#elif LOAD_VEC_A == 4
|
| 228 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 229 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
| 230 |
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
|
| 231 |
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
|
| 232 |
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
|
| 233 |
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
|
| 234 |
#else
|
| 235 |
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
|
| 236 |
+
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
|
| 237 |
} else {
|
| 238 |
+
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
|
| 239 |
}
|
| 240 |
#endif
|
| 241 |
#elif defined(DATA_A_Q4_0)
|
| 242 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 243 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
| 244 |
|
| 245 |
const uint ib = idx / 16;
|
| 246 |
const uint iqs = idx & 0xF;
|
|
|
|
| 253 |
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 254 |
#elif defined(DATA_A_Q4_1)
|
| 255 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 256 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
| 257 |
|
| 258 |
const uint ib = idx / 16;
|
| 259 |
const uint iqs = idx & 0xF;
|
|
|
|
| 267 |
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 268 |
#elif defined(DATA_A_Q5_0)
|
| 269 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 270 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
| 271 |
|
| 272 |
const uint ib = idx / 16;
|
| 273 |
const uint iqs = idx & 0xF;
|
|
|
|
| 282 |
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 283 |
#elif defined(DATA_A_Q5_1)
|
| 284 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 285 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
| 286 |
|
| 287 |
const uint ib = idx / 16;
|
| 288 |
const uint iqs = idx & 0xF;
|
|
|
|
| 298 |
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
| 299 |
#elif defined(DATA_A_Q8_0)
|
| 300 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 301 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
| 302 |
|
| 303 |
const uint ib = idx / 16;
|
| 304 |
const uint iqs = (idx & 0xF) * 2;
|
|
|
|
| 310 |
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
| 311 |
#elif defined(DATA_A_Q2_K)
|
| 312 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 313 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
| 314 |
|
| 315 |
const uint ib = idx / 128; // 2 values per idx
|
| 316 |
const uint iqs = idx % 128; // 0..127
|
|
|
|
| 329 |
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
| 330 |
#elif defined(DATA_A_Q3_K)
|
| 331 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 332 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
| 333 |
|
| 334 |
const uint ib = idx / 128; // 2 values per idx
|
| 335 |
const uint iqs = idx % 128; // 0..127
|
|
|
|
| 353 |
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
|
| 354 |
#elif defined(DATA_A_Q4_K)
|
| 355 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 356 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
| 357 |
|
| 358 |
const uint ib = idx / 128; // 2 values per idx
|
| 359 |
const uint iqs = idx % 128; // 0..127
|
|
|
|
| 365 |
|
| 366 |
const vec2 loadd = vec2(data_a[ib].d);
|
| 367 |
|
| 368 |
+
const uint scidx0 = (is < 4) ? is : (is + 4);
|
| 369 |
+
const uint scidx1 = (is < 4) ? is : (is - 4);
|
| 370 |
+
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 371 |
+
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
| 372 |
+
const uint mbidx0 = is + 4;
|
| 373 |
+
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
| 374 |
+
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
| 375 |
+
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
| 376 |
+
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 377 |
+
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
| 378 |
+
|
| 379 |
+
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
| 380 |
+
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
| 381 |
+
|
| 382 |
const float d = loadd.x * sc;
|
| 383 |
const float m = -loadd.y * mbyte;
|
| 384 |
|
|
|
|
| 386 |
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
|
| 387 |
#elif defined(DATA_A_Q5_K)
|
| 388 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 389 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
| 390 |
|
| 391 |
const uint ib = idx / 128; // 2 values per idx
|
| 392 |
const uint iqs = idx % 128; // 0..127
|
|
|
|
| 401 |
|
| 402 |
const vec2 loadd = vec2(data_a[ib].d);
|
| 403 |
|
| 404 |
+
const uint scidx0 = (is < 4) ? is : (is + 4);
|
| 405 |
+
const uint scidx1 = (is < 4) ? is : (is - 4);
|
| 406 |
+
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 407 |
+
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
| 408 |
+
const uint mbidx0 = is + 4;
|
| 409 |
+
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
| 410 |
+
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
| 411 |
+
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
| 412 |
+
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
| 413 |
+
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
| 414 |
+
|
| 415 |
+
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
| 416 |
+
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
| 417 |
+
|
| 418 |
const float d = loadd.x * sc;
|
| 419 |
const float m = -loadd.y * mbyte;
|
| 420 |
|
|
|
|
| 422 |
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
|
| 423 |
#elif defined(DATA_A_Q6_K)
|
| 424 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 425 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
| 426 |
|
| 427 |
const uint ib = idx / 128; // 2 values per idx
|
| 428 |
const uint iqs = idx % 128; // 0..127
|
|
|
|
| 441 |
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
| 442 |
#elif defined(DATA_A_IQ4_NL)
|
| 443 |
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
| 444 |
+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
| 445 |
|
| 446 |
const uint ib = idx / 16;
|
| 447 |
const uint iqs = idx & 0xF;
|
|
|
|
| 462 |
#else
|
| 463 |
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
| 464 |
#endif
|
| 465 |
+
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
|
| 466 |
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
|
| 467 |
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
|
| 468 |
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
|
|
|
|
| 478 |
#else
|
| 479 |
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
| 480 |
#endif
|
| 481 |
+
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
|
| 482 |
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
|
| 483 |
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
|
| 484 |
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
|
| 485 |
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
|
| 486 |
#elif !MUL_MAT_ID
|
| 487 |
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
|
| 488 |
+
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
|
| 489 |
} else {
|
| 490 |
+
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
| 491 |
}
|
| 492 |
#else
|
| 493 |
const uint row_i = ic * BN + loadc_b + l;
|
| 494 |
if (row_i < _ne1) {
|
| 495 |
const u16vec2 row_idx = row_ids[row_i];
|
| 496 |
+
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
| 497 |
} else {
|
| 498 |
+
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
| 499 |
}
|
| 500 |
#endif
|
| 501 |
}
|
|
|
|
| 505 |
pos_a += BK / LOAD_VEC_A;
|
| 506 |
pos_b += BK / LOAD_VEC_B;
|
| 507 |
|
| 508 |
+
#ifdef COOPMAT
|
| 509 |
+
[[unroll]] for (uint i = 0; i < BK; i += TK) {
|
| 510 |
+
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
| 511 |
+
// Load from shared into cache
|
| 512 |
+
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
| 513 |
+
|
| 514 |
+
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
| 515 |
+
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
| 516 |
+
|
| 517 |
+
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
|
| 518 |
+
}
|
| 519 |
+
}
|
| 520 |
+
}
|
| 521 |
+
#else
|
| 522 |
+
[[unroll]] for (uint i = 0; i < BK; i++) {
|
| 523 |
// Load from shared into cache
|
| 524 |
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
| 525 |
[[unroll]] for (uint j = 0; j < TM; j++) {
|
| 526 |
+
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
|
| 527 |
}
|
| 528 |
}
|
| 529 |
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
| 530 |
[[unroll]] for (uint j = 0; j < TN; j++) {
|
| 531 |
+
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
|
| 532 |
}
|
| 533 |
}
|
| 534 |
|
|
|
|
| 537 |
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
| 538 |
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
| 539 |
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
| 540 |
+
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
|
| 541 |
}
|
| 542 |
}
|
| 543 |
}
|
| 544 |
}
|
| 545 |
}
|
| 546 |
+
#endif
|
| 547 |
|
| 548 |
barrier();
|
| 549 |
}
|
|
|
|
| 555 |
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
| 556 |
#endif
|
| 557 |
|
| 558 |
+
#ifdef COOPMAT
|
| 559 |
+
#ifdef MUL_MAT_ID
|
| 560 |
+
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
| 561 |
+
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
| 562 |
+
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
| 563 |
+
|
| 564 |
+
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
|
| 565 |
+
const uint row_i = dc + cm_col * TN + col + store_c;
|
| 566 |
+
if (row_i >= _ne1) break;
|
| 567 |
+
|
| 568 |
+
const u16vec2 row_idx = row_ids[row_i];
|
| 569 |
+
|
| 570 |
+
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
| 571 |
+
}
|
| 572 |
+
}
|
| 573 |
+
}
|
| 574 |
+
#else
|
| 575 |
+
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
|
| 576 |
+
|
| 577 |
+
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
| 578 |
+
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
| 579 |
+
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
|
| 580 |
+
|
| 581 |
+
if (is_aligned && is_in_bounds) {
|
| 582 |
+
// Full coopMat is within bounds and stride_d is aligned with 16B
|
| 583 |
+
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
|
| 584 |
+
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
| 585 |
+
} else if (is_in_bounds) {
|
| 586 |
+
// Full coopMat is within bounds, but stride_d is not aligned
|
| 587 |
+
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
| 588 |
+
|
| 589 |
+
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
| 590 |
+
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
| 591 |
+
}
|
| 592 |
+
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
| 593 |
+
// Partial coopMat is within bounds
|
| 594 |
+
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
| 595 |
+
|
| 596 |
+
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
| 597 |
+
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
| 598 |
+
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
| 599 |
+
}
|
| 600 |
+
}
|
| 601 |
+
}
|
| 602 |
+
}
|
| 603 |
+
}
|
| 604 |
+
#endif // MUL_MAT_ID
|
| 605 |
+
#else
|
| 606 |
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
| 607 |
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
| 608 |
|
|
|
|
| 614 |
if (row_i >= _ne1) break;
|
| 615 |
|
| 616 |
const u16vec2 row_idx = row_ids[row_i];
|
| 617 |
+
#endif // MUL_MAT_ID
|
| 618 |
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
| 619 |
#ifdef MUL_MAT_ID
|
| 620 |
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
|
|
|
| 622 |
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
| 623 |
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
| 624 |
}
|
| 625 |
+
#endif // MUL_MAT_ID
|
| 626 |
}
|
| 627 |
}
|
| 628 |
}
|
| 629 |
}
|
| 630 |
+
#endif // COOPMAT
|
| 631 |
}
|
|
@@ -60,6 +60,7 @@ const std::vector<std::string> type_names = {
|
|
| 60 |
"iq4_nl"
|
| 61 |
};
|
| 62 |
|
|
|
|
| 63 |
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
|
| 64 |
#ifdef _WIN32
|
| 65 |
HANDLE stdout_read, stdout_write;
|
|
@@ -198,8 +199,8 @@ static uint32_t compile_count = 0;
|
|
| 198 |
static std::mutex compile_count_mutex;
|
| 199 |
static std::condition_variable compile_count_cond;
|
| 200 |
|
| 201 |
-
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
|
| 202 |
-
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
| 203 |
std::string out_fname = join_paths(output_dir, name + ".spv");
|
| 204 |
std::string in_path = join_paths(input_dir, in_fname);
|
| 205 |
|
|
@@ -258,7 +259,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
|
|
| 258 |
}
|
| 259 |
|
| 260 |
static std::vector<std::future<void>> compiles;
|
| 261 |
-
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
|
| 262 |
{
|
| 263 |
// wait until fewer than N compiles are in progress.
|
| 264 |
// 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
|
|
@@ -269,10 +270,10 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
|
|
| 269 |
}
|
| 270 |
compile_count++;
|
| 271 |
}
|
| 272 |
-
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc));
|
| 273 |
}
|
| 274 |
|
| 275 |
-
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
|
| 276 |
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
|
| 277 |
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
| 278 |
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
|
@@ -291,14 +292,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
|
|
| 291 |
|
| 292 |
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
| 295 |
|
| 296 |
// Shaders with f16 B_TYPE
|
| 297 |
-
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc);
|
| 298 |
-
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
|
| 299 |
|
| 300 |
-
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
|
| 301 |
-
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc);
|
| 302 |
|
| 303 |
for (const auto& tname : type_names) {
|
| 304 |
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
|
@@ -307,12 +314,12 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
|
|
| 307 |
// For aligned matmul loads
|
| 308 |
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
|
| 309 |
|
| 310 |
-
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
|
| 311 |
-
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
|
| 312 |
|
| 313 |
if (tname != "f16" && tname != "f32") {
|
| 314 |
-
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
|
| 315 |
-
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
|
| 316 |
}
|
| 317 |
}
|
| 318 |
}
|
|
@@ -322,25 +329,24 @@ void process_shaders() {
|
|
| 322 |
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
|
| 323 |
|
| 324 |
// matmul
|
| 325 |
-
for (const auto&
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
#endif
|
| 334 |
-
if (coopmat2 && !fp16) {
|
| 335 |
-
continue;
|
| 336 |
-
}
|
| 337 |
-
if (!coopmat2 && f16acc) {
|
| 338 |
-
continue;
|
| 339 |
-
}
|
| 340 |
-
matmul_shaders(fp16, matmul_id, coopmat2, f16acc);
|
| 341 |
-
}
|
| 342 |
-
}
|
| 343 |
-
}
|
| 344 |
}
|
| 345 |
|
| 346 |
#if defined(VK_NV_cooperative_matrix2)
|
|
@@ -355,11 +361,11 @@ void process_shaders() {
|
|
| 355 |
|
| 356 |
if (tname == "f16") {
|
| 357 |
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
| 358 |
-
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc);
|
| 359 |
} else {
|
| 360 |
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
| 361 |
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
| 362 |
-
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc);
|
| 363 |
}
|
| 364 |
}
|
| 365 |
}
|
|
@@ -524,6 +530,7 @@ void write_output_files() {
|
|
| 524 |
fclose(hdr);
|
| 525 |
fclose(src);
|
| 526 |
}
|
|
|
|
| 527 |
|
| 528 |
int main(int argc, char** argv) {
|
| 529 |
std::map<std::string, std::string> args;
|
|
|
|
| 60 |
"iq4_nl"
|
| 61 |
};
|
| 62 |
|
| 63 |
+
namespace {
|
| 64 |
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
|
| 65 |
#ifdef _WIN32
|
| 66 |
HANDLE stdout_read, stdout_write;
|
|
|
|
| 199 |
static std::mutex compile_count_mutex;
|
| 200 |
static std::condition_variable compile_count_cond;
|
| 201 |
|
| 202 |
+
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
| 203 |
+
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
| 204 |
std::string out_fname = join_paths(output_dir, name + ".spv");
|
| 205 |
std::string in_path = join_paths(input_dir, in_fname);
|
| 206 |
|
|
|
|
| 259 |
}
|
| 260 |
|
| 261 |
static std::vector<std::future<void>> compiles;
|
| 262 |
+
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
| 263 |
{
|
| 264 |
// wait until fewer than N compiles are in progress.
|
| 265 |
// 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
|
|
|
|
| 270 |
}
|
| 271 |
compile_count++;
|
| 272 |
}
|
| 273 |
+
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
|
| 274 |
}
|
| 275 |
|
| 276 |
+
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
|
| 277 |
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
|
| 278 |
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
| 279 |
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
|
|
|
| 292 |
|
| 293 |
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
| 294 |
|
| 295 |
+
if (coopmat) {
|
| 296 |
+
base_dict["COOPMAT"] = "1";
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
| 300 |
+
|
| 301 |
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
| 302 |
|
| 303 |
// Shaders with f16 B_TYPE
|
| 304 |
+
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
| 305 |
+
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
| 306 |
|
| 307 |
+
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
| 308 |
+
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
| 309 |
|
| 310 |
for (const auto& tname : type_names) {
|
| 311 |
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
|
|
|
| 314 |
// For aligned matmul loads
|
| 315 |
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
|
| 316 |
|
| 317 |
+
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
| 318 |
+
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
| 319 |
|
| 320 |
if (tname != "f16" && tname != "f32") {
|
| 321 |
+
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
| 322 |
+
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
| 323 |
}
|
| 324 |
}
|
| 325 |
}
|
|
|
|
| 329 |
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
|
| 330 |
|
| 331 |
// matmul
|
| 332 |
+
for (const auto& matmul_id : {false, true}) {
|
| 333 |
+
// No coopmats
|
| 334 |
+
// fp32
|
| 335 |
+
matmul_shaders(false, matmul_id, false, false, false);
|
| 336 |
+
|
| 337 |
+
// fp16, fp32acc and fp16acc
|
| 338 |
+
matmul_shaders(true, matmul_id, false, false, false);
|
| 339 |
+
matmul_shaders(true, matmul_id, false, false, true);
|
| 340 |
+
|
| 341 |
+
// Coopmat, fp32acc and fp16acc
|
| 342 |
+
matmul_shaders(true, matmul_id, true, false, false);
|
| 343 |
+
matmul_shaders(true, matmul_id, true, false, true);
|
| 344 |
+
|
| 345 |
+
#if defined(VK_NV_cooperative_matrix2)
|
| 346 |
+
// Coopmat2, fp32acc and fp16acc
|
| 347 |
+
matmul_shaders(true, matmul_id, false, true, false);
|
| 348 |
+
matmul_shaders(true, matmul_id, false, true, true);
|
| 349 |
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
}
|
| 351 |
|
| 352 |
#if defined(VK_NV_cooperative_matrix2)
|
|
|
|
| 361 |
|
| 362 |
if (tname == "f16") {
|
| 363 |
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
| 364 |
+
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
|
| 365 |
} else {
|
| 366 |
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
| 367 |
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
| 368 |
+
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
| 369 |
}
|
| 370 |
}
|
| 371 |
}
|
|
|
|
| 530 |
fclose(hdr);
|
| 531 |
fclose(src);
|
| 532 |
}
|
| 533 |
+
}
|
| 534 |
|
| 535 |
int main(int argc, char** argv) {
|
| 536 |
std::map<std::string, std::string> args;
|