lhez commited on
Commit
2b95b05
·
1 Parent(s): ccee17d

opencl : update upscale to support align corners (llama/14488)

Browse files
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -4453,7 +4453,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
4453
 
4454
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4455
 
4456
- const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
 
4457
  cl_kernel kernel = nullptr;
4458
 
4459
  if (mode == GGML_SCALE_MODE_NEAREST) {
@@ -4484,18 +4485,22 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
4484
  const cl_ulong nb02 = src0->nb[2];
4485
  const cl_ulong nb03 = src0->nb[3];
4486
 
4487
- const int ne00_src = src0->ne[0];
4488
- const int ne01_src = src0->ne[1];
 
 
4489
 
4490
- const int ne10_dst = dst->ne[0];
4491
- const int ne11_dst = dst->ne[1];
4492
- const int ne12_dst = dst->ne[2];
4493
- const int ne13_dst = dst->ne[3];
 
 
 
 
 
4494
 
4495
- const float sf0 = (float)dst->ne[0] / src0->ne[0];
4496
- const float sf1 = (float)dst->ne[1] / src0->ne[1];
4497
- const float sf2 = (float)dst->ne[2] / src0->ne[2];
4498
- const float sf3 = (float)dst->ne[3] / src0->ne[3];
4499
 
4500
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
4501
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
@@ -4507,29 +4512,36 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
4507
  CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03));
4508
 
4509
  if (mode == GGML_SCALE_MODE_NEAREST) {
4510
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne10_dst));
4511
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11_dst));
4512
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12_dst));
4513
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13_dst));
4514
  CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0));
4515
  CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1));
4516
  CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2));
4517
  CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
4518
  } else if (mode == GGML_SCALE_MODE_BILINEAR) {
4519
- CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00_src));
4520
- CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01_src));
4521
- CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10_dst));
4522
- CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11_dst));
4523
- CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12_dst));
4524
- CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13_dst));
 
 
 
 
 
 
4525
  CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0));
4526
  CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1));
4527
  CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2));
4528
  CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3));
 
4529
  }
4530
 
4531
 
4532
- size_t dst_total_elements = (size_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst;
4533
  if (dst_total_elements == 0) {
4534
  return;
4535
  }
 
4453
 
4454
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
4455
 
4456
+ const int mode_flags = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
4457
+ const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
4458
  cl_kernel kernel = nullptr;
4459
 
4460
  if (mode == GGML_SCALE_MODE_NEAREST) {
 
4485
  const cl_ulong nb02 = src0->nb[2];
4486
  const cl_ulong nb03 = src0->nb[3];
4487
 
4488
+ const int ne00 = src0->ne[0];
4489
+ const int ne01 = src0->ne[1];
4490
+ const int ne02 = src0->ne[2];
4491
+ const int ne03 = src0->ne[3];
4492
 
4493
+ const int ne0 = dst->ne[0];
4494
+ const int ne1 = dst->ne[1];
4495
+ const int ne2 = dst->ne[2];
4496
+ const int ne3 = dst->ne[3];
4497
+
4498
+ float sf0 = (float)ne0 / ne00;
4499
+ float sf1 = (float)ne1 / ne01;
4500
+ float sf2 = (float)ne2 / ne02;
4501
+ float sf3 = (float)ne3 / ne03;
4502
 
4503
+ float pixel_offset = 0.5f;
 
 
 
4504
 
4505
  CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
4506
  CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
 
4512
  CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03));
4513
 
4514
  if (mode == GGML_SCALE_MODE_NEAREST) {
4515
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne0));
4516
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne1));
4517
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne2));
4518
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne3));
4519
  CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0));
4520
  CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1));
4521
  CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2));
4522
  CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
4523
  } else if (mode == GGML_SCALE_MODE_BILINEAR) {
4524
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
4525
+ sf0 = (float)(ne0 - 1) / (ne00 - 1);
4526
+ sf1 = (float)(ne1 - 1) / (ne01 - 1);
4527
+ pixel_offset = 0.0f;
4528
+ }
4529
+
4530
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
4531
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
4532
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne0));
4533
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne1));
4534
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne2));
4535
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne3));
4536
  CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0));
4537
  CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1));
4538
  CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2));
4539
  CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3));
4540
+ CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &pixel_offset));
4541
  }
4542
 
4543
 
4544
+ size_t dst_total_elements = (size_t)ne0 * ne1 * ne2 * ne3;
4545
  if (dst_total_elements == 0) {
4546
  return;
4547
  }
ggml/src/ggml-opencl/kernels/upscale.cl CHANGED
@@ -60,7 +60,8 @@ kernel void kernel_upscale_bilinear(
60
  float sf0,
61
  float sf1,
62
  float sf2,
63
- float sf3
 
64
  ) {
65
  global const char * src_base = (global const char *)p_src0 + off_src0;
66
  global float * dst_base = (global float *)((global char *)p_dst + off_dst);
@@ -80,8 +81,6 @@ kernel void kernel_upscale_bilinear(
80
  int i02_src = (int)(i12_dst / sf2);
81
  int i03_src = (int)(i13_dst / sf3);
82
 
83
- const float pixel_offset = 0.5f;
84
-
85
  float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
86
  long y0_src = (long)floor(y_src_f);
87
  long y1_src = y0_src + 1;
 
60
  float sf0,
61
  float sf1,
62
  float sf2,
63
+ float sf3,
64
+ float pixel_offset
65
  ) {
66
  global const char * src_base = (global const char *)p_src0 + off_src0;
67
  global float * dst_base = (global float *)((global char *)p_dst + off_dst);
 
81
  int i02_src = (int)(i12_dst / sf2);
82
  int i03_src = (int)(i13_dst / sf3);
83
 
 
 
84
  float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
85
  long y0_src = (long)floor(y_src_f);
86
  long y1_src = y0_src + 1;