Diego Devesa committed on
Commit
e59d9a7
·
1 Parent(s): 464a186

ggml-cpu : add chunking support to mul_mat_id (llama/11666)

Browse files

* ggml-cpu : add chunking support to mul_mat_id

* allocate chunk counter in wdata
parallelize src1 quantization by column to allow parallelization even when there is only one row

* disable for arm

* cleanup

* better way to disable for arm

* fix uninitialized counter when using 1 thread only

* revert test-backend-ops changes

Files changed (1) hide show
  1. ggml/src/ggml-cpu/ggml-cpu.c +184 -85
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -7,10 +7,8 @@
7
  #include "ggml-cpu-impl.h"
8
  #include "ggml-cpu.h"
9
  #include "ggml-impl.h"
10
- #include "ggml-quants.h"
11
  #include "ggml-cpu-quants.h"
12
  #include "ggml-threading.h"
13
- #include "amx/amx.h"
14
  #include "ggml.h"
15
 
16
  #if defined(_MSC_VER) || defined(__MINGW32__)
@@ -1291,7 +1289,7 @@ struct ggml_threadpool {
1291
  atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
1292
  atomic_int GGML_CACHE_ALIGN n_barrier;
1293
  atomic_int GGML_CACHE_ALIGN n_barrier_passed;
1294
- atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
1295
 
1296
  // these are atomic as an annotation for thread-sanitizer
1297
  atomic_bool stop; // Used for stopping the threadpool altogether
@@ -7490,6 +7488,7 @@ UseGgmlGemm1:;
7490
  if (src1->type != vec_dot_type) {
7491
  char * wdata = params->wdata;
7492
 
 
7493
  const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
7494
  const size_t nbw2 = nbw1*ne11;
7495
  const size_t nbw3 = nbw2*ne12;
@@ -7497,6 +7496,7 @@ UseGgmlGemm1:;
7497
  assert(params->wsize >= ne13*nbw3);
7498
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
7499
 
 
7500
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7501
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
7502
  for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
@@ -7506,6 +7506,20 @@ UseGgmlGemm1:;
7506
  }
7507
  }
7508
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7509
  }
7510
 
7511
  if (ith == 0) {
@@ -7593,7 +7607,6 @@ UseGgmlGemm2:;
7593
  if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
7594
  num_rows_per_vec_dot = 1;
7595
  }
7596
-
7597
  ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
7598
 
