OccamRazor committed on
Commit
9a4de04
·
1 Parent(s): 44e7250

Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing (llama/10597)

Browse files

* Vulkan: Implement VK_KHR_cooperative_matrix support in the matrix matrix multiplication shader

* Improve performance with better q4_k and q5_k dequant and store unrolling

* Add Vulkan MUL_MAT and MUL_MAT_ID accumulator precision selection

* Rework mulmat shader selection and compilation logic, avoid compiling shaders that won't get used by device

* Vulkan: Implement accumulator switch for specific mul mat mat shaders

* Vulkan: Unroll more loops for more mul mat mat performance

* Vulkan: Add VK_AMD_shader_core_properties2 support to read Compute Unit count for split_k logic

* Disable coopmat support on AMD proprietary driver

* Remove redundant checks

* Add environment variable GGML_VK_DISABLE_COOPMAT to disable VK_KHR_cooperative_matrix support

* Fix rebase typo

* Fix coopmat2 MUL_MAT_ID pipeline selection

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -1,7 +1,8 @@
1
  #include "ggml-vulkan.h"
2
  #include <vulkan/vulkan_core.h>
3
- #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
4
  #include <chrono>
 
5
  #endif
6
 
7
  #include <vulkan/vulkan.hpp>
@@ -169,8 +170,22 @@ struct vk_device_struct {
169
  bool uma;
170
  bool coopmat2;
171
 
 
 
 
 
 
 
 
172
  size_t idx;
173
 
 
 
 
 
 
 
 
174
  vk_matmul_pipeline pipeline_matmul_f32;
175
  vk_matmul_pipeline pipeline_matmul_f32_f16;
176
  vk_matmul_pipeline2 pipeline_matmul_f16;
@@ -181,10 +196,10 @@ struct vk_device_struct {
181
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
182
 
183
  vk_matmul_pipeline pipeline_matmul_id_f32;
184
- vk_matmul_pipeline pipeline_matmul_id_f16;
185
- vk_matmul_pipeline pipeline_matmul_id_f16_f32;
186
 
187
- vk_matmul_pipeline pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
188
 
189
  vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
190
  vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
@@ -1325,6 +1340,18 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
1325
  return {64, 64};
1326
  };
1327
 
 
 
 
 
 
 
 
 
 
 
 
 
1328
 
1329
  static void ggml_vk_load_shaders(vk_device& device) {
1330
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
@@ -1382,12 +1409,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
1382
  m_align = 64;
1383
  s_align = 32;
1384
  } else {
1385
- l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
1386
- m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
1387
- s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
1388
- l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
1389
- m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
1390
- s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
 
 
 
 
 
 
 
 
 
 
 
 
 
1391
  l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
1392
  m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
1393
  s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
@@ -1428,25 +1468,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
1428
  // assert mul_mat_mat_id shaders will fit.
1429
  GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1430
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1431
  }
1432
 
1433
  device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1434
  device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
1435
 
1436
  device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1437
- device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1438
- device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
1439
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
1440
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
1441
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
1442
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
1443
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
1444
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
1445
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
1446
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
1447
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
1448
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
1449
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
1450
 
1451
  std::vector<std::future<void>> compiles;
1452
  auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
@@ -1543,119 +1594,191 @@ static void ggml_vk_load_shaders(vk_device& device) {
1543
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1544
 
1545
  CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1546
- CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1547
- CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1548
-
1549
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1550
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1551
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1552
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1553
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1554
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1555
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1556
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1557
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1558
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1559
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1560
  #undef CREATE_MM
1561
  #undef CREATE_MM2
1562
  } else
1563
- #endif
1564
- if (device->fp16) {
1565
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1566
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1567
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1568
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1569
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1570
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1571
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1572
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1573
-
1574
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1575
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1576
- CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1577
- CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1578
-
1579
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1580
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1581
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1582
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1583
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1584
-
1585
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1586
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1587
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1588
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1589
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1590
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
 
 
 
 
 
 
 
 
 
 
 
1591
 
1592
  // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1593
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1594
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1595
- CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1596
- CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1597
-
1598
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1599
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1600
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1601
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1602
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1603
-
1604
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1605
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1606
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1607
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1608
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1609
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1610
  }
1611
  #undef CREATE_MM
1612
  } else {
1613
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1614
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1615
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1616
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1617
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1618
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1619
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1620
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1621
-
1622
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1623
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1624
- CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1625
- CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1626
-
1627
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1628
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1629
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1630
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1631
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1632
-
1633
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1634
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1635
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1636
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1637
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1638
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
 
 
 
 
 
 
1639
 
1640
  // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1641
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1642
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1643
- CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1644
- CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1645
-
1646
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1647
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1648
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1649
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1650
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1651
-
1652
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1653
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1654
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1655
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1656
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1657
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1658
  }
 
1659
  #undef CREATE_MM
1660
  }
1661
 
@@ -1851,8 +1974,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
1851
  bool fp16_compute = false;
1852
  bool maintenance4_support = false;
1853
  bool sm_builtins = false;
 
1854
  bool pipeline_robustness = false;
1855
  bool coopmat2_support = false;
 
1856
 
1857
  // Check if maintenance4 is supported
1858
  for (const auto& properties : ext_props) {
@@ -1864,10 +1989,18 @@ static vk_device ggml_vk_get_device(size_t idx) {
1864
  fp16_compute = true;
1865
  } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
1866
  sm_builtins = true;
 
 
1867
  } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
1868
  pipeline_robustness = true;
 
 
 
 
 
 
1869
  } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
1870
- !getenv("GGML_VULKAN_DISABLE_COOPMAT2")) {
1871
  coopmat2_support = true;
1872
  }
1873
  }
@@ -1876,11 +2009,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
1876
  vk::PhysicalDeviceMaintenance3Properties props3;
1877
  vk::PhysicalDeviceMaintenance4Properties props4;
1878
  vk::PhysicalDeviceSubgroupProperties subgroup_props;
 
1879
  vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
 
1880
  props2.pNext = &props3;
1881
  props3.pNext = &subgroup_props;
 
1882
 
1883
- VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
1884
 
1885
  if (maintenance4_support) {
1886
  last_struct->pNext = (VkBaseOutStructure *)&props4;
@@ -1890,6 +2026,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
1890
  last_struct->pNext = (VkBaseOutStructure *)&sm_props;
1891
  last_struct = (VkBaseOutStructure *)&sm_props;
1892
  }
 
 
 
 
1893
 
1894
  #if defined(VK_NV_cooperative_matrix2)
1895
  vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
@@ -1905,7 +2045,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
1905
  const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
1906
 
1907
  if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
1908
- device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
1909
  } else if (maintenance4_support) {
1910
  device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
1911
  } else {
@@ -1917,15 +2057,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
1917
  device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1918
  if (sm_builtins) {
1919
  device->shader_core_count = sm_props.shaderSMCount;
 
 
1920
  } else {
1921
  device->shader_core_count = 0;
1922
  }
1923
 
1924
- const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1925
- const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1926
 
1927
  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1928
 
 
 
 
 
 
 
1929
  std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
1930
 
1931
  // Try to find a non-graphics compute queue and transfer-focused queues
@@ -1976,6 +2123,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
1976
  device_extensions.push_back("VK_EXT_pipeline_robustness");
1977
  }
1978
 
 
 
 
 
 
 
 
 
 
 
1979
  #if defined(VK_NV_cooperative_matrix2)
1980
  VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
1981
  coopmat2_features.pNext = nullptr;
@@ -1993,6 +2150,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
1993
 
1994
  device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
1995
 
 
 
