// PreFLMR_ViT-L / segmented_maxsim.cpp
// Uploaded by LinWeizheDragon via huggingface_hub (commit ab8d3d5, verified).
#include <pthread.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <numeric>
#include <vector>

#include <torch/extension.h>
// Work description handed to each `max` worker through pthread_create.
// Every pointer is borrowed: the caller owns the buffers and must keep them
// alive until the worker has been joined.
struct max_args_t {
  int tid;             // index of this worker, 0 <= tid < nthreads
  int nthreads;        // number of workers sharing the documents
  int ndocs;           // number of documents (entries in `lengths`)
  int ndoc_vectors;    // total number of rows in `scores`
  int nquery_vectors;  // floats per row of `scores` and of `max_scores`
  int64_t* lengths;    // rows belonging to each document, ndocs entries
  float* scores;       // input rows, read as consecutive nquery_vectors-wide rows
  int64_t* offsets;    // prefix sums of `lengths`, ndocs + 1 entries
  float* max_scores;   // output, one nquery_vectors-wide row per document
};
void* max(void* args) {
max_args_t* max_args = (max_args_t*)args;
int ndocs_per_thread =
std::ceil(((float)max_args->ndocs) / max_args->nthreads);
int start = max_args->tid * ndocs_per_thread;
int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs);
auto max_scores_offset =
max_args->max_scores + (start * max_args->nquery_vectors);
auto scores_offset =
max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors);
for (int i = start; i < end; i++) {
for (int j = 0; j < max_args->lengths[i]; j++) {
std::transform(max_scores_offset,
max_scores_offset + max_args->nquery_vectors,
scores_offset, max_scores_offset,
[](float a, float b) { return std::max(a, b); });
scores_offset += max_args->nquery_vectors;
}
max_scores_offset += max_args->nquery_vectors;
}
return NULL;
}
// Segmented MaxSim on CPU.
//
// `scores` holds one row per document vector, with the rows of each document
// stored consecutively, and one column per query vector. `lengths[i]` is the
// number of consecutive rows that belong to document i. For every document,
// this takes the elementwise max over its rows, then sums the result over
// the query columns.
//
// NOTE: the running max starts from 0 (torch::zeros). Negative similarities
// are therefore clamped to 0, and a zero-length document scores 0. This is
// kept unchanged because callers may depend on it.
//
// @param scores   2-D float32 CPU tensor (ndoc_vectors x nquery_vectors).
// @param lengths  1-D int64 CPU tensor (ndocs) of non-negative lengths whose
//                 sum does not exceed scores.size(0).
// @return         1-D tensor with one score per document.
// @throws c10::Error (via TORCH_CHECK / data_ptr) on invalid inputs.
torch::Tensor segmented_maxsim(const torch::Tensor scores,
                               const torch::Tensor lengths) {
  TORCH_CHECK(scores.dim() == 2, "segmented_maxsim: scores must be 2-D, got ",
              scores.dim(), "-D");
  TORCH_CHECK(lengths.dim() == 1,
              "segmented_maxsim: lengths must be 1-D, got ", lengths.dim(),
              "-D");
  TORCH_CHECK(scores.device().is_cpu() && lengths.device().is_cpu(),
              "segmented_maxsim: scores and lengths must be CPU tensors");

  // The worker walks raw pointers and assumes dense row-major storage.
  // contiguous() is a no-op for inputs that already are.
  const torch::Tensor scores_c = scores.contiguous();
  const torch::Tensor lengths_c = lengths.contiguous();

  // data_ptr<T>() throws if the dtype is not float32 / int64 respectively.
  float* scores_a = scores_c.data_ptr<float>();
  int64_t* lengths_a = lengths_c.data_ptr<int64_t>();
  const int64_t ndocs = lengths_c.size(0);
  const int64_t ndoc_vectors = scores_c.size(0);
  const int64_t nquery_vectors = scores_c.size(1);

  // max_args_t carries these counts as int.
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  TORCH_CHECK(ndocs <= kIntMax && nquery_vectors <= kIntMax,
              "segmented_maxsim: too many documents or query vectors");
  TORCH_CHECK(std::none_of(lengths_a, lengths_a + ndocs,
                           [](int64_t n) { return n < 0; }),
              "segmented_maxsim: lengths must be non-negative");

  // offsets[i] is the first row of document i in `scores`, and
  // offsets[ndocs] is the total number of rows consumed. The buffer lives on
  // the heap: a variable-length stack array is non-standard and can blow the
  // stack for large ndocs.
  std::vector<int64_t> offsets(static_cast<size_t>(ndocs) + 1, 0);
  std::partial_sum(lengths_a, lengths_a + ndocs, offsets.begin() + 1);
  TORCH_CHECK(offsets.back() <= ndoc_vectors, "segmented_maxsim: sum(lengths) = ",
              offsets.back(), " exceeds scores.size(0) = ", ndoc_vectors);

  torch::Tensor max_scores =
      torch::zeros({ndocs, nquery_vectors}, scores_c.options());
  float* max_scores_a = max_scores.data_ptr<float>();

  // Extra workers beyond the number of documents would have nothing to do.
  const int nthreads = static_cast<int>(std::max<int64_t>(
      1, std::min<int64_t>(at::get_num_threads(), ndocs)));

  std::vector<pthread_t> threads(static_cast<size_t>(nthreads));
  std::vector<max_args_t> args(static_cast<size_t>(nthreads));
  // Tracks which handles are valid to join. char avoids vector<bool>.
  std::vector<char> started(static_cast<size_t>(nthreads), 0);

  for (int i = 0; i < nthreads; ++i) {
    max_args_t& arg = args[i];
    arg.tid = i;
    arg.nthreads = nthreads;
    arg.ndocs = static_cast<int>(ndocs);
    // Informational only: the worker does not read ndoc_vectors.
    arg.ndoc_vectors = static_cast<int>(std::min(ndoc_vectors, kIntMax));
    arg.nquery_vectors = static_cast<int>(nquery_vectors);
    arg.lengths = lengths_a;
    arg.scores = scores_a;
    arg.offsets = offsets.data();
    arg.max_scores = max_scores_a;
    const int rc = pthread_create(&threads[i], nullptr, max, &arg);
    if (rc == 0) {
      started[i] = 1;
    } else {
      // Without this fallback the chunk would silently stay zero, and
      // joining a thread that was never created is undefined behavior.
      // Chunks are disjoint, so running it here while other workers are
      // still active is safe.
      std::fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
      max(&arg);
    }
  }
  for (int i = 0; i < nthreads; ++i) {
    if (started[i]) {
      pthread_join(threads[i], nullptr);
    }
  }
  return max_scores.sum(1);
}
// Python bindings: exposes segmented_maxsim as `segmented_maxsim_cpp`, which
// takes two positional tensor arguments. TORCH_EXTENSION_NAME is presumably
// supplied by the torch extension build; confirm against the build script.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def("segmented_maxsim_cpp",  // Python-visible name
          &segmented_maxsim,       // C++ implementation
          "Segmented MaxSim");     // Python docstring
}