7599
  if (nth >= nchunk0 * nchunk1) {
@@ -7606,6 +7619,84 @@ UseGgmlGemm2:;
7606
 
7607
  // ggml_compute_forward_mul_mat_id
7608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7609
  static void ggml_compute_forward_mul_mat_id(
7610
  const struct ggml_compute_params * params,
7611
  struct ggml_tensor * dst) {
@@ -7623,7 +7714,6 @@ static void ggml_compute_forward_mul_mat_id(
7623
 
7624
  const bool src1_cont = ggml_is_contiguous(src1);
7625
 
7626
- ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
7627
  enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
7628
  ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
7629
 
@@ -7641,21 +7731,27 @@ static void ggml_compute_forward_mul_mat_id(
7641
  const int n_ids = ids->ne[0]; // n_expert_used
7642
  const int n_as = ne02; // n_expert
7643
 
7644
- char * wdata_src1_end = (src1->type == vec_dot_type) ?
7645
- (char *) params->wdata :
7646
- (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
7647
 
7648
- struct mmid_row_mapping {
7649
- int32_t i1;
7650
- int32_t i2;
7651
- };
 
 
 
 
 
7652
 
7653
- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
7654
- struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
 
 
7655
 
7656
  if (src1->type != vec_dot_type) {
7657
  char * wdata = params->wdata;
7658
 
 
7659
  const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
7660
  const size_t nbw2 = nbw1*ne11;
7661
  const size_t nbw3 = nbw2*ne12;
@@ -7663,19 +7759,32 @@ static void ggml_compute_forward_mul_mat_id(
7663
  assert(params->wsize >= ne13*nbw3);
7664
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
7665
 
 
7666
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7667
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
7668
- for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
7669
  from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
7670
  (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
7671
  ne10);
7672
  }
7673
  }
7674
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7675
  }
7676
 
7677
- #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
7678
-
7679
  if (ith == 0) {
7680
  // initialize matrix_row_counts
7681
  memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -7693,9 +7802,14 @@ static void ggml_compute_forward_mul_mat_id(
7693
  }
7694
  }
7695
 
 
 
 
 
 
 
7696
  ggml_barrier(params->threadpool);
7697
 
7698
- // compute each matrix multiplication in sequence
7699
  for (int cur_a = 0; cur_a < n_as; ++cur_a) {
7700
  const int64_t cne1 = matrix_row_counts[cur_a];
7701
 
@@ -7703,84 +7817,64 @@ static void ggml_compute_forward_mul_mat_id(
7703
  continue;
7704
  }
7705
 
7706
- const char * src0_cur = (const char *) src0->data + cur_a*nb02;
7707
-
7708
- const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
7709
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
7710
 
7711
- const int64_t nr0 = ne01; // src0 rows
7712
- const int64_t nr1 = cne1; // src1 rows
7713
-
7714
- // distribute the thread work across the inner or outer loop based on which one is larger
7715
-
7716
- const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
7717
- const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
7718
-
7719
- const int64_t ith0 = ith % nth0;
7720
- const int64_t ith1 = ith / nth0;
7721
-
7722
- const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
7723
- const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
7724
-
7725
- const int64_t ir010 = dr0*ith0;
7726
- const int64_t ir011 = MIN(ir010 + dr0, nr0);
7727
 
7728
- const int64_t ir110 = dr1*ith1;
7729
- const int64_t ir111 = MIN(ir110 + dr1, nr1);
7730
-
7731
- // threads with no work simply yield (not sure if it helps)
7732
- //if (ir010 >= ir011 || ir110 >= ir111) {
7733
- // sched_yield();
7734
- // continue;
7735
- //}
7736
 
7737
- // block-tiling attempt
7738
- const int64_t blck_0 = 16;
7739
- const int64_t blck_1 = 16;
 
 
 
 
7740
 
7741
- // attempt to reduce false-sharing (does not seem to make a difference)
7742
- float tmp[16];
7743
 
7744
- for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
7745
- for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
7746
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
7747
- const int64_t _i12 = ir1; // logical row index for this expert
7748
 
7749
- struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
7750
- const int id = row_mapping.i1; // selected expert index
7751
 
7752
- const int64_t i11 = id % ne11;
7753
- const int64_t i12 = row_mapping.i2; // row index in src1
7754
 
7755
- const int64_t i1 = id; // selected expert index
7756
- const int64_t i2 = i12; // row
7757
 
7758
- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
7759
- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
7760
- // the original src1 data pointer, so we should index using the indices directly
7761
- // TODO: this is a bit of a hack, we should probably have a better way to handle this
7762
- const char * src1_col = (const char *) wdata +
7763
- (src1_cont || src1->type != vec_dot_type
7764
- ? (i11 + i12*ne11)*row_size
7765
- : (i11*nb11 + i12*nb12));
7766
 
7767
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
 
7768
 
7769
- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
7770
- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
7771
- //}
7772
 
7773
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
7774
- vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
7775
- }
 
 
7776
 
7777
- memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
7778
- }
7779
  }
 
 
7780
  }
7781
  }
7782
-
7783
- #undef MMID_MATRIX_ROW
7784
  }
7785
 
7786
  // ggml_compute_forward_out_prod
@@ -13713,14 +13807,19 @@ struct ggml_cplan ggml_graph_plan(
13713
  cur = 0;
13714
  const struct ggml_tensor * src0 = node->src[0];
13715
  const struct ggml_tensor * src1 = node->src[1];
 
13716
  const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
 
 
13717
  if (src1->type != vec_dot_type) {
13718
- cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
13719
  }
13720
- const int n_as = src0->ne[2];
13721
- cur += GGML_PAD(cur, sizeof(int64_t)); // align
13722
- cur += n_as * sizeof(int64_t); // matrix_row_counts
13723
- cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
 
 
13724
  } break;
13725
  case GGML_OP_OUT_PROD:
13726
  {
 
7
  #include "ggml-cpu-impl.h"
8
  #include "ggml-cpu.h"
9
  #include "ggml-impl.h"
 
10
  #include "ggml-cpu-quants.h"
11
  #include "ggml-threading.h"
 
12
  #include "ggml.h"
13
 
14
  #if defined(_MSC_VER) || defined(__MINGW32__)
 
1289
  atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
1290
  atomic_int GGML_CACHE_ALIGN n_barrier;
1291
  atomic_int GGML_CACHE_ALIGN n_barrier_passed;
1292
+ atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
1293
 
1294
  // these are atomic as an annotation for thread-sanitizer
1295
  atomic_bool stop; // Used for stopping the threadpool altogether
 
7488
  if (src1->type != vec_dot_type) {
7489
  char * wdata = params->wdata;
7490
 
7491
+ const size_t nbw0 = ggml_type_size(vec_dot_type);
7492
  const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
7493
  const size_t nbw2 = nbw1*ne11;
7494
  const size_t nbw3 = nbw2*ne12;
 
7496
  assert(params->wsize >= ne13*nbw3);
7497
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
7498
 
7499
+ #if 0
7500
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7501
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
7502
  for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
 
7506
  }
7507
  }