1996
  if (coopmat2_support) {
1997
  #if defined(VK_NV_cooperative_matrix2)
1998
  if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
@@ -2083,6 +2242,74 @@ static vk_device ggml_vk_get_device(size_t idx) {
2083
  if (device->fp16) {
2084
  device_extensions.push_back("VK_KHR_shader_float16_int8");
2085
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2086
  device->name = GGML_VK_NAME + std::to_string(idx);
2087
 
2088
  device_create_info = {
@@ -2098,6 +2325,37 @@ static vk_device ggml_vk_get_device(size_t idx) {
2098
  ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
2099
 
2100
  // Shaders
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2101
  ggml_vk_load_shaders(device);
2102
 
2103
  if (!device->single_queue) {
@@ -2155,15 +2413,24 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2155
 
2156
  bool fp16_storage = false;
2157
  bool fp16_compute = false;
 
2158
 
2159
  for (auto properties : ext_props) {
2160
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
2161
  fp16_storage = true;
2162
  } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
2163
  fp16_compute = true;
 
 
2164
  }
2165
  }
2166
 
 
 
 
 
 
 
2167
  const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
2168
  bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
2169
 
@@ -2186,13 +2453,28 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2186
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
2187
  vk11_features.pNext = &vk12_features;
2188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2189
  vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
2190
 
2191
  fp16 = fp16 && vk12_features.shaderFloat16;
2192
 
 
 
2193
  std::string device_name = props2.properties.deviceName.data();
2194
- GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu\n",
2195
- idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size);
2196
 
2197
  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
2198
  GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -2428,7 +2710,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2428
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
2429
  return ctx->device->pipeline_matmul_f32_f16;
2430
  }
2431
- if (prec == GGML_PREC_DEFAULT && ctx->device->coopmat2) {
2432
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2433
  return ctx->device->pipeline_matmul_f16_f32.f16acc;
2434
  }
@@ -2469,7 +2751,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2469
  assert(src1_type == GGML_TYPE_F16);
2470
  return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
2471
  }
2472
- return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
2473
  }
2474
 
2475
  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -2498,16 +2780,25 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
2498
  return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
2499
  }
2500
 
2501
- static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
2502
  VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
2503
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2504
  return ctx->device->pipeline_matmul_id_f32;
2505
  }
2506
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2507
- return ctx->device->pipeline_matmul_id_f16_f32;
2508
- }
2509
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2510
- return ctx->device->pipeline_matmul_id_f16;
 
 
 
 
 
 
 
 
 
2511
  }
2512
 
2513
  GGML_ASSERT(src1_type == GGML_TYPE_F32);
@@ -2529,7 +2820,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
2529
  return nullptr;
2530
  }
2531
 
2532
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
2533
  }
2534
 
2535
  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -3119,62 +3410,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
3119
  return split_k;
3120
  }
3121
 
3122
- static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3123
- if (m <= 32 || n <= 32) {
3124
- return aligned ? mmp->a_s : mmp->s;
3125
- }
3126
- return aligned ? mmp->a_m : mmp->m;
3127
-
3128
- GGML_UNUSED(ctx);
3129
- }
3130
-
3131
- static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
3132
- return aligned ? mmp->a_m : mmp->m;
3133
-
3134
- GGML_UNUSED(ctx);
3135
- }
3136
-
3137
- static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
3138
- return aligned ? mmp->a_s : mmp->s;
3139
-
3140
- GGML_UNUSED(ctx);
3141
- }
3142
-
3143
- static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3144
  VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
3145
- switch (ctx->device->vendor_id) {
3146
- case VK_VENDOR_ID_AMD:
3147
- return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
3148
- case VK_VENDOR_ID_APPLE:
3149
- return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
3150
- case VK_VENDOR_ID_INTEL:
3151
- return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
3152
- default:
3153
- break;
3154
- }
3155
 
3156
  if (ctx->device->coopmat2) {
3157
- if ((m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) {
3158
  return aligned ? mmp->a_l : mmp->l;
3159
  }
3160
- if ((m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) {
3161
  return aligned ? mmp->a_m : mmp->m;
3162
  }
3163
  return aligned ? mmp->a_s : mmp->s;
3164
  }
3165
 
3166
- if (m <= 32 || n <= 32) {
3167
  return aligned ? mmp->a_s : mmp->s;
3168
  }
3169
- if (m <= 64 || n <= 64) {
3170
  return aligned ? mmp->a_m : mmp->m;
3171
  }
3172
  return aligned ? mmp->a_l : mmp->l;
3173
  }
3174
 
3175
- static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
3176
  VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
3177
- return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
3178
  }
3179
 
3180
  static void ggml_vk_matmul(
@@ -3201,6 +3461,33 @@ static void ggml_vk_matmul(
3201
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
3202
  }
3203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3204
  static void ggml_vk_matmul_id(
3205
  ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
3206
  vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
@@ -3350,10 +3637,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3350
  const int y_ne = ne11 * ne10;
3351
  const int d_ne = ne11 * ne01;
3352
 
3353
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
3354
  const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
3355
 
3356
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
3357
 
3358
  const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
3359
 
@@ -3904,7 +4191,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
3904
 
3905
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
3906
 
3907
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
3908
 
3909
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3910
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
@@ -3920,10 +4207,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
3920
  const uint64_t y_ne = ne11 * ne10;
3921
  const uint64_t d_ne = ne21 * ne20;
3922
 
3923
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1));
3924
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
3925
 
3926
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned);
3927
 
3928
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
3929
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -5504,19 +5791,27 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5504
  for (size_t i = 0; i < x_ne; i++) {
5505
  if (std::is_same<float, X_TYPE>()) {
5506
  x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
 
 
 
5507
  } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
5508
  x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
 
 
 
5509
  } else {
5510
  GGML_ABORT("fatal error");
5511
  }
5512
  }
5513
  for (size_t i = 0; i < y_ne; i++) {
5514
  if (std::is_same<float, Y_TYPE>()) {
5515
- // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
5516
- y[i] = (i % k == i / k) ? 1.0f : 0.0f;
 
5517
  } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
5518
- // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
5519
- y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
 
5520
  } else {
5521
  GGML_ABORT("fatal error");
5522
  }
@@ -5600,7 +5895,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5600
  double err = std::fabs(d[i] - d_chk[i]);
5601
  avg_err += err;
5602
 
