uvos commited on
Commit
18afa4b
·
1 Parent(s): 7d25156

HIP/CUDA: set the paramerter value in maintain_cuda_graph instead of replaceing it. (llama/12209)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-cuda/ggml-cuda.cu +1 -1
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -2571,7 +2571,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
2571
  for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
2572
  if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
2573
  char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
2574
- cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
2575
  CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
2576
  }
2577
  }
 
2571
  for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
2572
  if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
2573
  char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
2574
+ *(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
2575
  CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
2576
  }
2577
  }