7508
  }
7509
+ #else
7510
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
7511
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
7512
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
7513
+ size_t bs = ggml_blck_size(vec_dot_type);
7514
+ int64_t ne10_block_start = (ith * ne10/bs) / nth;
7515
+ int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
7516
+ from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
7517
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
7518
+ (ne10_block_end - ne10_block_start) * bs);
7519
+ }
7520
+ }
7521
+ }
7522
+ #endif
7523
  }
7524
 
7525
  if (ith == 0) {
 
7607
  if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
7608
  num_rows_per_vec_dot = 1;
7609
  }
 
7610
  ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
7611
 
7612
  if (nth >= nchunk0 * nchunk1) {
 
7619
 
7620
  // ggml_compute_forward_mul_mat_id
7621
 
7622
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
7623
+
7624
+ struct mmid_row_mapping {
7625
+ int32_t i1;
7626
+ int32_t i2;
7627
+ };
7628
+
7629
+ static void ggml_compute_forward_mul_mat_id_one_chunk(
7630
+ struct ggml_tensor * dst,
7631
+ const struct ggml_tensor * src0,
7632
+ const struct ggml_tensor * src1,
7633
+ const struct ggml_tensor * ids,
7634
+ const int64_t cur_a,
7635
+ const int64_t ir0_start,
7636
+ const int64_t ir0_end,
7637
+ const int64_t ir1_start,
7638
+ const int64_t ir1_end,
7639
+ const char * src0_cur,
7640
+ const struct mmid_row_mapping * matrix_rows,
7641
+ const size_t row_size,
7642
+ const bool src1_cont,
7643
+ const void * wdata) {
7644
+
7645
+ GGML_TENSOR_BINARY_OP_LOCALS
7646
+
7647
+ const enum ggml_type type = src0->type;
7648
+
7649
+ ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
7650
+ enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
7651
+
7652
+ const int64_t blck_0 = 16;
7653
+ const int64_t blck_1 = 16;
7654
+
7655
+ float tmp[16];
7656
+
7657
+ for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
7658
+ for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
7659
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
7660
+ const int64_t _i12 = ir1; // logical row index for this expert
7661
+
7662
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
7663
+ const int id = row_mapping.i1; // selected expert index
7664
+
7665
+ const int64_t i11 = id % ne11;
7666
+ const int64_t i12 = row_mapping.i2; // row index in src1
7667
+
7668
+ const int64_t i1 = id; // selected expert index
7669
+ const int64_t i2 = i12; // row
7670
+
7671
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
7672
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
7673
+ // the original src1 data pointer, so we should index using the indices directly
7674
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
7675
+ const char * src1_col = (const char *) wdata +
7676
+ (src1_cont || src1->type != vec_dot_type
7677
+ ? (i11 + i12*ne11)*row_size
7678
+ : (i11*nb11 + i12*nb12));
7679
+
7680
+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
7681
+
7682
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
7683
+ vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
7684
+ }
7685
+
7686
+ memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
7687
+ }
7688
+ }
7689
+ }
7690
+ }
7691
+
7692
+ static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
7693
+
7694
+ void * ptr = *p;
7695
+ ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
7696
+ *p = (void *) ((char *) ptr + size);
7697
+ return ptr;
7698
+ }
7699
+
7700
  static void ggml_compute_forward_mul_mat_id(
7701
  const struct ggml_compute_params * params,
7702
  struct ggml_tensor * dst) {
 
7714
 
7715
  const bool src1_cont = ggml_is_contiguous(src1);
7716
 
 
7717
  enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
7718
  ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
7719
 
 
7731
  const int n_ids = ids->ne[0]; // n_expert_used
7732
  const int n_as = ne02; // n_expert
7733
 
7734
+ void * wdata_cur = params->wdata;
 
 
7735
 
7736
+ if (src1->type != vec_dot_type) {
7737
+ incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
7738
+ }
7739
+
7740
+ int64_t * matrix_row_counts = // [n_as]
7741
+ incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
7742
+
7743
+ struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
7744
+ incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
7745
 
7746
+ char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
7747
+ incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
7748
+
7749
+ GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
7750
 
7751
  if (src1->type != vec_dot_type) {
7752
  char * wdata = params->wdata;
7753
 
7754
+ const size_t nbw0 = ggml_type_size(vec_dot_type);
7755
  const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
7756
  const size_t nbw2 = nbw1*ne11;
7757
  const size_t nbw3 = nbw2*ne12;
 
7759
  assert(params->wsize >= ne13*nbw3);
7760
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
7761
 
7762
+ #if 0
7763
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7764
+ for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
7765
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
7766
  from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
7767
  (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
7768
  ne10);
7769
  }
7770
  }