5603
- if (err > 0.05f && first_err_n == -1) {
5604
  first_err_b = i / (m * n);
5605
  first_err_n = (i % (m * n)) / m;
5606
  first_err_m = (i % (m * n)) % m;
@@ -5613,12 +5908,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5613
 
5614
  std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
5615
 
5616
- if (avg_err > 0.1) {
5617
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
5618
  std::cerr << "Actual result: " << std::endl << std::endl;
5619
  ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
5620
- std::cerr << std::endl;
5621
- ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
5622
  std::cerr << "Expected result: " << std::endl << std::endl;
5623
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
5624
 
@@ -5801,13 +6094,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5801
  vk_pipeline p;
5802
  std::string shname;
5803
  if (shader_size == 0) {
5804
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
5805
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
5806
  } else if (shader_size == 1) {
5807
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
5808
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
5809
  } else if (shader_size == 2) {
5810
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
5811
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
5812
  } else {
5813
  GGML_ASSERT(0);
@@ -5817,13 +6110,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5817
 
5818
  if (k != kpad) {
5819
  if (shader_size == 0) {
5820
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
5821
  shname = std::string(ggml_type_name(quant)) + "_S";
5822
  } else if (shader_size == 1) {
5823
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
5824
  shname = std::string(ggml_type_name(quant)) + "_M";
5825
  } else if (shader_size == 2) {
5826
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
5827
  shname = std::string(ggml_type_name(quant)) + "_L";
5828
  } else {
5829
  GGML_ASSERT(0);
@@ -5982,105 +6275,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5982
 
5983
  static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
5984
  #if defined(GGML_VULKAN_RUN_TESTS)
5985
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
5986
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
5987
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
5988
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
5989
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
5990
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
5991
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
5992
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
5993
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
5994
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
5995
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
5996
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
5997
-
5998
- ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
5999
-
6000
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
6001
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
6002
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
6003
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
6004
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
6005
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
6006
-
6007
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
6008
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
6009
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
6010
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
6011
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
6012
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
6013
-
6014
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
6015
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
6016
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
6017
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
6018
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
6019
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
6020
-
6021
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
6022
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
6023
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
6024
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
6025
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
6026
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
6027
-
6028
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
6029
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
6030
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
6031
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
6032
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
6033
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
6034
-
6035
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
6036
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
6037
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
6038
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
6039
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
6040
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
6041
-
6042
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
6043
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
6044
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
6045
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
6046
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
6047
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
6048
-
6049
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
6050
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
6051
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
6052
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
6053
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
6054
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
6055
-
6056
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
6057
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
6058
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
6059
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
6060
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
6061
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
6062
-
6063
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
6064
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
6065
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
6066
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
6067
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
6068
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
6069
-
6070
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
6071
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
6072
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
6073
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
6074
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
6075
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
6076
-
6077
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
6078
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
6079
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
6080
-
6081
- std::cerr << std::endl;
6082
-
6083
  const std::vector<size_t> vals {
 
 
 
 
 
 
6084
  8, 8, 8,
6085
  100, 46, 576,
6086
  623, 111, 128,
@@ -6093,15 +6294,6 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
6093
  49, 49, 128,
6094
  128, 49, 49,
6095
  4096, 49, 4096,
6096
- 11008, 49, 4096,
6097
- 4096, 49, 11008,
6098
- 32000, 49, 4096,
6099
- 512, 512, 128,
6100
- 128, 512, 512,
6101
- 4096, 512, 4096,
6102
- 11008, 512, 4096,
6103
- 4096, 512, 11008,
6104
- 32000, 512, 4096,
6105
  };
6106
  const size_t num_it = 100;
6107
 
@@ -6109,10 +6301,45 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
6109
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
6110
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
6111
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
6112
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
6113
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
6114
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
6115
- std::cerr << std::endl;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6116
  }
6117
 
6118
  GGML_ABORT("fatal error");
@@ -7200,8 +7427,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
7200
  case GGML_OP_MUL_MAT_ID:
7201
  {
7202
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7203
- if (op->op == GGML_OP_MUL_MAT_ID &&
7204
- ggml_vk_get_device(ctx->device)->properties.limits.maxComputeSharedMemorySize < 32768) {
7205
  // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
7206
  return false;
7207
  }
 
1
  #include "ggml-vulkan.h"
2
  #include <vulkan/vulkan_core.h>
3
+ #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS)
4
  #include <chrono>
5
+ #include "ggml-cpu.h"
6
  #endif
7
 
8
  #include <vulkan/vulkan.hpp>
 
170
  bool uma;
171
  bool coopmat2;
172
 
173
+ bool coopmat_support;
174
+ bool coopmat_acc_f32_support;
175
+ bool coopmat_acc_f16_support;
176
+ uint32_t coopmat_m;
177
+ uint32_t coopmat_n;
178
+ uint32_t coopmat_k;
179
+
180
  size_t idx;
181
 
182
+ bool mul_mat_l;
183
+ bool mul_mat_m;
184
+ bool mul_mat_s;
185
+ bool mul_mat_id_l;
186
+ bool mul_mat_id_m;
187
+ bool mul_mat_id_s;
188
+
189
  vk_matmul_pipeline pipeline_matmul_f32;
190
  vk_matmul_pipeline pipeline_matmul_f32_f16;
191
  vk_matmul_pipeline2 pipeline_matmul_f16;
 
196
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
197
 
198
  vk_matmul_pipeline pipeline_matmul_id_f32;
199
+ vk_matmul_pipeline2 pipeline_matmul_id_f16;
200
+ vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
201
 
202
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
203
 
204
  vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
205
  vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
 
1340
  return {64, 64};
1341
  };
1342
 
1343
+ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
1344
+ // Needs to be kept up to date on shader changes
1345
+ const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
1346
+ const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
1347
+ const uint32_t warps = warptile[0] / device->subgroup_size;
1348
+
1349
+ const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
1350
+ const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
1351
+ const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
1352
+
1353
+ return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
1354
+ }
1355
 
1356
  static void ggml_vk_load_shaders(vk_device& device) {
1357
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
 
1409
  m_align = 64;
1410
  s_align = 32;
1411
  } else {
1412
+ // Matrix cores require different warp group sizes
1413
+ const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
1414
+ const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
1415
+ const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
1416
+ const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
1417
+ const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
1418
+ const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
1419
+ const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
1420
+ const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
1421
+ const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
1422
+
1423
+ l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
1424
+ m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
1425
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
1426
+
1427
+ l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
1428
+ m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
1429
+ s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
1430
+
1431
  l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
1432
  m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
1433
  s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
 
1468
  // assert mul_mat_mat_id shaders will fit.
1469
  GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1470
  }
1471
+ // Disable medium and large matrix multiplication if not enough shared memory is available
1472
+ // Check mmq warptiles as the largest configuration
1473
+ // Throw an error if not enough for any matrix multiplication is available
1474
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
1475
+ std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
1476
+ throw std::runtime_error("Shared memory size too small for matrix multiplication.");
1477
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
1478
+ device->mul_mat_m = false;
1479
+ device->mul_mat_l = false;
1480
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
1481
+ device->mul_mat_l = false;
1482
+ }
1483
+
1484
+ // Disable mul_mat_id if not enough shared memory is available
1485
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
1486
+ device->mul_mat_id_s = false;
1487
+ device->mul_mat_id_m = false;
1488
+ device->mul_mat_id_l = false;
1489
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
1490
+ device->mul_mat_id_m = false;
1491
+ device->mul_mat_id_l = false;
1492
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
1493
+ device->mul_mat_id_l = false;
1494
+ }
1495
  }
1496
 
1497
  device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1498
  device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
1499
 
1500
  device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
 
 
 
 
 
 
 
 
 
 
 
 
 
1501
 
1502
  std::vector<std::future<void>> compiles;
1503
  auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
 
1594
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1595
 
1596
  CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1597
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1598
+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1599
+
1600
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1601
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1602
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1603
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1604
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1605
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1606
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1607
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1608
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1609
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1610
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1611
  #undef CREATE_MM
1612
  #undef CREATE_MM2
1613
  } else
