File size: 2,373 Bytes
8b7c501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <benchmark/benchmark.h>
// Expands to five BENCHMARK_CAPTURE registrations of `bgemm_fn`, one per
// shape set defined in this header. Each registration:
//   - is given a distinct case name (albert, mobilebert, sd1x_diffusion, ...)
//     and a string label as its captured argument,
//   - receives its (B, M, N, K) argument tuples from the matching
//     *BgemmArguments function via Apply(), and
//   - calls UseRealTime() on the registered benchmark.
// Every expanded statement, including the last one, already ends with ';',
// so a call site does not need to add its own semicolon.
#define BENCHMARK_BGEMM(bgemm_fn) \
BENCHMARK_CAPTURE(bgemm_fn, albert, "Albert")->Apply(AlbertBgemmArguments)->UseRealTime(); \
BENCHMARK_CAPTURE(bgemm_fn, mobilebert, "MobileBert")->Apply(MobilebertBgemmArguments)->UseRealTime(); \
BENCHMARK_CAPTURE(bgemm_fn, sd1x_diffusion, "SD1.X Diffusion")->Apply(SD1XDiffusionBgemmArguments)->UseRealTime(); \
BENCHMARK_CAPTURE(bgemm_fn, sd1x_encoder_decoder, "SD1.X Encoder-Decoder")->Apply(SD1XEncoderDecoderBgemmArguments)->UseRealTime(); \
BENCHMARK_CAPTURE(bgemm_fn, sd1x_text_encoder, "SD1.X Text Encoder")->Apply(SD1XTextEncoderBgemmArguments)->UseRealTime();
// Adds the argument tuples used by the "Albert" case of BENCHMARK_BGEMM.
// Names the four arguments B, M, N, K and appends one Args() entry per table
// row, in the order the rows are listed.
static void AlbertBgemmArguments(benchmark::internal::Benchmark* b) {
  static constexpr int kShapes[][4] = {
      /* B    M    N    K */
      {12, 384, 64, 384},
      {12, 384, 384, 64},
  };

  b->ArgNames({"B", "M", "N", "K"});
  for (const auto& shape : kShapes) {
    b->Args({shape[0], shape[1], shape[2], shape[3]});
  }
}
// Adds the argument tuples used by the "MobileBert" case of BENCHMARK_BGEMM.
// Names the four arguments B, M, N, K and appends one Args() entry per table
// row, in the order the rows are listed.
static void MobilebertBgemmArguments(benchmark::internal::Benchmark* b) {
  static constexpr int kShapes[][4] = {
      /* B   M    N    K */
      {4, 384, 32, 384},
      {4, 384, 384, 32},
  };

  b->ArgNames({"B", "M", "N", "K"});
  for (const auto& shape : kShapes) {
    b->Args({shape[0], shape[1], shape[2], shape[3]});
  }
}
// Adds the argument tuples used by the "SD1.X Diffusion" case of
// BENCHMARK_BGEMM. Names the four arguments B, M, N, K and appends one Args()
// entry per table row, in the order the rows are listed.
static void SD1XDiffusionBgemmArguments(benchmark::internal::Benchmark* b) {
  static constexpr int kShapes[][4] = {
      /* B    M     N     K */
      {8, 4096, 4096, 40},
      {8, 4096, 40, 4096},
      {8, 4096, 77, 40},
      {8, 4096, 40, 77},
      {8, 1024, 1024, 80},
      {8, 1024, 80, 1024},
      {8, 1024, 77, 80},
      {8, 1024, 80, 77},
      {8, 256, 256, 160},
      {8, 256, 160, 256},
      {8, 256, 77, 160},
      {8, 256, 160, 77},
      {8, 64, 64, 160},
      {8, 64, 160, 64},
      {8, 64, 77, 160},
      {8, 64, 160, 77},
  };

  b->ArgNames({"B", "M", "N", "K"});
  for (const auto& shape : kShapes) {
    b->Args({shape[0], shape[1], shape[2], shape[3]});
  }
}
// Adds the argument tuples used by the "SD1.X Encoder-Decoder" case of
// BENCHMARK_BGEMM. Names the four arguments B, M, N, K and appends one Args()
// entry per table row, in the order the rows are listed.
static void SD1XEncoderDecoderBgemmArguments(benchmark::internal::Benchmark* b) {
  static constexpr int kShapes[][4] = {
      /* B    M     N     K */
      {1, 4096, 4096, 512},
      {1, 512, 4096, 4096},
  };

  b->ArgNames({"B", "M", "N", "K"});
  for (const auto& shape : kShapes) {
    b->Args({shape[0], shape[1], shape[2], shape[3]});
  }
}
// Adds the argument tuples used by the "SD1.X Text Encoder" case of
// BENCHMARK_BGEMM. Names the four arguments B, M, N, K and appends one Args()
// entry per table row, in the order the rows are listed.
static void SD1XTextEncoderBgemmArguments(benchmark::internal::Benchmark* b) {
  static constexpr int kShapes[][4] = {
      /* B   M   N   K */
      {12, 77, 77, 64},
      {12, 77, 64, 77},
  };

  b->ArgNames({"B", "M", "N", "K"});
  for (const auto& shape : kShapes) {
    b->Args({shape[0], shape[1], shape[2], shape[3]});
  }
}
|