7771
  }
7772
+ #else
7773
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
7774
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
7775
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
7776
+ size_t bs = ggml_blck_size(vec_dot_type);
7777
+ int64_t ne10_block_start = (ith * ne10/bs) / nth;
7778
+ int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
7779
+ from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
7780
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
7781
+ (ne10_block_end - ne10_block_start) * bs);
7782
+ }
7783
+ }
7784
+ }
7785
+ #endif
7786
  }
7787
 
 
 
7788
  if (ith == 0) {
7789
  // initialize matrix_row_counts
7790
  memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
 
7802
  }
7803
  }
7804
 
7805
+ // reset current_chunk
7806
+ for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
7807
+ atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
7808
+ *current_chunk_ctr = nth;
7809
+ }
7810
+
7811
  ggml_barrier(params->threadpool);
7812
 
 
7813
  for (int cur_a = 0; cur_a < n_as; ++cur_a) {
7814
  const int64_t cne1 = matrix_row_counts[cur_a];
7815
 
 
7817
  continue;
7818
  }
7819
 
7820
+ const char * src0_cur = (const char *) src0->data + cur_a * nb02;
7821
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
 
7822
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
7823
 
7824
+ const int64_t nr0 = ne01;
7825
+ const int64_t nr1 = cne1;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7826
 
7827
+ int chunk_size = 16;
7828
+ if (nr0 == 1 || nr1 == 1) {
7829
+ chunk_size = 64;
7830
+ }
 
 
 
 
7831
 
7832
+ #if defined(__aarch64__)
7833
+ // disable for ARM
7834
+ const bool disable_chunking = true;
7835
+ #else
7836
+ // disable for NUMA
7837
+ const bool disable_chunking = ggml_is_numa();
7838
+ #endif // defined(__aarch64__)
7839
 
7840
+ int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
7841
+ int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
7842
 
7843
+ if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
7844
+ nchunk0 = nr0 > nr1 ? nth : 1;
7845
+ nchunk1 = nr0 > nr1 ? 1 : nth;
7846
+ }
7847
 
7848
+ const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
7849
+ const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
7850
 
7851
+ int current_chunk = ith;
 
7852
 
7853
+ atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
 
7854
 
7855
+ while (current_chunk < nchunk0 * nchunk1) {
7856
+ const int64_t ith0 = current_chunk % nchunk0;
7857
+ const int64_t ith1 = current_chunk / nchunk0;
 
 
 
 
 
7858
 
7859
+ const int64_t ir0_start = dr0 * ith0;
7860
+ const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
7861
 
7862
+ const int64_t ir1_start = dr1 * ith1;
7863
+ const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
 
7864
 
7865
+ ggml_compute_forward_mul_mat_id_one_chunk(
7866
+ dst, src0, src1, ids, cur_a,
7867
+ ir0_start, ir0_end, ir1_start, ir1_end,
7868
+ src0_cur, matrix_rows, row_size, src1_cont, wdata
7869
+ );
7870
 
7871
+ if (nth >= nchunk0 * nchunk1) {
7872
+ break;
7873
  }
7874
+
7875
+ current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
7876
  }
7877
  }
 
 
7878
  }
7879
 
7880
  // ggml_compute_forward_out_prod
 
13807
  cur = 0;
13808
  const struct ggml_tensor * src0 = node->src[0];
13809
  const struct ggml_tensor * src1 = node->src[1];
13810
+ const struct ggml_tensor * ids = node->src[2];
13811
  const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
13812
+ const int n_as = src0->ne[2];
13813
+ // src1
13814
  if (src1->type != vec_dot_type) {
13815
+ cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
13816
  }
13817
+ // matrix_row_counts
13818
+ cur += n_as * sizeof(int64_t) + sizeof(int64_t);
13819
+ // matrix_rows
13820
+ cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
13821
+ // atomic_current_chunk
13822
+ cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
13823
  } break;
13824
  case GGML_OP_OUT_PROD:
13825
  {