1614
+ #endif // defined(VK_NV_cooperative_matrix2)
1615
+ if (device->coopmat_support) {
1616
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1617
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1618
+ if (device->mul_mat ## ID ## _l) \
1619
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1620
+ if (device->mul_mat ## ID ## _m) \
1621
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1622
+ if (device->mul_mat ## ID ## _s) \
1623
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1624
+ if (device->mul_mat ## ID ## _l) \
1625
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1626
+ if (device->mul_mat ## ID ## _m) \
1627
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1628
+ if (device->mul_mat ## ID ## _s) \
1629
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1630
+
1631
+ // Create 2 variants, {f16,f32} accumulator
1632
+ #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1633
+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1634
+ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1635
+
1636
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1637
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1638
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1639
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1640
+
1641
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1642
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1643
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1644
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1645
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1646
+
1647
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1648
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1649
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1650
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1651
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1652
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1653
 
1654
  // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1655
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1656
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1657
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1658
+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1659
+
1660
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1661
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1662
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1663
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1664
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1665
+
1666
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1667
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1668
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1669
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1670
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1671
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1672
+ }
1673
+ #undef CREATE_MM
1674
+ } else if (device->fp16) {
1675
+ // Create 6 variants, {s,m,l}x{unaligned,aligned}
1676
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1677
+ if (device->mul_mat ## ID ## _l) \
1678
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1679
+ if (device->mul_mat ## ID ## _m) \
1680
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1681
+ if (device->mul_mat ## ID ## _s) \
1682
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1683
+ if (device->mul_mat ## ID ## _l) \
1684
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1685
+ if (device->mul_mat ## ID ## _m) \
1686
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1687
+ if (device->mul_mat ## ID ## _s) \
1688
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1689
+
1690
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1691
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1692
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1693
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1694
+
1695
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1696
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1697
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1698
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1699
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1700
+
1701
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1702
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1703
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1704
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1705
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1706
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1707
+
1708
+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1709
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1710
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1711
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1712
+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1713
+
1714
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1715
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1716
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1717
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1718
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1719
+
1720
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1721
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1722
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1723
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1724
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1725
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1726
  }
1727
  #undef CREATE_MM
1728
  } else {
1729
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1730
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1731
+ if (device->mul_mat ## ID ## _l) \
1732
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1733
+ if (device->mul_mat ## ID ## _m) \
1734
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1735
+ if (device->mul_mat ## ID ## _s) \
1736
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1737
+ if (device->mul_mat ## ID ## _l) \
1738
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1739
+ if (device->mul_mat ## ID ## _m) \
1740
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1741
+ if (device->mul_mat ## ID ## _s) \
1742
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1743
+
1744
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1745
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1746
+ CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1747
+ CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1748
+
1749
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1750
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1751
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1752
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1753
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1754
+
1755
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1756
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1757
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1758
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1759
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1760
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1761
 
1762
  // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1763
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1764
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1765
+ CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1766
+ CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1767
+
1768
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1769
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1770
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1771
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1772
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1773
+
1774
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1775
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1776
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1777
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1778
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1779
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1780
  }
1781
+ #undef CREATE_MM2
1782
  #undef CREATE_MM
1783
  }
1784
 
 
1974
  bool fp16_compute = false;
1975
  bool maintenance4_support = false;
1976
  bool sm_builtins = false;
1977
+ bool amd_shader_core_properties2 = false;
1978
  bool pipeline_robustness = false;
1979
  bool coopmat2_support = false;
1980
+ device->coopmat_support = false;
1981
 
1982
  // Check if maintenance4 is supported
1983
  for (const auto& properties : ext_props) {
 
1989
  fp16_compute = true;
1990
  } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
1991
  sm_builtins = true;
1992
+ } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
1993
+ amd_shader_core_properties2 = true;
1994
  } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
1995
  pipeline_robustness = true;
1996
+ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
1997
+ !getenv("GGML_VK_DISABLE_COOPMAT")) {
1998
+ device->coopmat_support = true;
1999
+ device->coopmat_m = 0;
2000
+ device->coopmat_n = 0;
2001
+ device->coopmat_k = 0;
2002
  } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
2003
+ !getenv("GGML_VK_DISABLE_COOPMAT2")) {
2004
  coopmat2_support = true;
2005
  }
2006
  }
 
2009
  vk::PhysicalDeviceMaintenance3Properties props3;
2010
  vk::PhysicalDeviceMaintenance4Properties props4;
2011
  vk::PhysicalDeviceSubgroupProperties subgroup_props;
2012
+ vk::PhysicalDeviceDriverProperties driver_props;
2013
  vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
2014
+ vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2015
  props2.pNext = &props3;
2016
  props3.pNext = &subgroup_props;
2017
+ subgroup_props.pNext = &driver_props;
2018
 
2019
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
2020
 
2021
  if (maintenance4_support) {
2022
  last_struct->pNext = (VkBaseOutStructure *)&props4;
 
2026
  last_struct->pNext = (VkBaseOutStructure *)&sm_props;
2027
  last_struct = (VkBaseOutStructure *)&sm_props;
2028
  }
2029
+ if (amd_shader_core_properties2) {
2030
+ last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
2031
+ last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
2032
+ }
2033
 
2034
  #if defined(VK_NV_cooperative_matrix2)
2035
  vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
 
2045
  const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
2046
 
2047
  if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
2048
+ device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
2049
  } else if (maintenance4_support) {
2050
  device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
2051
  } else {
 
2057
  device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
2058
  if (sm_builtins) {
2059
  device->shader_core_count = sm_props.shaderSMCount;
2060
+ } else if (amd_shader_core_properties2) {
2061
+ device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
2062
  } else {
2063
  device->shader_core_count = 0;
2064
  }
2065
 
2066
+ const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
2067
 
2068
  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
2069
 
2070
+ if (device->vendor_id == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2071
+ // Intel drivers don't support coopmat properly yet
2072
+ // Only RADV supports coopmat properly on AMD
2073
+ device->coopmat_support = false;
2074
+ }
2075
+
2076
  std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
2077
 
2078
  // Try to find a non-graphics compute queue and transfer-focused queues
 
2123
  device_extensions.push_back("VK_EXT_pipeline_robustness");
2124
  }
2125
 
2126
+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
2127
+ coopmat_features.pNext = nullptr;
2128
+ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
2129
+ coopmat_features.cooperativeMatrix = VK_FALSE;
2130
+
2131
+ if (device->coopmat_support) {
2132
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
2133
+ last_struct = (VkBaseOutStructure *)&coopmat_features;
2134
+ }
2135
+
2136
  #if defined(VK_NV_cooperative_matrix2)
2137
  VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
2138
  coopmat2_features.pNext = nullptr;
 
2150
 
2151
  device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
2152
 
2153
+ device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
2154
+
2155
  if (coopmat2_support) {
2156
  #if defined(VK_NV_cooperative_matrix2)
2157
  if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
 
2242
  if (device->fp16) {
2243
  device_extensions.push_back("VK_KHR_shader_float16_int8");
2244
  }
2245
+
2246
+ if (device->coopmat_support) {
2247
+ // Query supported shapes
2248
+ std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
2249
+
2250
+ PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
2251
+ (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
2252
+
2253
+ uint32_t cm_props_num;
2254
+
2255
+ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
2256
+
2257
+ cm_props.resize(cm_props_num);
2258
+
2259
+ for (auto& prop : cm_props) {
2260
+ prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
2261
+ }
2262
+
2263
+ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
2264
+
2265
+ VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
2266
+
2267
+ for (auto& prop : cm_props) {
2268
+ VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
2269
+
2270
+ if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
2271
+ (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
2272
+ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
2273
+ ) {
2274
+ if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
2275
+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
2276
+ // coopmat sizes not set yet
2277
+ if (device->coopmat_m == 0) {
2278
+ device->coopmat_acc_f32_support = true;
2279
+ device->coopmat_m = prop.MSize;
2280
+ device->coopmat_n = prop.NSize;
2281
+ device->coopmat_k = prop.KSize;
2282
+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
2283
+ // Only enable if shape is identical
2284
+ device->coopmat_acc_f32_support = true;
2285
+ }
2286
+ } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
2287
+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
2288
+ // coopmat sizes not set yet
2289
+ if (device->coopmat_m == 0) {
2290
+ device->coopmat_acc_f16_support = true;
2291
+ device->coopmat_m = prop.MSize;
2292
+ device->coopmat_n = prop.NSize;
2293
+ device->coopmat_k = prop.KSize;
2294
+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
2295
+ // Only enable if shape is identical
2296
+ device->coopmat_acc_f16_support = true;
2297
+ }
2298
+ }
2299
+ }
2300
+ }
2301
+
2302
+ if (device->coopmat_m == 0) {
2303
+ // No suitable matmul mode found
2304
+ GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
2305
+ device->coopmat_support = false;
2306
+ }
2307
+ }
2308
+
2309
+ if (device->coopmat_support) {
2310
+ device_extensions.push_back("VK_KHR_cooperative_matrix");
2311
+ }
2312
+
2313
  device->name = GGML_VK_NAME + std::to_string(idx);
2314
 
2315
  device_create_info = {
 
2325
  ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
2326
 
2327
  // Shaders
2328
+ // Disable matmul tile sizes early if performance low or not supported
2329
+ switch (device->vendor_id) {
2330
+ #ifndef GGML_VULKAN_RUN_TESTS
2331
+ case VK_VENDOR_ID_AMD:
2332
+ case VK_VENDOR_ID_INTEL:
2333
+ device->mul_mat_l = false;
2334
+ device->mul_mat_m = true;
2335
+ device->mul_mat_s = true;
2336
+ device->mul_mat_id_l = false;
2337
+ device->mul_mat_id_m = true;
2338
+ device->mul_mat_id_s = true;
2339
+ break;
2340
+ case VK_VENDOR_ID_APPLE:
2341
+ device->mul_mat_l = false;
2342
+ device->mul_mat_m = true;
2343
+ device->mul_mat_s = false;
2344
+ device->mul_mat_id_l = false;
2345
+ device->mul_mat_id_m = true;
2346
+ device->mul_mat_id_s = false;
2347
+ break;
2348
+ #endif
2349
+ default:
2350
+ device->mul_mat_l = true;
2351
+ device->mul_mat_m = true;
2352
+ device->mul_mat_s = true;
2353
+ device->mul_mat_id_l = true;
2354
+ device->mul_mat_id_m = true;
2355
+ device->mul_mat_id_s = true;
2356
+ break;
2357
+ }
2358
+
2359
  ggml_vk_load_shaders(device);
2360
 
2361
  if (!device->single_queue) {
 
2413
 
2414
  bool fp16_storage = false;
2415
  bool fp16_compute = false;
2416
+ bool coopmat_support = false;
2417
 
2418
  for (auto properties : ext_props) {
2419
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
2420
  fp16_storage = true;
2421
  } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
2422
  fp16_compute = true;
2423
+ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
2424
+ coopmat_support = true;
2425
  }
2426
  }
2427
 
2428
+ if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2429
+ // Intel drivers don't support coopmat properly yet
2430
+ // Only RADV supports coopmat properly on AMD
2431
+ coopmat_support = false;
2432
+ }
2433
+
2434
  const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
2435
  bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
2436
 
 
2453
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
2454
  vk11_features.pNext = &vk12_features;
2455
 
2456
+ // Pointer to the last chain element
2457
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
2458
+
2459
+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
2460
+ coopmat_features.pNext = nullptr;
2461
+ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
2462
+ coopmat_features.cooperativeMatrix = VK_FALSE;
2463
+
2464
+ if (coopmat_support) {
2465
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
2466
+ last_struct = (VkBaseOutStructure *)&coopmat_features;
2467
+ }
2468
+
2469
  vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
2470
 
2471
  fp16 = fp16 && vk12_features.shaderFloat16;
2472
 
2473
+ coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
2474
+
2475
  std::string device_name = props2.properties.deviceName.data();
2476
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %d\n",
2477
+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, coopmat_support);
2478
 
2479
  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
2480
  GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
 
2710
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
2711
  return ctx->device->pipeline_matmul_f32_f16;
2712
  }
2713
+ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) {
2714
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2715
  return ctx->device->pipeline_matmul_f16_f32.f16acc;
2716
  }
 
2751
  assert(src1_type == GGML_TYPE_F16);
2752
  return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
2753
  }
2754
+ return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
2755
  }
2756
 
2757
  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
 
2780
  return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
2781
  }
2782
 
2783
+ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
2784
  VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
2785
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2786
  return ctx->device->pipeline_matmul_id_f32;
2787
  }
2788
+ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) {
2789
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2790
+ return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
2791
+ }
2792
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2793
+ return ctx->device->pipeline_matmul_id_f16.f16acc;
2794
+ }
2795
+ } else {
2796
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2797
+ return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
2798
+ }
2799
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2800
+ return ctx->device->pipeline_matmul_id_f16.f32acc;
2801
+ }
2802
  }
