|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "sparse_matmul/compute/thread_bounds.h"

#include <algorithm>
#include <iterator>
#include <vector>

#include "glog/logging.h"
|
|
|
namespace csrblocksparse { |
|
|
|
// Records the block dimensions, computes the per-thread row split points via
// ComputeThreadSplitPoints(), and then fills |weight_starts_|,
// |rhs_indices_starts_| and |bias_starts_| with one entry per element of
// |row_starts_|.
//
// For every row that precedes a given split point the running offsets grow by:
//   weights:     nnz_per_row[row] * block_width_ * block_height_
//   rhs indices: nnz_per_row[row]
//   bias:        block_height_
// so each stored entry is the cumulative offset at which that split begins.
//
// |num_threads| must be positive (enforced with CHECK_GT). |nnz_per_row| is
// indexed by every row below the largest value stored in |row_starts_|.
void ThreadBounds::PrepareForThreads(int block_width, int block_height,
                                     int num_threads,
                                     int reduced_rows_per_cache_row,
                                     int reduced_rows, const int* nnz_per_row) {
  CHECK_GT(num_threads, 0);
  block_width_ = block_width;
  block_height_ = block_height;
  ComputeThreadSplitPoints(num_threads, reduced_rows_per_cache_row,
                           reduced_rows, nnz_per_row);

  // Reset the three offset tables and size them to match |row_starts_|.
  const auto num_bounds = row_starts_.size();
  weight_starts_.clear();
  weight_starts_.reserve(num_bounds);
  rhs_indices_starts_.clear();
  rhs_indices_starts_.reserve(num_bounds);
  bias_starts_.clear();
  bias_starts_.reserve(num_bounds);

  // Running totals over all rows consumed so far. |next_row| only ever moves
  // forward, so each row contributes to the totals at most once.
  int weight_offset = 0;
  int rhs_offset = 0;
  int bias_offset = 0;
  int next_row = 0;
  for (const int thread_first_row : row_starts_) {
    // Consume every row that lies before this split point.
    for (; next_row < thread_first_row; ++next_row) {
      const int row_nnz = nnz_per_row[next_row];
      weight_offset += row_nnz * block_width_ * block_height_;
      rhs_offset += row_nnz;
      bias_offset += block_height_;
    }
    weight_starts_.push_back(weight_offset);
    rhs_indices_starts_.push_back(rhs_offset);
    bias_starts_.push_back(bias_offset);
  }
}
|
|
|
|
|
// Partitions rows [0, reduced_rows) into |num_threads| contiguous ranges of
// approximately equal work and stores the range boundaries in |row_starts_|.
//
// On return (for num_threads >= 1) |row_starts_| holds num_threads + 1
// non-decreasing entries: the first is 0, the last is |reduced_rows|, and every
// entry in between is a multiple of the (possibly reduced) cache-row
// granularity, clamped so that it never exceeds |reduced_rows|.
//
// Rows are grouped into "cache rows" of |reduced_rows_per_cache_row| rows. The
// work attributed to a cache row is the sum of |nnz_per_row| over its rows plus
// a constant 2 * reduced_rows_per_cache_row.
//
// |nnz_per_row| is read at indices [0, reduced_rows).
void ThreadBounds::ComputeThreadSplitPoints(int num_threads,
                                            int reduced_rows_per_cache_row,
                                            int reduced_rows,
                                            const int* nnz_per_row) {
  row_starts_.assign(1, 0);

  // If there are fewer rows than num_threads full cache rows, shrink the
  // granularity to reduced_rows / num_threads (but never below 1).
  if (reduced_rows_per_cache_row * num_threads > reduced_rows)
    reduced_rows_per_cache_row = std::max(reduced_rows / num_threads, 1);
  // Rounded up, so the last cache row may be only partially populated.
  const int cache_rows = (reduced_rows + reduced_rows_per_cache_row - 1) /
                         reduced_rows_per_cache_row;

  // work_upto_row[k] is the total work of cache rows [0, k), i.e. an exclusive
  // prefix sum with work_upto_row[0] == 0.
  std::vector<int> work_upto_row(cache_rows + 1, 0);
  const int extra_row_work = 2 * reduced_rows_per_cache_row;
  for (int i = 0; i < cache_rows; ++i) {
    int new_nnz = 0;
    for (int j = 0; j < reduced_rows_per_cache_row; ++j) {
      // The final cache row may extend past the end of the matrix; skip the
      // rows that do not exist.
      const int index = i * reduced_rows_per_cache_row + j;
      if (index < reduced_rows) new_nnz += nnz_per_row[index];
    }
    work_upto_row[i + 1] = new_nnz + extra_row_work + work_upto_row[i];
  }
  const int total_work = work_upto_row.back();

  // Thread i (1-based) ends at the first cache-row boundary whose cumulative
  // work reaches i / num_threads of the total.
  int prev_split = 0;
  for (int i = 1; i <= num_threads; ++i) {
    // Widen before multiplying: i * total_work can overflow int even though
    // the quotient always fits.
    const long long target_work =
        static_cast<long long>(i) * total_work / num_threads;
    const int split = static_cast<int>(std::distance(
        work_upto_row.begin(),
        std::lower_bound(work_upto_row.begin(), work_upto_row.end(),
                         target_work)));
    // When the last cache row is partial, split * reduced_rows_per_cache_row
    // can point past the end of the matrix. Clamp it so that no thread is
    // handed non-existent rows and callers never index |nnz_per_row| out of
    // range.
    int split_row = std::min(split * reduced_rows_per_cache_row, reduced_rows);
    if (i == num_threads) {
      // The last thread always takes everything that remains.
      split_row = reduced_rows;
    }

    VLOG(2) << "tid=" << i - 1 << " num rows=" << split_row - row_starts_.back()
            << " work=" << work_upto_row[split] - work_upto_row[prev_split];
    row_starts_.push_back(split_row);
    prev_split = split;
  }
  VLOG(2) << "total rows=" << reduced_rows << " total work=" << total_work;
}
|
|
|
} |
|
|