#include #include #include #include typedef struct { int tid; int nthreads; int ndocs; int ndoc_vectors; int nquery_vectors; int64_t* lengths; float* scores; int64_t* offsets; float* max_scores; } max_args_t; void* max(void* args) { max_args_t* max_args = (max_args_t*)args; int ndocs_per_thread = std::ceil(((float)max_args->ndocs) / max_args->nthreads); int start = max_args->tid * ndocs_per_thread; int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs); auto max_scores_offset = max_args->max_scores + (start * max_args->nquery_vectors); auto scores_offset = max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors); for (int i = start; i < end; i++) { for (int j = 0; j < max_args->lengths[i]; j++) { std::transform(max_scores_offset, max_scores_offset + max_args->nquery_vectors, scores_offset, max_scores_offset, [](float a, float b) { return std::max(a, b); }); scores_offset += max_args->nquery_vectors; } max_scores_offset += max_args->nquery_vectors; } return NULL; } torch::Tensor segmented_maxsim(const torch::Tensor scores, const torch::Tensor lengths) { auto lengths_a = lengths.data_ptr(); auto scores_a = scores.data_ptr(); auto ndocs = lengths.size(0); auto ndoc_vectors = scores.size(0); auto nquery_vectors = scores.size(1); auto nthreads = at::get_num_threads(); torch::Tensor max_scores = torch::zeros({ndocs, nquery_vectors}, scores.options()); int64_t offsets[ndocs + 1]; offsets[0] = 0; std::partial_sum(lengths_a, lengths_a + ndocs, offsets + 1); pthread_t threads[nthreads]; max_args_t args[nthreads]; for (int i = 0; i < nthreads; i++) { args[i].tid = i; args[i].nthreads = nthreads; args[i].ndocs = ndocs; args[i].ndoc_vectors = ndoc_vectors; args[i].nquery_vectors = nquery_vectors; args[i].lengths = lengths_a; args[i].scores = scores_a; args[i].offsets = offsets; args[i].max_scores = max_scores.data_ptr(); int rc = pthread_create(&threads[i], NULL, max, (void*)&args[i]); if (rc) { fprintf(stderr, "Unable to create thread %d: %d\n", i, rc); } } for (int i = 0; i < nthreads; i++) { pthread_join(threads[i], NULL); } return max_scores.sum(1); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("segmented_maxsim_cpp", &segmented_maxsim, "Segmented MaxSim"); }