2803
 
2804
  GGML_ASSERT(src1_type == GGML_TYPE_F32);
 
2820
  return nullptr;
2821
  }
2822
 
2823
+ return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
2824
  }
2825
 
2826
  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
 
3410
  return split_k;
3411
  }
3412
 
3413
+ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type type_a) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3414
  VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
 
 
 
 
 
 
 
 
 
 
3415
 
3416
  if (ctx->device->coopmat2) {
3417
+ if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
3418
  return aligned ? mmp->a_l : mmp->l;
3419
  }
3420
+ if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
3421
  return aligned ? mmp->a_m : mmp->m;
3422
  }
3423
  return aligned ? mmp->a_s : mmp->s;
3424
  }
3425
 
3426
+ if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
3427
  return aligned ? mmp->a_s : mmp->s;
3428
  }
3429
+ if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
3430
  return aligned ? mmp->a_m : mmp->m;
3431
  }
3432
  return aligned ? mmp->a_l : mmp->l;
3433
  }
3434
 
3435
+ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type type_a) {
3436
  VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
3437
+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, type_a)->align;
3438
  }
3439
 
3440
  static void ggml_vk_matmul(
 
3461
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
3462
  }
3463
 
3464
+ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3465
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
3466
+
3467
+ if (ctx->device->coopmat2) {
3468
+ if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
3469
+ return aligned ? mmp->a_l : mmp->l;
3470
+ }
3471
+ if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
3472
+ return aligned ? mmp->a_m : mmp->m;
3473
+ }
3474
+ return aligned ? mmp->a_s : mmp->s;
3475
+ }
3476
+
3477
+ if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
3478
+ return aligned ? mmp->a_s : mmp->s;
3479
+ }
3480
+ if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
3481
+ return aligned ? mmp->a_m : mmp->m;
3482
+ }
3483
+ return aligned ? mmp->a_l : mmp->l;
3484
+ }
3485
+
3486
+ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
3487
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
3488
+ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
3489
+ }
3490
+
3491
  static void ggml_vk_matmul_id(
3492
  ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
3493
  vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
 
3637
  const int y_ne = ne11 * ne10;
3638
  const int d_ne = ne11 * ne01;
3639
 
3640
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, src0->type));
3641
  const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
3642
 
3643
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, src0->type);
3644
 
3645
  const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
3646
 
 
4191
 
4192
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
4193
 
4194
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
4195
 
4196
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
4197
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
 
4207
  const uint64_t y_ne = ne11 * ne10;
4208
  const uint64_t d_ne = ne21 * ne20;
4209
 
4210
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
4211
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
4212
 
4213
+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
4214
 
4215
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
4216
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
 
5791
  for (size_t i = 0; i < x_ne; i++) {
5792
  if (std::is_same<float, X_TYPE>()) {
5793
  x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
5794
+ // x[i] = 1.0f;
5795
+ // x[i] = i + 1;
5796
+ // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
5797
  } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
5798
  x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
5799
+ // x[i] = ggml_fp32_to_fp16(1.0f);
5800
+ // x[i] = ggml_fp32_to_fp16(i + 1);
5801
+ // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
5802
  } else {
5803
  GGML_ABORT("fatal error");
5804
  }
5805
  }
5806
  for (size_t i = 0; i < y_ne; i++) {
5807
  if (std::is_same<float, Y_TYPE>()) {
5808
+ y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
5809
+ // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
5810
+ // y[i] = i + 1;
5811
  } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
5812
+ y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
5813
+ // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
5814
+ // y[i] = ggml_fp32_to_fp16(i + 1);
5815
  } else {
5816
  GGML_ABORT("fatal error");
5817
  }
 
5895
  double err = std::fabs(d[i] - d_chk[i]);
5896
  avg_err += err;
5897
 
5898
+ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
5899
  first_err_b = i / (m * n);
5900
  first_err_n = (i % (m * n)) / m;
5901
  first_err_m = (i % (m * n)) % m;
 
5908
 
5909
  std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
5910
 
5911
+ if (avg_err > 0.1 || std::isnan(avg_err)) {
5912
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
5913
  std::cerr << "Actual result: " << std::endl << std::endl;
5914
  ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
 
 
5915
  std::cerr << "Expected result: " << std::endl << std::endl;
5916
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
5917
 
 
6094
  vk_pipeline p;
6095
  std::string shname;
6096
  if (shader_size == 0) {
6097
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
6098
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
6099
  } else if (shader_size == 1) {
6100
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
6101
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
6102
  } else if (shader_size == 2) {
6103
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
6104
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
6105
  } else {
6106
  GGML_ASSERT(0);
 
6110
 
6111
  if (k != kpad) {
6112
  if (shader_size == 0) {
6113
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
6114
  shname = std::string(ggml_type_name(quant)) + "_S";
6115
  } else if (shader_size == 1) {
6116
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
6117
  shname = std::string(ggml_type_name(quant)) + "_M";
6118
  } else if (shader_size == 2) {
6119
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
6120
  shname = std::string(ggml_type_name(quant)) + "_L";
6121
  } else {
6122
  GGML_ASSERT(0);
 
6275
 
6276
  static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
6277
  #if defined(GGML_VULKAN_RUN_TESTS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6278
  const std::vector<size_t> vals {
6279
+ 512, 512, 128,
6280
+ 128, 512, 512,
6281
+ 4096, 512, 4096,
6282
+ 11008, 512, 4096,
6283
+ 4096, 512, 11008,
6284
+ 32000, 512, 4096,
6285
  8, 8, 8,
6286
  100, 46, 576,
6287
  623, 111, 128,
 
6294
  49, 49, 128,
6295
  128, 49, 49,
6296
  4096, 49, 4096,
 
 
 
 
 
 
 
 
 
6297
  };
6298
  const size_t num_it = 100;
6299
 
 
6301
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
6302
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
6303
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
6304
+ std::cerr << '\n';
6305
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
6306
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
6307
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
6308
+ std::cerr << '\n';
6309
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
6310
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
6311
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
6312
+ std::cerr << '\n' << std::endl;
6313
+
6314
+ if (vals[i + 2] % 32 == 0) {
6315
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
6316
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
6317
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
6318
+ std::cerr << '\n';
6319
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
6320
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
6321
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
6322
+ std::cerr << '\n';
6323
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
6324
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
6325
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
6326
+ std::cerr << '\n' << std::endl;
6327
+ }
6328
+
6329
+ if (vals[i + 2] % 256 == 0) {
6330
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
6331
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
6332
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
6333
+ std::cerr << '\n';
6334
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
6335
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
6336
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
6337
+ std::cerr << '\n';
6338
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
6339
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
6340
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
6341
+ std::cerr << '\n' << std::endl;
6342
+ }
6343
  }
6344
 
6345
  GGML_ABORT("fatal error");
 
7427
  case GGML_OP_MUL_MAT_ID:
7428
  {
7429
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7430
+ const vk_device& device = ggml_vk_get_device(ctx->device);
7431
+ if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
7432
  // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
7433
  return false;
7434
  }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp CHANGED
@@ -7,6 +7,12 @@
7
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
8
  #endif
9
 
 
 
 
 
 
 
10
  #ifdef MUL_MAT_ID
11
  #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
12
  #endif
@@ -57,6 +63,7 @@ layout (push_constant) uniform parameter
57
  #endif
58
  } p;
59
 
 
60
  layout (constant_id = 1) const uint BM = 64;
61
  layout (constant_id = 2) const uint BN = 64;
62
  layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
@@ -65,13 +72,26 @@ layout (constant_id = 5) const uint WN = 32;
65
  layout (constant_id = 6) const uint WMITER = 2;
66
  layout (constant_id = 7) const uint TM = 4;
67
  layout (constant_id = 8) const uint TN = 2;
68
- layout (constant_id = 9) const uint WARP = 32;
 
 
 
 
 
 
 
69
 
70
- shared FLOAT_TYPE buf_a[BM * (BK+1)];
71
- shared FLOAT_TYPE buf_b[BN * (BK+1)];
72
 
73
  #ifdef MUL_MAT_ID
74
  shared u16vec2 row_ids[3072];
 
 
 
 
 
 
75
  #endif
76
 
77
  void main() {
@@ -98,17 +118,32 @@ void main() {
98
  const uint ik = gl_WorkGroupID.x / blocks_m;
99
  const uint ic = gl_WorkGroupID.y;
100
 
101
- const uint warp_i = gl_LocalInvocationID.x / WARP;
102
- const uint warp_r = warp_i % (BM / WM);
103
- const uint warp_c = warp_i / (BM / WM);
104
-
105
  const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
106
  const uint WSUBM = WM / WMITER;
107
  const uint WSUBN = WN / WNITER;
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  const uint tiw = gl_LocalInvocationID.x % WARP;
 
110
  const uint tiwr = tiw % (WSUBM / TM);
111
  const uint tiwc = tiw / (WSUBM / TM);
 
 
 
 
112
 
113
  const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
114
  const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
@@ -156,21 +191,31 @@ void main() {
156
  uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
157
  #endif
158
 
159
- float sums[WMITER * TM * WNITER * TN];
 
 
 
 
 
 
 
 
 
160
  FLOAT_TYPE cache_a[WMITER * TM];
161
  FLOAT_TYPE cache_b[WNITER * TN];
162
 
163
  [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
164
- sums[i] = 0.0f;
165
  }
 
166
 
167
- [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
168
  [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
169
 
170
  #if defined(DATA_A_F32) || defined(DATA_A_F16)
171
  #if LOAD_VEC_A == 8
172
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
173
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
174
  buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
175
  buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
176
  buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
@@ -181,21 +226,21 @@ void main() {
181
  buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
182
  #elif LOAD_VEC_A == 4
183
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
184
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
185
  buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
186
  buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
187
  buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
188
  buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
189
  #else
190
  if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
191
- buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
192
  } else {
193
- buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
194
  }
195
  #endif
196
  #elif defined(DATA_A_Q4_0)
197
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
198
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
199
 
200
  const uint ib = idx / 16;
201
  const uint iqs = idx & 0xF;
@@ -208,7 +253,7 @@ void main() {
208
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
209
  #elif defined(DATA_A_Q4_1)
210
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
211
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
212
 
213
  const uint ib = idx / 16;
214
  const uint iqs = idx & 0xF;
@@ -222,7 +267,7 @@ void main() {
222
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
223
  #elif defined(DATA_A_Q5_0)
224
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
225
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
226
 
227
  const uint ib = idx / 16;
228
  const uint iqs = idx & 0xF;
@@ -237,7 +282,7 @@ void main() {
237
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
238
  #elif defined(DATA_A_Q5_1)
239
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
240
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
241
 
242
  const uint ib = idx / 16;
243
  const uint iqs = idx & 0xF;
@@ -253,7 +298,7 @@ void main() {
253
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
254
  #elif defined(DATA_A_Q8_0)
255
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
256
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
257
 
258
  const uint ib = idx / 16;
259
  const uint iqs = (idx & 0xF) * 2;
@@ -265,7 +310,7 @@ void main() {
265
  buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
266
  #elif defined(DATA_A_Q2_K)
267
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
268
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
269
 
270
  const uint ib = idx / 128; // 2 values per idx
271
  const uint iqs = idx % 128; // 0..127
@@ -284,7 +329,7 @@ void main() {
284
  buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
285
  #elif defined(DATA_A_Q3_K)
286
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
287
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
288
 
289
  const uint ib = idx / 128; // 2 values per idx
290
  const uint iqs = idx % 128; // 0..127
@@ -308,7 +353,7 @@ void main() {
308
  buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
309
  #elif defined(DATA_A_Q4_K)
310
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
311
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
312
 
313
  const uint ib = idx / 128; // 2 values per idx
314
  const uint iqs = idx % 128; // 0..127
@@ -320,15 +365,20 @@ void main() {
320
 
321
  const vec2 loadd = vec2(data_a[ib].d);
322
 
323
- uint8_t sc;
324
- uint8_t mbyte;
325
- if (is < 4) {
326
- sc = uint8_t(data_a[ib].scales[is ] & 63);
327
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
328
- } else {
329
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
330
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
331
- }
 
 
 
 
 
332
  const float d = loadd.x * sc;
333
  const float m = -loadd.y * mbyte;
334
 
@@ -336,7 +386,7 @@ void main() {
336
  buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
337
  #elif defined(DATA_A_Q5_K)
338
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
339
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
340
 
341
  const uint ib = idx / 128; // 2 values per idx
342
  const uint iqs = idx % 128; // 0..127
@@ -351,15 +401,20 @@ void main() {
351
 
352
  const vec2 loadd = vec2(data_a[ib].d);
353
 
354
- uint8_t sc;
355
- uint8_t mbyte;
356
- if (is < 4) {
357
- sc = uint8_t(data_a[ib].scales[is ] & 63);
358
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
359
- } else {
360
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
361
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
362
- }
 
 
 
 
 
363
  const float d = loadd.x * sc;
364
  const float m = -loadd.y * mbyte;
365
 
@@ -367,7 +422,7 @@ void main() {
367
  buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
368
  #elif defined(DATA_A_Q6_K)
369
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
370
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
371
 
372
  const uint ib = idx / 128; // 2 values per idx
373
  const uint iqs = idx % 128; // 0..127
@@ -386,7 +441,7 @@ void main() {
386
  buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
387
  #elif defined(DATA_A_IQ4_NL)
388
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
389
- const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
390
 
391
  const uint ib = idx / 16;
392
  const uint iqs = idx & 0xF;
@@ -407,7 +462,7 @@ void main() {
407
  #else
408
  const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
409
  #endif
410
- const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
411
  buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
412
  buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
413
  buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
@@ -423,24 +478,24 @@ void main() {
423
  #else
424
  const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
425
  #endif
426
- const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
427
  buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
428
  buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
429
  buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
430
  buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
431
  #elif !MUL_MAT_ID
432
  if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
433
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
434
  } else {
435
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
436
  }
437
  #else
438
  const uint row_i = ic * BN + loadc_b + l;
439
  if (row_i < _ne1) {
440
  const u16vec2 row_idx = row_ids[row_i];
441
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
442
  } else {
443
- buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
444
  }
445
  #endif
446
  }
@@ -450,16 +505,30 @@ void main() {
450
  pos_a += BK / LOAD_VEC_A;
451
  pos_b += BK / LOAD_VEC_B;
452
 
453
- for (uint i = 0; i < BK; i++) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  // Load from shared into cache
455
  [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
456
  [[unroll]] for (uint j = 0; j < TM; j++) {
457
- cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
458
  }
459
  }
460
  [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
461
  [[unroll]] for (uint j = 0; j < TN; j++) {
462
- cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
463
  }
464
  }
465
 
@@ -468,12 +537,13 @@ void main() {
468
  [[unroll]] for (uint cc = 0; cc < TN; cc++) {
469
  [[unroll]] for (uint cr = 0; cr < TM; cr++) {
470
  const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
471
- sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
472
  }
473
  }
474
  }
475
  }
476
  }
 
477
 
478
  barrier();
479
  }
@@ -485,6 +555,54 @@ void main() {
485
  const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
486
  #endif
487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
489
  [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
490
 
@@ -496,7 +614,7 @@ void main() {
496
  if (row_i >= _ne1) break;
497
 
498
  const u16vec2 row_idx = row_ids[row_i];
499
- #endif
500
  [[unroll]] for (uint cr = 0; cr < TM; cr++) {
501
  #ifdef MUL_MAT_ID
502
  data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
@@ -504,9 +622,10 @@ void main() {
504
  if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
505
  data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
506
  }
507
- #endif
508
  }
509
  }
510
  }
511
  }
 
512
  }
 
7
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
8
  #endif
9
 
10
+ #ifdef COOPMAT
11
+ #extension GL_KHR_cooperative_matrix : enable
12
+ #extension GL_KHR_memory_scope_semantics : enable
13
+ #extension GL_KHR_shader_subgroup_basic : enable
14
+ #endif
15
+
16
  #ifdef MUL_MAT_ID
17
  #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
18
  #endif
 
63
  #endif
64
  } p;
65
 
66
+ layout (constant_id = 0) const uint BLOCK_SIZE = 64;
67
  layout (constant_id = 1) const uint BM = 64;
68
  layout (constant_id = 2) const uint BN = 64;
69
  layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
 
72
  layout (constant_id = 6) const uint WMITER = 2;
73
  layout (constant_id = 7) const uint TM = 4;
74
  layout (constant_id = 8) const uint TN = 2;
75
+ layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
76
+ layout (constant_id = 10) const uint WARP = 32;
77
+
78
+ #ifdef COOPMAT
79
+ #define SHMEM_STRIDE (BK + 8)
80
+ #else
81
+ #define SHMEM_STRIDE (BK + 1)
82
+ #endif
83
 
84
+ shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
85
+ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
86
 
87
  #ifdef MUL_MAT_ID
88
  shared u16vec2 row_ids[3072];
89
+ #endif // MUL_MAT_ID
90
+
91
+ #define NUM_WARPS (BLOCK_SIZE / WARP)
92
+
93
+ #ifdef COOPMAT
94
+ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
95
  #endif
96
 
97
  void main() {
 
118
  const uint ik = gl_WorkGroupID.x / blocks_m;
119
  const uint ic = gl_WorkGroupID.y;
120
 
 
 
 
 
121
  const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
122
  const uint WSUBM = WM / WMITER;
123
  const uint WSUBN = WN / WNITER;
124
 
125
+ #ifdef COOPMAT
126
+ const uint warp_i = gl_SubgroupID;
127
+
128
+ const uint tiw = gl_SubgroupInvocationID;
129
+
130
+ const uint cms_per_row = WM / TM;
131
+ const uint cms_per_col = WN / TN;
132
+
133
+ const uint storestride = WARP / TM;
134
+ const uint store_r = tiw % TM;
135
+ const uint store_c = tiw / TM;
136
+ #else
137
+ const uint warp_i = gl_LocalInvocationID.x / WARP;
138
+
139
  const uint tiw = gl_LocalInvocationID.x % WARP;
140
+
141
  const uint tiwr = tiw % (WSUBM / TM);
142
  const uint tiwc = tiw / (WSUBM / TM);
143
+ #endif
144
+
145
+ const uint warp_r = warp_i % (BM / WM);
146
+ const uint warp_c = warp_i / (BM / WM);
147
 
148
  const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
149
  const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
 
191
  uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
192
  #endif
193
 
194
+ #ifdef COOPMAT
195
+ coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
196
+ coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
197
+ coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
198
+
199
+ [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
200
+ sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
201
+ }
202
+ #else
203
+ ACC_TYPE sums[WMITER * TM * WNITER * TN];
204
  FLOAT_TYPE cache_a[WMITER * TM];
205
  FLOAT_TYPE cache_b[WNITER * TN];
206
 
207
  [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
208
+ sums[i] = ACC_TYPE(0.0f);
209
  }
210
+ #endif
211
 
212
+ for (uint block = start_k; block < end_k; block += BK) {
213
  [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
214
 
215
  #if defined(DATA_A_F32) || defined(DATA_A_F16)
216
  #if LOAD_VEC_A == 8
217
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
218
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
219
  buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
220
  buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
221
  buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
 
226
  buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
227
  #elif LOAD_VEC_A == 4
228
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
229
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
230
  buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
231
  buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
232
  buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
233
  buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
234
  #else
235
  if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
236
+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
237
  } else {
238
+ buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
239
  }
240
  #endif
241
  #elif defined(DATA_A_Q4_0)
242
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
243
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
244
 
245
  const uint ib = idx / 16;
246
  const uint iqs = idx & 0xF;
 
253
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
254
  #elif defined(DATA_A_Q4_1)
255
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
256
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
257
 
258
  const uint ib = idx / 16;
259
  const uint iqs = idx & 0xF;
 
267
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
268
  #elif defined(DATA_A_Q5_0)
269
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
270
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
271
 
272
  const uint ib = idx / 16;
273
  const uint iqs = idx & 0xF;
 
282
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
283
  #elif defined(DATA_A_Q5_1)
284
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
285
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
286
 
287
  const uint ib = idx / 16;
288
  const uint iqs = idx & 0xF;
 
298
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
299
  #elif defined(DATA_A_Q8_0)
300
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
301
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
302
 
303
  const uint ib = idx / 16;
304
  const uint iqs = (idx & 0xF) * 2;
 
310
  buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
311
  #elif defined(DATA_A_Q2_K)
312
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
313
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
314
 
315
  const uint ib = idx / 128; // 2 values per idx
316
  const uint iqs = idx % 128; // 0..127
 
329
  buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
330
  #elif defined(DATA_A_Q3_K)
331
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
332
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
333
 
334
  const uint ib = idx / 128; // 2 values per idx
335
  const uint iqs = idx % 128; // 0..127
 
353
  buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
354
  #elif defined(DATA_A_Q4_K)
355
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
356
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
357
 
358
  const uint ib = idx / 128; // 2 values per idx
359
  const uint iqs = idx % 128; // 0..127
 
365
 
366
  const vec2 loadd = vec2(data_a[ib].d);
367
 
368
+ const uint scidx0 = (is < 4) ? is : (is + 4);
369
+ const uint scidx1 = (is < 4) ? is : (is - 4);
370
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
371
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
372
+ const uint mbidx0 = is + 4;
373
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
374
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
375
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
376
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
377
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
378
+
379
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
380
+ const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
381
+
382
  const float d = loadd.x * sc;
383
  const float m = -loadd.y * mbyte;
384
 
 
386
  buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
387
  #elif defined(DATA_A_Q5_K)
388
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
389
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
390
 
391
  const uint ib = idx / 128; // 2 values per idx
392
  const uint iqs = idx % 128; // 0..127
 
401
 
402
  const vec2 loadd = vec2(data_a[ib].d);
403
 
404
+ const uint scidx0 = (is < 4) ? is : (is + 4);
405
+ const uint scidx1 = (is < 4) ? is : (is - 4);
406
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
407
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
408
+ const uint mbidx0 = is + 4;
409
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
410
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
411
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
412
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
413
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
414
+
415
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
416
+ const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
417
+
418
  const float d = loadd.x * sc;
419
  const float m = -loadd.y * mbyte;
420
 
 
422
  buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
423
  #elif defined(DATA_A_Q6_K)
424
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
425
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
426
 
427
  const uint ib = idx / 128; // 2 values per idx
428
  const uint iqs = idx % 128; // 0..127
 
441
  buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
442
  #elif defined(DATA_A_IQ4_NL)
443
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
444
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
445
 
446
  const uint ib = idx / 16;
447
  const uint iqs = idx & 0xF;
 
462
  #else
463
  const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
464
  #endif
465
+ const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
466
  buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
467
  buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
468
  buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
 
478
  #else
479
  const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
480
  #endif
481
+ const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
482
  buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
483
  buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
484
  buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
485
  buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
486
  #elif !MUL_MAT_ID
487
  if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
488
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
489
  } else {
490
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
491
  }
492
  #else
493
  const uint row_i = ic * BN + loadc_b + l;
494
  if (row_i < _ne1) {
495
  const u16vec2 row_idx = row_ids[row_i];
496
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
497
  } else {
498
+ buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
499
  }
500
  #endif
501
  }
 
505
  pos_a += BK / LOAD_VEC_A;
506
  pos_b += BK / LOAD_VEC_B;
507
 
508
+ #ifdef COOPMAT
509
+ [[unroll]] for (uint i = 0; i < BK; i += TK) {
510
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
511
+ // Load from shared into cache
512
+ coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
513
+
514
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
515
+ coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
516
+
517
+ sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
518
+ }
519
+ }
520
+ }
521
+ #else
522
+ [[unroll]] for (uint i = 0; i < BK; i++) {
523
  // Load from shared into cache
524
  [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
525
  [[unroll]] for (uint j = 0; j < TM; j++) {
526
+ cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
527
  }
528
  }
529
  [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
530
  [[unroll]] for (uint j = 0; j < TN; j++) {
531
+ cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
532
  }
533
  }
534
 
 
537
  [[unroll]] for (uint cc = 0; cc < TN; cc++) {
538
  [[unroll]] for (uint cr = 0; cr < TM; cr++) {
539
  const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
540
+ sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
541
  }
542
  }
543
  }
544
  }
545
  }
546
+ #endif
547
 
548
  barrier();
549
  }
 
555
  const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
556
  #endif
557
 
558
+ #ifdef COOPMAT
559
+ #ifdef MUL_MAT_ID
560
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
561
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
562
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
563
+
564
+ [[unroll]] for (uint col = 0; col < BN; col += storestride) {
565
+ const uint row_i = dc + cm_col * TN + col + store_c;
566
+ if (row_i >= _ne1) break;
567
+
568
+ const u16vec2 row_idx = row_ids[row_i];
569
+
570
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
571
+ }
572
+ }
573
+ }
574
+ #else
575
+ const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
576
+
577
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
578
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
579
+ const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
580
+
581
+ if (is_aligned && is_in_bounds) {
582
+ // Full coopMat is within bounds and stride_d is aligned with 16B
583
+ coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
584
+ coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
585
+ } else if (is_in_bounds) {
586
+ // Full coopMat is within bounds, but stride_d is not aligned
587
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
588
+
589
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
590
+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
591
+ }
592
+ } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
593
+ // Partial coopMat is within bounds
594
+ coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
595
+
596
+ [[unroll]] for (uint col = 0; col < TN; col += storestride) {
597
+ if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
598
+ data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
599
+ }
600
+ }
601
+ }
602
+ }
603
+ }
604
+ #endif // MUL_MAT_ID
605
+ #else
606
  [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
607
  [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
608
 
 
614
  if (row_i >= _ne1) break;
615
 
616
  const u16vec2 row_idx = row_ids[row_i];
617
+ #endif // MUL_MAT_ID
618
  [[unroll]] for (uint cr = 0; cr < TM; cr++) {
619
  #ifdef MUL_MAT_ID
620
  data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
 
622
  if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
623
  data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
624
  }
625
+ #endif // MUL_MAT_ID
626
  }
627
  }
628
  }
629
  }
630
+ #endif // COOPMAT
631
  }
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -60,6 +60,7 @@ const std::vector<std::string> type_names = {
60
  "iq4_nl"
61
  };
62
 
 
63
  void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
64
  #ifdef _WIN32
65
  HANDLE stdout_read, stdout_write;
@@ -198,8 +199,8 @@ static uint32_t compile_count = 0;
198
  static std::mutex compile_count_mutex;
199
  static std::condition_variable compile_count_cond;
200
 
201
- void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
202
- std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
203
  std::string out_fname = join_paths(output_dir, name + ".spv");
204
  std::string in_path = join_paths(input_dir, in_fname);
205
 
@@ -258,7 +259,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
258
  }
259
 
260
  static std::vector<std::future<void>> compiles;
261
- void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
262
  {
263
  // wait until fewer than N compiles are in progress.
264
  // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
@@ -269,10 +270,10 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
269
  }
270
  compile_count++;
271
  }
272
- compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc));
273
  }
274
 
275
- void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
276
  std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
277
  std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
278
  std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
@@ -291,14 +292,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
291
 
292
  base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
293
 
 
 
 
 
 
 
294
  std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
295
 
296
  // Shaders with f16 B_TYPE
297
- string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc);
298
- string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
299
 
300
- string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
301
- string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc);
302
 
303
  for (const auto& tname : type_names) {
304
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
@@ -307,12 +314,12 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
307
  // For aligned matmul loads
308
  std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
309
 
310
- string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
311
- string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
312
 
313
  if (tname != "f16" && tname != "f32") {
314
- string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
315
- string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
316
  }
317
  }
318
  }
@@ -322,25 +329,24 @@ void process_shaders() {
322
  std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
323
 
324
  // matmul
325
- for (const auto& fp16 : {false, true}) {
326
- for (const auto& matmul_id : {false, true}) {
327
- for (const auto& coopmat2 : {false, true}) {
328
- for (const auto& f16acc : {false, true}) {
329
- #if !defined(VK_NV_cooperative_matrix2)
330
- if (coopmat2) {
331
- continue;
332
- }
 
 
 
 
 
 
 
 
 
333
  #endif
334
- if (coopmat2 && !fp16) {
335
- continue;
336
- }
337
- if (!coopmat2 && f16acc) {
338
- continue;
339
- }
340
- matmul_shaders(fp16, matmul_id, coopmat2, f16acc);
341
- }
342
- }
343
- }
344
  }
345
 
346
  #if defined(VK_NV_cooperative_matrix2)
@@ -355,11 +361,11 @@ void process_shaders() {
355
 
356
  if (tname == "f16") {
357
  string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
358
- merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc);
359
  } else {
360
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
361
  string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
362
- merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc);
363
  }
364
  }
365
  }
@@ -524,6 +530,7 @@ void write_output_files() {
524
  fclose(hdr);
525
  fclose(src);
526
  }
 
527
 
528
  int main(int argc, char** argv) {
529
  std::map<std::string, std::string> args;
 
60
  "iq4_nl"
61
  };
62
 
63
+ namespace {
64
  void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
65
  #ifdef _WIN32
66
  HANDLE stdout_read, stdout_write;
 
199
  static std::mutex compile_count_mutex;
200
  static std::condition_variable compile_count_cond;
201
 
202
+ void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
203
+ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
204
  std::string out_fname = join_paths(output_dir, name + ".spv");
205
  std::string in_path = join_paths(input_dir, in_fname);
206
 
 
259
  }
260
 
261
  static std::vector<std::future<void>> compiles;
262
+ void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
263
  {
264
  // wait until fewer than N compiles are in progress.
265
  // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
 
270
  }
271
  compile_count++;
272
  }
273
+ compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
274
  }
275
 
276
+ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
277
  std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
278
  std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
279
  std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
 
292
 
293
  base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
294
 
295
+ if (coopmat) {
296
+ base_dict["COOPMAT"] = "1";
297
+ }
298
+
299
+ base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
300
+
301
  std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
302
 
303
  // Shaders with f16 B_TYPE
304
+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
305
+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
306
 
307
+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
308
+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
309
 
310
  for (const auto& tname : type_names) {
311
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
 
314
  // For aligned matmul loads
315
  std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
316
 
317
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
318
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
319
 
320
  if (tname != "f16" && tname != "f32") {
321
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
322
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
323
  }
324
  }
325
  }
 
329
  std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
330
 
331
  // matmul
332
+ for (const auto& matmul_id : {false, true}) {
333
+ // No coopmats
334
+ // fp32
335
+ matmul_shaders(false, matmul_id, false, false, false);
336
+
337
+ // fp16, fp32acc and fp16acc
338
+ matmul_shaders(true, matmul_id, false, false, false);
339
+ matmul_shaders(true, matmul_id, false, false, true);
340
+
341
+ // Coopmat, fp32acc and fp16acc
342
+ matmul_shaders(true, matmul_id, true, false, false);
343
+ matmul_shaders(true, matmul_id, true, false, true);
344
+
345
+ #if defined(VK_NV_cooperative_matrix2)
346
+ // Coopmat2, fp32acc and fp16acc
347
+ matmul_shaders(true, matmul_id, false, true, false);
348
+ matmul_shaders(true, matmul_id, false, true, true);
349
  #endif
 
 
 
 
 
 
 
 
 
 
350
  }
351
 
352
  #if defined(VK_NV_cooperative_matrix2)
 
361
 
362
  if (tname == "f16") {
363
  string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
364
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
365
  } else {
366
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
367
  string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
368
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
369
  }
370
  }
371
  }
 
530
  fclose(hdr);
531
  fclose(src);
532
  }
533
+ }
534
 
535
  int main(int argc, char** argv) {
536
  std::map<std::string, std::string> args;