NTT123 committed
Commit d1a84ee
1 parent: df1ad02

add fast cpp wavegru

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. BUILD +44 -0
  2. WORKSPACE +154 -0
  3. app.py +12 -1
  4. inference.py +7 -6
  5. packages.txt +2 -1
  6. sparse_matmul/BUILD +22 -0
  7. sparse_matmul/compute/BUILD +88 -0
  8. sparse_matmul/compute/ar_inputs.h +37 -0
  9. sparse_matmul/compute/gru_gates.h +214 -0
  10. sparse_matmul/compute/gru_gates_arm.h +288 -0
  11. sparse_matmul/compute/gru_gates_avx_fixed.h +348 -0
  12. sparse_matmul/compute/gru_gates_generic.h +97 -0
  13. sparse_matmul/compute/gru_gates_test.cc +164 -0
  14. sparse_matmul/compute/kernels_arm.h +0 -0
  15. sparse_matmul/compute/kernels_avx.h +601 -0
  16. sparse_matmul/compute/kernels_generic.h +273 -0
  17. sparse_matmul/compute/matmul.h +199 -0
  18. sparse_matmul/compute/matmul_fixed_avx2.cc +235 -0
  19. sparse_matmul/compute/matmul_fixed_avx2.h +49 -0
  20. sparse_matmul/compute/matmul_generic.cc +122 -0
  21. sparse_matmul/compute/matmul_generic.h +41 -0
  22. sparse_matmul/compute/thread_bounds.cc +106 -0
  23. sparse_matmul/compute/thread_bounds.h +74 -0
  24. sparse_matmul/layers/BUILD +146 -0
  25. sparse_matmul/layers/csr_blocksparse_matrix.h +835 -0
  26. sparse_matmul/layers/csrblocksparse_test.cc +977 -0
  27. sparse_matmul/layers/errno_mapping.cc +195 -0
  28. sparse_matmul/layers/errno_mapping.h +29 -0
  29. sparse_matmul/layers/masked_sparse_matrix.h +206 -0
  30. sparse_matmul/layers/read_array_ifstream.h +66 -0
  31. sparse_matmul/layers/sparse_linear_layer.h +365 -0
  32. sparse_matmul/layers/sparse_linear_layer_test.cc +187 -0
  33. sparse_matmul/layers/status_macros.h +34 -0
  34. sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz +3 -0
  35. sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz +3 -0
  36. sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz +3 -0
  37. sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz +3 -0
  38. sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz +3 -0
  39. sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz +3 -0
  40. sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz +3 -0
  41. sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz +3 -0
  42. sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz +3 -0
  43. sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz +3 -0
  44. sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz +3 -0
  45. sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz +3 -0
  46. sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz +3 -0
  47. sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz +3 -0
  48. sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz +3 -0
  49. sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz +3 -0
  50. sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz +3 -0
BUILD ADDED
@@ -0,0 +1,44 @@
# [internal] load cc_fuzz_target.bzl
# [internal] load cc_proto_library.bzl
# [internal] load android_cc_test:def.bzl

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(default_visibility = [":__subpackages__"])

licenses(["notice"])

# To run all cc_tests in this directory:
# bazel test //:all

# [internal] Command to run dsp_util_android_test.

# [internal] Command to run lyra_integration_android_test.

exports_files(
    srcs = [
        "wavegru_mod.cc",
    ],
)

pybind_extension(
    name = "wavegru_mod",  # This name is not actually created!
    srcs = ["wavegru_mod.cc"],
    deps = [
        "//sparse_matmul",
    ],
)

py_library(
    name = "wavegru_mod",
    data = [":wavegru_mod.so"],
)

py_binary(
    name = "wavegru",
    srcs = ["wavegru.py"],
    deps = [
        ":wavegru_mod"
    ],
)
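The `pybind_extension` rule above builds `wavegru_mod.cc` into the `wavegru_mod` Python module, but that source file is not among the 50 files shown in this truncated view. The following is only a hypothetical sketch of what such a pybind11 binding could look like; the class name `WaveGRU` and its methods are assumptions for illustration, not code from this commit:

```cpp
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

namespace py = pybind11;

// Hypothetical wrapper; the real wavegru_mod.cc is not visible in this view.
class WaveGRU {
 public:
  // Store sparse weights/masks extracted on the Python side.
  void load_weights(const std::vector<std::vector<float>>& weights) {
    weights_ = weights;
  }

  // Run autoregressive sampling over upsampled conditioning features.
  std::vector<int> inference(const std::vector<std::vector<float>>& ft,
                             float temperature) {
    std::vector<int> samples;
    samples.reserve(ft.size());
    // ... per-step GRU update and sampling would go here ...
    return samples;
  }

 private:
  std::vector<std::vector<float>> weights_;
};

PYBIND11_MODULE(wavegru_mod, m) {
  py::class_<WaveGRU>(m, "WaveGRU")
      .def(py::init<>())
      .def("load_weights", &WaveGRU::load_weights)
      .def("inference", &WaveGRU::inference);
}
```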
WORKSPACE ADDED
@@ -0,0 +1,154 @@
########################
# Platform Independent #
########################

load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# GoogleTest/GoogleMock framework.
git_repository(
    name = "com_google_googletest",
    remote = "https://github.com/google/googletest.git",
    tag = "release-1.10.0",
)

# Google benchmark.
http_archive(
    name = "com_github_google_benchmark",
    urls = ["https://github.com/google/benchmark/archive/bf585a2789e30585b4e3ce6baf11ef2750b54677.zip"],  # 2020-11-26T11:14:03Z
    strip_prefix = "benchmark-bf585a2789e30585b4e3ce6baf11ef2750b54677",
    sha256 = "2a778d821997df7d8646c9c59b8edb9a573a6e04c534c01892a40aa524a7b68c",
)

# proto_library, cc_proto_library, and java_proto_library rules implicitly
# depend on @com_google_protobuf for protoc and proto runtimes.
# This statement defines the @com_google_protobuf repo.
git_repository(
    name = "com_google_protobuf",
    remote = "https://github.com/protocolbuffers/protobuf.git",
    tag = "v3.15.4",
)

load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
protobuf_deps()

# Google Abseil Libs
git_repository(
    name = "com_google_absl",
    remote = "https://github.com/abseil/abseil-cpp.git",
    branch = "lts_2020_09_23",
)

# Filesystem
# The new_* prefix is used because it is not a bazel project and there is
# no BUILD file in that repo.
FILESYSTEM_BUILD = """
cc_library(
    name = "filesystem",
    hdrs = glob(["include/ghc/*"]),
    visibility = ["//visibility:public"],
)
"""

new_git_repository(
    name = "gulrak_filesystem",
    remote = "https://github.com/gulrak/filesystem.git",
    tag = "v1.3.6",
    build_file_content = FILESYSTEM_BUILD
)

# Audio DSP
git_repository(
    name = "com_google_audio_dsp",
    remote = "https://github.com/google/multichannel-audio-tools.git",
    # There are no tags for this repo, we are synced to bleeding edge.
    branch = "master",
    repo_mapping = {
        "@com_github_glog_glog": "@com_google_glog",
    },
)

http_archive(
    name = "pybind11_bazel",
    strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672",
    urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"],
)
# We still require the pybind library.
http_archive(
    name = "pybind11",
    build_file = "@pybind11_bazel//:pybind11.BUILD",
    strip_prefix = "pybind11-2.9.0",
    urls = ["https://github.com/pybind/pybind11/archive/v2.9.0.tar.gz"],
)
load("@pybind11_bazel//:python_configure.bzl", "python_configure")
python_configure(name = "local_config_python")

# Transitive dependencies of Audio DSP.
http_archive(
    name = "eigen_archive",
    build_file = "eigen.BUILD",
    sha256 = "f3d69ac773ecaf3602cb940040390d4e71a501bb145ca9e01ce5464cf6d4eb68",
    strip_prefix = "eigen-eigen-049af2f56331",
    urls = [
        "http://mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
        "https://bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
    ],
)

http_archive(
    name = "fft2d",
    build_file = "fft2d.BUILD",
    sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9",
    urls = [
        "http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
    ],
)

# Google logging
git_repository(
    name = "com_google_glog",
    remote = "https://github.com/google/glog.git",
    branch = "master"
)
# Dependency for glog
git_repository(
    name = "com_github_gflags_gflags",
    remote = "https://github.com/mchinen/gflags.git",
    branch = "android_linking_fix"
)

# Bazel/build rules

http_archive(
    name = "bazel_skylib",
    urls = [
        "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
        "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
    ],
    sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
)
load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
bazel_skylib_workspace()

http_archive(
    name = "rules_android",
    sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
    strip_prefix = "rules_android-0.1.1",
    urls = ["https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip"],
)

# Google Maven Repository
GMAVEN_TAG = "20180625-1"

http_archive(
    name = "gmaven_rules",
    strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG,
    url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG,
)

load("@gmaven_rules//:gmaven.bzl", "gmaven_rules")

gmaven_rules()
app.py CHANGED
@@ -1,6 +1,14 @@
 import gradio as gr
+import os
+
+
+## build wavegru-cpp
+os.system("go get github.com/bazelbuild/bazelisk")
+os.system("bazelisk build wavegru_mod -c opt --copt=-march=native")
 
 from inference import load_tacotron_model, load_wavegru_net, text_to_mel, mel_to_wav
+from wavegru_cpp import load_wavegru_cpp, extract_weight_mask
+
 
 alphabet, tacotron_net, tacotron_config = load_tacotron_model(
     "./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_500k.ckpt"
@@ -11,10 +19,13 @@ wavegru_config, wavegru_net = load_wavegru_net(
     "./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt"
 )
 
+wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
+wavecpp = load_wavegru_cpp(wave_cpp_weight_mask)
+
 
 def speak(text):
     mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
-    y = mel_to_wav(wavegru_net, mel, wavegru_config)
+    y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
     return 24_000, y
 
 
inference.py CHANGED
@@ -56,10 +56,10 @@ def load_wavegru_net(config_file, model_file):
     return config, net
 
 
-wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=False))
+wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=True))
 
 
-def mel_to_wav(net, mel, config):
+def mel_to_wav(net, netcpp, mel, config):
     """convert mel to wav"""
     if len(mel.shape) == 2:
         mel = mel[None]
@@ -69,10 +69,11 @@ def mel_to_wav(net, mel, config):
         [(0, 0), (pad, pad), (0, 0)],
         constant_values=np.log(config["mel_min"]),
     )
-    x = wavegru_inference(net, mel)
-    x = jax.device_get(x)
-
-    wav = librosa.mu_expand(x - 127, mu=255)
+    ft = wavegru_inference(net, mel)
+    ft = jax.device_get(ft[0])
+    wav = netcpp.inference(ft, 1.0)
+    wav = np.array(wav)
+    wav = librosa.mu_expand(wav - 127, mu=255)
     wav = librosa.effects.deemphasis(wav, coef=0.86)
     wav = wav * 2.0
     wav = wav / max(1.0, np.max(np.abs(wav)))
packages.txt CHANGED
@@ -1 +1,2 @@
-libsndfile1-dev
+libsndfile1-dev
+golang-go
sparse_matmul/BUILD ADDED
@@ -0,0 +1,22 @@
# [internal] load placeholder

licenses(["notice"])

cc_library(
    name = "sparse_matmul",
    hdrs = [
        "sparse_matmul.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//sparse_matmul/compute:gru_gates",
        "//sparse_matmul/layers:layer",
        "//sparse_matmul/layers:matrix",
        "//sparse_matmul/layers:utils",
        "//sparse_matmul/numerics:fast_transcendentals",
        "//sparse_matmul/numerics:types",
        "//sparse_matmul/os:coop_threads",
        "//sparse_matmul/vector:cache_aligned_vector",
    ],  # internal :sparse_matmul deps placeholder
)
sparse_matmul/compute/BUILD ADDED
@@ -0,0 +1,88 @@
# Low-level computation code, including generic and architecture-specific
# variants.

licenses(["notice"])

cc_library(
    name = "gru_gates",
    srcs = [
        "ar_inputs.h",
        "gru_gates_arm.h",
        "gru_gates_avx_fixed.h",
        "gru_gates_generic.h",
    ],
    hdrs = ["gru_gates.h"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        ":matmul",
        "//sparse_matmul/numerics:fast_transcendentals",
        "//sparse_matmul/numerics:types",
        "//sparse_matmul/vector:cache_aligned_vector",
    ],
)

cc_library(
    name = "kernels",
    srcs = [
        "kernels_arm.h",
        "kernels_avx.h",
    ],
    hdrs = [
        "kernels_generic.h",
    ],
    visibility = [
        "//sparse_matmul:__subpackages__",
    ],
    deps = [
        "//sparse_matmul/numerics:fast_transcendentals",
        "//sparse_matmul/numerics:types",
    ],
)

cc_library(
    name = "matmul",
    srcs = [
        "matmul_fixed_avx2.cc",
        "matmul_fixed_avx2.h",
        "matmul_generic.cc",
        "matmul_generic.h",
    ],
    hdrs = [
        "matmul.h",
    ],
    visibility = [
        "//sparse_matmul:__subpackages__",
    ],
    deps = [
        "//sparse_matmul/numerics:types",
        "@com_google_absl//absl/time",
    ],
)

cc_library(
    name = "thread_bounds",
    srcs = ["thread_bounds.cc"],
    hdrs = ["thread_bounds.h"],
    visibility = [
        "//sparse_matmul:__subpackages__",
    ],
    deps = [
        "@com_google_glog//:glog",
    ],
)

cc_test(
    name = "gru_gates_test",
    size = "small",
    srcs = [
        "gru_gates_test.cc",
    ],
    deps = [
        ":gru_gates",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
    ],
)
sparse_matmul/compute/ar_inputs.h ADDED
@@ -0,0 +1,37 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_

namespace csrblocksparse {

// Possible numbers of Autoregressive inputs.
// TODO(b/188702959): Generalize to any non-negative integer value?
enum class ARInputsMode {
  // There are no autoregressive inputs. Inputs to the GRU gates are strictly
  // from the gate-recurrent matmul and other unrelated inputs.
  k0ARInputs,
  // Two autoregressive inputs, such as coarse and fine for WaveRNN.
  k2ARInputs,
  // Three autoregressive inputs, such as prev coarse and fine plus current
  // coarse for WaveRNN.
  k3ARInputs,
};

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
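`ARInputsMode` is used throughout the headers below as a non-type template parameter, so the number of autoregressive inputs is fixed at compile time and the unused branches are eliminated rather than tested per sample. A minimal sketch of that pattern (the function and variable names here are illustrative, not taken from this commit):

```cpp
#include <cstdio>

// Same shape as the enum added in ar_inputs.h.
enum class ARInputsMode { k0ARInputs, k2ARInputs, k3ARInputs };

// Hypothetical accumulator: the AR contribution is compiled in (or out)
// according to the template parameter, so the k0ARInputs instantiation has
// no per-sample branch and no unused weight reads.
template <ARInputsMode kInputsMode>
float AccumulateGateInput(float conditioning, float coarse, float fine,
                          float current_coarse, const float* weights) {
  float total = conditioning;
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    total += weights[0] * coarse + weights[1] * fine;
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      total += weights[2] * current_coarse;
    }
  }
  return total;
}

int main() {
  const float w[3] = {0.1f, 0.2f, 0.3f};
  // Each instantiation is a separate, fully specialized function.
  std::printf("%f\n", AccumulateGateInput<ARInputsMode::k2ARInputs>(
                          0.5f, 1.0f, -1.0f, 0.0f, w));
  std::printf("%f\n", AccumulateGateInput<ARInputsMode::k3ARInputs>(
                          0.5f, 1.0f, -1.0f, 2.0f, w));
  return 0;
}
```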
sparse_matmul/compute/gru_gates.h ADDED
@@ -0,0 +1,214 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_

#include <cstdint>
#include <vector>

// IWYU pragma: begin_exports
#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/compute/gru_gates_arm.h"
#include "sparse_matmul/compute/gru_gates_avx_fixed.h"
#include "sparse_matmul/compute/gru_gates_generic.h"
#include "sparse_matmul/compute/matmul.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/type_utils.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"
// IWYU pragma: end_exports

namespace csrblocksparse {

// The master template is really a catch-all for the unimplemented cases to
// run the generics.
template <typename GRUStateType, typename InputType, typename SampleType = void>
class GruGates : public MatmulBase {
 public:
  using SampleWeightType = float;
  static constexpr int kSIMDWidth = kGenericSIMDWidth;

  // Generic GRU function covers all uses for WaveRNN-like architectures and
  // conditioning.
  // Controlled by template parameters thus:
  // - |kInputsMode| == |k0ARInputs|: There are no autoregressive inputs so
  //   |ar_sample0|, |ar_sample1|, |ar_sample2|, |ar_01_weights|,
  //   |ar_2_weights| are ignored.
  // - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied
  //   by |ar_01_weights| and added to the (conditioning) input.
  // - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by
  //   |ar_2_weights| and added to the other two |ar_inputs| (and added to the
  //   conditioning input).
  // - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary
  //   recurrent input that must be added to |*gru_recurrent_ptr|.
  // - |num_replicas| determines the number of duplicates of the output to be
  //   written, separated by |replica_stride|.
  // - |start|, |end| are |rows| in [0, |state_size|] to be processed by this
  //   thread.
  //
  // Previous state is read from |*gru_state_ptr| and the new state is written
  // to *(|gru_state_ptr| + i * |replica_stride| for i in [0, |num_replicas|)).
  template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
            bool kSplitGates = false>
  void GruWithARInput(int start, int end, int state_size,
                      const InputType* gru_recurrent_ptr,
                      const InputType* input_ptr, GRUStateType* gru_state_ptr,
                      const SampleType* ar_sample0 = nullptr,
                      const SampleType* ar_sample1 = nullptr,
                      const SampleWeightType* ar_01_weights = nullptr,
                      int num_replicas = 1, int replica_stride = 0,
                      const SampleType* ar_sample2 = nullptr,
                      const SampleWeightType* ar_2_weights = nullptr,
                      const InputType* gru_recurrent_other_ptr = nullptr) {
    CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica";
    GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType,
                   kInputsMode, kSplitGates>(
        start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr,
        input_ptr, gru_state_ptr, ar_2_weights, state_size, ar_sample0,
        ar_sample1, ar_sample2);
  }

  // No AR inputs, no split gates, no batching, no replicated outputs.
  // TODO(b/188702959): Redirect conditioning GRU here, removing code from
  // gru_layer.h.
  // Copy to specializations.
  void PlainGru(int start, int end, int state_size,
                const InputType* gru_recurrent_ptr, const InputType* input_ptr,
                GRUStateType* gru_state_ptr) {
    GruWithARInput<ARInputsMode::k0ARInputs>(
        start, end, state_size, gru_recurrent_ptr, input_ptr, gru_state_ptr);
  }
};

#if defined __ARM_NEON || defined __aarch64__
// Partial specialization for float.
template <>
class GruGates<float, float, float> : public MatmulBase {
 public:
  static constexpr int kSIMDWidth = kNeonSIMDWidth;

  // Generic GRU function covers all uses for WaveRNN-like architectures and
  // conditioning.
  template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
            bool kSplitGates = false>
  void GruWithARInput(int start, int end, int state_size,
                      const float* gru_recurrent_data, const float* input_data,
                      float* gru_state_data, const float* ar_sample0 = nullptr,
                      const float* ar_sample1 = nullptr,
                      const float* ar_01_weights = nullptr,
                      int num_replicas = 1, int replica_stride = 0,
                      const float* ar_sample2 = nullptr,
                      const float* ar_2_weights = nullptr,
                      const float* gru_recurrent_other_data = nullptr) {
    DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica";
    GoThroughGatesFloat<kInputsMode, kSplitGates>(
        start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data,
        input_data, gru_state_data, ar_2_weights, state_size, ar_sample0,
        ar_sample1, ar_sample2);
  }
};
#endif  // defined __ARM_NEON || defined __aarch64__

// Partial specialization for fixed types. The sample weights are always float
// whatever the fixed type of the other weights.
template <int kGRUStateBits, int kInputBits, int kSampleBits>
class GruGates<fixed16<kGRUStateBits>, fixed32<kInputBits>,
               fixed16<kSampleBits>> : public MatmulBase {
 public:
#if defined __ARM_NEON || defined __aarch64__
  static constexpr int kSIMDWidth = kNeonSIMDWidth;
#elif defined __AVX2__
  static constexpr int kSIMDWidth = kAVX2SIMDWidth * 2;
#else   // Generic case.
  static constexpr int kSIMDWidth = kGenericSIMDWidth;
#endif  // __ARM_NEON || defined __aarch64__ / __AVX2__

  using GRUStateType = fixed16<kGRUStateBits>;
  using InputType = fixed32<kInputBits>;
  using SampleType = fixed16<kSampleBits>;
  using SampleWeightType = float;
  static constexpr int kInputMantissaBits = InputType::kMantissaBits;
  static constexpr int kSampleMantissaBits = SampleType::kMantissaBits;
  static constexpr int kStateMantissaBits = GRUStateType::kMantissaBits;
  // Generic GRU function covers all uses for WaveRNN-like architectures and
  // conditioning.
  template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
            bool kSplitGates = false>
  void GruWithARInput(int start, int end, int state_size,
                      const InputType* gru_recurrent_data,
                      const InputType* input_data, GRUStateType* gru_state_data,
                      const SampleType* ar_sample0 = nullptr,
                      const SampleType* ar_sample1 = nullptr,
                      const SampleWeightType* ar_01_weights = nullptr,
                      int num_replicas = 1, int replica_stride = 0,
                      const SampleType* ar_sample2 = nullptr,
                      const SampleWeightType* ar_2_weights = nullptr,
                      const InputType* gru_recurrent_other_data = nullptr) {
#if defined __ARM_NEON || defined __aarch64__ || defined __AVX2__
    const int32_t* gru_recurrent_ptr =
        reinterpret_cast<const int32_t*>(gru_recurrent_data);
    const int32_t* gru_recurrent_other_ptr =
        reinterpret_cast<const int32_t*>(gru_recurrent_other_data);
    const int32_t* input_ptr = reinterpret_cast<const int32_t*>(input_data);
    int16_t* gru_state_ptr = reinterpret_cast<int16_t*>(gru_state_data);
#if defined __AVX2__
    // The samples are fixed16, but we scale them up here and convert to float
    // so that the product with the QR weights is always on the same scale as
    // InputType, so we don't have to do any more scaling inside.
    const float sample_factor = static_cast<float>(1 << kInputMantissaBits);
#else
    const float sample_factor = 1.0f;
#endif
    // AR sample 0 and 1 are packed into a pair because the QR weights are
    // formatted with the weights interleaved for sample 0 and 1.
    std::pair<float, float> ar_sample01;
    float ar_sample2_float = 0.0f;
    if (kInputsMode == ARInputsMode::k2ARInputs ||
        kInputsMode == ARInputsMode::k3ARInputs) {
      ar_sample01 = {static_cast<float>(*ar_sample0) * sample_factor,
                     static_cast<float>(*ar_sample1) * sample_factor};
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        ar_sample2_float = static_cast<float>(*ar_sample2) * sample_factor;
      }
    }
#if defined __AVX2__
    CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
    GruGatesAVXFixed<kInputMantissaBits, kStateMantissaBits, kInputsMode,
                     kSplitGates>(
        start, end, state_size, gru_recurrent_ptr, input_ptr, &ar_sample01,
        ar_01_weights, num_replicas, replica_stride, &ar_sample2_float,
        ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
#else   // ARM.
    DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica";
    GoThroughGatesFixed<GRUStateType, InputType, kInputsMode, kSplitGates>(
        start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr,
        input_ptr, gru_state_ptr, ar_2_weights, state_size, &ar_sample01,
        &ar_sample2_float);
#endif  // __AVX2__ / ARM.
#else   // Generic case.
    CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica";
    GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType,
                   kInputsMode, kSplitGates>(
        start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data,
        input_data, gru_state_data, ar_2_weights, state_size, ar_sample0,
        ar_sample1, ar_sample2);
#endif  // __ARM_NEON || defined __aarch64__ / __AVX2__
  }
};

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
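For orientation, `GruGates` is meant to be called per sample over a thread's slice of the GRU state, with the reset/update/cell blocks packed back to back. The following is a rough usage sketch under assumed buffer layouts (the surrounding WaveRNN sampling loop and weight contents are not part of this commit's visible files, so sizes and values here are illustrative only):

```cpp
// Sketch only: assumes //sparse_matmul is available on the include path.
#include <vector>

#include "sparse_matmul/sparse_matmul.h"

int main() {
  constexpr int kStateSize = 8;  // Rows of GRU state handled by this call.
  csrblocksparse::GruGates<float, float, float> gru_gates;

  // |gru_recurrent| and |conditioning| each hold reset/update/cell blocks,
  // i.e. 3 * kStateSize values, as GruWithARInput() expects.
  std::vector<float> gru_recurrent(3 * kStateSize, 0.0f);
  std::vector<float> conditioning(3 * kStateSize, 0.0f);
  std::vector<float> gru_state(kStateSize, 0.0f);

  // Two AR inputs (coarse, fine) with interleaved QR weights: two weights per
  // row for each of the three gates (layout assumed for illustration).
  float coarse = 0.25f, fine = -0.5f;
  std::vector<float> qr_weights(3 * 2 * kStateSize, 0.01f);

  gru_gates.GruWithARInput<csrblocksparse::ARInputsMode::k2ARInputs>(
      /*start=*/0, /*end=*/kStateSize, kStateSize, gru_recurrent.data(),
      conditioning.data(), gru_state.data(), &coarse, &fine,
      qr_weights.data());
  return 0;
}
```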
sparse_matmul/compute/gru_gates_arm.h ADDED
@@ -0,0 +1,288 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_

#if defined __ARM_NEON || defined __aarch64__
#include <arm_neon.h>
#endif
#include <cstdint>

#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/numerics/fast_transcendentals.h"

namespace csrblocksparse {

static constexpr int kNeonSIMDWidth = 4;

// ------ Scalar calculation --------
// See "Efficient Neural Audio Synthesis" for a description of the calculation.
// https://arxiv.org/abs/1802.08435
//
// NOTE:
// |sample| = (|coarse_at_sminus1|, |fine_at_sminus1|,
//             |coarse_at_sminus1|, |fine_at_sminus1|)
// |w_sample| = (|coarse_at_s|, |coarse_at_s|, |coarse_at_s|, |coarse_at_s|)
//
// CHEATSHEET:
// vld1q_f32 = load 4 32-bit floats
// vmulq_f32(a, b) : return a * b;
// vaddq_f32(a, b) : return a + b;
// vmlaq_f32(c, a, b) : return c + a * b;
// vpaddq_f32(a, b) : return (a0 + a1, a2 + a3, b0 + b1, b2 + b3)
// vsubq_f32(a, b) : return a - b;
// vst1q_f32 = store 4 32-bit floats
#if defined __ARM_NEON || defined __aarch64__

#if !defined __aarch64__
// Backport of vpaddq_f32 to ARM32.
inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) {
  float32x2_t a10 = vget_low_f32(a);
  float32x2_t a32 = vget_high_f32(a);
  float32x2_t b10 = vget_low_f32(b);
  float32x2_t b32 = vget_high_f32(b);
  return vcombine_f32(vpadd_f32(a10, a32), vpadd_f32(b10, b32));
}
#endif

template <ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFloat(int start, int end, const float* qr_ptr,
                         const float* gru_gates_ptr,
                         const float* gru_gates_other_ptr,
                         const float* conditioning_ptr, float* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const float* coarse_at_sminus1,
                         const float* fine_at_sminus1,
                         const float* coarse_at_s) {
  // Increment all the pointers to save on pointer arithmetic in the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    qr_ptr += 2 * start;
    DCHECK_NE(coarse_at_sminus1, nullptr);
    DCHECK_NE(fine_at_sminus1, nullptr);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    float32x4_t reset = vld1q_f32(gru_gates_ptr);
    float32x4_t update = vld1q_f32(gru_gates_ptr + proj_size);
    float32x4_t cell = vld1q_f32(gru_gates_ptr + 2 * proj_size);
    float32x4_t qr_cell;
    if (SplitGates) {
      reset = vaddq_f32(reset, vld1q_f32(gru_gates_other_ptr));
      update = vaddq_f32(update, vld1q_f32(gru_gates_other_ptr + proj_size));
      cell = vaddq_f32(cell, vld1q_f32(gru_gates_other_ptr + 2 * proj_size));
    }
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // Setup the sample vector.
      float32x4_t sample = vdupq_n_f32(*coarse_at_sminus1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 3);

      // All auto types are float32x4_t, auto used to fit statements on one line
      // for readability. Do two rows of QR at once.
      auto qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample);
      auto qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample);
      auto qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);

      auto qr_update_0 = vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample);
      auto qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample);
      auto qr_update = vpaddq_f32(qr_update_0, qr_update_1);

      auto qr_cell_0 = vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample);
      auto qr_cell_1 = vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);

      if (kInputsMode == ARInputsMode::k3ARInputs) {
        float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      reset = vaddq_f32(reset, qr_reset);
      update = vaddq_f32(update, qr_update);
    }
    auto reset_conditioning = vld1q_f32(conditioning_ptr);
    auto update_conditioning = vld1q_f32(conditioning_ptr + proj_size);
    auto cell_conditioning = vld1q_f32(conditioning_ptr + 2 * proj_size);

    reset = fast_sigmoid(vaddq_f32(reset, reset_conditioning));
    update = fast_sigmoid(vaddq_f32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset, cell);
    } else {
      cell = vmlaq_f32(qr_cell, reset, cell);
    }
    auto hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));

    auto prev_h = vld1q_f32(gru_h_ptr);
    auto diff = vsubq_f32(prev_h, hbar);
    auto new_h = vmlaq_f32(hbar, diff, update);

    vst1q_f32(gru_h_ptr, new_h);
    // Increment all the pointers.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}

// This version should only be used if all of the 32-bit fixed point
// representations have the same number of mantissa bits.
// |ar_at_sminus1| packs sample 0 and 1 into a pair because the QR weights are
// formatted with the weights interleaved for sample 0 and 1. The two samples
// represent coarse and fine for WaveRNN.
template <typename GRUStateType, typename GRUMatMulOutType,
          ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFixed(int start, int end, const float* qr_ptr,
                         const int32_t* gru_gates_ptr,
                         const int32_t* gru_gates_other_ptr,
                         const int32_t* conditioning_ptr, int16_t* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const std::pair<float, float>* ar_at_sminus1,
                         const float* coarse_at_s) {
  // Increment all the pointers to save on pointer arithmetic in the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  float32x4_t sample01;
  float32x4_t w_sample;
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    qr_ptr += 2 * start;
    DCHECK_NE(ar_at_sminus1, nullptr);
    sample01 = vdupq_n_f32(ar_at_sminus1->first);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 1);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 3);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
      w_sample = vdupq_n_f32(*coarse_at_s);
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    auto reset = vld1q_s32(gru_gates_ptr);
    auto update = vld1q_s32(gru_gates_ptr + proj_size);
    // vcvtq_n_f32_s32 = convert 32-bit fixed point to fp32
    auto cell_int = vld1q_s32(gru_gates_ptr + 2 * proj_size);
    if (SplitGates) {
      reset = vaddq_s32(reset, vld1q_s32(gru_gates_other_ptr));
      update = vaddq_s32(update, vld1q_s32(gru_gates_other_ptr + proj_size));
      cell_int =
          vaddq_s32(cell_int, vld1q_s32(gru_gates_other_ptr + 2 * proj_size));
    }
    float32x4_t cell =
        vcvtq_n_f32_s32(cell_int, GRUMatMulOutType::kMantissaBits);
    float32x4_t qr_cell;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // Do two rows of QR at once.
      float32x4_t qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample01);
      float32x4_t qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample01);
      float32x4_t qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);

      float32x4_t qr_update_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample01);
      float32x4_t qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample01);
      float32x4_t qr_update = vpaddq_f32(qr_update_0, qr_update_1);

      float32x4_t qr_cell_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample01);
      float32x4_t qr_cell_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample01);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      reset = vaddq_s32(
          reset, vcvtq_n_s32_f32(qr_reset, GRUMatMulOutType::kMantissaBits));
      update = vaddq_s32(
          update, vcvtq_n_s32_f32(qr_update, GRUMatMulOutType::kMantissaBits));
    }

    auto reset_conditioning = vld1q_s32(conditioning_ptr);
    auto update_conditioning = vld1q_s32(conditioning_ptr + proj_size);
    float32x4_t cell_conditioning =
        vcvtq_n_f32_s32(vld1q_s32(conditioning_ptr + 2 * proj_size),
                        GRUMatMulOutType::kMantissaBits);

    float32x4_t reset_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(reset, reset_conditioning));
    float32x4_t update_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset_f32, cell);
    } else {
      cell = vmlaq_f32(qr_cell, reset_f32, cell);
    }
    float32x4_t hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));

    float32x4_t prev_h = vcvtq_n_f32_s32(vmovl_s16(vld1_s16(gru_h_ptr)),
                                         GRUStateType::kMantissaBits);
    float32x4_t diff = vsubq_f32(prev_h, hbar);
    float32x4_t new_h = vmlaq_f32(hbar, diff, update_f32);

    // vcvtq_n_s32_f32 = convert fp32 to signed 32-bit fixed point
    // vqrshrn_n_s32 = saturating, rounding, narrowing right shift - used to
    // convert a 32-bit fixed point value to a 16-bit fixed point value
    vst1_s16(gru_h_ptr,
             vqrshrn_n_s32(
                 vcvtq_n_s32_f32(new_h, GRUStateType::kMantissaBits + 16), 16));
    // Increment all the pointers.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}
#endif  // defined __ARM_NEON || defined __aarch64__

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
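The NEON and AVX2 kernels are easier to follow against a plain scalar version of the same per-element update: sigmoid on the reset and update gates, tanh on the candidate, then interpolation with the previous state. The sketch below is a reference written for this write-up, not code from the commit; it omits the AR/QR terms, split gates, and fixed-point scaling:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar sketch of the per-element GRU state update performed by the SIMD
// kernels: |gates| and |conditioning| hold reset, update, cell blocks back to
// back (3 * state_size values each); |state| is updated in place.
void ScalarGruUpdate(int state_size, const std::vector<float>& gates,
                     const std::vector<float>& conditioning,
                     std::vector<float>& state) {
  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
  for (int i = 0; i < state_size; ++i) {
    float reset = sigmoid(gates[i] + conditioning[i]);
    float update =
        sigmoid(gates[i + state_size] + conditioning[i + state_size]);
    float cell = reset * gates[i + 2 * state_size];
    float hbar = std::tanh(cell + conditioning[i + 2 * state_size]);
    // new_h = hbar + (prev_h - hbar) * update, as in the vmlaq_f32 line above.
    state[i] = hbar + (state[i] - hbar) * update;
  }
}

int main() {
  const int kStateSize = 4;
  std::vector<float> gates(3 * kStateSize, 0.1f);
  std::vector<float> conditioning(3 * kStateSize, -0.2f);
  std::vector<float> state(kStateSize, 0.0f);
  ScalarGruUpdate(kStateSize, gates, conditioning, state);
  for (float h : state) std::printf("%f\n", h);
  return 0;
}
```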
sparse_matmul/compute/gru_gates_avx_fixed.h ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
19
+
20
+ #include <cstdint>
21
+ #if defined __AVX2__
22
+ #include <immintrin.h>
23
+ #endif
24
+ #include <vector>
25
+
26
+ #include "sparse_matmul/compute/ar_inputs.h"
27
+ #include "sparse_matmul/numerics/fast_transcendentals.h"
28
+
29
+ namespace csrblocksparse {
30
+
31
+ #if defined __AVX2__
32
+
33
+ constexpr int kAVX2SIMDWidth = 8;
34
+
35
+ // Loads 8x fixed32 from |ptr0| and adds to |input|.
36
+ // If |kTwoInputs|, also loads from |ptr1| and adds that as well.
37
+ // Returns the 2 or 3-way sum.
38
+ template <bool kTwoInputs>
39
+ inline __m256i LoadAndAddFixed32(const int32_t* ptr0, const int32_t* ptr1,
40
+ const __m256i& input) {
41
+ __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
42
+ if (kTwoInputs) {
43
+ __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
44
+ data0 = _mm256_add_epi32(data0, data1);
45
+ }
46
+ return _mm256_add_epi32(data0, input);
47
+ }
48
+
49
+ // Loads 8x fixed32 from ptr0.
50
+ // If |kTwoInputs|, also loads from |ptr1| and adds.
51
+ // Multiplies the loaded values by the factor and adds to |input|, which also
52
+ // is converted to float.
53
+ // Returns the sum.
54
+ template <bool kTwoInputs>
55
+ inline __m256 LoadMultiplyAddToFloat(const int32_t* ptr0, const int32_t* ptr1,
56
+ const __m256& float_factor,
57
+ const __m256& input) {
58
+ __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
59
+ if (kTwoInputs) {
60
+ __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
61
+ data0 = _mm256_add_epi32(data0, data1);
62
+ }
63
+ __m256 float_result = _mm256_cvtepi32_ps(data0);
64
+ float_result = _mm256_mul_ps(float_result, float_factor);
65
+ return _mm256_add_ps(float_result, input);
66
+ }
67
+
68
+ // Loads 16x float in 2x 8x registers from |ptr0_1| and multiplies by
69
+ // |input_pairs|, likewise formatted as 8x floats, alternating between the two
70
+ // AR inputs and sums each pair of results, making 8x float results.
71
+ // If |kThreeInputs|, also loads 8x float from |ptr2| and multiplies by
72
+ // |third_input|, which must be formatted as 8x float. The second product is
73
+ // added to the previous result.
74
+ // Returns the sum added to |accumulator|.
75
+ template <bool kThreeInputs>
76
+ inline __m256 MultiplyAddFloat(const __m256& input_pairs,
77
+ const __m256& third_input, const float* ptr0_1,
78
+ const float* ptr2, const __m256& accumulator) {
79
+ __m256 data_pair0 = _mm256_load_ps(ptr0_1);
80
+ __m256 data_pair1 = _mm256_load_ps(ptr0_1 + 8);
81
+ data_pair0 = _mm256_mul_ps(data_pair0, input_pairs);
82
+ data_pair1 = _mm256_mul_ps(data_pair1, input_pairs);
83
+ data_pair0 = _mm256_hadd_ps(data_pair0, data_pair1);
84
+ // Swap the middle 2 64 bit pairs to correct the hadd result.
85
+ data_pair0 = _mm256_permute4x64_pd((__m256d)data_pair0, 0xd8);
86
+ if (kThreeInputs) {
87
+ // Load 256 bits (8 x float) of data, then multiply-accumulate.
88
+ data_pair1 = _mm256_load_ps(ptr2);
89
+ data_pair1 = _mm256_mul_ps(data_pair1, third_input);
90
+ data_pair0 = _mm256_add_ps(data_pair0, data_pair1);
91
+ }
92
+ // Add conditioning.
93
+ return _mm256_add_ps(data_pair0, accumulator);
94
+ }
95
+
96
+ // Processes the tanh and the final combination, returns the new GRU state.
97
+ template <int kInputMantissaBits, int kStateMantissaBits, bool kSplitGates>
98
+ inline __m256i GRUComputeState(const __m256& cell0, const __m256& cell1,
99
+ const __m256& reset0, const __m256& reset1,
100
+ const __m256& update0, const __m256& update1,
101
+ const int32_t* gate_ptr,
102
+ const int32_t* gate_other_ptr,
103
+ const void* gru_h_ptr) {
104
+ // Multiply the cell gru output and the reset.
105
+ __m256 float_gru0 = LoadMultiplyAddToFloat<kSplitGates>(
106
+ gate_ptr, gate_other_ptr, reset0, cell0);
107
+ __m256 float_gru1 = LoadMultiplyAddToFloat<kSplitGates>(
108
+ gate_ptr + kAVX2SIMDWidth, gate_other_ptr + kAVX2SIMDWidth, reset1,
109
+ cell1);
110
+ // Compute tanh on the result.
111
+ __m256 hbar0, hbar1;
112
+ float_tanh_float<kInputMantissaBits, TM_ORDER4_FLOAT>(float_gru0, float_gru1,
113
+ hbar0, hbar1);
114
+ // Load the 16-bit previous gru state and update.
115
+ __m256i gru = _mm256_load_si256(reinterpret_cast<__m256i const*>(gru_h_ptr));
116
+ __m256 state_factor =
117
+ _mm256_set1_ps(1.0f / (static_cast<float>(1 << kStateMantissaBits)));
118
+ float_gru0 =
119
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(gru)));
120
+ float_gru1 = _mm256_cvtepi32_ps(
121
+ _mm256_cvtepi16_epi32(_mm256_extractf128_si256(gru, 1)));
122
+ float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
123
+ float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
124
+ float_gru0 = _mm256_sub_ps(float_gru0, hbar0);
125
+ float_gru1 = _mm256_sub_ps(float_gru1, hbar1);
126
+ float_gru0 = _mm256_mul_ps(float_gru0, update0);
127
+ float_gru1 = _mm256_mul_ps(float_gru1, update1);
128
+ state_factor = _mm256_set1_ps(static_cast<float>(1 << kStateMantissaBits));
129
+ float_gru0 = _mm256_add_ps(float_gru0, hbar0);
130
+ float_gru1 = _mm256_add_ps(float_gru1, hbar1);
131
+ float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
132
+ float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
133
+ return PackFloatsToFixed16(float_gru0, float_gru1);
134
+ }
135
+
136
+ // According to |kInputsMode|, processes 0, 2 or 3 autoregressive inputs and
137
+ // combines with |input| and |gates*|.
138
+ // With 2 AR inputs, loads 8x pairs of float from |pair_weights| and multiplies
139
+ // by |paired_ar|, likewise formatted as 8x float, but scaled such that the
140
+ // product with pair_weights is on the same scale as |*input| and |*gates0|,
141
+ // and sums each pair result, making 8x float results.
142
+ // If 3 AR inputs, also loads 8x float from |third_weights| and multiplies by
143
+ // |third_ar|, which must be formatted as 8x scaled floats. The second product
144
+ // is added to the previous result.
145
+ // Inputs, 8x fixed32 are loaded from |input|, and added to the total.
146
+ // Finally 8x fixed32 from |gates0| (and |gates1| if |kTwoGates|) are added as
147
+ // well.
148
+ // Returns the total sum as a float, but on the scale of |*input|.
149
+ template <bool kTwoGates, ARInputsMode kInputsMode>
150
+ inline __m256 GruInput32ToFloat(const __m256& paired_ar,
151
+ const __m256& third_ar,
152
+ const float* pair_weights,
153
+ const float* third_weights,
154
+ const int32_t* gates0, const int32_t* gates1,
155
+ const int32_t* input) {
156
+ __m256i data32 = _mm256_load_si256(reinterpret_cast<__m256i const*>(input));
157
+ data32 = LoadAndAddFixed32<kTwoGates>(gates0, gates1, data32);
158
+ __m256 float_data = _mm256_cvtepi32_ps(data32);
159
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
160
+ float_data = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
161
+ paired_ar, third_ar, pair_weights, third_weights, float_data);
162
+ }
163
+ return float_data;
164
+ }
165
+
166
+ // Generic GRU gates function controlled by template parameters thus:
167
+ // - |kInputBits|: the mantissa bits in |*input_ptr|, |*gru_recurrent_ptr|.
168
+ // - |kStateBits|: the mantissa_bits in |*gru_state_ptr|.
169
+ // - |kInputsMode == |k0ARInputs|: There are no autoregressive inputs so
170
+ // |ar_sample, |ar_sample1|, |ar_sample2|, |ar_01_weights|, |ar_2_weights| are
171
+ // ignored.
172
+ // - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied by
173
+ // |ar_01_weights| and added to the (conditioning) input.
174
+ // - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by |ar_2_weights|
175
+ // and added to the other two AR inputs (and added to the conditioning input).
176
+ // - |kReplicas| determines the number of duplicates of the output to be
177
+ // written, separated by |replica_stride|. If zero, then the number of
178
+ // replicas is variable and taken from the |replicas| argument.
179
+ // - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary
180
+ // recurrent input that must be added to |*gru_recurrent_ptr|.
181
+ // - |start|, |end| are |rows| in [0, |state_size|] to be processed by this
182
+ // thread.
183
+ //
184
+ // Previous state is read from |*gru_state_ptr| and the new state is written to
185
+ // *(|gru_state_ptr| + i * |replica_stride| for i in [0, |kReplicas|]).
186
+ template <int kInputBits, int kStateBits,
187
+ ARInputsMode kInputsMode = ARInputsMode::k0ARInputs,
188
+ int kReplicas = 1, bool kSplitGates = false>
189
+ inline void GruGatesTemplate(
190
+ int start, int end, int state_size, int replicas, int replica_stride,
191
+ const int32_t* gru_recurrent_ptr, const int32_t* input_ptr,
192
+ const std::pair<float, float>* ar_sample01, const float* ar_01_weights,
193
+ const float* ar_sample2, const float* ar_2_weights,
194
+ const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
195
+ constexpr int kQRIncrement = kAVX2SIMDWidth;
196
+ // Increment all the pointers to save on pointer arithmetic in the loop.
197
+ input_ptr += start;
198
+ gru_state_ptr += start;
199
+ gru_recurrent_ptr += start;
200
+ if (kSplitGates) gru_recurrent_other_ptr += start;
201
+ __m256 ar_2_inputs, ar_3rd_input;
202
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
203
+ ar_01_weights += 2 * start;
204
+ ar_2_inputs = _mm256_castsi256_ps(
205
+ _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(ar_sample01)));
206
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
207
+ ar_2_weights += start;
208
+ ar_3rd_input = _mm256_set1_ps(*ar_sample2);
209
+ } else {
210
+ ar_3rd_input = {};
211
+ }
212
+ } else {
213
+ ar_2_inputs = {};
214
+ ar_3rd_input = {};
215
+ }
216
+ // The transcendentals handle 2x registers of data at once, so we have to do
217
+ // everything in duplicate.
218
+ for (int i = start; i < end; i += kQRIncrement * 2) {
219
+ // Load 8 pairs of fixed16s for each of reset, update and cell.
220
+ __m256 reset0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
221
+ ar_2_inputs, ar_3rd_input, ar_01_weights, ar_2_weights,
222
+ gru_recurrent_ptr, gru_recurrent_other_ptr, input_ptr);
223
+ __m256 reset1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
224
+ ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * kQRIncrement,
225
+ ar_2_weights + kQRIncrement, gru_recurrent_ptr + kAVX2SIMDWidth,
226
+ gru_recurrent_other_ptr + kAVX2SIMDWidth, input_ptr + kAVX2SIMDWidth);
227
+ float_sigmoid_float<kInputBits>(reset0, reset1);
228
+ __m256 update0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
229
+ ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * state_size,
230
+ ar_2_weights + state_size, gru_recurrent_ptr + state_size,
231
+ gru_recurrent_other_ptr + state_size, input_ptr + state_size);
232
+ __m256 update1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
233
+ ar_2_inputs, ar_3rd_input,
234
+ ar_01_weights + 2 * state_size + 2 * kQRIncrement,
235
+ ar_2_weights + state_size + kQRIncrement,
236
+ gru_recurrent_ptr + state_size + kAVX2SIMDWidth,
237
+ gru_recurrent_other_ptr + state_size + kAVX2SIMDWidth,
238
+ input_ptr + state_size + kAVX2SIMDWidth);
239
+ float_sigmoid_float<kInputBits>(update0, update1);
240
+ __m256 cell0 = _mm256_cvtepi32_ps(_mm256_load_si256(
241
+ reinterpret_cast<__m256i const*>(input_ptr + 2 * state_size)));
242
+ __m256 cell1 =
243
+ _mm256_cvtepi32_ps(_mm256_load_si256(reinterpret_cast<__m256i const*>(
244
+ input_ptr + 2 * state_size + kAVX2SIMDWidth)));
245
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
246
+ cell0 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
247
+ ar_2_inputs, ar_3rd_input, ar_01_weights + 4 * state_size,
248
+ ar_2_weights + 2 * state_size, cell0);
249
+ cell1 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
250
+ ar_2_inputs, ar_3rd_input,
251
+ ar_01_weights + 4 * state_size + 2 * kQRIncrement,
252
+ ar_2_weights + 2 * state_size + kQRIncrement, cell1);
253
+ }
254
+ __m256i gru_state = GRUComputeState<kInputBits, kStateBits, kSplitGates>(
255
+ cell0, cell1, reset0, reset1, update0, update1,
256
+ gru_recurrent_ptr + 2 * state_size,
257
+ gru_recurrent_other_ptr + 2 * state_size, gru_state_ptr);
258
+ if (kReplicas > 0) {
259
+ // With |kReplicas| a template parameter, the compiler will unroll the
260
+ // loop.
261
+ for (int j = 0; j < kReplicas; ++j) {
262
+ _mm256_store_si256(
263
+ reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
264
+ gru_state);
265
+ }
266
+ } else {
267
+ // This loop will not unroll as replicas is variable.
268
+ for (int j = 0; j < replicas; ++j) {
269
+ _mm256_store_si256(
270
+ reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
271
+ gru_state);
272
+ }
273
+ }
274
+ // Increment all the pointers.
275
+ input_ptr += 2 * kAVX2SIMDWidth;
276
+ gru_state_ptr += 2 * kAVX2SIMDWidth;
277
+ gru_recurrent_ptr += 2 * kAVX2SIMDWidth;
278
+ if (kSplitGates) gru_recurrent_other_ptr += 2 * kAVX2SIMDWidth;
279
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
280
+ ar_01_weights += 4 * kQRIncrement;
281
+ if (kInputsMode == ARInputsMode::k3ARInputs)
282
+ ar_2_weights += 2 * kQRIncrement;
283
+ }
284
+ }
285
+ }
286
+
287
+ // Dispatches calls to the GruGatesTemplate function above converting the
288
+ // replicas variable argument to a template parameter to allow the compiler to
289
+ // unroll the write loop.
290
+ // |ar_sample01| packs sample 0 and 1 into a pair because the QR weights are
291
+ // formatted with the weights interleaved for sample 0 and 1. The two samples
292
+ // represent coarse and fine for WaveRNN.
293
+ template <int kInputBits, int kStateBits,
294
+ ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
295
+ bool kSplitGates = false>
296
+ inline void GruGatesAVXFixed(
297
+ int start, int end, int state_size, const int32_t* gru_recurrent_ptr,
298
+ const int32_t* input_ptr, const std::pair<float, float>* ar_sample01,
299
+ const float* ar_01_weights, int num_replicas, int replica_stride,
300
+ const float* ar_sample2, const float* ar_2_weights,
301
+ const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
302
+ // Convert the number of replicas from a variable to a template parameter
303
+ // with a switch. This enables the compiler to unroll the loop for
304
+ // the write, making it faster for common numbers of threads.
305
+ switch (num_replicas) {
306
+ case 1:
307
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/1,
308
+ kSplitGates>(
309
+ start, end, state_size, num_replicas, replica_stride,
310
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
311
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
312
+ break;
313
+ case 2:
314
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/2,
315
+ kSplitGates>(
316
+ start, end, state_size, num_replicas, replica_stride,
317
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
318
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
319
+ break;
320
+ case 4:
321
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/4,
322
+ kSplitGates>(
323
+ start, end, state_size, num_replicas, replica_stride,
324
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
325
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
326
+ break;
327
+ case 6:
328
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/6,
329
+ kSplitGates>(
330
+ start, end, state_size, num_replicas, replica_stride,
331
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
332
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
333
+ break;
334
+ default:
335
+ // Zero |kReplicas| tells the function to use the |num_replicas| variable.
336
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/0,
337
+ kSplitGates>(
338
+ start, end, state_size, num_replicas, replica_stride,
339
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
340
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
341
+ }
342
+ }
343
+
344
+ #endif // __AVX2__
345
+
346
+ } // namespace csrblocksparse
347
+
348
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
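Editor's note on the dispatch pattern above: GruGatesAVXFixed turns the runtime |num_replicas| into the template parameter |kReplicas| purely so the compiler can unroll the state-replication store loop. Below is a minimal standalone sketch of that idiom; the names Dispatch and StoreReplicas are hypothetical and not part of this commit.

template <int kReplicas>
void StoreReplicas(int num_replicas, int stride, int value, int* dst) {
  if (kReplicas > 0) {
    // kReplicas is a compile-time constant, so this loop fully unrolls.
    for (int j = 0; j < kReplicas; ++j) dst[j * stride] = value;
  } else {
    // kReplicas == 0 means "unknown at compile time": fall back to the
    // runtime count, with no unrolling guarantee.
    for (int j = 0; j < num_replicas; ++j) dst[j * stride] = value;
  }
}

void Dispatch(int num_replicas, int stride, int value, int* dst) {
  switch (num_replicas) {
    case 1: StoreReplicas<1>(num_replicas, stride, value, dst); break;
    case 2: StoreReplicas<2>(num_replicas, stride, value, dst); break;
    case 4: StoreReplicas<4>(num_replicas, stride, value, dst); break;
    default: StoreReplicas<0>(num_replicas, stride, value, dst); break;
  }
}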
sparse_matmul/compute/gru_gates_generic.h ADDED
@@ -0,0 +1,97 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
19
+
20
+ #include "sparse_matmul/compute/ar_inputs.h"
21
+ #include "sparse_matmul/numerics/fast_transcendentals.h"
22
+
23
+ namespace csrblocksparse {
24
+
25
+ constexpr int kGenericSIMDWidth = 4;
26
+
27
+ // TODO(b/188702959): Rename arguments to match gru_gates.h.
28
+ template <typename GRUStateType, typename GRUMatMulOutType, typename QR_W_Type,
29
+ typename SampleType, ARInputsMode kInputsMode,
30
+ bool SplitGates = false>
31
+ void GoThroughGates(int start, int end, const QR_W_Type* qr_ptr,
32
+ const GRUMatMulOutType* gru_gates_ptr,
33
+ const GRUMatMulOutType* gru_gates_other_ptr,
34
+ const GRUMatMulOutType* conditioning_ptr,
35
+ GRUStateType* gru_h_ptr, const QR_W_Type* w_hat,
36
+ int proj_size, const SampleType* coarse_at_sminus1,
37
+ const SampleType* fine_at_sminus1,
38
+ const SampleType* coarse_at_s = nullptr) {
39
+ float qr_cell = 0.0f, reset, update, cell;
40
+ for (int i = start; i < end; ++i) {
41
+ if (kInputsMode == ARInputsMode::k0ARInputs) {
42
+ reset = static_cast<float>(gru_gates_ptr[i]);
43
+ update = static_cast<float>(gru_gates_ptr[proj_size + i]);
44
+ } else {
45
+ float qr_c_reset = static_cast<float>(qr_ptr[2 * i + 0]);
46
+ float qr_f_reset = static_cast<float>(qr_ptr[2 * i + 1]);
47
+ float qr_c_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 0]);
48
+ float qr_f_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 1]);
49
+ float qr_c_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 0]);
50
+ float qr_f_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 1]);
51
+ float w_hat_i_reset = 0.0f;
52
+ float w_hat_i_update = 0.0f;
53
+ float w_hat_i_cell = 0.0f;
54
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
55
+ w_hat_i_reset = static_cast<float>(w_hat[i]);
56
+ w_hat_i_update = static_cast<float>(w_hat[proj_size + i]);
57
+ w_hat_i_cell = static_cast<float>(w_hat[2 * proj_size + i]);
58
+ }
59
+ float coarse = static_cast<float>(coarse_at_sminus1[0]);
60
+ float fine = static_cast<float>(fine_at_sminus1[0]);
61
+ reset = qr_c_reset * coarse + qr_f_reset * fine;
62
+ update = qr_c_update * coarse + qr_f_update * fine;
63
+ qr_cell = qr_c_cell * coarse + qr_f_cell * fine;
64
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
65
+ float coarse = static_cast<float>(coarse_at_s[0]);
66
+ reset += w_hat_i_reset * coarse;
67
+ update += w_hat_i_update * coarse;
68
+ qr_cell += w_hat_i_cell * coarse;
69
+ }
70
+ reset += static_cast<float>(gru_gates_ptr[i]);
71
+ update += static_cast<float>(gru_gates_ptr[proj_size + i]);
72
+ }
73
+ cell = static_cast<float>(gru_gates_ptr[2 * proj_size + i]);
74
+ if (SplitGates) {
75
+ reset += static_cast<float>(gru_gates_other_ptr[i]);
76
+ update += static_cast<float>(gru_gates_other_ptr[proj_size + i]);
77
+ cell += static_cast<float>(gru_gates_other_ptr[2 * proj_size + i]);
78
+ }
79
+ float reset_conditioning = static_cast<float>(conditioning_ptr[i]);
80
+ float update_conditioning =
81
+ static_cast<float>(conditioning_ptr[proj_size + i]);
82
+ float cell_conditioning =
83
+ static_cast<float>(conditioning_ptr[2 * proj_size + i]);
84
+ reset = fast_sigmoid(reset + reset_conditioning);
85
+ update = fast_sigmoid(update + update_conditioning);
86
+ float hbar = fast_tanh(qr_cell + reset * cell + cell_conditioning);
87
+ int h_index = i;
88
+ float prev_h = static_cast<float>(gru_h_ptr[h_index]);
89
+ float diff = prev_h - hbar;
90
+ float new_h = hbar + diff * update;
91
+ gru_h_ptr[h_index] = static_cast<GRUStateType>(new_h);
92
+ }
93
+ }
94
+
95
+ } // namespace csrblocksparse
96
+
97
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
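For reference, the per-element arithmetic in GoThroughGates reduces to the standard GRU update sketched below. This is a float-only illustration with a hypothetical helper name (GruElement) and <cmath> transcendentals in place of the library's fast_sigmoid/fast_tanh; it is not part of this commit.

#include <cmath>

// One GRU element. |reset_sum|, |update_sum| and |cell_sum| are the matmul
// outputs (plus any AR-sample contributions), the *_cond values come from the
// conditioning input, and |prev_h| is the previous state element.
float GruElement(float reset_sum, float update_sum, float cell_sum,
                 float qr_cell, float reset_cond, float update_cond,
                 float cell_cond, float prev_h) {
  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
  float reset = sigmoid(reset_sum + reset_cond);
  float update = sigmoid(update_sum + update_cond);
  float hbar = std::tanh(qr_cell + reset * cell_sum + cell_cond);
  // Interpolate between the candidate state and the previous state.
  return hbar + (prev_h - hbar) * update;
}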
sparse_matmul/compute/gru_gates_test.cc ADDED
@@ -0,0 +1,164 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/compute/gru_gates.h"
16
+
17
+ #include <cstdint>
18
+ #include <cstring>
19
+ #include <numeric>
20
+
21
+ #include "absl/memory/memory.h"
22
+ #include "absl/types/span.h"
23
+ #include "gmock/gmock.h"
24
+ #include "gtest/gtest.h"
25
+
26
+ namespace {
27
+
28
+ using csrblocksparse::ARInputsMode;
29
+
30
+ template <typename GRUStateType, typename InputType, typename SampleType = void,
31
+ csrblocksparse::ARInputsMode kInputsMode, bool kSplitGates>
32
+ csrblocksparse::CacheAlignedVector<GRUStateType> TestGruGates() {
33
+ using SampleWeightType = float;
34
+ constexpr int kStateSize = 16;
35
+ csrblocksparse::CacheAlignedVector<SampleWeightType> qr(6 * kStateSize);
36
+ csrblocksparse::CacheAlignedVector<SampleWeightType> w(3 * kStateSize);
37
+ csrblocksparse::CacheAlignedVector<InputType> gru_gates(3 * kStateSize);
38
+ csrblocksparse::CacheAlignedVector<InputType> gru_other_gates(3 * kStateSize);
39
+ csrblocksparse::CacheAlignedVector<InputType> conditioning(3 * kStateSize);
40
+ csrblocksparse::CacheAlignedVector<GRUStateType> gru_h(kStateSize);
41
+ csrblocksparse::GruGates<GRUStateType, InputType, SampleType> gru_gates_impl;
42
+ const SampleType kCoarseAtSMinus1(0.03f);
43
+ const SampleType kFineAtSMinus1(0.07f);
44
+ const SampleType kCoarseAtS(-0.02f);
45
+
46
+ qr.FillOnes();
47
+ w.FillOnes();
48
+ gru_gates.FillRandom();
49
+ gru_other_gates.FillRandom();
50
+ conditioning.FillRandom();
51
+ gru_h.FillZero();
52
+
53
+ gru_gates_impl.template GruWithARInput<kInputsMode, kSplitGates>(
54
+ /*start=*/0, /*end=*/kStateSize, kStateSize, gru_gates.data(),
55
+ conditioning.data(), gru_h.data(), &kCoarseAtSMinus1, &kFineAtSMinus1,
56
+ qr.data(),
57
+ /*num_replicas=*/1, /*replica_stride=*/0, &kCoarseAtS, w.data(),
58
+ gru_other_gates.data());
59
+ return gru_h;
60
+ }
61
+
62
+ TEST(GruGates, FloatWaveRNNCoarseMatchesGolden) {
63
+ // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
64
+ // will also need to change.
65
+ const std::vector<float> kGoldenValues = {
66
+ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.746f, 0.0f, 0.0f,
67
+ 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.993f};
68
+ csrblocksparse::CacheAlignedVector<float> gru_h =
69
+ TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
70
+ /*kSplitGates=*/true>();
71
+
72
+ ASSERT_EQ(kGoldenValues.size(), gru_h.size());
73
+ for (int i = 0; i < gru_h.size(); ++i) {
74
+ EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
75
+ }
76
+ }
77
+
78
+ TEST(GruGates, FloatWaveRNNFineMatchesGolden) {
79
+ // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
80
+ // will also need to change.
81
+ const std::vector<float> kGoldenValues = {
82
+ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.737f, 0.0f, 0.0f,
83
+ 0.0f, 0.0f, 0.969f, 0.0f, 0.0f, 1.0f, 0.0f, -0.994f};
84
+ csrblocksparse::CacheAlignedVector<float> gru_h =
85
+ TestGruGates<float, float, float, ARInputsMode::k3ARInputs,
86
+ /*kSplitGates=*/true>();
87
+
88
+ ASSERT_EQ(kGoldenValues.size(), gru_h.size());
89
+ for (int i = 0; i < gru_h.size(); ++i) {
90
+ EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
91
+ }
92
+ }
93
+
94
+ TEST(GruGates, FloatTwoArInputsNonSplitGateMatchesGolden) {
95
+ // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
96
+ // will also need to change.
97
+ const std::vector<float> kGoldenValues = {
98
+ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.714f, 0.0f, -0.002f,
99
+ 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.965f};
100
+ csrblocksparse::CacheAlignedVector<float> gru_h =
101
+ TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
102
+ /*kSplitGates=*/false>();
103
+
104
+ ASSERT_EQ(kGoldenValues.size(), gru_h.size());
105
+ for (int i = 0; i < gru_h.size(); ++i) {
106
+ EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
107
+ }
108
+ }
109
+
110
+ TEST(GruGates, FixedWaveRNNCoarseMatchesFloat) {
111
+ using GRUMatMulOutType = csrblocksparse::fixed32<11>;
112
+ using GRUStateType = csrblocksparse::fixed16<2>;
113
+ using SampleType = csrblocksparse::fixed16<0>;
114
+ csrblocksparse::CacheAlignedVector<float> float_gru_h =
115
+ TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
116
+ /*kSplitGates=*/true>();
117
+ csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
118
+ TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
119
+ ARInputsMode::k2ARInputs, /*kSplitGates=*/true>();
120
+
121
+ ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
122
+ for (int i = 0; i < fixed_gru_h.size(); ++i) {
123
+ EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
124
+ << "i=" << i;
125
+ }
126
+ }
127
+
128
+ TEST(GruGates, FixedWaveRNNFineMatchesFloat) {
129
+ using GRUMatMulOutType = csrblocksparse::fixed32<11>;
130
+ using GRUStateType = csrblocksparse::fixed16<2>;
131
+ using SampleType = csrblocksparse::fixed16<0>;
132
+ csrblocksparse::CacheAlignedVector<float> float_gru_h =
133
+ TestGruGates<float, float, float, ARInputsMode::k3ARInputs,
134
+ /*kSplitGates=*/true>();
135
+ csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
136
+ TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
137
+ ARInputsMode::k3ARInputs, /*kSplitGates=*/true>();
138
+
139
+ ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
140
+ for (int i = 0; i < fixed_gru_h.size(); ++i) {
141
+ EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
142
+ << "i=" << i;
143
+ }
144
+ }
145
+
146
+ TEST(GruGates, FixedTwoArInputsNonSplitGateMatchesFloat) {
147
+ using GRUMatMulOutType = csrblocksparse::fixed32<11>;
148
+ using GRUStateType = csrblocksparse::fixed16<2>;
149
+ using SampleType = csrblocksparse::fixed16<0>;
150
+ csrblocksparse::CacheAlignedVector<float> float_gru_h =
151
+ TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
152
+ /*kSplitGates=*/false>();
153
+ csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
154
+ TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
155
+ ARInputsMode::k2ARInputs, /*kSplitGates=*/false>();
156
+
157
+ ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
158
+ for (int i = 0; i < fixed_gru_h.size(); ++i) {
159
+ EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
160
+ << "i=" << i;
161
+ }
162
+ }
163
+
164
+ } // namespace
sparse_matmul/compute/kernels_arm.h ADDED
The diff for this file is too large to render. See raw diff
 
sparse_matmul/compute/kernels_avx.h ADDED
@@ -0,0 +1,601 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
19
+
20
+ #if defined __AVX__
21
+ #include <immintrin.h>
22
+
23
+ #include <algorithm>
24
+ #include <type_traits>
25
+ // TODO(b/188702959): Remove fast_transcendentals with GRU refactor.
26
+ #include "sparse_matmul/numerics/fast_transcendentals.h"
27
+ #include "sparse_matmul/numerics/fixed_types.h"
28
+ #include "sparse_matmul/numerics/float16_types.h"
29
+ #include "sparse_matmul/numerics/type_utils.h"
30
+
31
+ namespace csrblocksparse {
32
+ namespace detail {
33
+
34
+ template <typename WeightType, typename RhsType, typename OutType>
35
+ struct IsAllowableFloatTypes
36
+ : std::integral_constant<bool, std::is_same<WeightType, float>::value &&
37
+ std::is_same<RhsType, float>::value &&
38
+ std::is_same<OutType, float>::value> {};
39
+
40
+ #if defined __AVX2__
41
+ // 16-bit inputs, 32-bit output exponent matches sum of input exponents
42
+ // OR
43
+ // 16-bit inputs, 16-bit output - will shift to match exponent
44
+ template <typename WeightType, typename RhsType, typename OutType>
45
+ struct IsAllowableFixedTypes
46
+ : std::integral_constant<bool, (IsFixed16Type<WeightType>::value &&
47
+ IsFixed16Type<RhsType>::value) &&
48
+ (IsFixed32Type<OutType>::value ||
49
+ IsFixed16Type<OutType>::value)> {};
50
+
51
+ template <typename WeightType, typename RhsType, typename OutType>
52
+ struct ShouldEnableGenericKernel
53
+ : std::integral_constant<
54
+ bool,
55
+ !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value &&
56
+ !IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {};
57
+
58
+ template <typename Type>
59
+ struct IsAddableFixedTypes
60
+ : std::integral_constant<bool, IsFixed32Type<Type>::value ||
61
+ IsFixed16Type<Type>::value> {};
62
+ template <typename Type>
63
+ struct ShouldEnableGenericAdd
64
+ : std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {};
65
+
66
+ #else // No AVX2.
67
+
68
+ template <typename WeightType, typename RhsType, typename OutType>
69
+ struct ShouldEnableGenericKernel
70
+ : std::integral_constant<
71
+ bool, !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value> {};
72
+
73
+ template <typename Type>
74
+ struct ShouldEnableGenericAdd : std::true_type {};
75
+ #endif // __AVX2__
76
+
77
+ template <typename WeightType, typename RhsType, typename OutType>
78
+ struct ShouldEnableGenericSpMV_4x4
79
+ : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
80
+ template <typename WeightType, typename RhsType, typename OutType>
81
+ struct ShouldEnableGenericSpMM5_4x4
82
+ : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
83
+ template <typename WeightType, typename RhsType, typename OutType>
84
+ struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
85
+ template <typename WeightType, typename RhsType, typename OutType>
86
+ struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
87
+
88
+ // The computational routines do NO error checking for speed. It is assumed
89
+ // that this has been handled by CSRBlockSparseMatrix.
90
+
91
+ // In-line function to extract results from a pair of registers and store in
92
+ // memory. Note that the non-const references are registers, and are modified
93
+ // by this function!
94
+ inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2,
95
+ float** out_ptr) {
96
+ // Horizontally add the results. We have 2 registers, |sum1| and |sum2| that
97
+ // each contain 2 sets of 4 values that need to be added.
98
+ sum1 = _mm256_hadd_ps(sum1, sum2);
99
+ sum1 = _mm256_hadd_ps(sum1, sum1);
100
+ // Now |sum1| contains [|res0|, |res2|, |res0|, |res2|, |res1|, |res3|,
101
+ // |res1|, |res3|]
102
+ if (relu) {
103
+ sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps());
104
+ }
105
+ // It is really hard in AVX to cross the 128 bit 'lanes' and this is the
106
+ // *only* way to do it.
107
+ // Get the top half of |sum1| in to bottom of |sum2|.
108
+ sum2 = _mm256_permute2f128_ps(sum1, sum1, 1);
109
+ // Interleave the values between the two registers.
110
+ sum1 = _mm256_unpacklo_ps(sum1, sum2);
111
+ // Save the lower 128 bits (4 floats).
112
+ __m128 result = _mm256_extractf128_ps(sum1, 0);
113
+ _mm_store_ps(*out_ptr, result);
114
+ *out_ptr += 4;
115
+ }
116
+
117
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
118
+ // blocked pattern, x is a vector and b is vector. Weights are stored for this
119
+ // routine by making each 4x4 block contiguous. Blocks are ordered in standard
120
+ // row-major format. column indices are converted to deltas and then multiplied
121
+ // by 2 to convert to bytes, so that the value can be used directly to offset
122
+ // the pointer into the rhs vector.
123
+ //
124
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
125
+ // this function. This is automatically taken care of in SparseLinearLayer.
126
+ // The bias is reconstructed through horizontal additions, which leads to a small
127
+ // speedup by reducing latencies at the end of the loop.
128
+ template <typename WeightType, typename RhsType, typename OutType>
129
+ typename std::enable_if<std::is_same<WeightType, float>::value &&
130
+ std::is_same<RhsType, float>::value &&
131
+ std::is_same<OutType, float>::value>::type
132
+ SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
133
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
134
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
135
+ OutType* out_ptr, int64_t assigned_rows,
136
+ int64_t rows /* only used in SpMM variants */,
137
+ int64_t cols /* only used in SpMM variants */, int relu) {
138
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
139
+ // Broadcast the biases by 4 to undo the division by 4 in the input biases.
140
+ __m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
141
+ _mm_broadcast_ss(bias_ptr));
142
+ bias_ptr += 2;
143
+ __m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
144
+ _mm_broadcast_ss(bias_ptr));
145
+ bias_ptr += 2;
146
+
147
+ int reduced_col_count = *nnz_per_row++;
148
+ for (int c = 0; c < reduced_col_count; ++c) {
149
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
150
+ rhs_ptr += col_delta;
151
+ // Multiply this 4x4 block.
152
+ __m256 rhs =
153
+ _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
154
+ __m256 weights1 = _mm256_load_ps(weights_ptr);
155
+ weights_ptr += 8;
156
+ sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs));
157
+ __m256 weights2 = _mm256_load_ps(weights_ptr);
158
+ weights_ptr += 8;
159
+ sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs));
160
+ }
161
+ Extract4Results(relu, sum1, sum2, &out_ptr);
162
+ }
163
+ }
164
+
165
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
166
+ // blocked pattern, x is a fat vector with 5 columns and b is vector. b is
167
+ // broadcast. Weights are stored for this routine by making each 4x4 block
168
+ // contiguous. Blocks are ordered in standard row-major format. column indices
169
+ // are converted to deltas and then multiplied by 2 to convert to bytes, so
170
+ // that the value can be used directly to offset the pointer into the rhs
171
+ // vector.
172
+ //
173
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
174
+ // this function. This is automatically taken care of in SparseLinearLayer.
175
+ // The bias is reconstructed through horizontal additions, which leads to a small
176
+ // speedup by reducing latencies at the end of the loop.
177
+ template <typename WeightType, typename RhsType, typename OutType>
178
+ typename std::enable_if<std::is_same<WeightType, float>::value &&
179
+ std::is_same<RhsType, float>::value &&
180
+ std::is_same<OutType, float>::value>::type
181
+ SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
182
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
183
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
184
+ OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
185
+ int relu) {
186
+ const RhsType* rhs_ptrs[5];
187
+ for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
188
+
189
+ OutType* out_ptrs[5];
190
+ for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
191
+
192
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
193
+ // We will accumulate the results in 10 registers, |sum1_0| to |sum2_4|.
194
+ // Broadcast the biases by 4 to undo the division by 4 in the input biases.
195
+ __m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
196
+ _mm_broadcast_ss(bias_ptr));
197
+ bias_ptr += 2;
198
+ __m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
199
+ _mm_broadcast_ss(bias_ptr));
200
+ bias_ptr += 2;
201
+ __m256 sum1_1 = sum1_0;
202
+ __m256 sum2_1 = sum2_0;
203
+ __m256 sum1_2 = sum1_0;
204
+ __m256 sum2_2 = sum2_0;
205
+ __m256 sum1_3 = sum1_0;
206
+ __m256 sum2_3 = sum2_0;
207
+ __m256 sum1_4 = sum1_0;
208
+ __m256 sum2_4 = sum2_0;
209
+
210
+ int reduced_col_count = *nnz_per_row++;
211
+ for (int c = 0; c < reduced_col_count; ++c) {
212
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
213
+ for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
214
+
215
+ // Multiply this 4x4 block.
216
+ __m256 rhs =
217
+ _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[0]));
218
+ __m256 weights1 = _mm256_load_ps(weights_ptr);
219
+ weights_ptr += 8;
220
+ sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs));
221
+ __m256 weights2 = _mm256_load_ps(weights_ptr);
222
+ weights_ptr += 8;
223
+ sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs));
224
+ rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[1]));
225
+ sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs));
226
+ sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs));
227
+ rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[2]));
228
+ sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs));
229
+ sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs));
230
+ rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[3]));
231
+ sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs));
232
+ sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs));
233
+ rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[4]));
234
+ sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs));
235
+ sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs));
236
+ }
237
+
238
+ Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]);
239
+ Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]);
240
+ Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]);
241
+ Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]);
242
+ Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]);
243
+ }
244
+ }
245
+
246
+ #ifdef __AVX2__
247
+
248
+ // In-line function to finish the computation of the result as 4x int32 in
249
+ // |sum|.
250
+ inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) {
251
+ // Horizontally add the results. We have 1 register that contains results
252
+ // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
253
+ // cross lanes, so we end up with [0 1 0 1 2 3 2 3]
254
+ sum = _mm256_hadd_epi32(sum, sum);
255
+ // Permutes the middle two pairs to get the answers together.
256
+ sum = _mm256_permute4x64_epi64(sum, 0xd8);
257
+ if (kShiftAmount > 0) {
258
+ // Shift right with rounding to get the right number of mantissa bits.
259
+ __m256i rounding = _mm256_set1_epi32(1 << (kShiftAmount - 1));
260
+ sum = _mm256_add_epi32(sum, rounding);
261
+ sum = _mm256_srai_epi32(sum, kShiftAmount);
262
+ }
263
+ // Now |sum| contains [|res0|, |res1|, |res2|, |res3|, |res0|, |res1|,
264
+ // |res2|, |res3|]
265
+ if (relu) {
266
+ sum = _mm256_max_epi32(sum, _mm256_setzero_si256());
267
+ }
268
+ }
269
+
270
+ // In-line function to extract the 4x int32 results from |sum| to memory.
271
+ // Non-const reference for |sum| as it is a register.
272
+ inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum,
273
+ int32_t** out_ptr) {
274
+ Compute4Results(relu, kShiftAmount, sum);
275
+ // Save the lower 128 bits (4x int32).
276
+ __m128i result = _mm256_extractf128_si256(sum, 0);
277
+ _mm_store_si128(reinterpret_cast<__m128i*>(*out_ptr), result);
278
+ *out_ptr += 4;
279
+ }
280
+
281
+ // In-line function to extract the 4x int32 results from sum to 4x int16 in
282
+ // memory.
283
+ // Non-const reference for |sum| as it is a register.
284
+ inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum,
285
+ int16_t** out_ptr) {
286
+ Compute4Results(relu, kShiftAmount, sum);
287
+ // Clip to 16 bit range (with saturation) and pack in the bottom 64 bits.
288
+ // Converts the lower 4x int32 in bottom 128 bits to 4x int16 in bottom 64
289
+ // bits, replicated in the next 64 bits.
290
+ sum = _mm256_packs_epi32(sum, sum);
291
+ // Save 4x int 16 from the bottom 64 bits.
292
+ *reinterpret_cast<int64_t*>(*out_ptr) = _mm256_extract_epi64(sum, 0);
293
+ *out_ptr += 4;
294
+ }
295
+
296
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
297
+ // blocked pattern, x is a vector and b is vector. Weights are stored for this
298
+ // routine by making each 4x4 block contiguous. Blocks are ordered in standard
299
+ // row-major format. column indices are converted to deltas and then multiplied
300
+ // by 2 to convert to bytes, so that the value can be used directly to offset
301
+ // the pointer into the rhs vector.
302
+ //
303
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
304
+ // this function. This is automatically taken care of in SparseLinearLayer.
305
+ // The bias is reconstructed through horizontal additions, which leads to a small
306
+ // speedup by reducing latencies at the end of the loop.
307
+ template <typename WeightType, typename RhsType, typename OutType>
308
+ typename std::enable_if<
309
+ IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
310
+ (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
311
+ SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
312
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
313
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
314
+ OutType* out_ptr, int64_t assigned_rows,
315
+ int64_t rows /* only used in SpMM variants */,
316
+ int64_t cols /* only used in SpMM variants */, int relu) {
317
+ constexpr int kShiftAmount =
318
+ TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
319
+ OutType::kMantissaBits;
320
+ static_assert(kShiftAmount >= 0,
321
+ "Result must have fewer mantissa bits than product");
322
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
323
+ // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
324
+ __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
325
+ __m256i biases = _mm256_set_m128i(bias, bias);
326
+ bias_ptr += 4;
327
+ // Swap the top two pairs: [0 1 2 3 2 3 0 1]
328
+ // TODO(b/188702959): consider |_mm256_permutevar8x32|, and set the index
329
+ // register outside the row loop.
330
+ biases = _mm256_permute4x64_epi64(biases, 0xb4);
331
+ // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
332
+ biases = _mm256_unpacklo_epi32(biases, biases);
333
+ // Double the results to make up for the division by 4.
334
+ // TODO(b/188702959): consider moving this to where the biases are computed.
335
+ __m256i sum = _mm256_add_epi32(biases, biases);
336
+
337
+ // TODO(b/188702959): People don't like the old-fashioned, close-to-the-
338
+ // metal notation of *|nnz_per_row|++, so measure the effect of putting the
339
+ // increment in the for loop.
340
+ int reduced_col_count = *nnz_per_row;
341
+ ++nnz_per_row;
342
+ for (int c = 0; c < reduced_col_count; ++c) {
343
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
344
+ rhs_ptr += col_delta;
345
+ // Multiply this 4x4 block.
346
+ // Get the 4x int16 into the bottom of rhs_64.
347
+ __m128i rhs_64 =
348
+ _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr));
349
+ // Load all 16 weights.
350
+ __m256i weights =
351
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
352
+ // Broadcast the rhs, pretending that each is a 64-bit unit:
353
+ // [0123 0123 0123 0123].
354
+ __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
355
+ weights_ptr += 16;
356
+ // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
357
+ // adds adjacent pairs to make 8x32 bit results. Add these to the sum.
358
+ sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs));
359
+ }
360
+ static_assert(
361
+ IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
362
+ "AVX2 kernel only supports fixed16 and fixed32 types");
363
+ // The only significant difference between fixed16 and fixed32 is the size
364
+ // of the storage unit. The registers have to be repacked accordingly.
365
+ if (IsFixed32Type<OutType>::value) {
366
+ Extract4xint32(relu, kShiftAmount, sum,
367
+ reinterpret_cast<int32_t**>(&out_ptr));
368
+ } else {
369
+ Extract4xint16(relu, kShiftAmount, sum,
370
+ reinterpret_cast<int16_t**>(&out_ptr));
371
+ }
372
+ }
373
+ }
374
+
375
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
376
+ // blocked pattern, x is a fat vector with 5 columns and b is vector. b is
377
+ // broadcast. Weights are stored for this routine by making each 4x4 block
378
+ // contiguous. Blocks are ordered in standard row-major format. column indices
379
+ // are converted to deltas and then multiplied by 2 to convert to bytes, so
380
+ // that the value can be used directly to offset the pointer into the rhs
381
+ // vector.
382
+ //
383
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
384
+ // this function. This is automatically taken care of in SparseLinearLayer.
385
+ // The bias is reconstructed through horizontal additions, which leads to a small
386
+ // speedup by reducing latencies at the end of the loop.
387
+ template <typename WeightType, typename RhsType, typename OutType>
388
+ typename std::enable_if<
389
+ IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
390
+ (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
391
+ SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
392
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
393
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
394
+ OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
395
+ int relu) {
396
+ constexpr int kShiftAmount =
397
+ TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
398
+ OutType::kMantissaBits;
399
+ static_assert(kShiftAmount >= 0,
400
+ "Result must have fewer mantissa bits than product");
401
+ const RhsType* rhs_ptrs[5];
402
+ for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
403
+
404
+ OutType* out_ptrs[5];
405
+ for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
406
+
407
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
408
+ // We will accumulate the results in 5 registers, sum_0 to sum_4.
409
+ // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
410
+ __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
411
+ __m256i biases = _mm256_set_m128i(bias, bias);
412
+ bias_ptr += 4;
413
+ // Swap the top two pairs: [0 1 2 3 2 3 0 1]
414
+ biases = _mm256_permute4x64_epi64(biases, 0xb4);
415
+ // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
416
+ biases = _mm256_unpacklo_epi32(biases, biases);
417
+ // Double the results to make up for the division by 4.
418
+ __m256i sum_0 = _mm256_add_epi32(biases, biases);
419
+ __m256i sum_1 = sum_0;
420
+ __m256i sum_2 = sum_0;
421
+ __m256i sum_3 = sum_0;
422
+ __m256i sum_4 = sum_0;
423
+
424
+ int reduced_col_count = *nnz_per_row;
425
+ ++nnz_per_row;
426
+ for (int c = 0; c < reduced_col_count; ++c) {
427
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
428
+ for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
429
+ // Multiply this 4x4 block.
430
+ // Get the 4x int16 into the bottom of |rhs_64|.
431
+ __m128i rhs_64 =
432
+ _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0]));
433
+ // Load all 16 weights.
434
+ __m256i weights =
435
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
436
+ // Broadcast the rhs, pretending that each is a 64-bit unit:
437
+ // [0123 0123 0123 0123].
438
+ __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
439
+ weights_ptr += 16;
440
+ // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
441
+ // adds adjacent pairs to make 8x32 bit results. Add these to the sum.
442
+ sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs));
443
+ rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1]));
444
+ rhs = _mm256_broadcastq_epi64(rhs_64);
445
+ sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs));
446
+ rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2]));
447
+ rhs = _mm256_broadcastq_epi64(rhs_64);
448
+ sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs));
449
+ rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3]));
450
+ rhs = _mm256_broadcastq_epi64(rhs_64);
451
+ sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs));
452
+ rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4]));
453
+ rhs = _mm256_broadcastq_epi64(rhs_64);
454
+ sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs));
455
+ }
456
+ static_assert(
457
+ IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
458
+ "AVX2 kernel only supports fixed16 and fixed32 types");
459
+ // The only significant difference between fixed16 and fixed32 is the size
460
+ // of the storage unit. The registers have to be repacked accordingly.
461
+ if (IsFixed32Type<OutType>::value) {
462
+ Extract4xint32(relu, kShiftAmount, sum_0,
463
+ reinterpret_cast<int32_t**>(&out_ptrs[0]));
464
+ Extract4xint32(relu, kShiftAmount, sum_1,
465
+ reinterpret_cast<int32_t**>(&out_ptrs[1]));
466
+ Extract4xint32(relu, kShiftAmount, sum_2,
467
+ reinterpret_cast<int32_t**>(&out_ptrs[2]));
468
+ Extract4xint32(relu, kShiftAmount, sum_3,
469
+ reinterpret_cast<int32_t**>(&out_ptrs[3]));
470
+ Extract4xint32(relu, kShiftAmount, sum_4,
471
+ reinterpret_cast<int32_t**>(&out_ptrs[4]));
472
+ } else {
473
+ Extract4xint16(relu, kShiftAmount, sum_0,
474
+ reinterpret_cast<int16_t**>(&out_ptrs[0]));
475
+ Extract4xint16(relu, kShiftAmount, sum_1,
476
+ reinterpret_cast<int16_t**>(&out_ptrs[1]));
477
+ Extract4xint16(relu, kShiftAmount, sum_2,
478
+ reinterpret_cast<int16_t**>(&out_ptrs[2]));
479
+ Extract4xint16(relu, kShiftAmount, sum_3,
480
+ reinterpret_cast<int16_t**>(&out_ptrs[3]));
481
+ Extract4xint16(relu, kShiftAmount, sum_4,
482
+ reinterpret_cast<int16_t**>(&out_ptrs[4]));
483
+ }
484
+ }
485
+ }
486
+
487
+ // Processes one GRU gate input with sigmoid.
488
+ template <int InputMantissaBits, int StateMantissaBits, bool SplitGates>
489
+ inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr,
490
+ const __m256i& input,
491
+ const int32_t* sigmoid_table) {
492
+ __m256i gate = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_ptr));
493
+ if (SplitGates) {
494
+ __m256i other =
495
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_other_ptr));
496
+ gate = _mm256_add_epi32(gate, other);
497
+ }
498
+ gate = _mm256_add_epi32(gate, input);
499
+ // Compute sigmoids on reset and update.
500
+ return csrblocksparse::fixed32_sigmoid_fixed16<InputMantissaBits,
501
+ StateMantissaBits>(
502
+ sigmoid_table, gate);
503
+ }
504
+
505
+ // Processes the tanh and the final combination, returning the new GRU state.
506
+ template <int InputMantissaBits, int StateMantissaBits, bool SplitGates = false>
507
+ inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset,
508
+ const __m256i& update,
509
+ const __m256i& rounding_offset,
510
+ const void* gate_ptr, const void* gate_other_ptr,
511
+ const void* gru_h_ptr, const int32_t* tanh_table) {
512
+ // Multiply the cell GRU output and the reset. There is a slight danger of
513
+ // loss of precision here, so use 32x32=64 bit and shift back after.
514
+ __m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr));
515
+ if (SplitGates) {
516
+ __m256i other_gru =
517
+ _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr));
518
+ gru = _mm256_add_epi32(gru, other_gru);
519
+ }
520
+ // This only computes the products of the low-order 32 bits of each pair.
521
+ __m256i gru_lo = _mm256_mul_epi32(gru, reset);
522
+ // Swap odd and even 32-bit units and do it again to get the high products.
523
+ gru = _mm256_shuffle_epi32(gru, 0xb1);
524
+ __m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1));
525
+ // Now shift right to compensate for the multiply and re-interleave the
526
+ // 32-bit results.
527
+ // NOTE: There is no shift right arithmetic for 64 bit values until AVX512!
528
+ // Fortunately it doesn't matter, as the results are being truncated to 32
529
+ // bits and we aren't shifting right by more than 32 bits here.
530
+ gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits);
531
+ // The upper results are shifted LEFT, so we can use blend to recombine in
532
+ // a single instruction.
533
+ gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits);
534
+ // Recombine the 32 bit results from lo and hi, alternating.
535
+ gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa);
536
+ gru = _mm256_add_epi32(cell, gru);
537
+ // Compute tanh on the result. Although this instantly discards a bunch of
538
+ // bits, there were only 7 surplus bits for the multiply, which isn't enough
539
+ // to do it as 16x16=32.
540
+ __m256i hbar =
541
+ csrblocksparse::fixed32_tanh_fixed16<InputMantissaBits,
542
+ StateMantissaBits>(tanh_table, gru);
543
+ // Load the 16-bit previous GRU state and sign-extend to 32 bits.
544
+ gru = _mm256_cvtepi16_epi32(
545
+ _mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr)));
546
+ gru = _mm256_sub_epi32(gru, hbar);
547
+ // Since |gru| is 16 bit sign-extended to 32, and |update| is the output of
548
+ // sigmoid, it is always contained within 16 bits and never negative, we can
549
+ // use |madd_epi16| to do 16x16=32 multiply with horizontal adding as the
550
+ // addend will always be zero, and this is twice as fast as full blown
551
+ // 32x32=32. The only possible problem is if the subtract above caused
552
+ // overflow.
553
+ gru = _mm256_madd_epi16(gru, update);
554
+ // Renormalize to fixed16. This time rounding is critical, as this is the
555
+ // output GRU state.
556
+ gru = _mm256_add_epi32(gru, rounding_offset);
557
+ gru = _mm256_srai_epi32(gru, StateMantissaBits);
558
+ return _mm256_add_epi32(gru, hbar);
559
+ }
560
+
561
+ template <typename Type>
562
+ typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors(
563
+ int start, int end, const Type* add1, const Type* add2, Type* result) {
564
+ constexpr int kSIMDWidth = 8;
565
+ for (int i = start; i < end; i += kSIMDWidth) {
566
+ __m256i data1 =
567
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
568
+ __m256i data2 =
569
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
570
+ data1 = _mm256_add_epi32(data1, data2);
571
+ _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
572
+ }
573
+ }
574
+
575
+ template <typename Type>
576
+ typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors(
577
+ int start, int end, const Type* add1, const Type* add2, Type* result) {
578
+ constexpr int kSIMDWidth = 16;
579
+ for (int i = start; i < end; i += kSIMDWidth) {
580
+ __m256i data1 =
581
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
582
+ __m256i data2 =
583
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
584
+ data1 = _mm256_add_epi16(data1, data2);
585
+ _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
586
+ }
587
+ }
588
+
589
+ #endif // __AVX2__
590
+
591
+ } // namespace detail
592
+ } // namespace csrblocksparse
593
+
594
+ #undef LABEL_COL_LOOP
595
+ #undef LABEL_ROW_LOOP
596
+ #undef LABEL_SKIP_COL_LOOP
597
+ #undef LABEL_TOP_LOOP
598
+
599
+ #endif // __AVX__
600
+
601
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
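A note on the fixed-point output path above: Extract4xint32/Extract4xint16 renormalize the 32-bit accumulators with an arithmetic right shift plus rounding, where the shift is the difference between the product's and the output's mantissa bits. The standalone scalar sketch below (hypothetical function name, not part of this commit) shows just that step.

#include <cstdint>

// Mirrors the rounding + arithmetic shift in Compute4Results: drops
// |shift_amount| mantissa bits from |sum| with round-to-nearest.
int32_t RoundingShiftRight(int32_t sum, int shift_amount) {
  if (shift_amount <= 0) return sum;
  const int32_t rounding = 1 << (shift_amount - 1);
  return (sum + rounding) >> shift_amount;  // arithmetic shift on int32_t
}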
sparse_matmul/compute/kernels_generic.h ADDED
@@ -0,0 +1,273 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
19
+
20
+ #include <algorithm>
21
+ #include <type_traits>
22
+
23
+ #include "sparse_matmul/numerics/fixed_types.h"
24
+ #include "sparse_matmul/numerics/float16_types.h"
25
+ #include "sparse_matmul/numerics/type_utils.h"
26
+
27
+ // Separate out the assembly kernels for readability. Eventually this will
28
+ // become an ifdef switch on the architecture type.
29
+ #if defined __aarch64__
30
+ #include "sparse_matmul/compute/kernels_arm.h"
31
+ #elif defined __AVX__
32
+ #include "sparse_matmul/compute/kernels_avx.h"
33
+ #else // defined __AVX__
34
+ // If there is no architecture-specific implementation, then always use generic.
35
+ template <typename WeightType, typename RhsType, typename OutType>
36
+ struct ShouldEnableGenericSpMV_4x4 : std::true_type {};
37
+ template <typename WeightType, typename RhsType, typename OutType>
38
+ struct ShouldEnableGenericSpMM5_4x4 : std::true_type {};
39
+ template <typename WeightType, typename RhsType, typename OutType>
40
+ struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
41
+ template <typename WeightType, typename RhsType, typename OutType>
42
+ struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
43
+ template <typename Type>
44
+ struct ShouldEnableGenericAdd : std::true_type {};
45
+ #endif // defined __aarch64__
46
+
47
+ namespace csrblocksparse {
48
+ namespace detail {
49
+
50
+ // The computational routines do NO error checking for speed. It is assumed
51
+ // that this has been handled by CSRBlockSparseMatrix.
52
+
53
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
54
+ // blocked pattern, x is a vector and b is vector. Weights are stored for this
55
+ // routine by making each 4x4 block contiguous. Blocks are ordered in standard
56
+ // row-major format. column indices are converted to deltas and then multiplied
57
+ // by 2 to convert to bytes, so that the value can be used directly to offset
58
+ // the pointer into the rhs vector.
59
+ //
60
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
61
+ // this function. This is automatically taken care of in SparseLinearLayer.
62
+ // The bias is reconstructed through horizontal additions, which leads to a small
63
+ // speedup by reducing latencies at the end of the loop.
64
+ template <typename WeightType, typename RhsType, typename OutType>
65
+ typename std::enable_if<
66
+ ShouldEnableGenericSpMV_4x4<WeightType, RhsType, OutType>::value>::type
67
+ SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
68
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
69
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
70
+ OutType* out_ptr, int64_t assigned_rows,
71
+ int64_t rows /* only used in SpMM variants */,
72
+ int64_t cols /* only used in SpMM variants */, int relu) {
73
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
74
+ float accumulators[4];
75
+ // Undo the division by 4 that is done for the benefit of the assembly version.
76
+ for (int i = 0; i < 4; ++i)
77
+ accumulators[i] = 4.f * static_cast<float>(*bias_ptr++);
78
+
79
+ int reduced_col_count = *nnz_per_row++;
80
+ for (int c = 0; c < reduced_col_count; ++c) {
81
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
82
+ rhs_ptr += col_delta;
83
+
84
+ // Multiply this 4x4 block.
85
+ for (int i = 0; i < 4; ++i) {
86
+ for (int j = 0; j < 4; ++j) {
87
+ accumulators[i] += static_cast<float>(*weights_ptr++) *
88
+ static_cast<float>(rhs_ptr[j]);
89
+ }
90
+ }
91
+ }
92
+
93
+ for (int i = 0; i < 4; ++i)
94
+ *out_ptr++ = static_cast<OutType>(relu ? std::max(accumulators[i], 0.f)
95
+ : accumulators[i]);
96
+ }
97
+ }
98
+
99
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
100
+ // blocked pattern, x is a fat vector with 5 columns and b is vector. b is
101
+ // broadcast. Weights are stored for this routine by making each 4x4 block
102
+ // contiguous. Blocks are ordered in standard row-major format. column indices
103
+ // are converted to deltas and then multiplied by 2 to convert to bytes, so
104
+ // that the value can be used directly to offset the pointer into the rhs
105
+ // vector.
106
+ //
107
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
108
+ // this function. This is automatically taken care of in SparseLinearLayer.
109
+ // The bias is reconstructed through horizontal additions, which leads to a small
110
+ // speedup by reducing latencies at the end of the loop.
111
+ template <typename WeightType, typename RhsType, typename OutType>
112
+ typename std::enable_if<
113
+ ShouldEnableGenericSpMM5_4x4<WeightType, RhsType, OutType>::value>::type
114
+ SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
115
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
116
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
117
+ OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
118
+ int relu) {
119
+ const RhsType* rhs_ptrs[5];
120
+ for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
121
+
122
+ OutType* out_ptrs[5];
123
+ for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
124
+
125
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
126
+ float accumulators[4][5];
127
+ // Undo the division by 4 that is done for the benefit of the assembly version.
128
+ for (int i = 0; i < 4; ++i) {
129
+ for (int k = 0; k < 5; ++k) {
130
+ accumulators[i][k] = 4.f * static_cast<float>(*bias_ptr);
131
+ }
132
+ ++bias_ptr;
133
+ }
134
+
135
+ int reduced_col_count = *nnz_per_row++;
136
+ for (int c = 0; c < reduced_col_count; ++c) {
137
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
138
+ for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
139
+
140
+ // multiply this 4x4 block
141
+ for (int i = 0; i < 4; ++i) {
142
+ for (int j = 0; j < 4; ++j) {
143
+ for (int k = 0; k < 5; ++k) {
144
+ accumulators[i][k] += static_cast<float>(*weights_ptr) *
145
+ static_cast<float>(rhs_ptrs[k][j]);
146
+ }
147
+ weights_ptr++;
148
+ }
149
+ }
150
+ }
151
+
152
+ for (int k = 0; k < 5; ++k) {
153
+ for (int i = 0; i < 4; ++i) {
154
+ out_ptrs[k][0] = static_cast<OutType>(
155
+ relu ? std::max(accumulators[i][k], 0.f) : accumulators[i][k]);
156
+ out_ptrs[k]++;
157
+ }
158
+ }
159
+ }
160
+ }
161
+
162
+ // Performs the calculation y = A * x + b where A is a sparse matrix with
163
+ // a 1x1 blocked pattern (ie unstructured), x is a
164
+ // vector and b is vector.
165
+ // Weights are stored for this routine in standard CSR format. Each row must
166
+ // have a multiple of 8 columns.
167
+ // column indices are converted to deltas and then multiplied by 2 to convert
168
+ // to bytes, so that the value can be used directly to offset the pointer
169
+ // into the rhs vector.
170
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
171
+ // this function. This is automatically taken care of in SparseLinearLayer.
172
+ // The bias is reconstructed through horizontal additions, which leads to a small
173
+ // speedup by reducing latencies at the end of the loop.
174
+ template <typename WeightType, typename RhsType, typename OutType>
175
+ typename std::enable_if<
176
+ ShouldEnableGenericSpMV_1x1<WeightType, RhsType, OutType>::value>::type
177
+ SpMV_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
178
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
179
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
180
+ OutType* out_ptr, int64_t assigned_rows,
181
+ int64_t rows /* only used in SpMM variants */,
182
+ int64_t cols /* only used in SpMM variants */, int relu) {
183
+ for (int row = 0; row < assigned_rows; ++row) {
184
+ // Undo the division by 4 that is done for the benefit of the assembly version.
185
+ float accumulator = 4.f * static_cast<float>(*bias_ptr++);
186
+
187
+ int col_count = *nnz_per_row++;
188
+ for (int c = 0; c < col_count; ++c) {
189
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
190
+ rhs_ptr += col_delta;
191
+
192
+ accumulator +=
193
+ static_cast<float>(*weights_ptr++) * static_cast<float>(*rhs_ptr);
194
+ }
195
+
196
+ *out_ptr++ =
197
+ static_cast<OutType>(relu ? std::max(accumulator, 0.f) : accumulator);
198
+ }
199
+ }
200
+
201
+ // Performs the calculation y = A * x + b where A is a sparse matrix with
202
+ // a 1x1 blocked pattern (ie unstructured), x is a
203
+ // vector and b is vector.
204
+ // Weights are stored for this routine in standard CSR format. Each row must
205
+ // have a multiple of 8 columns.
206
+ // column indices are converted to deltas and then multiplied by 2 to convert
207
+ // to bytes, so that the value can be used directly to offset the pointer
208
+ // into the rhs vector.
209
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
210
+ // this function. This is automatically taken care of in SparseLinearLayer.
211
+ // The bias is reconstructed through horizontal additions, which leads to a small
212
+ // speedup by reducing latencies at the end of the loop.
213
+ template <typename WeightType, typename RhsType, typename OutType>
214
+ typename std::enable_if<
215
+ ShouldEnableGenericSpMM5_1x1<WeightType, RhsType, OutType>::value>::type
216
+ SpMM5_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
217
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
218
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
219
+ OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
220
+ int relu) {
221
+ const RhsType* rhs_ptrs[5];
222
+ for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
223
+
224
+ OutType* out_ptrs[5];
225
+ for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
226
+
227
+ for (int row = 0; row < assigned_rows; ++row) {
228
+ // Undo the division by 4 that is done for the benefit of the assembly version.
229
+ float accumulator[5];
230
+ for (int i = 0; i < 5; ++i)
231
+ accumulator[i] = 4.f * static_cast<float>(*bias_ptr);
232
+
233
+ ++bias_ptr;
234
+
235
+ int col_count = *nnz_per_row++;
236
+ for (int c = 0; c < col_count; ++c) {
237
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
238
+ for (int i = 0; i < 5; ++i) {
239
+ rhs_ptrs[i] += col_delta;
240
+ accumulator[i] += static_cast<float>(*weights_ptr) *
241
+ static_cast<float>(rhs_ptrs[i][0]);
242
+ }
243
+ weights_ptr++;
244
+ }
245
+
246
+ for (int i = 0; i < 5; ++i) {
247
+ out_ptrs[i][0] = static_cast<OutType>(relu ? std::max(accumulator[i], 0.f)
248
+ : accumulator[i]);
249
+ out_ptrs[i]++;
250
+ }
251
+ }
252
+ }
253
+
254
+ template <typename Type>
255
+ typename std::enable_if<ShouldEnableGenericAdd<Type>::value>::type SumVectors(
256
+ int start, int end, const Type* add1, const Type* add2, Type* result) {
257
+ LOG_FIRST_N(WARNING, 1) << "SumVectors: using generic kernel!";
258
+ for (int i = start; i < end; ++i) {
259
+ Type sum = static_cast<Type>(static_cast<float>(add1[i]) +
260
+ static_cast<float>(add2[i]));
261
+ result[i] = sum;
262
+ }
263
+ }
264
+
265
+ } // namespace detail
266
+ } // namespace csrblocksparse
267
+
268
+ #undef LABEL_COL_LOOP
269
+ #undef LABEL_ROW_LOOP
270
+ #undef LABEL_SKIP_COL_LOOP
271
+ #undef LABEL_TOP_LOOP
272
+
273
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
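
The .25f bias convention called out above is easy to miss when driving these kernels directly, so here is a minimal sketch of the round trip it describes; PrepareBias is a hypothetical helper name, not part of this library:

  #include <cstddef>
  #include <vector>

  // Pre-scale biases by 0.25f before handing them to the SpMM/SpMV kernels;
  // the kernel computes accumulator = 4.f * scaled_bias, so the two factors
  // cancel and the effective bias is unchanged.
  std::vector<float> PrepareBias(const std::vector<float>& bias) {
    std::vector<float> scaled(bias.size());
    for (std::size_t i = 0; i < bias.size(); ++i) scaled[i] = 0.25f * bias[i];
    return scaled;
  }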
sparse_matmul/compute/matmul.h ADDED
@@ -0,0 +1,199 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
19
+
20
+ #include <cstdint>
21
+ #include <vector>
22
+
23
+ #include "absl/time/time.h"
24
+ #include "sparse_matmul/compute/matmul_fixed_avx2.h"
25
+ #include "sparse_matmul/compute/matmul_generic.h"
26
+ #include "sparse_matmul/numerics/fixed_types.h"
27
+ #include "sparse_matmul/numerics/type_utils.h"
28
+ #if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
29
+ #include <cpuid.h>
30
+ #endif
31
+
32
+ namespace csrblocksparse {
33
+
34
+ // The number of elements in a block.
35
+ constexpr int kBlockSize = 4;
36
+
37
+ // Base class for Matmul containing the members that are not type-specific.
38
+ class MatmulBase {
39
+ public:
40
+ // Constructor initializes the flags that determine which implementation to
41
+ // use at run-time, constrained by both compiler flags and cpuid.
42
+ MatmulBase() {
43
+ #if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
44
+ // Code tested to work on Linux systems and multiple Android emulators.
45
+ unsigned int eax, ebx, ecx, edx;
46
+ if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) {
47
+ using_avx_ = (ecx & bit_AVX) != 0;
48
+ if (using_avx_) {
49
+ __get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx);
50
+ using_avx2_ = (ebx & bit_AVX2) != 0;
51
+ using_avx512_ = (ebx & bit_AVX512F) != 0 && (ebx & bit_AVX512DQ) &&
52
+ (ebx & bit_AVX512BW) != 0;
53
+ VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_;
54
+ } else {
55
+ LOG(ERROR) << "AVX not found at all!";
56
+ }
57
+ }
58
+ #else
59
+ using_aarch64_ = true;
60
+ #endif
61
+ }
62
+
63
+ protected:
64
+ // Flags that define what (runtime) architectures are available. Flags that
65
+ // are set are limited by both the compiler flags and runtime environment.
66
+ bool using_avx512_ = false;
67
+ bool using_avx2_ = false;
68
+ bool using_avx_ = false;
69
+ bool using_aarch64_ = false;
70
+ };
71
+
72
+ // The master template is really a catch-all for the unimplemented cases to
73
+ // report an error.
74
+ template <typename WeightType, typename RhsType>
75
+ class Matmul : public MatmulBase {
76
+ public:
77
+ // Sparse inputs, outputs replicated strided for each thread.
78
+ template <typename OutType>
79
+ void MatVec4x4(const WeightType* weights, const RhsType* rhs,
80
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias,
81
+ const int32_t* nnz_per_row, const int16_t* rhs_indices,
82
+ int start_row, int end_row, bool relu, int replicas,
83
+ int stride, OutType* output) {
84
+ // The specializations should take care of every real case.
85
+ CHECK(false) << "Unsupported combination of types used!";
86
+ }
87
+ template <typename OutType>
88
+ void MatVec8x4(const WeightType* weights, const RhsType* rhs,
89
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias,
90
+ const int32_t* nnz_per_row, const int16_t* rhs_indices,
91
+ int start_row, int end_row, bool relu, int replicas,
92
+ int stride, OutType* output) {
93
+ // The specializations should take care of every real case.
94
+ CHECK(false) << "Unsupported combination of types used!";
95
+ }
96
+ };
97
+
98
+ // Full specialization for float.
99
+ template <>
100
+ class Matmul<float, float> : public MatmulBase {
101
+ public:
102
+ void MatVec4x4(const float* weights, const float* rhs, const float* bias,
103
+ const int32_t* nnz_per_row, const int16_t* rhs_indices,
104
+ int start_row, int end_row, bool relu, int replicas,
105
+ int stride, float* output) {
106
+ detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
107
+ start_row, end_row, /*block_height=*/4,
108
+ /*block_width=*/4, relu, replicas, stride,
109
+ output);
110
+ }
111
+ void MatVec8x4(const float* weights, const float* rhs, const float* bias,
112
+ const int32_t* nnz_per_row, const int16_t* rhs_indices,
113
+ int start_row, int end_row, bool relu, int replicas,
114
+ int stride, float* output) {
115
+ detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
116
+ start_row, end_row, /*block_height=*/8,
117
+ /*block_width=*/4, relu, replicas, stride,
118
+ output);
119
+ }
120
+ };
121
+
122
+ // Partial specialization for fixed types. Covers fixed16xfixed16 = OutType,
123
+ // where OutType should be fixed16 or fixed32. The mantissa bits don't have
124
+ // to match.
125
+ template <int WeightBits, int RhsBits>
126
+ class Matmul<fixed16<WeightBits>, fixed16<RhsBits>> : public MatmulBase {
127
+ public:
128
+ using WeightType = fixed16<WeightBits>;
129
+ using RhsType = fixed16<RhsBits>;
130
+
131
+ template <typename OutType>
132
+ void MatVec4x4(const int16_t* weights, const int16_t* rhs,
133
+ const int32_t* bias, const int32_t* nnz_per_row,
134
+ const int16_t* rhs_indices, int start_row, int end_row,
135
+ bool relu, int replicas, int stride, OutType* output) {
136
+ constexpr int kShiftAmount =
137
+ TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
138
+ OutType::kMantissaBits;
139
+ static_assert(kShiftAmount >= 0,
140
+ "OutType must not have more mantissa bits than inputs");
141
+ #if defined __AVX2__
142
+ CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
143
+ if (sizeof(*output) == 4) {
144
+ int32_t* out32 = reinterpret_cast<int32_t*>(output);
145
+ detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
146
+ start_row, end_row, relu, kShiftAmount,
147
+ replicas, stride, out32);
148
+ } else {
149
+ int16_t* out16 = reinterpret_cast<int16_t*>(output);
150
+ detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
151
+ start_row, end_row, relu, kShiftAmount,
152
+ replicas, stride, out16);
153
+ }
154
+ #elif defined __aarch64__
155
+ if (using_aarch64_) {
156
+ LOG(FATAL) << "Fixed16 MatVec4x4 not yet implemented!";
157
+ }
158
+
159
+ #else
160
+ detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
161
+ start_row, end_row, /*block_height=*/4,
162
+ /*block_width=*/4, relu, sizeof(*output),
163
+ kShiftAmount, replicas, stride, output);
164
+ #endif // __AVX2__
165
+ }
166
+
167
+ template <typename OutType>
168
+ void MatVec8x4(const int16_t* weights, const int16_t* rhs,
169
+ const int32_t* bias, const int32_t* nnz_per_row,
170
+ const int16_t* rhs_indices, int start_row, int end_row,
171
+ bool relu, int replicas, int stride, OutType* output) {
172
+ constexpr int kShiftAmount =
173
+ TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
174
+ OutType::kMantissaBits;
175
+ static_assert(kShiftAmount >= 0,
176
+ "OutType must not have more mantissa bits than inputs");
177
+ #if defined __AVX2__
178
+ CHECK(replicas == 1 && sizeof(*output) == 4)
179
+ << "Only replicas == 1 and fixed32 output are implemented for AVX2!";
180
+ CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
181
+ int32_t* out32 = reinterpret_cast<int32_t*>(output);
182
+ detail::MatVec8x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
183
+ start_row, end_row, relu, kShiftAmount, out32);
184
+ #elif defined __aarch64__
185
+ if (using_aarch64_) {
186
+ LOG(FATAL) << "Fixed16 MatVec8x4 not yet implemented!";
187
+ }
188
+ #else
189
+ detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
190
+ start_row, end_row, /*block_height=*/8,
191
+ /*block_width=*/4, relu, sizeof(*output),
192
+ kShiftAmount, replicas, stride, output);
193
+ #endif // __AVX2__
194
+ }
195
+ };
196
+
197
+ } // namespace csrblocksparse
198
+
199
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
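
The MatmulBase constructor above is where the AVX/AVX2 dispatch flags get set; a small standalone sketch of the same cpuid pattern, assuming GCC/Clang's <cpuid.h> as used in the class:

  #include <cpuid.h>
  #include <cstdio>

  // Mirrors the checks in MatmulBase(): leaf 1 ECX for AVX, then
  // leaf 7 / subleaf 0 EBX for AVX2.
  static bool CpuHasAvx2() {
    unsigned int eax, ebx, ecx, edx;
    if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) == 0) return false;
    if ((ecx & bit_AVX) == 0) return false;
    if (__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx) == 0) return false;
    return (ebx & bit_AVX2) != 0;
  }

  int main() { std::printf("avx2=%d\n", CpuHasAvx2() ? 1 : 0); }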
sparse_matmul/compute/matmul_fixed_avx2.cc ADDED
@@ -0,0 +1,235 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/compute/matmul_fixed_avx2.h"
16
+
17
+ #include <cstdint>
18
+
19
+ #if defined __AVX__
20
+ #include <immintrin.h>
21
+ #endif
22
+
23
+ #include "sparse_matmul/compute/matmul.h"
24
+
25
+ namespace csrblocksparse {
26
+ namespace detail {
27
+
28
+ static const int32_t kint32min = static_cast<int32_t>(~0x7FFFFFFF);
29
+ static const int32_t kint32max = static_cast<int32_t>(0x7FFFFFFF);
30
+
31
+ #if defined __AVX2__
32
+ // In-line function computes and returns the result of one row (of blocks) as
33
+ // 4x int32_t. |weights_ptr| is a non-const reference so it can easily be
34
+ // interpreted as belonging to the caller.
35
+ inline __m256i ComputeRowResults(const __m128i& bias128, const int16_t* rhs,
36
+ const int16_t* rhs_indices, int nnz,
37
+ int16_t const*& weights_ptr) {
38
+ // Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is
39
+ // Zero and 0-3 are the 4x32 bit bias values.
40
+ __m256i sum = _mm256_cvtepu32_epi64(bias128);
41
+
42
+ for (int c = 0; c < nnz; ++c) {
43
+ int rhs_index = rhs_indices[c];
44
+ // Load all 16 weights.
45
+ __m256i weights =
46
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
47
+ // Get the 4x int16_t into the bottom of |rhs_64|.
48
+ __m128i rhs_64 = _mm_loadl_epi64(
49
+ reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
50
+ // Broadcast the rhs, pretending that each is a 64-bit unit:
51
+ // [0123 0123 0123 0123].
52
+ __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
53
+ weights_ptr += 16;
54
+ sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs_value));
55
+ }
56
+ // Horizontally add the results. We have 1 register that contains results
57
+ // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
58
+ // cross lanes, so we end up with [0 1 0 1 2 3 2 3]
59
+ sum = _mm256_hadd_epi32(sum, sum);
60
+ // Permutes the middle two pairs to get the answers together.
61
+ return _mm256_permute4x64_epi64(sum, 0xd8);
62
+ }
63
+
64
+ // Template that allows any fixed combination of OutType and replicas, plus
65
+ // variable |relu|, |shift_out|. Note that |kReplicas| is a template arg as
66
+ // well as a function arg so we can hard-code a limited amount of unrolling.
67
+ template <typename OutType, int kReplicas>
68
+ void MatVec4x4FixedAVX2Template(const int16_t* weights_ptr, const int16_t* rhs,
69
+ const int32_t* bias, const int32_t* nnz_per_row,
70
+ const int16_t* rhs_indices, int start_row,
71
+ int end_row, bool relu, int shift_out,
72
+ int replicas, int stride, OutType* output) {
73
+ int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
74
+ __m256i rounding = _mm256_set1_epi32(rounding_addon);
75
+ __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
76
+ for (int row_block = start_row; row_block < end_row; ++row_block) {
77
+ // Load 4 biases [0 1 2 3].
78
+ __m128i bias128 = _mm_load_si128(reinterpret_cast<__m128i const*>(bias));
79
+ bias += kBlockSize;
80
+ int nnz = nnz_per_row[row_block];
81
+ __m256i sum =
82
+ ComputeRowResults(bias128, rhs, rhs_indices, nnz, weights_ptr);
83
+ rhs_indices += nnz;
84
+ // Shift right with rounding to get the right number of mantissa bits.
85
+ sum = _mm256_add_epi32(sum, rounding);
86
+ sum = _mm256_srai_epi32(sum, shift_out);
87
+ // Now sum contains [res0, res1, res2, res3, res0, res1, res2, res3]
88
+ sum = _mm256_max_epi32(sum, zero);
89
+ if (sizeof(OutType) == 2) {
90
+ // Clip to 16 bit range (with saturation) and pack in the bottom 64
91
+ // bits. The 64 bit result is replicated across the whole 256 bit
92
+ // register. [0123 0123 0123 0123]
93
+ sum = _mm256_packs_epi32(sum, sum);
94
+ int64_t result = _mm256_extract_epi64(sum, 0);
95
+ *reinterpret_cast<int64_t*>(output) = result;
96
+ if (kReplicas > 1) {
97
+ *reinterpret_cast<int64_t*>(output + stride) = result;
98
+ if (kReplicas > 2) {
99
+ for (int r = 2; r < replicas; ++r) {
100
+ *reinterpret_cast<int64_t*>(output + r * stride) = result;
101
+ }
102
+ }
103
+ }
104
+ } else {
105
+ // Save the lower 128 bits (4x int32_t).
106
+ __m128i result = _mm256_extractf128_si256(sum, 0);
107
+ _mm_store_si128(reinterpret_cast<__m128i*>(output), result);
108
+ if (kReplicas > 1) {
109
+ _mm_store_si128(reinterpret_cast<__m128i*>(output + stride), result);
110
+ if (kReplicas > 2) {
111
+ for (int r = 2; r < replicas; ++r) {
112
+ _mm_store_si128(reinterpret_cast<__m128i*>(output + r * stride),
113
+ result);
114
+ }
115
+ }
116
+ }
117
+ }
118
+ output += kBlockSize;
119
+ }
120
+ }
121
+
122
+ // Version that covers all possible combinations of the variable conditions:
123
+ // |relu|, |shift_out|, |replicas|, with int16_t |output|.
124
+ void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
125
+ const int32_t* bias, const int32_t* nnz_per_row,
126
+ const int16_t* rhs_indices, int start_row, int end_row,
127
+ bool relu, int shift_out, int replicas, int stride,
128
+ int16_t* output) {
129
+ if (replicas <= 1) {
130
+ MatVec4x4FixedAVX2Template<int16_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
131
+ rhs_indices, start_row, end_row,
132
+ relu, shift_out, 1, stride, output);
133
+ } else if (replicas == 2) {
134
+ MatVec4x4FixedAVX2Template<int16_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
135
+ rhs_indices, start_row, end_row,
136
+ relu, shift_out, 2, stride, output);
137
+ } else {
138
+ MatVec4x4FixedAVX2Template<int16_t, 3>(
139
+ weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
140
+ relu, shift_out, replicas, stride, output);
141
+ }
142
+ }
143
+
144
+ // Version that covers all possible combinations of the variable conditions:
145
+ // |relu|, |shift_out|, |replicas|, with int32_t |output|.
146
+ void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
147
+ const int32_t* bias, const int32_t* nnz_per_row,
148
+ const int16_t* rhs_indices, int start_row, int end_row,
149
+ bool relu, int shift_out, int replicas, int stride,
150
+ int32_t* output) {
151
+ if (replicas <= 1) {
152
+ MatVec4x4FixedAVX2Template<int32_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
153
+ rhs_indices, start_row, end_row,
154
+ relu, shift_out, 1, stride, output);
155
+ } else if (replicas == 2) {
156
+ MatVec4x4FixedAVX2Template<int32_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
157
+ rhs_indices, start_row, end_row,
158
+ relu, shift_out, 2, stride, output);
159
+ } else {
160
+ MatVec4x4FixedAVX2Template<int32_t, 3>(
161
+ weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
162
+ relu, shift_out, replicas, stride, output);
163
+ }
164
+ }
165
+
166
+ // In-line function computes and returns the result of one row (of blocks) as
167
+ // 8x int32_t. weights_ptr is a non-const reference so it can easily be
168
+ // interpreted as belonging to the caller.
169
+ inline __m256i Compute8RowResults(const __m256i& bias256, const int16_t* rhs,
170
+ const int16_t* rhs_indices, int nnz,
171
+ int16_t const*& weights_ptr) {
172
+ // Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is
173
+ // Zero and 0-3 are the 4x32 bit bias values from 128 bit half of the input.
174
+ __m256i sum1 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(bias256));
175
+ // Plus 4 more in another sum register from the upper 128 bit half.
176
+ __m256i sum2 = _mm256_cvtepu32_epi64(_mm256_extractf128_si256(bias256, 1));
177
+
178
+ for (int c = 0; c < nnz; ++c) {
179
+ int rhs_index = rhs_indices[c];
180
+ // Load all 16 weights.
181
+ __m256i weights =
182
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
183
+ // Get the 4x int16_t into the bottom of |rhs_64|.
184
+ __m128i rhs_64 = _mm_loadl_epi64(
185
+ reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
186
+ // Broadcast the rhs, pretending that each is a 64-bit unit:
187
+ // [0123 0123 0123 0123].
188
+ __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
189
+ weights_ptr += 16;
190
+ sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(weights, rhs_value));
191
+ // Same again for the other 4 results, re-using the same rhs value.
192
+ weights = _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
193
+ weights_ptr += 16;
194
+ sum2 = _mm256_add_epi32(sum2, _mm256_madd_epi16(weights, rhs_value));
195
+ }
196
+ // Horizontally add the results. We have 2 registers that contain results
197
+ // [0 0 1 1 2 2 3 3], and [4 4 5 5 6 6 7 7] but hadd (and almost no other AVX
198
+ // instruction) will not cross lanes, so we end up with [0 1 4 5 2 3 6 7]
199
+ sum1 = _mm256_hadd_epi32(sum1, sum2);
200
+ // Permutes the middle two pairs to get the answers in the right order.
201
+ return _mm256_permute4x64_epi64(sum1, 0xd8);
202
+ }
203
+
204
+ // Version that covers the main conditions used with 8x4:
205
+ // |relu|, |shift_out|, with int32_t |output|.
206
+ void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
207
+ const int32_t* bias, const int32_t* nnz_per_row,
208
+ const int16_t* rhs_indices, int start_row, int end_row,
209
+ bool relu, int shift_out, int32_t* output) {
210
+ int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
211
+ __m256i rounding = _mm256_set1_epi32(rounding_addon);
212
+ __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
213
+ for (int row_block = start_row; row_block < end_row; ++row_block) {
214
+ // Load 8 biases [0 1 2 3 4 5 6 7].
215
+ __m256i bias256 = _mm256_load_si256(reinterpret_cast<__m256i const*>(bias));
216
+ bias += kBlockSize * 2;
217
+ int nnz = nnz_per_row[row_block];
218
+ __m256i sum =
219
+ Compute8RowResults(bias256, rhs, rhs_indices, nnz, weights_ptr);
220
+ rhs_indices += nnz;
221
+ // Shift right with rounding to get the right number of mantissa bits.
222
+ sum = _mm256_add_epi32(sum, rounding);
223
+ sum = _mm256_srai_epi32(sum, shift_out);
224
+ // Now sum contains [res0, res1, res2, res3, res4, res5, res6, res7].
225
+ sum = _mm256_max_epi32(sum, zero);
226
+ // Save all 256 bits (8x int32_t).
227
+ _mm256_store_si256(reinterpret_cast<__m256i*>(output), sum);
228
+ output += kBlockSize * 2;
229
+ }
230
+ }
231
+
232
+ #endif
233
+
234
+ } // namespace detail
235
+ } // namespace csrblocksparse
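
The rounding shift used by both AVX2 kernels above is the lane-wise form of a simple scalar identity; a minimal scalar sketch for one accumulator:

  #include <cstdint>

  // Drops |shift_out| mantissa bits with round-to-nearest, matching the
  // _mm256_add_epi32(sum, rounding) + _mm256_srai_epi32(sum, shift_out) pair.
  inline int32_t ShiftRightRounding(int32_t acc, int shift_out) {
    const int32_t rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
    return (acc + rounding_addon) >> shift_out;
  }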
sparse_matmul/compute/matmul_fixed_avx2.h ADDED
@@ -0,0 +1,49 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
19
+
20
+ #include <cstdint>
21
+
22
+ namespace csrblocksparse {
23
+ namespace detail {
24
+
25
+ // Version that covers all possible combinations of the variable conditions:
26
+ // |relu|, |shift_out|, |replicas|, with int16 output.
27
+ void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
28
+ const int32_t* bias, const int32_t* nnz_per_row,
29
+ const int16_t* rhs_indices, int start_row, int end_row,
30
+ bool relu, int shift_out, int replicas, int stride,
31
+ int16_t* output);
32
+ // Version that covers all possible combinations of the variable conditions:
33
+ // |relu|, |shift_out|, |replicas|, with int32 output.
34
+ void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
35
+ const int32_t* bias, const int32_t* nnz_per_row,
36
+ const int16_t* rhs_indices, int start_row, int end_row,
37
+ bool relu, int shift_out, int replicas, int stride,
38
+ int32_t* output);
39
+ // Version that covers the main conditions used with 8x4:
40
+ // |relu|, |shift_out|, with int32 output.
41
+ void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
42
+ const int32_t* bias, const int32_t* nnz_per_row,
43
+ const int16_t* rhs_indices, int start_row, int end_row,
44
+ bool relu, int shift_out, int32_t* output);
45
+
46
+ } // namespace detail
47
+ } // namespace csrblocksparse
48
+
49
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
sparse_matmul/compute/matmul_generic.cc ADDED
@@ -0,0 +1,122 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/compute/matmul_generic.h"
16
+
17
+ #include <cstdint>
18
+ #include <vector>
19
+
20
+ #include "sparse_matmul/compute/matmul.h"
21
+
22
+ namespace csrblocksparse {
23
+ namespace detail {
24
+
25
+ void MatVecFloatGeneric(const float* weights, const float* rhs,
26
+ const float* bias, const int32_t* nnz_per_row,
27
+ const int16_t* rhs_indices, int start_row, int end_row,
28
+ int block_height, int block_width, bool relu,
29
+ int replicas, int stride, float* output) {
30
+ int weight_index = 0;
31
+ int bias_index = 0;
32
+ std::vector<float> accumulators(block_height);
33
+ for (int row_block = start_row; row_block < end_row;
34
+ ++row_block, output += block_height) {
35
+ int nnz = nnz_per_row[row_block];
36
+ // Biases are now stored and used directly without pre-division.
37
+ for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++];
38
+
39
+ for (int c = 0; c < nnz; ++c) {
40
+ int rhs_index = rhs_indices[c];
41
+ const float* block_rhs = rhs + rhs_index * block_width;
42
+ // Multiply this |block_height| x |block_width| block.
43
+ for (int i = 0; i < block_height; ++i) {
44
+ for (int j = 0; j < block_width; ++j) {
45
+ accumulators[i] += weights[weight_index++] * block_rhs[j];
46
+ }
47
+ }
48
+ }
49
+ rhs_indices += nnz;
50
+ // Apply relu if desired.
51
+ if (relu) {
52
+ for (int i = 0; i < block_height; ++i) {
53
+ if (accumulators[i] < 0) accumulators[i] = 0;
54
+ }
55
+ }
56
+ for (int r = 0; r < replicas; ++r) {
57
+ for (int i = 0; i < block_height; ++i) {
58
+ output[i + r * stride] = accumulators[i];
59
+ }
60
+ }
61
+ }
62
+ }
63
+
64
+ void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs,
65
+ const int32_t* bias, const int32_t* nnz_per_row,
66
+ const int16_t* rhs_indices, int start_row, int end_row,
67
+ int block_height, int block_width, bool relu,
68
+ int bytes_out, int shift_out, int replicas, int stride,
69
+ void* output) {
70
+ int weight_index = 0;
71
+ int bias_index = 0;
72
+ std::vector<int32_t> accumulators(block_height);
73
+ for (int row_block = start_row; row_block < end_row; ++row_block) {
74
+ int nnz = nnz_per_row[row_block];
75
+ // Biases are now stored and used directly without pre-division.
76
+ for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++];
77
+
78
+ for (int c = 0; c < nnz; ++c) {
79
+ int rhs_index = rhs_indices[c];
80
+ const int16_t* block_rhs = rhs + rhs_index * block_width;
81
+ // Multiply this |block_height| x |block_width| block.
82
+ for (int i = 0; i < block_height; ++i) {
83
+ for (int j = 0; j < block_width; ++j) {
84
+ accumulators[i] += weights[weight_index++] * block_rhs[j];
85
+ }
86
+ }
87
+ }
88
+ rhs_indices += nnz;
89
+ // Apply relu if desired.
90
+ if (relu) {
91
+ for (int i = 0; i < block_height; ++i) {
92
+ if (accumulators[i] < 0) accumulators[i] = 0;
93
+ }
94
+ }
95
+ // Output shift.
96
+ if (shift_out > 0) {
97
+ for (int i = 0; i < block_height; ++i) {
98
+ accumulators[i] >>= shift_out;
99
+ }
100
+ }
101
+ if (bytes_out == 2) {
102
+ int16_t* out16 = reinterpret_cast<int16_t*>(output);
103
+ output = out16 + block_height;
104
+ for (int r = 0; r < replicas; ++r, out16 += stride) {
105
+ for (int i = 0; i < block_height; ++i) {
106
+ out16[i] = accumulators[i];
107
+ }
108
+ }
109
+ } else {
110
+ int32_t* out32 = reinterpret_cast<int32_t*>(output);
111
+ output = out32 + block_height;
112
+ for (int r = 0; r < replicas; ++r, out32 += stride) {
113
+ for (int i = 0; i < block_height; ++i) {
114
+ out32[i] = accumulators[i];
115
+ }
116
+ }
117
+ }
118
+ }
119
+ }
120
+
121
+ } // namespace detail
122
+ } // namespace csrblocksparse
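
For reference, a minimal sketch of calling the generic float kernel directly for a single 4x4 block (in normal use the layer/matrix classes set these arguments up); the include path is taken from this commit, everything else is an illustrative assumption:

  #include <cstdint>
  #include <cstdio>

  #include "sparse_matmul/compute/matmul_generic.h"

  int main() {
    // One non-zero 4x4 block in the only block-row; weights are stored
    // row-major within the block. An identity block makes output = bias + rhs.
    float weights[16] = {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1};
    float rhs[4] = {1.f, 2.f, 3.f, 4.f};
    float bias[4] = {0.5f, 0.5f, 0.5f, 0.5f};
    int32_t nnz_per_row[1] = {1};  // One block in block-row 0.
    int16_t rhs_indices[1] = {0};  // That block reads rhs block 0.
    float output[4];
    csrblocksparse::detail::MatVecFloatGeneric(
        weights, rhs, bias, nnz_per_row, rhs_indices,
        /*start_row=*/0, /*end_row=*/1, /*block_height=*/4, /*block_width=*/4,
        /*relu=*/false, /*replicas=*/1, /*stride=*/0, output);
    for (int i = 0; i < 4; ++i) std::printf("%g\n", output[i]);  // 1.5 2.5 3.5 4.5
    return 0;
  }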
sparse_matmul/compute/matmul_generic.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
19
+
20
+ #include <cstdint>
21
+
22
+ namespace csrblocksparse {
23
+ namespace detail {
24
+
25
+ // Generic version uses plain C++ code.
26
+ void MatVecFloatGeneric(const float* weights, const float* rhs,
27
+ const float* bias, const int32_t* nnz_per_row,
28
+ const int16_t* rhs_indices, int start_row, int end_row,
29
+ int block_height, int block_width, bool relu,
30
+ int replicas, int stride, float* output);
31
+ void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs,
32
+ const int32_t* bias, const int32_t* nnz_per_row,
33
+ const int16_t* rhs_indices, int start_row, int end_row,
34
+ int block_height, int block_width, bool relu,
35
+ int bytes_out, int shift_out, int replicas, int stride,
36
+ void* output);
37
+
38
+ } // namespace detail
39
+ } // namespace csrblocksparse
40
+
41
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
sparse_matmul/compute/thread_bounds.cc ADDED
@@ -0,0 +1,106 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/compute/thread_bounds.h"
16
+
17
+ #include <vector>
18
+
19
+ #include "glog/logging.h"
20
+
21
+ namespace csrblocksparse {
22
+
23
+ void ThreadBounds::PrepareForThreads(int block_width, int block_height,
24
+ int num_threads,
25
+ int reduced_rows_per_cache_row,
26
+ int reduced_rows, const int* nnz_per_row) {
27
+ CHECK_GT(num_threads, 0);
28
+ block_width_ = block_width;
29
+ block_height_ = block_height;
30
+ ComputeThreadSplitPoints(num_threads, reduced_rows_per_cache_row,
31
+ reduced_rows, nnz_per_row);
32
+ weight_starts_.clear();
33
+ rhs_indices_starts_.clear();
34
+ bias_starts_.clear();
35
+ weight_starts_.reserve(row_starts_.size());
36
+ rhs_indices_starts_.reserve(row_starts_.size());
37
+ bias_starts_.reserve(row_starts_.size());
38
+
39
+ // Compute the start indices of each of the types, given what we know about
40
+ // padding, and number of |nnz_per_row|.
41
+ int weight_index = 0;
42
+ int rhs_indices_index = 0;
43
+ int bias_index = 0;
44
+ int row = 0;
45
+ for (int start : row_starts_) {
46
+ while (row < start) {
47
+ weight_index += nnz_per_row[row] * block_width_ * block_height_;
48
+ rhs_indices_index += nnz_per_row[row];
49
+ bias_index += block_height_;
50
+ ++row;
51
+ }
52
+ weight_starts_.push_back(weight_index);
53
+ rhs_indices_starts_.push_back(rhs_indices_index);
54
+ bias_starts_.push_back(bias_index);
55
+ }
56
+ }
57
+
58
+ // Computes the block row (reduced) index of the start of each thread.
59
+ void ThreadBounds::ComputeThreadSplitPoints(int num_threads,
60
+ int reduced_rows_per_cache_row,
61
+ int reduced_rows,
62
+ const int* nnz_per_row) {
63
+ row_starts_.assign(/*n=*/1, /*val=*/0);
64
+ // Break the rule if the matrix is too small to allow one per thread, which
65
+ // occurs only during tests.
66
+ if (reduced_rows_per_cache_row * num_threads > reduced_rows)
67
+ reduced_rows_per_cache_row = std::max(reduced_rows / num_threads, 1);
68
+ int cache_rows = (reduced_rows + reduced_rows_per_cache_row - 1) /
69
+ reduced_rows_per_cache_row;
70
+
71
+ // Compute exclusive prefix sum of the amount of work per row.
72
+ std::vector<int> work_upto_row(cache_rows + 1, 0);
73
+ int extra_row_work = 2 * reduced_rows_per_cache_row;
74
+ for (int i = 0; i < cache_rows; ++i) {
75
+ int new_nnz = 0;
76
+ for (int j = 0; j < reduced_rows_per_cache_row; ++j) {
77
+ // if |reduced_rows_per_cache_row| isn't an exact multiple of the
78
+ // matrix size, then we need to be careful here.
79
+ int index = i * reduced_rows_per_cache_row + j;
80
+ if (index < reduced_rows) new_nnz += nnz_per_row[index];
81
+ }
82
+ work_upto_row[i + 1] = new_nnz + extra_row_work + work_upto_row[i];
83
+ }
84
+ int total_work = work_upto_row.back();
85
+ // Find the split points based on assigning an approximately equal amount
86
+ // of work to each thread.
87
+ int prev_split = 0;
88
+ for (int i = 1; i <= num_threads; ++i) {
89
+ int split = std::distance(
90
+ work_upto_row.begin(),
91
+ std::lower_bound(work_upto_row.begin(), work_upto_row.end(),
92
+ i * total_work / num_threads));
93
+ int split_row = split * reduced_rows_per_cache_row;
94
+ if (i == num_threads) {
95
+ split_row = reduced_rows;
96
+ }
97
+
98
+ VLOG(2) << "tid=" << i - 1 << " num rows=" << split_row - row_starts_.back()
99
+ << " work=" << work_upto_row[split] - work_upto_row[prev_split];
100
+ row_starts_.push_back(split_row);
101
+ prev_split = split;
102
+ }
103
+ VLOG(2) << "total rows=" << reduced_rows << " total work=" << total_work;
104
+ }
105
+
106
+ } // namespace csrblocksparse
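
ComputeThreadSplitPoints is a standard prefix-sum plus lower_bound partition; a toy standalone sketch of the same idea (hypothetical names, simplified to ignore cache rows and the extra per-row work term):

  #include <algorithm>
  #include <cstddef>
  #include <cstdio>
  #include <vector>

  // Splits |work| into |num_threads| contiguous ranges of roughly equal total
  // work and returns the start index of each range plus a final end index.
  std::vector<int> SplitByWork(const std::vector<int>& work, int num_threads) {
    std::vector<int> prefix(work.size() + 1, 0);
    for (std::size_t i = 0; i < work.size(); ++i) prefix[i + 1] = prefix[i] + work[i];
    const int total = prefix.back();
    std::vector<int> starts = {0};
    for (int t = 1; t <= num_threads; ++t) {
      auto it = std::lower_bound(prefix.begin(), prefix.end(),
                                 t * total / num_threads);
      starts.push_back(static_cast<int>(std::distance(prefix.begin(), it)));
    }
    starts.back() = static_cast<int>(work.size());  // Last thread takes the tail.
    return starts;
  }

  int main() {
    for (int s : SplitByWork({2, 2, 2, 2, 2, 2}, 2)) std::printf("%d ", s);  // 0 3 6
  }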
sparse_matmul/compute/thread_bounds.h ADDED
@@ -0,0 +1,74 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
19
+
20
+ #include <vector>
21
+
22
+ namespace csrblocksparse {
23
+
24
+ // Class to compute and store the bounds of each thread used in a computation,
25
+ // and to provide corresponding spans of vectors.
26
+ class ThreadBounds {
27
+ public:
28
+ ThreadBounds() : block_width_(0), block_height_(0) {}
29
+
30
+ void PrepareForThreads(int block_width, int block_height, int num_threads,
31
+ int reduced_rows_per_cache_row, int reduced_rows,
32
+ const int* nnz_per_row);
33
+
34
+ // Functions that offset the appropriate type to the start of the data
35
+ // needed by the given thread id (|tid|).
36
+ template <typename WeightType>
37
+ const WeightType* OffsetWeights(const WeightType* weights, int tid) const {
38
+ return weights + weight_starts_[tid];
39
+ }
40
+ template <typename RhsIndType>
41
+ const RhsIndType* OffsetRhsIndices(const RhsIndType* rhs_indices,
42
+ int tid) const {
43
+ return rhs_indices + rhs_indices_starts_[tid];
44
+ }
45
+ template <typename BiasType>
46
+ const BiasType* OffsetBias(const BiasType* bias, int tid) const {
47
+ return bias + bias_starts_[tid];
48
+ }
49
+ template <typename OutType>
50
+ OutType* OffsetOutput(OutType* output, int tid) const {
51
+ return output + block_height_ * row_starts_[tid];
52
+ }
53
+ int StartRow(int tid) const { return row_starts_[tid]; }
54
+ const std::vector<int>& row_starts() const { return row_starts_; }
55
+
56
+ private:
57
+ // Computes the block row (reduced) index of the start of each thread.
58
+ void ComputeThreadSplitPoints(int num_threads, int reduced_rows_per_cache_row,
59
+ int reduced_rows, const int* nnz_per_row);
60
+
61
+ // Sizes of a sparse block.
62
+ int block_width_;
63
+ int block_height_;
64
+ // Start indices of each data type by thread-id with an extra value at the
65
+ // end.
66
+ std::vector<int> row_starts_;
67
+ std::vector<int> weight_starts_;
68
+ std::vector<int> rhs_indices_starts_;
69
+ std::vector<int> bias_starts_;
70
+ };
71
+
72
+ } // namespace csrblocksparse
73
+
74
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
sparse_matmul/layers/BUILD ADDED
@@ -0,0 +1,146 @@
1
+ # Sparse/Masked Matrix and Layer.
2
+
3
+ # [internal] load android_library_selector
4
+ # [internal] load android_cc_test:def.bzl
5
+
6
+ licenses(["notice"])
7
+
8
+ cc_library(
9
+ name = "layer",
10
+ hdrs = [
11
+ "sparse_linear_layer.h",
12
+ ],
13
+ visibility = [
14
+ "//sparse_matmul:__subpackages__",
15
+ ],
16
+ deps = [
17
+ ":matrix",
18
+ "//sparse_matmul/numerics:types",
19
+ "//sparse_matmul/os:coop_threads",
20
+ "//sparse_matmul/vector:cache_aligned_vector",
21
+ "@com_google_absl//absl/memory",
22
+ "@com_google_absl//absl/strings:str_format",
23
+ "@com_google_glog//:glog",
24
+ ],
25
+ )
26
+
27
+ cc_library(
28
+ name = "matrix",
29
+ hdrs = [
30
+ "csr_blocksparse_matrix.h",
31
+ "masked_sparse_matrix.h",
32
+ ],
33
+ visibility = [
34
+ "//sparse_matmul:__subpackages__",
35
+ ],
36
+ deps = [
37
+ "//sparse_matmul/compute:kernels",
38
+ "//sparse_matmul/compute:matmul",
39
+ "//sparse_matmul/compute:thread_bounds",
40
+ "//sparse_matmul/numerics:types",
41
+ "//sparse_matmul/os:coop_threads",
42
+ "//sparse_matmul/vector:cache_aligned_vector",
43
+ "@com_google_absl//absl/memory",
44
+ "@com_google_absl//absl/strings:str_format",
45
+ "@com_google_glog//:glog",
46
+ ],
47
+ )
48
+
49
+ cc_library(
50
+ name = "utils",
51
+ srcs = [
52
+ "utils.cc",
53
+ ],
54
+ hdrs = [
55
+ "read_array_ifstream.h",
56
+ "utils.h",
57
+ ],
58
+ visibility = [
59
+ "//sparse_matmul:__subpackages__",
60
+ ],
61
+ deps = [
62
+ ":layer",
63
+ ":matrix",
64
+ ":status",
65
+ "//sparse_matmul/numerics:types",
66
+ "//sparse_matmul/vector:cache_aligned_vector",
67
+ "//sparse_matmul/zlib_wrapper",
68
+ "@com_google_absl//absl/status",
69
+ "@com_google_absl//absl/strings",
70
+ "@com_google_absl//absl/strings:cord",
71
+ "@gulrak_filesystem//:filesystem",
72
+ ],
73
+ )
74
+
75
+ cc_library(
76
+ name = "status",
77
+ srcs = [
78
+ "errno_mapping.cc",
79
+ ],
80
+ hdrs = [
81
+ "errno_mapping.h",
82
+ "status_macros.h",
83
+ ],
84
+ deps = [
85
+ "@com_google_absl//absl/status",
86
+ "@com_google_absl//absl/status:statusor",
87
+ "@com_google_absl//absl/strings",
88
+ "@com_google_absl//absl/strings:cord",
89
+ ],
90
+ )
91
+
92
+ cc_test(
93
+ name = "csrblocksparse_test",
94
+ size = "small",
95
+ srcs = [
96
+ "csrblocksparse_test.cc",
97
+ ],
98
+ data = glob(["testdata/*"]),
99
+ linkopts = select({
100
+ "@bazel_tools//platforms:android": ["-landroid"],
101
+ "//conditions:default": [],
102
+ }),
103
+ shard_count = 10,
104
+ deps = [
105
+ ":status",
106
+ ":utils",
107
+ "//sparse_matmul/compute:matmul",
108
+ "//sparse_matmul/numerics:test_utils",
109
+ "//sparse_matmul/os:coop_threads",
110
+ "@com_google_absl//absl/status",
111
+ "@com_google_absl//absl/strings",
112
+ "@com_google_absl//absl/types:span",
113
+ "@com_google_googletest//:gtest_main",
114
+ "@gulrak_filesystem//:filesystem",
115
+ ],
116
+ )
117
+
118
+ cc_test(
119
+ name = "sparse_linear_layer_test",
120
+ srcs = [
121
+ "sparse_linear_layer_test.cc",
122
+ ],
123
+ deps = [
124
+ ":layer",
125
+ "//sparse_matmul/numerics:test_utils",
126
+ "@com_google_googletest//:gtest_main",
127
+ ],
128
+ )
129
+
130
+ cc_test(
131
+ name = "utils_test",
132
+ srcs = ["utils_test.cc"],
133
+ deps = [
134
+ ":layer",
135
+ ":matrix",
136
+ ":status",
137
+ ":utils",
138
+ "//sparse_matmul/numerics:fast_transcendentals",
139
+ "//sparse_matmul/numerics:test_utils",
140
+ "//sparse_matmul/numerics:types",
141
+ "//sparse_matmul/vector:cache_aligned_vector",
142
+ "@com_google_absl//absl/flags:flag",
143
+ "@com_google_googletest//:gtest_main",
144
+ "@gulrak_filesystem//:filesystem",
145
+ ],
146
+ )
sparse_matmul/layers/csr_blocksparse_matrix.h ADDED
@@ -0,0 +1,835 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
19
+
20
+ #include <algorithm>
21
+ #include <cstdint>
22
+ #include <iostream>
23
+ #include <memory>
24
+ #include <tuple>
25
+ #include <vector>
26
+
27
+ #include "glog/logging.h"
28
+ // IWYU pragma: begin_exports
29
+ #include "sparse_matmul/compute/kernels_generic.h"
30
+ #include "sparse_matmul/compute/matmul.h"
31
+ #include "sparse_matmul/compute/thread_bounds.h"
32
+ #include "sparse_matmul/layers/masked_sparse_matrix.h"
33
+ #include "sparse_matmul/numerics/fixed_types.h"
34
+ #include "sparse_matmul/numerics/float16_types.h"
35
+ #include "sparse_matmul/os/coop_threads.h"
36
+ #include "sparse_matmul/vector/cache_aligned_vector.h"
37
+ // IWYU pragma: end_exports
38
+ #include "absl/memory/memory.h"
39
+
40
+ namespace csrblocksparse {
41
+ // CsrBlockSparseMatrix stores a modified block compressed sparse row
42
+ // representation of a sparse matrix. The ordering of the weights is modified
43
+ // in the 16x1 and 1x1 cases so that a certain number (4 and 8 respectively)
44
+ // of columns of weights are stored contiguously before moving on to the next
45
+ // row. The 4x4 case stores each block contiguously.
46
+ //
47
+ // Currently it is constructed from a MaskedSparseMatrix which uses a dense
48
+ // binary mask representation. The construction generates the compressed
49
+ // representation. Further iterations will support a direct serialization
50
+ // of the compressed representation.
51
+ //
52
+ // MaskedSparseMatrix masked_matrix(rows, cols, existing_mask, existing_values)
53
+ // CsrBlockSparseMatrix matrix(masked_matrix)
54
+ //
55
+ // matrix.SpMV_bias(rhs, bias, &out);
56
+ //
57
+ // This class is thread compatible.
58
+ template <typename WeightType, typename RhsType, typename DeltaType = int16_t>
59
+ class CsrBlockSparseMatrix {
60
+ public:
61
+ CsrBlockSparseMatrix() {}
62
+
63
+ // Reference used to indicate that this is an input and not an output.
64
+ CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len) {
65
+ ReadFromFlatBuffer(buffer, len);
66
+ ComputeRHSIndices();
67
+ }
68
+
69
+ template <typename InputType>
70
+ CsrBlockSparseMatrix(const MaskedSparseMatrix<InputType>& masked_matrix) {
71
+ sparsity_ = masked_matrix.sparsity();
72
+ rows_ = masked_matrix.rows();
73
+ cols_ = masked_matrix.cols();
74
+
75
+ DetermineBlockSize(masked_matrix);
76
+
77
+ if (block_width_ == 1 && block_height_ == 1)
78
+ col_multiple_ = 8;
79
+ else
80
+ col_multiple_ = 1;
81
+
82
+ std::vector<InputType> weights(masked_matrix.values().begin(),
83
+ masked_matrix.values().end());
84
+
85
+ reduced_rows_ = (rows_ + block_height_ - 1) / block_height_;
86
+ rows_ = reduced_rows_ * block_height_;
87
+ reduced_cols_ = cols_ / block_width_;
88
+
89
+ // Calculate the reduced CSR representation of the matrix.
90
+ std::vector<int> reduced_mask(reduced_rows_ * reduced_cols_);
91
+ std::vector<int> row_offsets = {0};
92
+ int nnz = 0;
93
+ const auto& mask = masked_matrix.mask();
94
+ for (int r = 0; r < reduced_rows_; ++r) {
95
+ for (int c = 0; c < reduced_cols_; ++c) {
96
+ int mask_val = mask[r * block_height_ * cols_ + c * block_width_];
97
+ reduced_mask[r * reduced_cols_ + c] = mask_val;
98
+ nnz += mask_val;
99
+ }
100
+ row_offsets.push_back(nnz);
101
+ }
102
+
103
+ // Make sure the reduced representation has the correct number of columns.
104
+ MakeColumnsMultiple(row_offsets, &reduced_mask, &weights);
105
+
106
+ std::vector<int> col_indices;
107
+ std::vector<WeightType> weights_csr;
108
+ std::vector<int> nnz_per_row;
109
+ MaskAndWeightsToCsr(reduced_mask, weights, &nnz_per_row, &col_indices,
110
+ &weights_csr);
111
+
112
+ // Generate column deltas from |col_indices|.
113
+ std::vector<DeltaType> col_deltas;
114
+ for (int i = 0; i < col_indices.size(); ++i) {
115
+ // |col_indices| are used to index the RHS vector which is always float.
116
+ int64_t diff = sizeof(RhsType);
117
+ if (i == 0)
118
+ diff *= block_width_ * (col_indices[i]);
119
+ else
120
+ diff *= block_width_ * (col_indices[i] - col_indices[i - 1]);
121
+
122
+ CHECK(diff < std::numeric_limits<DeltaType>::max())
123
+ << "delta between column indices in bytes " << diff
124
+ << " exceeded the maximum size of the DeltaType "
125
+ << std::numeric_limits<DeltaType>::max();
126
+ col_deltas.push_back(static_cast<DeltaType>(diff));
127
+ }
128
+
129
+ // Because of pre-fetching we need some extra values at the end.
130
+ col_deltas.insert(col_deltas.end(), std::max(2, col_multiple_ + 1), 0);
131
+ nnz_per_row.insert(nnz_per_row.end(), 2, nnz_per_row.back());
132
+
133
+ weights_ = CacheAlignedVector<WeightType>(weights_csr);
134
+ col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
135
+ nnz_per_row_ = CacheAlignedVector<int>(nnz_per_row);
136
+ ComputeRHSIndices();
137
+
138
+ num_threads_ = 0;
139
+ PrepareForThreads(1);
140
+ }
141
+
142
+ // Constructor makes a matrix from the given weights, deltas and nnz, taking
143
+ // the other parameters from |src_matrix|. |cols| is the number of raw columns
144
+ // (NOT blocks) of the new matrix.
145
+ CsrBlockSparseMatrix(
146
+ const CsrBlockSparseMatrix<WeightType, RhsType, DeltaType>& src_matrix,
147
+ const std::vector<WeightType>& new_weights,
148
+ const std::vector<DeltaType>& new_deltas, const std::vector<int>& new_nnz,
149
+ int cols) {
150
+ num_threads_ = 0;
151
+ col_multiple_ = src_matrix.col_multiple_;
152
+ block_width_ = src_matrix.block_width_;
153
+ block_height_ = src_matrix.block_height_;
154
+ reduced_rows_ = new_nnz.size();
155
+ rows_ = reduced_rows_ * block_height_;
156
+ cols_ = cols;
157
+ reduced_cols_ = cols_ / block_width_;
158
+ weights_ = CacheAlignedVector<WeightType>(new_weights);
159
+ col_deltas_ = CacheAlignedVector<DeltaType>(new_deltas);
160
+ nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
161
+ sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
162
+ ComputeRHSIndices();
163
+ name_ = src_matrix.name_;
164
+ PrepareForThreads(1);
165
+ }
166
+
167
+ // Factory method takes a column slice out of *this and returns a sparse
168
+ // matrix that takes as inputs [|start_col|, |end_col|) of *this, and
169
+ // returns the same number of outputs, but only a partial result.
170
+ // If |keep_rhs_size|, then the new matrix takes the same rhs as the current
171
+ // matrix, but uses a subset of it, instead of expecting just the reduced rhs.
172
+ // If |start_col| > |end_col|, then we slice out the complement of the defined
173
+ // interval, ie [0, |end_col|) + [|start_col|, current end).
174
+ // NOTE That |start_col| and |end_col| are in raw column coordinates, NOT
175
+ // block units.
176
+ CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col,
177
+ bool keep_rhs_size = false) const {
178
+ int weight_index = 0;
179
+ int delta_index = 0;
180
+ std::vector<DeltaType> new_deltas;
181
+ std::vector<WeightType> new_weights;
182
+ std::vector<int> new_nnz(reduced_rows_);
183
+ int col = 0;
184
+ int prev_col = keep_rhs_size ? 0 : start_col;
185
+ for (int r = 0; r < reduced_rows_; ++r) {
186
+ int reduced_col_count = nnz_per_row_[r];
187
+ for (int c = 0; c < reduced_col_count; ++c, ++delta_index) {
188
+ col += col_deltas_[delta_index] / sizeof(RhsType);
189
+ if ((start_col < end_col && start_col <= col && col < end_col) ||
190
+ (start_col > end_col && (col < end_col || col >= start_col))) {
191
+ ++new_nnz[r];
192
+ new_deltas.push_back((col - prev_col) * sizeof(RhsType));
193
+ prev_col = col;
194
+ for (int i = 0; i < block_width_ * block_height_;
195
+ ++i, ++weight_index) {
196
+ new_weights.push_back(weights_[weight_index]);
197
+ }
198
+ } else {
199
+ weight_index += block_width_ * block_height_;
200
+ }
201
+ }
202
+ }
203
+ int new_cols = keep_rhs_size ? cols_ : end_col - start_col;
204
+ return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz,
205
+ new_cols);
206
+ }
207
+
208
+ // Factory method takes a row slice out of *this and returns a sparse
209
+ // matrix that takes the same inputs as *this, and returns the outputs for
210
+ // the range [|start_row|, |end_row|).
211
+ // NOTE That |start_row| and |end_row| are in raw row coordinates, NOT
212
+ // block units.
213
+ CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const {
214
+ int start_reduced = start_row / block_height_;
215
+ int end_reduced = end_row / block_height_;
216
+ std::vector<int> new_nnz(nnz_per_row_.data() + start_reduced,
217
+ nnz_per_row_.data() + end_reduced);
218
+ int weight_start = 0;
219
+ for (int r = 0; r < start_reduced; ++r) {
220
+ weight_start += nnz_per_row_[r];
221
+ }
222
+ int weight_end = weight_start;
223
+ for (int r = start_reduced; r < end_reduced; ++r) {
224
+ weight_end += nnz_per_row_[r];
225
+ }
226
+ int delta_start = 0;
227
+ for (int i = 0; i < weight_start; ++i) {
228
+ delta_start += col_deltas_[i];
229
+ }
230
+ std::vector<DeltaType> new_deltas(col_deltas_.data() + weight_start,
231
+ col_deltas_.data() + weight_end);
232
+ new_deltas[0] += delta_start;
233
+ int block_size = block_height_ * block_width_;
234
+ std::vector<WeightType> new_weights(
235
+ weights_.data() + weight_start * block_size,
236
+ weights_.data() + weight_end * block_size);
237
+ return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, cols_);
238
+ }
239
+
240
+ // Combines adjacent row blocks, doubling the block height.
241
+ // This necessarily involves adding zero weights where the blocks don't align
242
+ // across adjacent pairs of rows, so use with caution, as the resulting matrix
243
+ // is most likely to run slower if very sparse to begin with.
244
+ // In the few cases where the blocks do mostly align, the resulting matmul
245
+ // could be much faster, as the number of reads of the rhs will be halved.
246
+ void DoubleBlockHeight() {
247
+ int new_rows = reduced_rows_ / 2;
248
+ std::vector<int> new_nnz(new_rows);
249
+ std::vector<DeltaType> new_rhs_indices;
250
+ std::vector<WeightType> new_weights;
251
+ int rhs_index1 = 0;
252
+ int rhs_index2 = 0;
253
+ int block_size = block_height_ * block_width_;
254
+ for (int r = 0; r < new_rows; ++r) {
255
+ int start_nnz = new_rhs_indices.size();
256
+ rhs_index2 += nnz_per_row_[r * 2];
257
+ int end1 = rhs_index1 + nnz_per_row_[r * 2];
258
+ int end2 = rhs_index2 + nnz_per_row_[r * 2 + 1];
259
+ // Run over a pair of rows with 2 iterators, combining blocks as we go, or
260
+ // padding with zeros where the block positions don't match.
261
+ while (rhs_index1 < end1 || rhs_index2 < end2) {
262
+ int col1 = rhs_index1 < end1 ? rhs_indices_[rhs_index1] : reduced_cols_;
263
+ int col2 = rhs_index2 < end2 ? rhs_indices_[rhs_index2] : reduced_cols_;
264
+ if (col1 < col2) {
265
+ // Need zero weights for row2 to pad out weights block.
266
+ new_rhs_indices.push_back(col1);
267
+ new_weights.insert(new_weights.end(),
268
+ weights_.data() + rhs_index1 * block_size,
269
+ weights_.data() + (rhs_index1 + 1) * block_size);
270
+ new_weights.insert(new_weights.end(), block_size,
271
+ static_cast<WeightType>(0.0f));
272
+ ++rhs_index1;
273
+ } else if (col1 > col2) {
274
+ // Need zero weights for row1 to pad out weights block.
275
+ new_rhs_indices.push_back(col2);
276
+ new_weights.insert(new_weights.end(), block_size,
277
+ static_cast<WeightType>(0.0f));
278
+ new_weights.insert(new_weights.end(),
279
+ weights_.data() + rhs_index2 * block_size,
280
+ weights_.data() + (rhs_index2 + 1) * block_size);
281
+ ++rhs_index2;
282
+ } else {
283
+ // Combine weights for both row1 and row2.
284
+ new_rhs_indices.push_back(col1);
285
+ new_weights.insert(new_weights.end(),
286
+ weights_.data() + rhs_index1 * block_size,
287
+ weights_.data() + (rhs_index1 + 1) * block_size);
288
+ new_weights.insert(new_weights.end(),
289
+ weights_.data() + rhs_index2 * block_size,
290
+ weights_.data() + (rhs_index2 + 1) * block_size);
291
+ ++rhs_index1;
292
+ ++rhs_index2;
293
+ }
294
+ }
295
+ rhs_index1 = rhs_index2;
296
+ new_nnz[r] = new_rhs_indices.size() - start_nnz;
297
+ }
298
+ block_height_ *= 2;
299
+ reduced_rows_ /= 2;
300
+ weights_ = CacheAlignedVector<WeightType>(new_weights);
301
+ rhs_indices_ = CacheAlignedVector<DeltaType>(new_rhs_indices);
302
+ nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
303
+ sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
304
+ ComputeColDeltas();
305
+ if (num_threads_ > 0) {
306
+ int num_threads = num_threads_;
307
+ num_threads_ = 0;
308
+ PrepareForThreads(num_threads);
309
+ }
310
+ }
311
+
312
+ // Allocates memory and fills buffer.
313
+ // Caller is responsible for the memory de-allocation.
314
+ // TODO(b/189958858): Both Read and Write need to eventually handle the
315
+ // different possible HalfType and DeltaType values, but punting for now as
316
+ // there is only one supported combination.
317
+ std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer) {
318
+ std::size_t bytes = 0;
319
+ bytes += FixedParameterSize();
320
+ bytes += weights_.size() * sizeof(WeightType);
321
+ bytes += col_deltas_.size() * sizeof(DeltaType);
322
+ bytes += nnz_per_row_.size() * sizeof(int);
323
+
324
+ uint8_t* bytes_ptr_ptr =
325
+ reinterpret_cast<uint8_t*>(CHECK_NOTNULL(malloc(bytes)));
326
+
327
+ int* int_bytes_ptr = reinterpret_cast<int*>(bytes_ptr_ptr);
328
+
329
+ *int_bytes_ptr++ = rows_;
330
+ *int_bytes_ptr++ = cols_;
331
+ *int_bytes_ptr++ = reduced_rows_;
332
+ *int_bytes_ptr++ = reduced_cols_;
333
+ *int_bytes_ptr++ = block_width_;
334
+ *int_bytes_ptr++ = block_height_;
335
+ *int_bytes_ptr++ = col_multiple_;
336
+ *int_bytes_ptr++ = num_threads_;
337
+ *int_bytes_ptr++ = weights_.size();
338
+ *int_bytes_ptr++ = col_deltas_.size();
339
+ *int_bytes_ptr++ = nnz_per_row_.size();
340
+
341
+ float* float_bytes_ptr = reinterpret_cast<float*>(int_bytes_ptr);
342
+ *float_bytes_ptr++ = sparsity_;
343
+
344
+ uint8_t* bytes_ptr = reinterpret_cast<uint8_t*>(float_bytes_ptr);
345
+
346
+ memcpy(bytes_ptr, weights_.data(), weights_.size() * sizeof(WeightType));
347
+ bytes_ptr += weights_.size() * sizeof(WeightType);
348
+
349
+ memcpy(bytes_ptr, col_deltas_.data(),
350
+ col_deltas_.size() * sizeof(DeltaType));
351
+ bytes_ptr += col_deltas_.size() * sizeof(DeltaType);
352
+
353
+ memcpy(bytes_ptr, nnz_per_row_.data(), nnz_per_row_.size() * sizeof(int));
354
+ bytes_ptr += nnz_per_row_.size() * sizeof(int);
355
+
356
+ csr_flatbuffer->resize(bytes);
357
+ csr_flatbuffer->assign(reinterpret_cast<char*>(bytes_ptr_ptr), bytes);
358
+ free(bytes_ptr_ptr);
359
+
360
+ return bytes;
361
+ }
362
+
363
+ void ReadFromFlatBuffer(const uint8_t* const& bytes, const std::size_t& len) {
364
+ CHECK_GE(len, FixedParameterSize());
365
+
366
+ const int* int_bytes_ptr = reinterpret_cast<const int*>(bytes);
367
+ rows_ = *int_bytes_ptr++;
368
+ cols_ = *int_bytes_ptr++;
369
+ reduced_rows_ = *int_bytes_ptr++;
370
+ reduced_cols_ = *int_bytes_ptr++;
371
+ block_width_ = *int_bytes_ptr++;
372
+ block_height_ = *int_bytes_ptr++;
373
+ col_multiple_ = *int_bytes_ptr++;
374
+ int num_threads = *int_bytes_ptr++;
375
+ int32_t weights_size = *int_bytes_ptr++;
376
+ int32_t col_deltas_size = *int_bytes_ptr++;
377
+ int32_t nnz_per_row_size = *int_bytes_ptr++;
378
+
379
+ // Make sure negative sizes don't mess things up.
380
+ weights_size = std::max(0, weights_size);
381
+ col_deltas_size = std::max(0, col_deltas_size);
382
+ nnz_per_row_size = std::max(0, nnz_per_row_size);
383
+
384
+ const float* float_bytes_ptr =
385
+ reinterpret_cast<const float*>(int_bytes_ptr);
386
+ sparsity_ = *float_bytes_ptr++;
387
+
388
+ std::size_t total_bytes =
389
+ FixedParameterSize() + weights_size * sizeof(WeightType) +
390
+ col_deltas_size * sizeof(DeltaType) + nnz_per_row_size * sizeof(int);
391
+
392
+ CHECK_EQ(total_bytes, len)
393
+ << "total bytes: " << total_bytes << ", actual len given: " << len;
394
+
395
+ const uint8_t* bytes_ptr =
396
+ reinterpret_cast<const uint8_t*>(float_bytes_ptr);
397
+ std::vector<WeightType> weights_raw(weights_size);
398
+ memcpy(weights_raw.data(), bytes_ptr, weights_size * sizeof(WeightType));
399
+ weights_ = CacheAlignedVector<WeightType>(weights_raw);
400
+ bytes_ptr += weights_size * sizeof(WeightType);
401
+
402
+ std::vector<DeltaType> deltas_raw(col_deltas_size);
403
+ memcpy(deltas_raw.data(), bytes_ptr, col_deltas_size * sizeof(DeltaType));
404
+ col_deltas_ = CacheAlignedVector<DeltaType>(deltas_raw);
405
+ bytes_ptr += col_deltas_size * sizeof(DeltaType);
406
+
407
+ std::vector<int> nnz_raw(nnz_per_row_size);
408
+ memcpy(nnz_raw.data(), bytes_ptr, nnz_per_row_size * sizeof(int));
409
+ nnz_per_row_ = CacheAlignedVector<int>(nnz_raw);
410
+ num_threads_ = 0;
411
+ PrepareForThreads(num_threads);
412
+ }
413
+
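
As a usage note, the serialization round trip looks roughly like the following sketch, which mirrors the FlatBufferSerialization test added later in this commit (the template arguments and the raw-bytes constructor are taken from that test; the MaskedSparseMatrix input is assumed to be built elsewhere):

#include <cstddef>
#include <cstdint>
#include <string>

#include "sparse_matmul/layers/csr_blocksparse_matrix.h"

// Serialize a block-sparse matrix and rebuild it from the raw bytes.
void RoundTripSketch(
    const csrblocksparse::MaskedSparseMatrix<float>& masked_matrix) {
  csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float,
                                       int16_t>
      sparse(masked_matrix);
  std::string buffer;
  const std::size_t num_bytes = sparse.WriteToFlatBuffer(&buffer);
  // The (bytes, length) constructor restores the matrix from the buffer.
  csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float,
                                       int16_t>
      restored(reinterpret_cast<const uint8_t*>(buffer.c_str()), num_bytes);
  CHECK_EQ(restored.rows(), sparse.rows());
}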
414
+ // Multiply a Sparse matrix by a possibly dense matrix. Often the matrix is
415
+ // a vector with a small number of columns, hence the term "fat vector".
416
+ // 1x1 and 4x4 have specializations for output columns (i.e. fatness) >= 5,
417
+ // and often achieve twice as many GFlops when multiplying a right hand side
418
+ // that has 5 or more columns. (Best is a multiple of 5).
419
+ // 16x1 doesn't have enough registers and just loops over the width 1 kernel.
420
+ //
421
+ // |rhs| and |out| are COLUMN MAJOR.
422
+
423
+ // Fast Tuples WeightType, BiasType, RhsType, OutType are:
424
+ // (float, float, float, float)
425
+ // (bfloat16, float, float, float)
426
+ // and only on ARM64. All other cases use a slow generic implementation.
427
+ template <typename RhsClass, typename BiasClass, typename OutClass,
428
+ typename BiasType = typename BiasClass::value_type,
429
+ typename OutType = typename OutClass::value_type>
430
+ void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
431
+ bool relu = false, int tid = 0,
432
+ SpinBarrier* barrier = nullptr) const {
433
+ static_assert(std::is_same<typename RhsClass::value_type, RhsType>::value,
434
+ "Rhs types must match");
435
+ CHECK_LT(tid, num_threads_);
436
+ CHECK_EQ(rhs.cols(), out->cols());
437
+ CHECK_EQ(rhs.rows(), cols_);
438
+ CHECK_GE(out->rows(), rows_);
439
+ int cols_to_go = out->cols();
440
+ int rhs_index = *thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid);
441
+ const RhsType* rhs_ptr = rhs.data() + rhs_index * block_height_;
442
+ OutType* out_ptr = thread_bounds_.OffsetOutput(out->data(), tid);
443
+ const WeightType* weights_ptr =
444
+ thread_bounds_.OffsetWeights(weights_.data(), tid);
445
+ const DeltaType* delta_ptr =
446
+ thread_bounds_.OffsetRhsIndices(col_deltas_.data(), tid);
447
+ int offset = *delta_ptr / sizeof(RhsType);
448
+ rhs_ptr -= offset;
449
+ const int* nnz_ptr = nnz_per_row_.data() + thread_bounds_.StartRow(tid);
450
+ int assigned_rows =
451
+ thread_bounds_.StartRow(tid + 1) - thread_bounds_.StartRow(tid);
452
+ const BiasType* bias_ptr = thread_bounds_.OffsetBias(bias.data(), tid);
453
+
454
+ while (cols_to_go > 0) {
455
+ if (block_width_ == 4 && block_height_ == 4) {
456
+ if (cols_to_go >= 5) {
457
+ detail::SpMM5_4x4<WeightType, RhsType, OutType>(
458
+ weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
459
+ assigned_rows, out->col_stride(), rhs.col_stride(), relu);
460
+ } else {
461
+ detail::SpMV_4x4<WeightType, RhsType, OutType>(
462
+ weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
463
+ assigned_rows, out->col_stride(), rhs.col_stride(), relu);
464
+ }
465
+ } else {
466
+ if (cols_to_go >= 5) {
467
+ detail::SpMM5_1x1<WeightType, RhsType, OutType>(
468
+ weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
469
+ assigned_rows, out->col_stride(), rhs.col_stride(), relu);
470
+ } else {
471
+ detail::SpMV_1x1<WeightType, RhsType, OutType>(
472
+ weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
473
+ assigned_rows, out->col_stride(), rhs.col_stride(), relu);
474
+ }
475
+ }
476
+
477
+ if (cols_to_go >= 5) {
478
+ cols_to_go -= 5;
479
+ rhs_ptr += rhs.col_stride() * 5;
480
+ out_ptr += out->col_stride() * 5;
481
+ } else {
482
+ cols_to_go--;
483
+ rhs_ptr += rhs.col_stride();
484
+ out_ptr += out->col_stride();
485
+ }
486
+ if (barrier) barrier->barrier();
487
+ }
488
+ }
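
A minimal single-threaded calling sketch for SpMM_bias (vector types and fill helpers as used by the unit tests in this commit; the MaskedSparseMatrix input is assumed to be built elsewhere):

// Computes out = sparse * rhs + bias on the calling thread.
void SpMMBiasSketch(const csrblocksparse::MaskedSparseMatrix<float>& masked) {
  csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float>
      sparse(masked);
  csrblocksparse::CacheAlignedVector<float> rhs(sparse.cols());
  csrblocksparse::CacheAlignedVector<float> bias(sparse.rows());
  csrblocksparse::CacheAlignedVector<float> out(sparse.rows());
  rhs.FillOnes();
  bias.FillZero();
  out.FillZero();
  // Default arguments: no relu, tid 0, no barrier.
  sparse.SpMM_bias(rhs, bias, &out);
}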
489
+ template <typename MVRhsType, typename MVBiasType, typename OutType>
490
+ void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid,
491
+ int replicas, int output_stride, OutType* output) {
492
+ CHECK_LT(tid, num_threads_);
493
+ CHECK_EQ(block_width_, 4) << "Block width must be 4!";
494
+ if (block_height_ == 8) {
495
+ matmul_.MatVec8x4(
496
+ thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
497
+ thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
498
+ thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
499
+ thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
500
+ replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
501
+ } else {
502
+ CHECK_EQ(block_height_, 4) << "Block height must be 4 or 8!";
503
+ matmul_.MatVec4x4(
504
+ thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
505
+ thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
506
+ thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
507
+ thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
508
+ replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
509
+ }
510
+ }
511
+
512
+ int rows() const { return rows_; }
513
+ int cols() const { return cols_; }
514
+ int block_height() const { return block_height_; }
515
+ int block_width() const { return block_width_; }
516
+ float sparsity() const { return sparsity_; }
517
+ int num_threads() const { return num_threads_; }
518
+ const ThreadBounds& thread_bounds() const { return thread_bounds_; }
519
+ const CacheAlignedVector<DeltaType>& rhs_indices() const {
520
+ return rhs_indices_;
521
+ }
522
+ const std::string& name() const { return name_; }
523
+ void set_name(const std::string& name) { name_ = name; }
524
+ const std::vector<int>& split_points() const {
525
+ return thread_bounds_.row_starts();
526
+ }
527
+
528
+ std::size_t bytes() const {
529
+ return weights_.size() * sizeof(WeightType) +
530
+ col_deltas_.size() * sizeof(DeltaType) +
531
+ nnz_per_row_.size() * sizeof(int);
532
+ }
533
+
534
+ // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
535
+ // and then samples from the output (softmax distribution) layer.
536
+ template <typename RhsClass, typename BiasClass, typename OutClass,
537
+ typename BiasType = typename BiasClass::value_type,
538
+ typename OutType = typename OutClass::value_type>
539
+ typename std::enable_if<!IsFixed32Type<OutType>::value, int>::type
540
+ SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
541
+ float temperature, int tid, SpinBarrier* barrier,
542
+ std::minstd_rand* gen,
543
+ CacheAlignedVector<float>* scratch) const {
544
+ SpMM_bias(rhs, bias, out, /*relu=*/false, tid, barrier);
545
+ return out->Sample(temperature, gen, scratch);
546
+ }
547
+ // Fixed32 version.
548
+ template <typename RhsClass, typename BiasClass, typename OutClass,
549
+ typename BiasType = typename BiasClass::value_type,
550
+ typename OutType = typename OutClass::value_type>
551
+ typename std::enable_if<IsFixed32Type<OutType>::value, int>::type
552
+ SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
553
+ float temperature, int tid, SpinBarrier* barrier,
554
+ std::minstd_rand* gen,
555
+ CacheAlignedVector<float>* scratch) const {
556
+ // We don't pass the barrier on, as we have more work to do.
557
+ SpMM_bias(rhs, bias, out, /*relu=*/false, tid);
558
+ return out->ReducingSample(gen, scratch, tid, temperature, barrier);
559
+ }
560
+
561
+ void Print() const {
562
+ std::cout << "Weights\n";
563
+ weights_.Print();
564
+ std::cout << std::endl;
565
+ std::cout << "Deltas\n";
566
+ col_deltas_.Print();
567
+ std::cout << std::endl;
568
+ std::cout << "nnz\n";
569
+ nnz_per_row_.Print();
570
+ std::cout << std::endl;
571
+ }
572
+
573
+ // Splits the computation amongst threads by rows, based on the number of
574
+ // non-zeros plus a constant to account for the work of the bias and the
575
+ // horizontal add at the end. It also guarantees that each
576
+ // thread writes only whole cache lines, based on the size of OutType.
577
+ // The |cache_line_size| arg is used only for testing. Normally it is provided
578
+ // through the architecture #defines.
579
+ // Each thread gets a contiguous row range (|split_points|).
580
+ // Thread t does rows [ split_points[t], split_points[t + 1] )
581
+ // Each thread also needs to know how many non zeros were before it to skip
582
+ // (|nnz_to_skip|). And finally it also needs to know what the offset into
583
+ // the rhs vector would have been at the split point (|rhs_to_skip|).
584
+ //
585
+ // Some tricky corner cases where the number of non-zeros doesn't split
586
+ // nicely amongst the number of requested threads are not handled and default
587
+ // to one thread; these cases are only going to happen in tests and not in
588
+ // the matrices that occur in real models.
589
+ //
590
+ // Returns the maximum number of threads that can be used; <= |num_threads|.
591
+ template <typename OutType = int32_t>
592
+ int PrepareForThreads(int num_threads, int cache_line_size = -1) {
593
+ CHECK_GT(num_threads, 0);
594
+ // We've already prepared for this number of threads; nothing to do.
595
+ if (num_threads == num_threads_) return num_threads_;
596
+
597
+ num_threads_ = num_threads;
598
+ thread_bounds_.PrepareForThreads(
599
+ block_width_, block_height_, num_threads_,
600
+ ReducedRowsPerCacheLine<OutType>(cache_line_size), reduced_rows_,
601
+ nnz_per_row_.data());
602
+ return num_threads_;
603
+ }
604
+
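
For illustration, a sketch of the intended multi-threaded calling pattern (run serially here for brevity; the tests in this commit drive each tid on its own thread via csrblocksparse::LaunchOnThreadsWithBarrier):

// PrepareForThreads() partitions the rows and may return fewer threads than
// requested; each worker then computes only its own row range.
void ThreadedSpMMSketch(
    csrblocksparse::CsrBlockSparseMatrix<float, float>* sparse,
    const csrblocksparse::CacheAlignedVector<float>& rhs,
    const csrblocksparse::CacheAlignedVector<float>& bias,
    csrblocksparse::CacheAlignedVector<float>* out, int requested_threads) {
  const int num_threads = sparse->PrepareForThreads(requested_threads);
  for (int tid = 0; tid < num_threads; ++tid) {
    // In real use each tid runs concurrently on its own thread.
    sparse->SpMM_bias(rhs, bias, out, /*relu=*/false, tid);
  }
}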
605
+ // Computes and stores the |rhs_indices_| from the |col_deltas_|.
606
+ void ComputeRHSIndices() {
607
+ std::vector<int> cumulative_deltas = CumulativeColDeltas();
608
+ std::vector<DeltaType> rhs_indices(cumulative_deltas.size() +
609
+ reduced_rows_);
610
+ int total_indices = 0;
611
+ int delta_index = 0;
612
+ for (int r = 0; r < reduced_rows_; ++r) {
613
+ for (int n = 0; n < nnz_per_row_[r]; ++n, ++delta_index) {
614
+ rhs_indices[total_indices++] =
615
+ cumulative_deltas[delta_index] / block_width_;
616
+ }
617
+ }
618
+ rhs_indices_ = CacheAlignedVector<DeltaType>(rhs_indices);
619
+ }
620
+
621
+ // Computes and stores the |col_deltas_| from the |rhs_indices_|.
622
+ void ComputeColDeltas() {
623
+ std::vector<int> col_deltas(rhs_indices_.size());
624
+ int prev_index = 0;
625
+ for (int i = 0; i < rhs_indices_.size(); ++i) {
626
+ int offset = rhs_indices_[i] - prev_index;
627
+ prev_index = rhs_indices_[i];
628
+ col_deltas[i] = offset * block_width_ * sizeof(RhsType);
629
+ }
630
+ col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
631
+ }
632
+
633
+ // Computes and returns the inclusive prefix sum of the deltas, i.e. absolute
634
+ // positions.
635
+ std::vector<int> CumulativeColDeltas() const {
636
+ std::vector<int> cum_col_deltas(col_deltas_.size());
637
+ for (int i = 0; i < col_deltas_.size(); ++i) {
638
+ cum_col_deltas[i] = col_deltas_[i] / sizeof(RhsType);
639
+ if (i > 0) cum_col_deltas[i] += cum_col_deltas[i - 1];
640
+ }
641
+ return cum_col_deltas;
642
+ }
643
+
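
The two encodings are interchangeable, as the RhsIndicesDeltasRoundTrip test later in this commit verifies. A small sketch of the forward direction (hypothetical helper; plain vectors, with RhsType assumed to be float):

#include <cstddef>
#include <vector>

// Converts absolute block-column indices into byte deltas between
// consecutive blocks, mirroring ComputeColDeltas() above.
std::vector<int> IndicesToByteDeltas(const std::vector<int>& rhs_indices,
                                     int block_width) {
  std::vector<int> deltas(rhs_indices.size());
  int prev = 0;
  for (std::size_t i = 0; i < rhs_indices.size(); ++i) {
    deltas[i] = (rhs_indices[i] - prev) * block_width *
                static_cast<int>(sizeof(float));
    prev = rhs_indices[i];
  }
  // e.g. indices {0, 3, 5} with block_width 4 -> deltas {0, 48, 32} bytes.
  return deltas;
}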
644
+ private:
645
+ constexpr std::size_t FixedParameterSize() const {
646
+ return sizeof(int) // rows
647
+ + sizeof(int) // cols
648
+ + sizeof(int) // reduced_rows
649
+ + sizeof(int) // reduced_cols
650
+ + sizeof(int) // block_width
651
+ + sizeof(int) // block_height
652
+ + sizeof(float) // sparsity
653
+ + sizeof(int) // col_multiple
654
+ + sizeof(int) // num_threads_
655
+ + sizeof(int) // weights_.size()
656
+ + sizeof(int) // col_deltas_.size()
657
+ + sizeof(int); // nnz_per_row_.size()
658
+ }
659
+ // Possible block sizes are only those that are supported by the computation;
660
+ // the default is 1x1, and the other options are 4x4 and 16x1.
661
+ template <typename InputType>
662
+ void DetermineBlockSize(const MaskedSparseMatrix<InputType>& masked_matrix) {
663
+ const std::vector<std::pair<int, int>> kPreferredOrder = {{4, 4}};
664
+ int rows = masked_matrix.rows();
665
+ int cols = masked_matrix.cols();
666
+
667
+ for (const auto& block_size : kPreferredOrder) {
668
+ int block_height, block_width;
669
+ std::tie(block_height, block_width) = block_size;
670
+ if (cols % block_width != 0) continue;
671
+
672
+ int reduced_rows = (rows + block_height - 1) / block_height;
673
+ int reduced_cols = cols / block_width;
674
+
675
+ // For each possible block, confirm that it is either all 0s or all 1s.
676
+ bool all_same = true;
677
+ const auto& mask = masked_matrix.mask();
678
+ for (int r = 0; r < reduced_rows; ++r) {
679
+ for (int c = 0; c < reduced_cols; ++c) {
680
+ int val = mask[r * block_height * cols + c * block_width];
681
+ for (int i = 0; i < block_height; ++i) {
682
+ for (int j = 0; j < block_width; ++j) {
683
+ int index = (r * block_height + i) * cols + c * block_width + j;
684
+ if (index < masked_matrix.mask().size()) {
685
+ all_same &= (masked_matrix.mask()[index] == val);
686
+ }
687
+ }
688
+ }
689
+ }
690
+ }
691
+
692
+ // If this block configuration is possible, accept it.
693
+ if (all_same) {
694
+ block_height_ = block_height;
695
+ block_width_ = block_width;
696
+ return;
697
+ }
698
+ }
699
+
700
+ // No large blocks were found, default to 1x1.
701
+ block_height_ = 1;
702
+ block_width_ = 1;
703
+ }
704
+
705
+ // CSR descriptors are for the reduced matrix; the weights are for the full matrix.
706
+ template <typename InputType>
707
+ void MakeColumnsMultiple(const std::vector<int>& row_offsets,
708
+ std::vector<int>* reduced_mask,
709
+ std::vector<InputType>* weights) {
710
+ if (col_multiple_ > 0) {
711
+ // Make sure each row has a number of columns that is a multiple of
712
+ // |col_multiple|.
713
+ for (int r = 1; r < row_offsets.size(); ++r) {
714
+ int num_row = row_offsets[r] - row_offsets[r - 1];
715
+ int num_needed = col_multiple_ - num_row % col_multiple_;
716
+ if (num_needed < col_multiple_) {
717
+ // Find gaps in the columns where we can insert a column of 0 weights.
718
+ int num_added = 0;
719
+ for (int c = 0; c < reduced_cols_; ++c) {
720
+ if ((*reduced_mask)[(r - 1) * reduced_cols_ + c] == 0) {
721
+ (*reduced_mask)[(r - 1) * reduced_cols_ + c] = 1;
722
+
723
+ // Zero out the weights that correspond to this block.
724
+ for (int i = 0; i < block_height_; ++i) {
725
+ for (int j = 0; j < block_width_; ++j) {
726
+ (*weights)[((r - 1) * block_height_ + i) * cols_ +
727
+ block_width_ * c + j] = InputType(0.f);
728
+ }
729
+ }
730
+ num_added++;
731
+ }
732
+
733
+ if (num_added == num_needed) break;
734
+ }
735
+ }
736
+ }
737
+ }
738
+ }
739
+
740
+ // Given the final dense mask and weights, convert to the compressed
741
+ // block CSR representation.
742
+ template <typename InputType>
743
+ void MaskAndWeightsToCsr(const std::vector<int>& mask,
744
+ const std::vector<InputType>& weights,
745
+ std::vector<int>* nnz_per_row,
746
+ std::vector<int>* col_indices,
747
+ std::vector<WeightType>* weights_csr) {
748
+ std::vector<int> row_offsets = {0};
749
+ int nnz = 0;
750
+ // Standard CSR format.
751
+ if (block_width_ == 1 && block_height_ == 1) {
752
+ for (int r = 0; r < rows_; ++r) {
753
+ for (int c = 0; c < cols_; ++c) {
754
+ if (mask[r * cols_ + c] == 1) {
755
+ nnz++;
756
+ col_indices->push_back(c);
757
+ weights_csr->push_back(WeightType(weights[r * cols_ + c]));
758
+ }
759
+ }
760
+ row_offsets.push_back(nnz);
761
+ }
762
+ } else if (block_width_ == 4 && block_height_ == 4) {
763
+ // Weights are stored contiguously for each block in this case.
764
+ for (int r = 0; r < reduced_rows_; ++r) {
765
+ for (int c = 0; c < reduced_cols_; ++c) {
766
+ if (mask[r * reduced_cols_ + c] == 1) {
767
+ col_indices->push_back(c);
768
+ nnz++;
769
+ for (int i = 0; i < block_height_; ++i) {
770
+ for (int j = 0; j < block_width_; ++j) {
771
+ int row_index = (block_height_ * r + i) * cols_;
772
+ int w_index = row_index + block_width_ * c + j;
773
+ WeightType weight = w_index < weights.size()
774
+ ? WeightType(weights[w_index])
775
+ : WeightType(0.0f);
776
+ weights_csr->push_back(weight);
777
+ }
778
+ }
779
+ }
780
+ }
781
+ row_offsets.push_back(nnz);
782
+ }
783
+ }
784
+ for (int i = 1; i < row_offsets.size(); ++i)
785
+ nnz_per_row->push_back(row_offsets[i] - row_offsets[i - 1]);
786
+ }
787
+
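
As a worked example (using the 8x8 block-diagonal mask from the FlatBufferSerialization test later in this commit, with 4x4 blocks): the reduced mask is 2x2 with one block per reduced row, so this routine produces nnz_per_row = {1, 1}, col_indices = {0, 1}, and weights_csr holding 2 * 16 values, each 4x4 block stored contiguously in row-major order.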
788
+ // Returns the number of block rows per cache line. This is the minimum unit
789
+ // into which the calculation is broken for threads.
790
+ template <typename OutType>
791
+ int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const {
792
+ int line_size = kCacheLineSize;
793
+ if (override_cache_line_size >= 1) line_size = override_cache_line_size;
794
+ return std::max<int>(line_size / (block_height_ * sizeof(OutType)), 1);
795
+ }
796
+
797
+ int col_multiple_;
798
+ int rows_;
799
+ int cols_;
800
+ int reduced_rows_;
801
+ int reduced_cols_;
802
+ float sparsity_;
803
+ int block_width_;
804
+ int block_height_;
805
+ int num_threads_;
806
+ std::string name_;
807
+
808
+ CacheAlignedVector<WeightType> weights_;
809
+ CacheAlignedVector<DeltaType> col_deltas_;
810
+ CacheAlignedVector<int> nnz_per_row_;
811
+ // |thread_bounds_| and |rhs_indices_| don't need to be serialized as they are
812
+ // always recalculated from serialized data.
813
+ CacheAlignedVector<DeltaType> rhs_indices_;
814
+ Matmul<WeightType, RhsType> matmul_;
815
+ ThreadBounds thread_bounds_;
816
+ static constexpr int kCacheLineSize = 64;
817
+ };
818
+
819
+ // Converts a sparse matrix represented with (|mask|, |weights|, |size|) into
820
+ // the CSR format, and returns that as a serialized string.
821
+ template <typename MaskType>
822
+ std::string ConvertDenseToSparseRepresentation_Int16Deltas(
823
+ const std::vector<MaskType>& mask, const std::vector<float>& weights,
824
+ const int rows, const int cols) {
825
+ MaskedSparseMatrix<float> masked_weights(rows, cols, mask.data(),
826
+ weights.data());
827
+ CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
828
+ sparse_masked_weights(masked_weights);
829
+ std::string buffer;
830
+ sparse_masked_weights.WriteToFlatBuffer(&buffer);
831
+ return buffer;
832
+ }
833
+
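
A brief usage sketch for this helper (hypothetical dense inputs; in practice the mask and weights come from a pruned model checkpoint):

#include <string>
#include <vector>

std::string SerializeDenseSketch() {
  const int kRows = 8;
  const int kCols = 8;
  std::vector<int> mask(kRows * kCols, 1);          // fully dense mask
  std::vector<float> weights(kRows * kCols, 0.5f);  // arbitrary values
  return csrblocksparse::ConvertDenseToSparseRepresentation_Int16Deltas(
      mask, weights, kRows, kCols);
}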
834
+ } // namespace csrblocksparse
835
+ #endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
sparse_matmul/layers/csrblocksparse_test.cc ADDED
@@ -0,0 +1,977 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include <array>
16
+ #include <cstdint>
17
+ #include <tuple>
18
+ #include <vector>
19
+
20
+ // Placeholder for get runfiles header.
21
+ #include "absl/status/status.h"
22
+ #include "absl/strings/str_cat.h"
23
+ #include "absl/strings/string_view.h"
24
+ #include "absl/types/span.h"
25
+ #include "gtest/gtest.h"
26
+ #include "include/ghc/filesystem.hpp"
27
+ #include "sparse_matmul/compute/matmul.h"
28
+ #include "sparse_matmul/layers/utils.h"
29
+ #include "sparse_matmul/numerics/test_utils.h"
30
+ #include "sparse_matmul/os/coop_threads.h"
31
+
32
+ namespace csrblocksparse {
33
+ namespace {
34
+
35
+ inline constexpr absl::string_view kTestdataPath = "layers/testdata";
36
+
37
+ TEST(CSRBlockSparseMatrix, FlatBufferSerialization) {
38
+ const int kRows = 8;
39
+ const int kCols = 8;
40
+ std::vector<int> mask = {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
41
+ 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
42
+ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
43
+ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1};
44
+ std::vector<float> values(kRows * kCols, 1.f);
45
+ values[1] = 2.f;
46
+ values[3] = 3.f;
47
+ values[36] = -1.f;
48
+ values[45] = -2.f;
49
+
50
+ csrblocksparse::CacheAlignedVector<float> bias(kRows);
51
+ csrblocksparse::CacheAlignedVector<float> rhs(kCols);
52
+ csrblocksparse::CacheAlignedVector<float> out_ref(kRows);
53
+ csrblocksparse::CacheAlignedVector<float> out_test(kRows);
54
+
55
+ bias.FillZero();
56
+ rhs.FillOnes();
57
+
58
+ csrblocksparse::MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(),
59
+ values.data());
60
+
61
+ matrix.SpMM_bias(rhs, bias, &out_ref);
62
+
63
+ csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
64
+ block_sparse_matrix(matrix);
65
+
66
+ std::string buffer;
67
+ std::size_t num_bytes = block_sparse_matrix.WriteToFlatBuffer(&buffer);
68
+
69
+ csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
70
+ new_block_sparse_matrix(reinterpret_cast<const uint8_t*>(buffer.c_str()),
71
+ num_bytes);
72
+
73
+ new_block_sparse_matrix.SpMM_bias(rhs, bias, &out_test);
74
+
75
+ CheckResult(out_ref, out_test, kCols);
76
+ }
77
+
78
+ template <typename ComputeType, typename RhsType, typename OutType>
79
+ void CorrectnessCheckBlockSpMM(int rows, int cols, int block_height,
80
+ int block_width, float sparsity,
81
+ bool use_relu = false, int num_threads = 1,
82
+ int fatness = 1, bool test_matmul = false) {
83
+ using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
84
+ MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
85
+ block_width);
86
+ matrix.CastWeights<ComputeType>();
87
+ FatCacheAlignedVector<RhsType> rhs(cols, fatness);
88
+ CacheAlignedVector<BiasType> bias(rows);
89
+ FatCacheAlignedVector<OutType> out(rows, fatness);
90
+
91
+ bias.FillRandom();
92
+ rhs.FillRandom();
93
+ out.FillZero();
94
+ FatCacheAlignedVector<OutType> out_reference = out;
95
+
96
+ matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
97
+
98
+ CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
99
+
100
+ SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
101
+ std::move(sparse_matrix), std::move(bias));
102
+ num_threads = sparse_linear_layer.PrepareForThreads(num_threads);
103
+
104
+ // Checks that the result of applying each thread's portion serially is
105
+ // correct.
106
+ for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
107
+ sparse_linear_layer.SpMM_bias(rhs, &out, use_relu, thread_id);
108
+ }
109
+
110
+ CheckResult(out_reference, out, sparse_linear_layer.cols());
111
+
112
+ if (test_matmul) {
113
+ for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
114
+ sparse_linear_layer.MatVec(rhs, use_relu, thread_id,
115
+ /*replicas=*/1, /*output_stride=*/0, &out);
116
+ }
117
+
118
+ CheckResult(out_reference, out, sparse_linear_layer.cols());
119
+ }
120
+ }
121
+
122
+ // Does:
123
+ // y = Ax + b;
124
+ // x = Ay + b;
125
+ // y = Ax + b;
126
+ //
127
+ // to make sure that dependent multiplies are correct.
128
+ template <typename ComputeType, typename RhsType, typename OutType>
129
+ void ThreadBody(
130
+ SpinBarrier* spin_barrier, int tid,
131
+ const SparseLinearLayer<ComputeType, RhsType>& sparse_linear_layer,
132
+ FatCacheAlignedVector<RhsType>* rhs, FatCacheAlignedVector<OutType>* out,
133
+ bool use_relu) {
134
+ sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
135
+ spin_barrier->barrier();
136
+ sparse_linear_layer.SpMM_bias(*out, rhs, use_relu, tid);
137
+ spin_barrier->barrier();
138
+ sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
139
+ }
140
+
141
+ template <typename ComputeType, typename RhsType, typename OutType>
142
+ void CorrectnessCheckBlockSpMM_MultiThread(int rows, int cols, int block_height,
143
+ int block_width, float sparsity,
144
+ bool use_relu = false,
145
+ int num_threads = 1,
146
+ int fatness = 1) {
147
+ typedef typename TypeOfProduct<ComputeType, RhsType>::type BiasType;
148
+ CHECK(rows == cols);
149
+ MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
150
+ block_width);
151
+ matrix.CastWeights<ComputeType>();
152
+ FatCacheAlignedVector<RhsType> rhs(cols, fatness);
153
+ FatCacheAlignedVector<RhsType> rhs_mt(cols, fatness);
154
+ CacheAlignedVector<BiasType> bias(rows);
155
+ FatCacheAlignedVector<OutType> out(rows, fatness);
156
+
157
+ bias.FillOnes();
158
+ rhs.FillOnes();
159
+ rhs_mt.FillOnes();
160
+ out.FillZero();
161
+ FatCacheAlignedVector<OutType> out_reference = out;
162
+
163
+ matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
164
+ matrix.SpMM_bias(out_reference, bias, &rhs, use_relu);
165
+ matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
166
+
167
+ CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
168
+
169
+ num_threads = sparse_matrix.PrepareForThreads(num_threads,
170
+ /*cache_line_size=*/1);
171
+
172
+ SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
173
+ std::move(sparse_matrix), std::move(bias));
174
+
175
+ csrblocksparse::LaunchOnThreadsWithBarrier(
176
+ num_threads, ThreadBody<ComputeType, RhsType, OutType>,
177
+ sparse_linear_layer, &rhs_mt, &out, use_relu);
178
+
179
+ CheckResult(out_reference, out, cols);
180
+ }
181
+
182
+ } // namespace
183
+
184
+ TEST(MaskedSparseCorrectness, HandCoded) {
185
+ const int kRows = 8;
186
+ const int kCols = 8;
187
+ // clang-format off
188
+ std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
189
+ 0, 1, 0, 1, 0, 1, 0, 1,
190
+ 1, 0, 0, 1, 1, 1, 1, 0,
191
+ 0, 0, 0, 0, 0, 0, 0, 0,
192
+ 1, 1, 1, 1, 1, 1, 1, 1,
193
+ 0, 0, 0, 0, 1, 1, 0, 0,
194
+ 1, 1, 0, 0, 1, 1, 0, 0,
195
+ 1, 0, 0, 0, 0, 1, 0, 1};
196
+ // clang-format on
197
+ std::vector<float> values(kRows * kCols, 1.f);
198
+
199
+ std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
200
+
201
+ MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
202
+ CacheAlignedVector<float> rhs(kCols);
203
+ CacheAlignedVector<float> bias(kRows);
204
+ CacheAlignedVector<float> out(kRows);
205
+
206
+ bias.FillOnes();
207
+ rhs.FillOnes();
208
+ out.FillZero();
209
+
210
+ MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
211
+ std::move(bias));
212
+
213
+ masked_linear_layer.SpMM_bias(rhs, &out);
214
+
215
+ for (int i = 0; i < kRows; ++i) {
216
+ EXPECT_EQ(answer[i], out[i]);
217
+ }
218
+ }
219
+
220
+ TEST(MaskedSparseCorrectness, HandCodedFatVector) {
221
+ const int kRows = 8;
222
+ const int kCols = 8;
223
+ // clang-format off
224
+ std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
225
+ 0, 1, 0, 1, 0, 1, 0, 1,
226
+ 1, 0, 0, 1, 1, 1, 1, 0,
227
+ 0, 0, 0, 0, 0, 0, 0, 0,
228
+ 1, 1, 1, 1, 1, 1, 1, 1,
229
+ 0, 0, 0, 0, 1, 1, 0, 0,
230
+ 1, 1, 0, 0, 1, 1, 0, 0,
231
+ 1, 0, 0, 0, 0, 1, 0, 1};
232
+ // clang-format on
233
+
234
+ std::vector<float> values(kRows * kCols, 1.f);
235
+ std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
236
+
237
+ MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
238
+ const int kMaxWidth = 5;
239
+ for (int width = 5; width <= kMaxWidth; ++width) {
240
+ FatCacheAlignedVector<float> rhs(kCols, width);
241
+ CacheAlignedVector<float> bias(kRows);
242
+ FatCacheAlignedVector<float> out(kRows, width);
243
+
244
+ bias.FillOnes();
245
+ rhs.FillOnes();
246
+ out.FillZero();
247
+
248
+ MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
249
+ std::move(bias));
250
+
251
+ masked_linear_layer.SpMM_bias(rhs, &out);
252
+
253
+ for (int i = 0; i < kRows; ++i) {
254
+ for (int width = 0; width < kMaxWidth; ++width) {
255
+ EXPECT_EQ(answer[i], out[i + width * kRows]);
256
+ }
257
+ }
258
+ }
259
+ }
260
+
261
+ TEST(CsrBlockSparseMatrix, HandCodedMultiThread) {
262
+ const int kRows = 8;
263
+ const int kCols = 8;
264
+ // clang-format off
265
+ std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
266
+ 0, 1, 0, 1, 0, 1, 0, 1,
267
+ 1, 0, 0, 1, 1, 1, 1, 0,
268
+ 0, 0, 0, 0, 0, 0, 0, 0,
269
+ 1, 1, 1, 1, 1, 1, 1, 1,
270
+ 0, 0, 0, 0, 1, 1, 0, 0,
271
+ 1, 1, 0, 0, 1, 1, 0, 0,
272
+ 1, 0, 0, 0, 0, 1, 0, 1};
273
+ // clang-format on
274
+ std::vector<float> values(kRows * kCols, 1.f);
275
+
276
+ std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
277
+
278
+ MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
279
+ CacheAlignedVector<float> rhs(kCols);
280
+ CacheAlignedVector<float> bias(kRows);
281
+ CacheAlignedVector<float> out(kRows);
282
+
283
+ bias.FillOnes();
284
+ rhs.FillOnes();
285
+ out.FillZero();
286
+
287
+ CacheAlignedVector<float> bias_csr = bias;
288
+
289
+ CsrBlockSparseMatrix<bfloat16, float> sparse_matrix(matrix);
290
+
291
+ MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
292
+ std::move(bias));
293
+
294
+ masked_linear_layer.SpMM_bias(rhs, &out);
295
+
296
+ SparseLinearLayer<bfloat16, float> sparse_linear_layer(
297
+ std::move(sparse_matrix), std::move(bias_csr));
298
+ sparse_linear_layer.PrepareForThreads(2, /*cache_line_size=*/1);
299
+
300
+ CacheAlignedVector<float> out_tmp(kRows);
301
+ const bool kUseRelu = false;
302
+ sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/0);
303
+ sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/1);
304
+
305
+ for (int i = 0; i < kRows; ++i) {
306
+ EXPECT_EQ(answer[i], out_tmp[i]);
307
+ }
308
+ }
309
+
310
+ TEST(TestCasts, TestBfloat16) {
311
+ const int kRows = 1000;
312
+ const int kCols = 100;
313
+ const float kSparsity = 0.f;
314
+
315
+ MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
316
+ MaskedSparseMatrix<float> matrix_bfloat16(kRows, kCols, matrix.mask().data(),
317
+ matrix.values().data());
318
+
319
+ matrix_bfloat16.CastWeights<bfloat16>();
320
+
321
+ CheckResult(matrix.values(), matrix_bfloat16.values(), kCols);
322
+ }
323
+
324
+ TEST(TestCasts, TestFP16) {
325
+ const int kRows = 1000;
326
+ const int kCols = 100;
327
+ const float kSparsity = 0.f;
328
+
329
+ MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
330
+ #if !defined __arm__ && !defined __aarch64__
331
+ // Conversion doesn't handle denormals, so flush denormals to zero first.
332
+ for (int i = 0; i < matrix.values().size(); ++i) {
333
+ if (matrix.data()[i] < 1. / static_cast<float>(1 << 14))
334
+ matrix.data()[i] = 0.f;
335
+ }
336
+ #endif
337
+ MaskedSparseMatrix<float> matrix_fp16(kRows, kCols, matrix.mask().data(),
338
+ matrix.values().data());
339
+
340
+ matrix_fp16.CastWeights<csrblocksparse::fp16>();
341
+
342
+ CheckResult(matrix.values(), matrix_fp16.values(), kCols);
343
+ }
344
+
345
+ TEST(TestCasts, TestFixed16) {
346
+ const int kRows = 100000;
347
+ const int kCols = 1;
348
+ const float kSparsity = 0.f;
349
+
350
+ MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
351
+
352
+ // Relative error for fixed point is high near 0.
353
+ for (int i = 0; i < matrix.values().size(); ++i) {
354
+ // 1.1e-3 is based on the max error of .013 and a grid spacing of 1 / 2**16
355
+ // == 3e-5. 3e-5 / .013 / 2 = 1.1e-3.
356
+ if (std::abs(matrix.data()[i]) < 1.1e-3) {
357
+ matrix.data()[i] = 0.f;
358
+ }
359
+ }
360
+
361
+ MaskedSparseMatrix<float> matrix_fixed16 = matrix;
362
+
363
+ matrix_fixed16.CastWeights<csrblocksparse::fixed16</*ExponentBits=*/0>>();
364
+
365
+ CheckResult(matrix.values(), matrix_fixed16.values(), kCols);
366
+ }
367
+
368
+ TEST(TestCasts, TestFixed32) {
369
+ const int kRows = 100000;
370
+ const int kCols = 1;
371
+ const float kSparsity = 0.f;
372
+
373
+ MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
374
+ MaskedSparseMatrix<float> matrix_fixed32 = matrix;
375
+
376
+ matrix_fixed32.CastWeights<csrblocksparse::fixed32</*ExponentBits=*/0>>();
377
+
378
+ CheckResult(matrix.values(), matrix_fixed32.values(), kCols);
379
+ }
380
+
381
+ template <typename ComputeType, typename RhsType, typename OutType>
382
+ void TestSpMM(int block_width, int block_height, int fatness,
383
+ bool test_matmul = false) {
384
+ std::array<bool, 2> use_relu = {false, true};
385
+ std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
386
+ std::vector<std::pair<int, int>> sizes = {{8, 8}, {128, 128}, {128, 64},
387
+ {256, 192}, {512, 512}, {1024, 512},
388
+ {384, 384}, {512, 384}};
389
+ for (int num_threads = 1; num_threads < 2 + test_matmul; ++num_threads) {
390
+ for (const auto& relu : use_relu) {
391
+ for (const auto& sparsity : sparsity_levels) {
392
+ for (const auto& size : sizes) {
393
+ int rows, cols;
394
+ std::tie(rows, cols) = size;
395
+ CorrectnessCheckBlockSpMM<ComputeType, RhsType, OutType>(
396
+ rows, cols, block_height, block_width, sparsity, relu,
397
+ num_threads, fatness, test_matmul);
398
+ }
399
+ }
400
+ }
401
+ }
402
+ }
403
+
404
+ template <typename ComputeType, typename RhsType, typename OutType>
405
+ void TestSpMM_MultiThread(int block_width, int block_height, int fatness) {
406
+ std::array<bool, 2> use_relu = {false, true};
407
+ std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
408
+ std::vector<std::pair<int, int>> sizes = {
409
+ {48, 48}, {128, 128}, {512, 512}, {384, 384}};
410
+ for (int num_threads = 1; num_threads < 5; ++num_threads) {
411
+ for (const auto& relu : use_relu) {
412
+ for (const auto& sparsity : sparsity_levels) {
413
+ for (const auto& size : sizes) {
414
+ int rows, cols;
415
+ std::tie(rows, cols) = size;
416
+ CorrectnessCheckBlockSpMM_MultiThread<ComputeType, RhsType, OutType>(
417
+ rows, cols, block_height, block_width, sparsity, relu,
418
+ num_threads, fatness);
419
+ }
420
+ }
421
+ }
422
+ }
423
+ }
424
+
425
+ template <typename DataType>
426
+ void TestSumVectors(int start = 0, int end = -1, int size = 6) {
427
+ std::vector<DataType> values;
428
+ std::vector<DataType> answer;
429
+
430
+ for (int i = 1; i < size + 1; ++i) {
431
+ const float x = static_cast<float>(i);
432
+ values.push_back(static_cast<DataType>(x));
433
+ answer.push_back(static_cast<DataType>(x * 2));
434
+ }
435
+
436
+ if (end == -1) {
437
+ end = values.size();
438
+ }
439
+
440
+ csrblocksparse::CacheAlignedVector<DataType> result(values.size());
441
+ csrblocksparse::CacheAlignedVector<DataType> values_aligned(values);
442
+ detail::SumVectors(start, end, values_aligned.data(), values_aligned.data(),
443
+ result.data());
444
+ for (int i = start; i < end; ++i) {
445
+ EXPECT_EQ(static_cast<float>(answer[i]), static_cast<float>(result[i]));
446
+ }
447
+ }
448
+
449
+ TEST(CsrBlockSparseMatrix, SumVectors_Generic) {
450
+ TestSumVectors<float>();
451
+ TestSumVectors<float>(1);
452
+ TestSumVectors<float>(1, 4);
453
+ }
454
+
455
+ TEST(CsrBlockSparseMatrix, SumVectors_Bfloat16) {
456
+ TestSumVectors<csrblocksparse::bfloat16>();
457
+ TestSumVectors<csrblocksparse::bfloat16>(1);
458
+ TestSumVectors<csrblocksparse::bfloat16>(1, 4);
459
+ }
460
+
461
+ // For SIMD-optimized SumVectors, the memory of the vector should be at least
462
+ // |kSIMDWidth * sizeof(float)| long, and the start position has to be an
463
+ // aligned memory location. So we set |size| to 100 to be safe and
464
+ // |start| to 0 (|start| == 1 is not aligned).
465
+ TEST(CsrBlockSparseMatrix, SumVectors_Fixed16) {
466
+ TestSumVectors<csrblocksparse::fixed16<8>>(0, -1, 100);
467
+ TestSumVectors<csrblocksparse::fixed16<8>>(0, 4, 100);
468
+ }
469
+
470
+ TEST(CsrBlockSparseMatrix, SumVectors_Fixed32) {
471
+ TestSumVectors<csrblocksparse::fixed32<11>>(0, -1, 100);
472
+ TestSumVectors<csrblocksparse::fixed32<11>>(0, 4, 100);
473
+ }
474
+
475
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_Bfloat16) {
476
+ TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/4,
477
+ /*block_height=*/4,
478
+ /*fatness=*/7);
479
+ }
480
+
481
+ // This actually uses multiple threads, and uses the output as the input for
482
+ // multiple steps to test that synchronization and memory visibility are
483
+ // working correctly. Requires square matrices.
484
+ TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_Bfloat16) {
485
+ TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
486
+ /*block_width=*/4,
487
+ /*block_height=*/4,
488
+ /*fatness=*/1);
489
+ }
490
+
491
+ TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_Bfloat16) {
492
+ TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
493
+ /*block_width=*/4,
494
+ /*block_height=*/4,
495
+ /*fatness=*/7);
496
+ }
497
+
498
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_Bfloat16) {
499
+ TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/1,
500
+ /*block_height=*/1,
501
+ /*fatness=*/1);
502
+ }
503
+
504
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_Bfloat16) {
505
+ TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/1,
506
+ /*block_height=*/1,
507
+ /*fatness=*/7);
508
+ }
509
+
510
+ // This actually uses multiple threads, and uses the output as the input for
511
+ // multiple steps to test that synchronization and memory visibility are
512
+ // working correctly. Requires square matrices.
513
+ TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_Bfloat16) {
514
+ TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
515
+ /*block_width=*/1,
516
+ /*block_height=*/1,
517
+ /*fatness=*/1);
518
+ }
519
+
520
+ TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_Bfloat16) {
521
+ TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
522
+ /*block_width=*/1,
523
+ /*block_height=*/1,
524
+ /*fatness=*/7);
525
+ }
526
+
527
+ TEST(CsrBlockSparseMatrix, SpMV_Block4x4_float) {
528
+ TestSpMM<float, float, float>(/*block_width=*/4,
529
+ /*block_height=*/4,
530
+ /*fatness=*/1,
531
+ /*test_matmul=*/true);
532
+ }
533
+
534
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_float) {
535
+ TestSpMM<float, float, float>(/*block_width=*/4,
536
+ /*block_height=*/4,
537
+ /*fatness=*/7);
538
+ }
539
+
540
+ // This actually uses multiple threads, and uses the output as the input for
541
+ // multiple steps to test that synchronization and memory visibility are
542
+ // working correctly. Requires square matrices.
543
+ TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_float) {
544
+ TestSpMM_MultiThread<float, float, float>(/*block_width=*/4,
545
+ /*block_height=*/4,
546
+ /*fatness=*/1);
547
+ }
548
+
549
+ TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_float) {
550
+ TestSpMM_MultiThread<float, float, float>(/*block_width=*/4,
551
+ /*block_height=*/4,
552
+ /*fatness=*/7);
553
+ }
554
+
555
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_float) {
556
+ TestSpMM<float, float, float>(/*block_width=*/1,
557
+ /*block_height=*/1,
558
+ /*fatness=*/1);
559
+ }
560
+
561
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_float) {
562
+ TestSpMM<float, float, float>(/*block_width=*/1,
563
+ /*block_height=*/1,
564
+ /*fatness=*/7);
565
+ }
566
+
567
+ // This actually uses multiple threads, and uses the output as the input for
568
+ // multiple steps to test that synchronization and memory visibility are
569
+ // working correctly. Requires square matrices.
570
+ TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_float) {
571
+ TestSpMM_MultiThread<float, float, float>(/*block_width=*/1,
572
+ /*block_height=*/1,
573
+ /*fatness=*/1);
574
+ }
575
+
576
+ TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_float) {
577
+ TestSpMM_MultiThread<float, float, float>(/*block_width=*/1,
578
+ /*block_height=*/1,
579
+ /*fatness=*/7);
580
+ }
581
+
582
+ TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32) {
583
+ TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
584
+ typename csrblocksparse::TypeOfProduct<
585
+ csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
586
+ /*block_width=*/4,
587
+ /*block_height=*/4,
588
+ /*fatness=*/1,
589
+ /*test_matmul=*/true);
590
+ }
591
+
592
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32) {
593
+ TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
594
+ typename csrblocksparse::TypeOfProduct<
595
+ csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
596
+ /*block_width=*/4,
597
+ /*block_height=*/4,
598
+ /*fatness=*/7);
599
+ }
600
+
601
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32) {
602
+ TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
603
+ typename csrblocksparse::TypeOfProduct<
604
+ csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
605
+ /*block_width=*/1,
606
+ /*block_height=*/1,
607
+ /*fatness=*/1);
608
+ }
609
+
610
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32) {
611
+ TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
612
+ typename csrblocksparse::TypeOfProduct<
613
+ csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
614
+ /*block_width=*/1,
615
+ /*block_height=*/1,
616
+ /*fatness=*/7);
617
+ }
618
+
619
+ TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_16) {
620
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
621
+ csrblocksparse::fixed16<8>>(
622
+ /*block_width=*/4,
623
+ /*block_height=*/4,
624
+ /*fatness=*/1,
625
+ /*test_matmul=*/true);
626
+ }
627
+
628
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_16) {
629
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
630
+ csrblocksparse::fixed16<8>>(
631
+ /*block_width=*/4,
632
+ /*block_height=*/4,
633
+ /*fatness=*/7);
634
+ }
635
+
636
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_16) {
637
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
638
+ csrblocksparse::fixed16<8>>(
639
+ /*block_width=*/1,
640
+ /*block_height=*/1,
641
+ /*fatness=*/1);
642
+ }
643
+
644
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_16) {
645
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
646
+ csrblocksparse::fixed16<8>>(
647
+ /*block_width=*/1,
648
+ /*block_height=*/1,
649
+ /*fatness=*/7);
650
+ }
651
+
652
+ TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32_unmatched) {
653
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
654
+ csrblocksparse::fixed32<13>>(
655
+ /*block_width=*/4,
656
+ /*block_height=*/4,
657
+ /*fatness=*/1,
658
+ /*test_matmul=*/true);
659
+ }
660
+
661
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32_unmatched) {
662
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
663
+ csrblocksparse::fixed32<13>>(
664
+ /*block_width=*/4,
665
+ /*block_height=*/4,
666
+ /*fatness=*/7);
667
+ }
668
+
669
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32_unmatched) {
670
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
671
+ csrblocksparse::fixed32<13>>(
672
+ /*block_width=*/1,
673
+ /*block_height=*/1,
674
+ /*fatness=*/1);
675
+ }
676
+
677
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32_unmatched) {
678
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
679
+ csrblocksparse::fixed32<13>>(
680
+ /*block_width=*/1,
681
+ /*block_height=*/1,
682
+ /*fatness=*/7);
683
+ }
684
+
685
+ TEST(CsrBlockSparseMatrix, RhsIndicesDeltasRoundTrip) {
686
+ MaskedSparseMatrix<float> matrix(/*rows=*/256, /*cols=*/256,
687
+ /*sparsity=*/0.9, /*block_height=*/4,
688
+ /*block_width=*/4);
689
+ CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
690
+ CacheAlignedVector<int16_t> copy_indices = sparse_matrix.rhs_indices();
691
+ sparse_matrix.ComputeColDeltas();
692
+ sparse_matrix.ComputeRHSIndices();
693
+ // They get padded when created, so the newer one could be bigger.
694
+ EXPECT_LE(copy_indices.size(), sparse_matrix.rhs_indices().size());
695
+ for (int i = 0; i < copy_indices.size(); ++i) {
696
+ EXPECT_EQ(copy_indices[i], sparse_matrix.rhs_indices()[i]) << "i=" << i;
697
+ }
698
+ }
699
+
700
+ // Tests that a Layer that is split into 2 by columns (inputs) computes the same
701
+ // result as the original layer.
702
+ TEST(CsrBlockSparseMatrix, SplitByCol) {
703
+ int kRows = 1024;
704
+ int kCols = 1024;
705
+ MaskedSparseMatrix<float> matrix(kRows, kCols, 0.95, /*block_height=*/4,
706
+ /*block_width=*/4);
707
+ FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
708
+ CacheAlignedVector<float> bias(kRows);
709
+ FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
710
+ FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);
711
+
712
+ bias.FillRandom();
713
+ rhs.FillRandom();
714
+ out1.FillZero();
715
+ out2.FillZero();
716
+ FatCacheAlignedVector<float> out_reference = out1;
717
+
718
+ CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
719
+
720
+ SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
721
+ std::move(bias));
722
+ sparse_linear_layer.PrepareForThreads(1);
723
+ sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
724
+ /*tid=*/0);
725
+ // Split the layer into 2 parts.
726
+ SparseLinearLayer<float, float> part1, part2;
727
+ sparse_linear_layer.SplitInputs(&part1, &part2);
728
+ part1.PrepareForThreads(1);
729
+ part2.PrepareForThreads(1);
730
+ EXPECT_EQ(kRows, part1.rows());
731
+ EXPECT_EQ(kCols / 2, part1.cols());
732
+ EXPECT_EQ(kRows, part2.rows());
733
+ EXPECT_EQ(kCols / 2, part2.cols());
734
+ MutableVectorView<float> rhs1(&rhs, 0, kCols / 2);
735
+ MutableVectorView<float> rhs2(&rhs, kCols / 2, kCols / 2);
736
+ for (int i = 0; i < kCols / 2; ++i) {
737
+ EXPECT_FLOAT_EQ(rhs[i], rhs1.data()[i]);
738
+ EXPECT_FLOAT_EQ(rhs[i + kCols / 2], rhs2.data()[i]);
739
+ }
740
+ part1.SpMM_bias(rhs1, &out1, /*relu=*/false, /*tid=*/0);
741
+ part2.SpMM_bias(rhs2, &out2, /*relu=*/false, /*tid=*/0);
742
+ // Check that out1 + out2 = out_reference.
743
+ for (int i = 0; i < kRows; ++i) {
744
+ EXPECT_NEAR(out_reference[i], out1[i] + out2[i], 2e-5)
745
+ << " i=" << i << " out1=" << out1[i] << " out2=" << out2[i];
746
+ }
747
+ }
748
+ // Tests that a Layer that is split into 2 by rows (outputs) computes the same
749
+ // result as the original layer.
750
+ TEST(CsrBlockSparseMatrix, SplitByRow) {
751
+ int kRows = 1024;
752
+ int kCols = 1024;
753
+ MaskedSparseMatrix<float> matrix(kRows, kCols, 0.95, /*block_height=*/4,
754
+ /*block_width=*/4);
755
+ FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
756
+ CacheAlignedVector<float> bias(kRows);
757
+ FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
758
+ FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);
759
+
760
+ bias.FillRandom();
761
+ rhs.FillRandom();
762
+ out1.FillZero();
763
+ out2.FillZero();
764
+ FatCacheAlignedVector<float> out_reference = out1;
765
+
766
+ CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
767
+
768
+ SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
769
+ std::move(bias));
770
+ sparse_linear_layer.PrepareForThreads(1);
771
+ sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
772
+ /*tid=*/0);
773
+ // Split the layer into 2 parts.
774
+ SparseLinearLayer<float, float> part1, part2;
775
+ sparse_linear_layer.SplitOutputs(&part1, &part2);
776
+ part1.PrepareForThreads(1);
777
+ part2.PrepareForThreads(1);
778
+ EXPECT_EQ(kRows / 2, part1.rows());
779
+ EXPECT_EQ(kCols, part1.cols());
780
+ EXPECT_EQ(kRows / 2, part2.rows());
781
+ EXPECT_EQ(kCols, part2.cols());
782
+ MutableVectorView<float> out2a(&out2, 0, kRows / 2);
783
+ MutableVectorView<float> out2b(&out2, kRows / 2, kRows / 2);
784
+ part1.SpMM_bias(rhs, &out2a, /*relu=*/false, /*tid=*/0);
785
+ part2.SpMM_bias(rhs, &out2b, /*relu=*/false, /*tid=*/0);
786
+ // Check that out2 = out_reference.
787
+ for (int i = 0; i < kRows; ++i) {
788
+ EXPECT_NEAR(out_reference[i], out2[i], 2e-5)
789
+ << " i=" << i << " out1=" << out_reference[i] << " out2=" << out2[i];
790
+ }
791
+ }
792
+
793
+ TEST(CsrBlockSparseMatrix, MutableVectorView) {
794
+ const int kRows = 1024;
795
+ const int kCols = 1024;
796
+ const int kFatness = 2;
797
+
798
+ std::vector<float> values(kRows * kCols, 1.f);
799
+ std::vector<int> mask(kRows * kCols);
800
+ for (int i = 0; i < mask.size(); ++i) mask[i] = i % 2;
801
+
802
+ auto masked_matrix =
803
+ MaskedSparseMatrix<float>(kRows, kCols, mask.data(), values.data());
804
+ auto sparse_matrix = CsrBlockSparseMatrix<bfloat16, float>(masked_matrix);
805
+ FatCacheAlignedVector<float> x(kCols, kFatness);
806
+ x.FillOnes();
807
+
808
+ CacheAlignedVector<float> bias(kRows);
809
+ bias.FillZero();
810
+
811
+ // First check that we can use spans as output. Split a multiplication
812
+ // into upper and lower halves times the full vector:
813
+ // --------------- x t
814
+ // | | x t
815
+ // | | x t
816
+ // --------------- =
817
+ // | | x b
818
+ // | | x b
819
+ // --------------- x b
820
+
821
+ FatCacheAlignedVector<float> out(kRows, kFatness);
822
+ FatCacheAlignedVector<float> out_view(kRows, kFatness);
823
+
824
+ MutableVectorView<float> out_view_top(&out_view, 0, kRows / 2);
825
+ MutableVectorView<float> out_view_bottom(&out_view, kRows / 2, kRows / 2);
826
+
827
+ sparse_matrix.SpMM_bias(x, bias, &out);
828
+
829
+ auto masked_matrix_top =
830
+ MaskedSparseMatrix<float>(kRows / 2, kCols, mask.data(), values.data());
831
+ auto masked_matrix_bottom = MaskedSparseMatrix<float>(
832
+ kRows / 2, kCols, mask.data() + kRows * kCols / 2,
833
+ values.data() + kRows * kCols / 2);
834
+ auto sparse_matrix_top =
835
+ CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_top);
836
+ auto sparse_matrix_bottom =
837
+ CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_bottom);
838
+
839
+ sparse_matrix_top.SpMM_bias(x, bias, &out_view_top);
840
+ sparse_matrix_bottom.SpMM_bias(x, bias, &out_view_bottom);
841
+
842
+ CheckResult(out, out_view, kCols);
843
+
844
+ // Check that we can use a span as an input vector. Multiply upper left
845
+ // portion of the matrix by the top half of the vector.
846
+ // ---------------
847
+ // |oooooo | x q
848
+ // |oooooo | x q
849
+ // | | =
850
+ // | |
851
+ // ---------------
852
+
853
+ auto masked_matrix_quarter = MaskedSparseMatrix<float>(
854
+ kRows / 2, kCols / 2, mask.data(), values.data());
855
+ auto sparse_matrix_quarter =
856
+ CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_quarter);
857
+
858
+ MutableVectorView<float> x_top(&x, 0, kCols / 2);
859
+ FatCacheAlignedVector<float> out_correct(kRows / 2, /*cols=*/2);
860
+
861
+ for (int i = 0; i < kFatness * (kRows / 2); ++i) out_correct[i] = 256.f;
862
+
863
+ MutableVectorView<float> bias_top(&bias, 0, kRows / 2);
864
+ FatCacheAlignedVector<float> out_quarter(kRows / 2, kFatness);
865
+
866
+ sparse_matrix_quarter.SpMM_bias(x_top, bias_top, &out_quarter);
867
+
868
+ CheckResult(out_correct, out_quarter, kCols / 2);
869
+ }
870
+
871
+ namespace {
872
+
873
+ bool skip_test(const absl::Status& status, absl::string_view msg) {
874
+ if (!status.ok()) {
875
+ LOG(INFO) << "Couldn't load " << msg << ", skipping test " << status;
876
+ return true;
877
+ }
878
+
879
+ return false;
880
+ }
881
+
882
+ } // namespace
883
+
884
+ TEST(CsrBlockSparseMatrix, ModelMatrices_Bfloat16) {
885
+ std::vector<std::string> names = {
886
+ "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
887
+ "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
888
+ "768_512_95_4x4_finelogit_", "lyra_conv1d_"};
889
+ const std::string kPath =
890
+ #if defined __arm__ || defined __aarch64__
891
+ "/data/local/tmp/";
892
+ #else
893
+ (ghc::filesystem::current_path() / kTestdataPath).string();
894
+ #endif
895
+ for (auto& layer_name : names) {
896
+ SparseLinearLayer<bfloat16, float> sparse_linear_layer;
897
+ auto status = LoadSparseLayer<bfloat16, float>(layer_name, /*zipped=*/true,
898
+ &sparse_linear_layer, kPath);
899
+ // If the files don't exist on the device we're running on, just skip this
900
+ // test and log that it was skipped.
901
+ if (skip_test(status, layer_name)) return;
902
+
903
+ int rows = sparse_linear_layer.rows();
904
+ int cols = sparse_linear_layer.cols();
905
+
906
+ MaskedLinearLayer<float> masked_linear_layer;
907
+ status = LoadMaskedLayer<float>(layer_name, /*zipped=*/true,
908
+ &masked_linear_layer, kPath);
909
+ if (skip_test(status, layer_name)) return;
910
+ masked_linear_layer.CastWeights<csrblocksparse::bfloat16>();
911
+
912
+ CacheAlignedVector<float> rhs(cols);
913
+ CacheAlignedVector<float> out_ref(rows);
914
+ CacheAlignedVector<float> out_spmv(rows);
915
+
916
+ rhs.FillRandom();
917
+ out_ref.FillZero();
918
+ out_spmv.FillZero();
919
+
920
+ std::array<bool, 2> use_relus = {false, true};
921
+ for (bool use_relu : use_relus) {
922
+ masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
923
+ sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);
924
+
925
+ CheckResult(out_ref, out_spmv, cols);
926
+ }
927
+ }
928
+ }
929
+
930
+ TEST(CsrBlockSparseMatrix, ModelMatrices_float) {
931
+ std::vector<std::string> names = {
932
+ "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
933
+ "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
934
+ "768_512_95_4x4_finelogit_", "lyra_conv1d_"};
935
+ const std::string kPath =
936
+ #if defined __arm__ || defined __aarch64__
937
+ "/data/local/tmp/";
938
+ #else
939
+ (ghc::filesystem::current_path() / kTestdataPath).string();
940
+ #endif
941
+ for (auto& layer_name : names) {
942
+ SparseLinearLayer<float, float> sparse_linear_layer;
943
+ auto status = LoadSparseLayer<float, float>(layer_name, /*zipped=*/true,
944
+ &sparse_linear_layer, kPath);
945
+ // If the files don't exist on the device we're running on, just skip this
946
+ // test and log that it was skipped.
947
+ if (skip_test(status, layer_name)) return;
948
+
949
+ int rows = sparse_linear_layer.rows();
950
+ int cols = sparse_linear_layer.cols();
951
+
952
+ MaskedLinearLayer<float> masked_linear_layer;
953
+ status = LoadMaskedLayer<float>(layer_name, /*zipped=*/true,
954
+ &masked_linear_layer, kPath);
955
+ if (skip_test(status, layer_name)) return;
956
+
957
+ CacheAlignedVector<float> rhs(cols);
958
+ CacheAlignedVector<float> out_ref(rows);
959
+ CacheAlignedVector<float> out_spmv(rows);
960
+
961
+ rhs.FillRandom();
962
+ out_ref.FillZero();
963
+ out_spmv.FillZero();
964
+
965
+ std::array<bool, 2> use_relus = {false, true};
966
+ for (bool use_relu : use_relus) {
967
+ masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
968
+ sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);
969
+
970
+ CheckResult(out_ref, out_spmv, cols);
971
+ }
972
+ }
973
+ }
974
+
975
+ #undef SKIP_TEST
976
+
977
+ } // namespace csrblocksparse
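
For reference, every comparison in the tests above funnels through CheckResult from sparse_matmul/numerics/test_utils.h. The sketch below is not that implementation, only an illustration of the kind of element-wise relative-error check these tests rely on; the name OutputsMatch and the tolerance value are made up.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Illustrative stand-in for a relative-error comparison between a reference
// output and an optimized kernel's output. Both containers only need an
// operator[] returning something convertible to float.
template <typename VectorA, typename VectorB>
bool OutputsMatch(const VectorA& a, const VectorB& b, int size,
                  float tolerance = 1e-5f) {
  for (int i = 0; i < size; ++i) {
    const float expected = static_cast<float>(a[i]);
    const float actual = static_cast<float>(b[i]);
    const float denom = std::max(std::fabs(expected), 1e-6f);
    if (std::fabs(expected - actual) / denom > tolerance) {
      std::printf("Mismatch at %d: %f vs %f\n", i, expected, actual);
      return false;
    }
  }
  return true;
}
```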
sparse_matmul/layers/errno_mapping.cc ADDED
@@ -0,0 +1,195 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/layers/errno_mapping.h"
16
+
17
+ #include <cerrno>
+ #include <cstring>
+ #include <string>
18
+
19
+ #include "absl/base/attributes.h"
+ #include "absl/base/optimization.h"
+ #include "absl/strings/str_cat.h"
20
+
21
+ namespace csrblocksparse {
22
+
23
+ namespace {
24
+
25
+ absl::StatusCode ErrnoToCode(int error_number) {
26
+ switch (error_number) {
27
+ case 0:
28
+ return absl::StatusCode::kOk;
29
+ case EINVAL: // Invalid argument
30
+ case ENAMETOOLONG: // Filename too long
31
+ case E2BIG: // Argument list too long
32
+ case EDESTADDRREQ: // Destination address required
33
+ case EDOM: // Mathematics argument out of domain of function
34
+ case EFAULT: // Bad address
35
+ case EILSEQ: // Illegal byte sequence
36
+ case ENOPROTOOPT: // Protocol not available
37
+ case ENOSTR: // Not a STREAM
38
+ case ENOTSOCK: // Not a socket
39
+ case ENOTTY: // Inappropriate I/O control operation
40
+ case EPROTOTYPE: // Protocol wrong type for socket
41
+ case ESPIPE: // Invalid seek
42
+ return absl::StatusCode::kInvalidArgument;
43
+ case ETIMEDOUT: // Connection timed out
44
+ case ETIME: // Timer expired
45
+ return absl::StatusCode::kDeadlineExceeded;
46
+ case ENODEV: // No such device
47
+ case ENOENT: // No such file or directory
48
+ #ifdef ENOMEDIUM
49
+ case ENOMEDIUM: // No medium found
50
+ #endif
51
+ case ENXIO: // No such device or address
52
+ case ESRCH: // No such process
53
+ return absl::StatusCode::kNotFound;
54
+ case EEXIST: // File exists
55
+ case EADDRNOTAVAIL: // Address not available
56
+ case EALREADY: // Connection already in progress
57
+ #ifdef ENOTUNIQ
58
+ case ENOTUNIQ: // Name not unique on network
59
+ #endif
60
+ return absl::StatusCode::kAlreadyExists;
61
+ case EPERM: // Operation not permitted
62
+ case EACCES: // Permission denied
63
+ #ifdef ENOKEY
64
+ case ENOKEY: // Required key not available
65
+ #endif
66
+ case EROFS: // Read only file system
67
+ return absl::StatusCode::kPermissionDenied;
68
+ case ENOTEMPTY: // Directory not empty
69
+ case EISDIR: // Is a directory
70
+ case ENOTDIR: // Not a directory
71
+ case EADDRINUSE: // Address already in use
72
+ case EBADF: // Invalid file descriptor
73
+ #ifdef EBADFD
74
+ case EBADFD: // File descriptor in bad state
75
+ #endif
76
+ case EBUSY: // Device or resource busy
77
+ case ECHILD: // No child processes
78
+ case EISCONN: // Socket is connected
79
+ #ifdef EISNAM
80
+ case EISNAM: // Is a named type file
81
+ #endif
82
+ #ifdef ENOTBLK
83
+ case ENOTBLK: // Block device required
84
+ #endif
85
+ case ENOTCONN: // The socket is not connected
86
+ case EPIPE: // Broken pipe
87
+ #ifdef ESHUTDOWN
88
+ case ESHUTDOWN: // Cannot send after transport endpoint shutdown
89
+ #endif
90
+ case ETXTBSY: // Text file busy
91
+ #ifdef EUNATCH
92
+ case EUNATCH: // Protocol driver not attached
93
+ #endif
94
+ return absl::StatusCode::kFailedPrecondition;
95
+ case ENOSPC: // No space left on device
96
+ #ifdef EDQUOT
97
+ case EDQUOT: // Disk quota exceeded
98
+ #endif
99
+ case EMFILE: // Too many open files
100
+ case EMLINK: // Too many links
101
+ case ENFILE: // Too many open files in system
102
+ case ENOBUFS: // No buffer space available
103
+ case ENODATA: // No message is available on the STREAM read queue
104
+ case ENOMEM: // Not enough space
105
+ case ENOSR: // No STREAM resources
106
+ #ifdef EUSERS
107
+ case EUSERS: // Too many users
108
+ #endif
109
+ return absl::StatusCode::kResourceExhausted;
110
+ #ifdef ECHRNG
111
+ case ECHRNG: // Channel number out of range
112
+ #endif
113
+ case EFBIG: // File too large
114
+ case EOVERFLOW: // Value too large to be stored in data type
115
+ case ERANGE: // Result too large
116
+ return absl::StatusCode::kOutOfRange;
117
+ #ifdef ENOPKG
118
+ case ENOPKG: // Package not installed
119
+ #endif
120
+ case ENOSYS: // Function not implemented
121
+ case ENOTSUP: // Operation not supported
122
+ case EAFNOSUPPORT: // Address family not supported
123
+ #ifdef EPFNOSUPPORT
124
+ case EPFNOSUPPORT: // Protocol family not supported
125
+ #endif
126
+ case EPROTONOSUPPORT: // Protocol not supported
127
+ #ifdef ESOCKTNOSUPPORT
128
+ case ESOCKTNOSUPPORT: // Socket type not supported
129
+ #endif
130
+ case EXDEV: // Improper link
131
+ return absl::StatusCode::kUnimplemented;
132
+ case EAGAIN: // Resource temporarily unavailable
133
+ #ifdef ECOMM
134
+ case ECOMM: // Communication error on send
135
+ #endif
136
+ case ECONNREFUSED: // Connection refused
137
+ case ECONNABORTED: // Connection aborted
138
+ case ECONNRESET: // Connection reset
139
+ case EINTR: // Interrupted function call
140
+ #ifdef EHOSTDOWN
141
+ case EHOSTDOWN: // Host is down
142
+ #endif
143
+ case EHOSTUNREACH: // Host is unreachable
144
+ case ENETDOWN: // Network is down
145
+ case ENETRESET: // Connection aborted by network
146
+ case ENETUNREACH: // Network unreachable
147
+ case ENOLCK: // No locks available
148
+ case ENOLINK: // Link has been severed
149
+ #ifdef ENONET
150
+ case ENONET: // Machine is not on the network
151
+ #endif
152
+ return absl::StatusCode::kUnavailable;
153
+ case EDEADLK: // Resource deadlock avoided
154
+ #ifdef ESTALE
155
+ case ESTALE: // Stale file handle
156
+ #endif
157
+ return absl::StatusCode::kAborted;
158
+ case ECANCELED: // Operation cancelled
159
+ return absl::StatusCode::kCancelled;
160
+ default:
161
+ return absl::StatusCode::kUnknown;
162
+ }
163
+ }
164
+
165
+ // POSIX `strerror_r()` returns `int`.
166
+ ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(int result, const char* buffer,
167
+ int error_code) {
168
+ if (ABSL_PREDICT_FALSE(result != 0)) {
169
+ return absl::StrCat("Unknown error ", error_code);
170
+ }
171
+ return buffer;
172
+ }
173
+
174
+ // GNU `strerror_r()` returns `char*`.
175
+ ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(char* result,
176
+ const char* buffer,
177
+ int error_code) {
178
+ return result;
179
+ }
180
+
181
+ std::string StrError(int error_code) {
182
+ char message[256];
183
+ return StrErrorResult(strerror_r(error_code, message, sizeof(message)),
184
+ message, error_code);
185
+ }
186
+
187
+ } // namespace
188
+
189
+ absl::Status ErrnoToCanonicalStatus(int error_number,
190
+ absl::string_view message) {
191
+ return absl::Status(ErrnoToCode(error_number),
192
+ absl::StrCat(message, ": ", StrError(error_number)));
193
+ }
194
+
195
+ } // namespace csrblocksparse
sparse_matmul/layers/errno_mapping.h ADDED
@@ -0,0 +1,29 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
16
+ #define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
17
+
18
+ #include "absl/status/status.h"
19
+ #include "absl/strings/string_view.h"
20
+
21
+ namespace csrblocksparse {
22
+
23
+ // Converts |error_number| value to absl::Status.
24
+ absl::Status ErrnoToCanonicalStatus(int error_number,
25
+ absl::string_view message);
26
+
27
+ } // namespace csrblocksparse
28
+
29
+ #endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
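
A minimal caller sketch for the helper declared above; OpenForRead and its arguments are hypothetical, and only the ErrnoToCanonicalStatus call is taken from this header.

```cpp
#include <cerrno>
#include <cstdio>

#include "absl/status/status.h"
#include "sparse_matmul/layers/errno_mapping.h"

// Converts a failed fopen() into a canonical absl::Status.
absl::Status OpenForRead(const char* path, std::FILE** file) {
  *file = std::fopen(path, "rb");
  if (*file == nullptr) {
    // errno was set by fopen(); the mapping above picks the matching
    // StatusCode (e.g. kNotFound for ENOENT) and appends the strerror text.
    return csrblocksparse::ErrnoToCanonicalStatus(errno, "Error opening file");
  }
  return absl::OkStatus();
}
```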
sparse_matmul/layers/masked_sparse_matrix.h ADDED
@@ -0,0 +1,206 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
19
+
20
+ #include <algorithm>
21
+ #include <cstdio>
22
+ #include <numeric>
23
+ #include <vector>
24
+
25
+ #include "absl/strings/str_format.h"
26
+ #include "sparse_matmul/vector/cache_aligned_vector.h"
27
+
28
+ namespace csrblocksparse {
29
+
30
+ // MaskedSparseMatrix serves two purposes:
31
+ // 1) It is useful as a reference implementation of SpMV for correctness
32
+ // checking the much more complicated implementations in CsrBlockSparseMatrix.
33
+ // 2) This is the format in which sparse matrices are represented after pruning
34
+ // in TF. This class provides a bridge to getting these parameters into
35
+ // a compressed form suitable for computation and serialization.
36
+ //
37
+ // MaskedSparseMatrix<float> matrix(rows, cols, mask_from_tf, values_from_tf);
38
+ // CsrBlockSparseMatrix<bfloat16, float> csr_matrix(matrix);
39
+ // csr_matrix.SpMM_bias(rhs, bias, &out);
40
+ template <typename T>
41
+ class MaskedSparseMatrix {
42
+ public:
43
+ MaskedSparseMatrix() {}
44
+
45
+ // Construct a MaskedSparseMatrix of the given size, sparsity and block size.
46
+ // This is mainly useful for testing.
47
+ MaskedSparseMatrix(int rows, int cols, float sparsity, int block_height = 1,
48
+ int block_width = 1, float constant = 1.f,
49
+ bool random = true)
50
+ : rows_(rows), cols_(cols), sparsity_(sparsity) {
51
+ CHECK_EQ(rows % block_height, 0);
52
+ CHECK_EQ(cols % block_width, 0);
53
+
54
+ init(sparsity, block_height, block_width, constant, random);
55
+ }
56
+
57
+ // Construct from an existing mask and values (most likely from a TF model).
58
+ template <typename MaskType>
59
+ MaskedSparseMatrix(int rows, int cols, const MaskType* mask, const T* values)
60
+ : rows_(rows), cols_(cols) {
61
+ mask_.resize(rows * cols);
62
+ values_.resize(rows * cols);
63
+ std::copy_n(mask, rows * cols, mask_.begin());
64
+ std::copy_n(values, rows * cols, values_.begin());
65
+ sparsity_ =
66
+ 1.f - std::accumulate(mask_.begin(), mask_.end(), 0.f) / mask_.size();
67
+ }
68
+
69
+ const std::vector<int>& mask() const { return mask_; }
70
+ const std::vector<T>& values() const { return values_; }
71
+ T* data() { return values_.data(); }
72
+ const T* data() const { return values_.data(); }
73
+
74
+ int rows() const { return rows_; }
75
+ int cols() const { return cols_; }
76
+ float sparsity() const { return sparsity_; }
77
+
78
+ void Print() const {
79
+ absl::PrintF("-------Values---------\n");
80
+ for (int r = 0; r < rows_; ++r) {
81
+ for (int c = 0; c < cols_; ++c) {
82
+ absl::PrintF("%+6.3f ", static_cast<float>(values_[r * cols_ + c]));
83
+ }
84
+ absl::PrintF("\n");
85
+ }
86
+ absl::PrintF("-------Mask---------\n");
87
+ for (int r = 0; r < rows_; ++r) {
88
+ for (int c = 0; c < cols_; ++c) {
89
+ printf("%2d ", mask_[r * cols_ + c]);
90
+ }
91
+ absl::PrintF("\n");
92
+ }
93
+ }
94
+
95
+ // This routine is useful for rounding the possibly higher precision values
96
+ // stored in this class to a lower precision, so that correctness checks
97
+ // between this class and CSRBlockSparseMatrix can have a tighter tolerance.
98
+ template <typename U>
99
+ void CastWeights() {
100
+ for (int i = 0; i < values_.size(); ++i) {
101
+ values_[i] = static_cast<T>(U(values_[i]));
102
+ }
103
+ }
104
+
105
+ // Only meant for correctness checking.
106
+ // RhsClassType is meant to be either CacheAlignedVector OR
107
+ // FatCacheAlignedVector.
108
+ // The weight matrix is ROW MAJOR and RhsClassType is COLUMN MAJOR.
109
+ // |bias| is broadcast if |rhs| has more than one column.
110
+ template <typename RhsClassType, typename BiasType, typename OutClassType,
111
+ typename RhsType = typename RhsClassType::value_type,
112
+ typename OutType = typename OutClassType::value_type>
113
+ void SpMM_bias(const RhsClassType& rhs,
114
+ const CacheAlignedVector<BiasType>& bias, OutClassType* out,
115
+ bool relu = false) {
116
+ for (int r = 0; r < rows_; ++r) {
117
+ for (int n = 0; n < rhs.cols(); ++n) {
118
+ float sum = 0.f;
119
+ const RhsType* rhs_ptr = rhs.data() + n * rhs.rows();
120
+ OutType* out_ptr = out->data() + n * out->rows();
121
+ const int* mask_ptr = mask_.data() + r * cols_;
122
+ const T* value_ptr = values_.data() + r * cols_;
123
+ for (int c = 0; c < cols_; ++c) {
124
+ sum += mask_ptr[c] * static_cast<float>(value_ptr[c]) *
125
+ static_cast<float>(rhs_ptr[c]);
126
+ }
127
+ out_ptr[r] = static_cast<OutType>(
128
+ relu ? std::max(sum + static_cast<float>(bias[r]), 0.f)
129
+ : sum + static_cast<float>(bias[r]));
130
+ }
131
+ }
132
+ }
133
+
134
+ private:
135
+ // Generate a random matrix with the specified sparsity.
136
+ // Useful for testing.
137
+ void init(float sparsity, int block_height, int block_width, float constant,
138
+ bool random = true) {
139
+ int reduced_rows = rows_ / block_height;
140
+ int reduced_cols = cols_ / block_width;
141
+ mask_.resize(rows_ * cols_, 0);
142
+
143
+ // Fill with non-zero value to make sure masking works.
144
+ values_.resize(rows_ * cols_, static_cast<T>(2.f));
145
+
146
+ std::mt19937 generator(0);
147
+ std::uniform_real_distribution<float> dist_sparsity;
148
+ std::uniform_real_distribution<float> dist_value(-1.f, 1.f);
149
+ int nnz = 0;
150
+ while (nnz == 0) {
151
+ for (int r = 0; r < reduced_rows; ++r) {
152
+ for (int c = 0; c < reduced_cols; ++c) {
153
+ if (dist_sparsity(generator) > sparsity) {
154
+ nnz++;
155
+ for (int i = 0; i < block_height; ++i) {
156
+ for (int j = 0; j < block_width; ++j) {
157
+ mask_[(r * block_height + i) * cols_ + block_width * c + j] = 1;
158
+ values_[(r * block_height + i) * cols_ + block_width * c + j] =
159
+ static_cast<T>(random ? dist_value(generator) : constant);
160
+ }
161
+ }
162
+ }
163
+ }
164
+ }
165
+ }
166
+ }
167
+
168
+ std::vector<int> mask_;
169
+ std::vector<T> values_;
170
+ int rows_;
171
+ int cols_;
172
+ float sparsity_;
173
+ };
174
+
175
+ template <typename T>
176
+ class MaskedLinearLayer {
177
+ public:
178
+ MaskedLinearLayer(MaskedSparseMatrix<T>&& weights,
179
+ CacheAlignedVector<T>&& bias)
180
+ : weights_(std::move(weights)), bias_(std::move(bias)) {}
181
+
182
+ MaskedLinearLayer() {}
183
+
184
+ template <typename U>
185
+ void CastWeights() {
186
+ weights_.template CastWeights<U>();
187
+ }
188
+
189
+ // Does Ax + b where A is a masked sparse ROW MAJOR matrix and
190
+ // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is
191
+ // broadcast if |rhs| has more than one column.
192
+ template <typename FatVector>
193
+ void SpMM_bias(const FatVector& rhs, FatVector* out, bool relu = false) {
194
+ static_assert(std::is_same<typename FatVector::value_type, T>::value,
195
+ "FatVector value_type must match masked_linear_layer type");
196
+ weights_.SpMM_bias(rhs, bias_, out, relu);
197
+ }
198
+
199
+ private:
200
+ MaskedSparseMatrix<T> weights_;
201
+ CacheAlignedVector<T> bias_;
202
+ };
203
+
204
+ } // namespace csrblocksparse
205
+
206
+ #endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
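
To make the intended workflow above concrete, here is a sketch that pits the masked reference against the optimized CSR block-sparse kernel, using only constructors and helpers that appear elsewhere in this diff; the sizes, sparsity, and function name are arbitrary, and the final comparison would go through CheckResult as in the tests.

```cpp
#include "sparse_matmul/layers/csr_blocksparse_matrix.h"
#include "sparse_matmul/layers/masked_sparse_matrix.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"

void CompareReferenceAndCsr() {
  using csrblocksparse::CacheAlignedVector;
  constexpr int kRows = 64, kCols = 64;
  // Random 64x64 matrix, 90% sparse, 4x4 blocks.
  csrblocksparse::MaskedSparseMatrix<float> reference(
      kRows, kCols, /*sparsity=*/0.9f, /*block_height=*/4, /*block_width=*/4);
  csrblocksparse::CsrBlockSparseMatrix<float, float> csr(reference);
  csr.PrepareForThreads(1);

  CacheAlignedVector<float> rhs(kCols), bias(kRows);
  CacheAlignedVector<float> out_ref(kRows), out_csr(kRows);
  rhs.FillRandom();
  bias.FillRandom();
  out_ref.FillZero();
  out_csr.FillZero();

  reference.SpMM_bias(rhs, bias, &out_ref);  // Dense masked reference SpMV.
  csr.SpMM_bias(rhs, bias, &out_csr);        // Optimized block-sparse kernel.
  // out_ref and out_csr should agree to within numeric tolerance.
}
```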
sparse_matmul/layers/read_array_ifstream.h ADDED
@@ -0,0 +1,66 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ // Low-level array reading function using std::ifstream.
18
+
19
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
20
+ #define LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
21
+
22
+ #include <algorithm>
+ #include <cstdint>
23
+ #include <fstream>
24
+ #include <sstream>
25
+ #include <string>
+ #include <vector>
26
+
27
+ #include "absl/status/status.h"
28
+ #include "absl/strings/substitute.h"
29
+ #include "glog/logging.h"
+ #include "include/ghc/filesystem.hpp"
30
+
31
+ namespace csrblocksparse {
32
+ namespace detail {
33
+
34
+ template <typename T>
35
+ absl::Status ReadArrayIfstream(const std::string& file_name,
36
+ const std::string& path, std::vector<T>* array,
37
+ int64_t* length) {
38
+ ghc::filesystem::path complete_path(path);
39
+ complete_path /= file_name;
40
+ std::ifstream in_stream(complete_path.u8string(), std::ios::binary);
41
+ if (!in_stream.is_open()) {
42
+ return absl::UnknownError(
43
+ absl::Substitute("Error opening $0", complete_path.string()));
44
+ }
45
+
46
+ std::stringstream buffer;
47
+ buffer << in_stream.rdbuf();
48
+ if (buffer.str().empty()) {
49
+ LOG(ERROR) << "File " << complete_path << " was empty.";
50
+ return absl::UnknownError(
51
+ absl::Substitute("File $0 was empty", complete_path.string()));
52
+ }
53
+ std::string contents = buffer.str();
54
+ *length = contents.length();
55
+ int64_t elem = (*length + sizeof(T) - 1) / sizeof(T);
56
+ array->resize(elem);
57
+ std::move(contents.begin(), contents.end(),
58
+ reinterpret_cast<char*>(array->data()));
59
+
60
+ return absl::OkStatus();
61
+ }
62
+
63
+ } // namespace detail
64
+ } // namespace csrblocksparse
65
+
66
+ #endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
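
A short usage sketch, assuming only the detail::ReadArrayIfstream template above; the file name and directory here are hypothetical placeholders.

```cpp
#include <cstdint>
#include <vector>

#include "absl/status/status.h"
#include "sparse_matmul/layers/read_array_ifstream.h"

absl::Status LoadRawFloats(std::vector<float>* values) {
  int64_t byte_length = 0;
  // Reads the raw bytes of "example_weights.raw" from the given directory and
  // reinterprets them as a float array; |byte_length| receives the file size.
  return csrblocksparse::detail::ReadArrayIfstream<float>(
      "example_weights.raw", "/tmp/testdata", values, &byte_length);
}
```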
sparse_matmul/layers/sparse_linear_layer.h ADDED
@@ -0,0 +1,365 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
19
+
20
+ #include <cstdint>
+ #include <random>
+ #include <vector>
21
+
22
+ #include "absl/memory/memory.h"
23
+ #include "glog/logging.h"
24
+ #include "sparse_matmul/layers/csr_blocksparse_matrix.h"
25
+ #include "sparse_matmul/layers/masked_sparse_matrix.h"
26
+ #include "sparse_matmul/numerics/type_utils.h"
27
+ #include "sparse_matmul/os/coop_threads.h"
28
+ #include "sparse_matmul/vector/cache_aligned_vector.h"
29
+
30
+ namespace csrblocksparse {
31
+
32
+ template <typename WeightType, typename RhsType,
33
+ typename BiasType = typename TypeOfProduct<WeightType, RhsType>::type,
34
+ typename DeltaType = int16_t>
35
+ class SparseLinearLayer {
36
+ public:
37
+ SparseLinearLayer() {}
38
+
39
+ SparseLinearLayer(CsrBlockSparseMatrix<WeightType, RhsType>&& sparse_matrix,
40
+ CacheAlignedVector<BiasType>&& bias)
41
+ : sparse_matrix_(std::move(sparse_matrix)), full_bias_(std::move(bias)) {
42
+ CHECK_EQ(sparse_matrix_.rows(), full_bias_.size());
43
+ // Some kernels expect that the bias is divided by 4, so we store a second
44
+ // copy of a quarter of the bias.
45
+ // TODO(b/189958858): Remove the quartered bias if it can be done without
46
+ // loss of speed, and rename the |full_bias_| member back to |bias_|.
47
+ bias_ = full_bias_;
48
+ for (int i = 0; i < bias_.size(); ++i) {
49
+ bias_[i] = static_cast<BiasType>(.25f * static_cast<float>(bias_[i]));
50
+ }
51
+ }
52
+ SparseLinearLayer(
53
+ const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
54
+ *this = src;
55
+ }
56
+ SparseLinearLayer& operator=(
57
+ const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
58
+ sparse_matrix_ = src.sparse_matrix_;
59
+ bias_ = src.bias_;
60
+ full_bias_ = src.full_bias_;
61
+ mid_output_ = src.mid_output_;
62
+ thread_layers_ = src.thread_layers_;
63
+ num_threads_ = src.num_threads_;
64
+ if (src.split_pc_) {
65
+ split_pc_ = absl::make_unique<ProducerConsumer>(
66
+ src.split_pc_->num_producers(), src.split_pc_->num_consumers());
67
+ }
68
+ return *this;
69
+ }
70
+
71
+ // Does Ax + b where A is a block sparse compressed sparse row matrix and
72
+ // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is
73
+ // broadcast if rhs has more than one column.
74
+ template <typename RhsClassType, typename OutType>
75
+ void SpMM_bias(const RhsClassType& rhs, OutType* out, bool relu = false,
76
+ int tid = 0, SpinBarrier* barrier = nullptr) const {
77
+ static_assert(
78
+ std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
79
+ sparse_matrix_.SpMM_bias(rhs, bias_, out, relu, tid, barrier);
80
+ }
81
+ // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
82
+ // and then samples from the output (softmax distribution) layer.
83
+ template <typename RhsClassType, typename OutType>
84
+ int SpMM_bias_Sample(const RhsClassType& rhs, OutType* out, float temperature,
85
+ int tid, SpinBarrier* barrier, std::minstd_rand* gen,
86
+ CacheAlignedVector<float>* scratch) const {
87
+ static_assert(
88
+ std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
89
+ return sparse_matrix_.SpMM_bias_Sample(rhs, bias_, out, temperature, tid,
90
+ barrier, gen, scratch);
91
+ }
92
+ template <typename RhsClassType, typename OutType>
93
+ void MatVec(const RhsClassType& rhs, bool relu, int tid, int replicas,
94
+ int output_stride, OutType* output,
95
+ SpinBarrier* barrier = nullptr) {
96
+ static_assert(
97
+ std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
98
+ #ifdef __AVX2__
99
+ if (block_width() == 4 && (block_height() == 4 || block_height() == 8) &&
100
+ !IsCustomFloatType<WeightType>::value) {
101
+ if (!IsSplit()) {
102
+ sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu,
103
+ tid, replicas, output_stride, output->data());
104
+ if (barrier != nullptr) barrier->barrier();
105
+ return;
106
+ }
107
+ // NOTE: Until the quartered bias is removed it is a bad idea to split
108
+ // for ARM in the same way, as we would have to quarter the output of
109
+ // the first part of the split before running the second part.
110
+ // Signal completion of the previous MatVec.
111
+ split_pc_->produce();
112
+ PartLinearLayer& thread_part = thread_layers_[tid];
113
+ auto offset_output =
114
+ sparse_matrix_.thread_bounds().OffsetOutput(output->data(), tid);
115
+ auto mid_output =
116
+ sparse_matrix_.thread_bounds().OffsetOutput(mid_output_.data(), tid);
117
+ auto offset_bias = sparse_matrix_.thread_bounds().OffsetOutput(
118
+ mid_output_.cast_data(), tid);
119
+ // We can continue to consume the data that this thread produced and
120
+ // compute just the |self_matrix| part.
121
+ // No |relu| or |replicas|, as this is only a partial matmul.
122
+ // |tid| is always zero because the matrix has been split by tid.
123
+ thread_part.self_matrix.MatVec(
124
+ rhs.cast_data(), thread_part.full_bias.cast_data(), /*relu=*/false,
125
+ /*tid=*/0, /*replicas=*/1, output_stride, mid_output);
126
+ // We have to wait for the other threads to finish working on the previous
127
+ // MatMul before consuming the rest of |rhs|.
128
+ split_pc_->consume();
129
+ thread_part.other_matrix.MatVec(rhs.cast_data(), offset_bias, relu,
130
+ /*tid=*/0, replicas, output_stride,
131
+ offset_output);
132
+ return;
133
+ }
134
+ #endif
135
+ DCHECK_EQ(replicas, 1) << "Must have single replica for SpMM API";
136
+ if (IsSplit()) {
137
+ // Generics aren't setup to use a split matrix. This will be inefficient.
138
+ split_pc_->produce();
139
+ split_pc_->consume();
140
+ }
141
+ if (block_height() == 8) {
142
+ // We are currently forced to use MatVec generics for this case.
143
+ LOG(WARNING) << "Need to implement MatVec for 8x4 for non-AVX2 targets!!";
144
+ sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu, tid,
145
+ replicas, output_stride, output->data());
146
+ if (barrier != nullptr) barrier->barrier();
147
+ } else {
148
+ sparse_matrix_.SpMM_bias(rhs, bias_, output, relu, tid, barrier);
149
+ }
150
+ }
151
+
152
+ int rows() const { return sparse_matrix_.rows(); }
153
+ int cols() const { return sparse_matrix_.cols(); }
154
+ float sparsity() const { return sparse_matrix_.sparsity(); }
155
+ int block_width() const { return sparse_matrix_.block_width(); }
156
+ int block_height() const { return sparse_matrix_.block_height(); }
157
+ int num_threads() const { return sparse_matrix_.num_threads(); }
158
+ const CacheAlignedVector<BiasType>& bias() const { return bias_; }
159
+ const std::vector<int>& split_points() const {
160
+ return sparse_matrix_.split_points();
161
+ }
162
+ bool IsSplit() const {
163
+ return !thread_layers_.empty() && split_pc_ != nullptr;
164
+ }
165
+
166
+ std::size_t bytes() const { return sparse_matrix_.bytes() + bias_.bytes(); }
167
+ void Print() const {
168
+ printf("Matrix\n");
169
+ sparse_matrix_.Print();
170
+ printf("Bias\n");
171
+ bias_.Print();
172
+ }
173
+
174
+ // Combines adjacent row blocks, doubling the block height.
175
+ // This necessarily involves adding zero weights where the blocks don't align
176
+ // across adjacent pairs of rows, so use with caution, as the resulting matrix
177
+ // is most likely to run slower if very sparse to begin with.
178
+ // In the few cases where the blocks do mostly align, the resulting matmul
179
+ // could be much faster, as the number of reads of the rhs will be halved.
180
+ void DoubleBlockHeight() { sparse_matrix_.DoubleBlockHeight(); }
181
+
182
+ // Cache_line_size is provided only for testing. Normally uses a value for
183
+ // the current architecture.
184
+ int PrepareForThreads(int num_threads, int cache_line_size = -1) {
185
+ num_threads_ = num_threads;
186
+ if (num_threads_ > 1) {
187
+ split_pc_ =
188
+ absl::make_unique<ProducerConsumer>(num_threads_, num_threads_);
189
+ } else {
190
+ split_pc_.reset(nullptr);
191
+ }
192
+ return sparse_matrix_.PrepareForThreads(num_threads, cache_line_size);
193
+ }
194
+
195
+ // Partitions the matrix into pieces by thread.
196
+ // In this matrix, we can go ahead and calculate the part that only depends
197
+ // on rhs inputs that were generated by this thread in the previous matvec,
198
+ // without having to use any thread synchronization, and only after that do we
199
+ // have to wait for the other threads to finish the previous matvec.
200
+ // So we split the matrix using the |split_points| from the previous matrix
201
+ // into 2 * |num_threads_| pieces: self and other for each thread, being the
202
+ // parts that can be calculated before and after the other threads have
203
+ // completed their calculation of the previous matvec.
204
+ // We then have to use a ProducerConsumer lock instead of a SpinBarrier to
205
+ // synchronize the data produced by the other threads.
206
+ void SliceForThreads(const std::vector<int>& split_points) {
207
+ thread_layers_.clear();
208
+ thread_layers_.reserve(num_threads_);
209
+ LOG(INFO) << "Slicing " << rows() << "x" << cols() << " matrix for "
210
+ << num_threads_ << " threads";
211
+ for (int tid = 0; tid < num_threads_; ++tid) {
212
+ thread_layers_.emplace_back(
213
+ sparse_matrix_, full_bias_, bias_, tid,
214
+ split_points[tid] * sparse_matrix_.block_height(),
215
+ split_points[tid + 1] * sparse_matrix_.block_height());
216
+ }
217
+ mid_output_ =
218
+ std::move(csrblocksparse::CacheAlignedVector<BiasType>(rows()));
219
+ mid_output_.FillZero();
220
+ }
221
+
222
+ // Splits the layer by inputs into 2 equal pieces. Each of the resulting
223
+ // layers should be computed independently on the first and second halves of
224
+ // the inputs respectively and the results added to achieve the same effect
225
+ // as the original layer.
226
+ void SplitInputs(
227
+ SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
228
+ SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
229
+ CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
230
+ sparse_matrix_.SplitByColumn(0, sparse_matrix_.cols() / 2));
231
+ CsrBlockSparseMatrix<WeightType, RhsType> matrix2(
232
+ sparse_matrix_.SplitByColumn(sparse_matrix_.cols() / 2,
233
+ sparse_matrix_.cols()));
234
+ *part1 =
235
+ std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
236
+ std::move(matrix1),
237
+ std::move(CacheAlignedVector<BiasType>(full_bias_))));
238
+ CacheAlignedVector<BiasType> bias2(sparse_matrix_.rows());
239
+ bias2.FillZero();
240
+ *part2 =
241
+ std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
242
+ std::move(matrix2), std::move(bias2)));
243
+ }
244
+
245
+ // Splits the layer by outputs into 2 equal pieces. Each of the resulting
246
+ // layers should be computed independently on the full inputs and the results
247
+ // concatenated to achieve the same effect as the original layer.
248
+ void SplitOutputs(
249
+ SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
250
+ SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
251
+ LOG(INFO) << "input rows=" << sparse_matrix_.rows()
252
+ << ", cols=" << sparse_matrix_.cols();
253
+ CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
254
+ sparse_matrix_.SplitByRow(0, sparse_matrix_.rows() / 2));
255
+ CsrBlockSparseMatrix<WeightType, RhsType> matrix2(sparse_matrix_.SplitByRow(
256
+ sparse_matrix_.rows() / 2, sparse_matrix_.rows()));
257
+ CacheAlignedVector<BiasType> bias1(full_bias_, 0, full_bias_.size() / 2);
258
+ *part1 =
259
+ std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
260
+ std::move(matrix1), std::move(bias1)));
261
+ CacheAlignedVector<BiasType> bias2(full_bias_, full_bias_.size() / 2,
262
+ full_bias_.size());
263
+ *part2 =
264
+ std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
265
+ std::move(matrix2), std::move(bias2)));
266
+ }
267
+
268
+ private:
269
+ // Simple struct to hold a partitioned layer.
270
+ struct PartLinearLayer {
271
+ // The original matrix is first split by row to generate only the outputs
272
+ // for the given tid. The |row_sub_matrix| is then split by column into two
273
+ // partitions:
274
+ // self is the part for which the rhs elements in [|start_col|, |end_col|)
275
+ // were generated by this thread in some previous matmul.
276
+ // |other| is the rest of the columns that require rhs elements from other
277
+ // threads.
278
+ // NOTE that |start_col|, |end_col| are in raw columns, not blocks.
279
+ PartLinearLayer(const CsrBlockSparseMatrix<WeightType, RhsType>& matrix,
280
+ const CacheAlignedVector<BiasType>& bias,
281
+ const CacheAlignedVector<BiasType>& bias_4, int tid,
282
+ int start_col, int end_col) {
283
+ int block_height = matrix.block_height();
284
+ // Split the input matrix by row, selecting only the rows relevant to
285
+ // thread tid.
286
+ int start_row = matrix.split_points()[tid] * block_height;
287
+ int end_row = matrix.split_points()[tid + 1] * block_height;
288
+ LOG(INFO) << "input cols [" << start_col << "," << end_col << ") rows ["
289
+ << start_row << "," << end_row << ")";
290
+ CsrBlockSparseMatrix<WeightType, RhsType> row_sub_matrix =
291
+ matrix.SplitByRow(start_row, end_row);
292
+ // Partition into the columns that use rhs elements that thread tid
293
+ // produced in a previous matmul, and the other rhs elements.
294
+ // NOTE that we pass |keep_rhs_size|=true so that each matrix can operate on
295
+ // the same rhs input vector. The self matrix just guarantees not to
296
+ // access any of the elements that are generated by another thread.
297
+ self_matrix = std::move(row_sub_matrix.SplitByColumn(
298
+ start_col, end_col, /*keep_rhs_size=*/true));
299
+ self_matrix.PrepareForThreads(1);
300
+ // The reversed start and end slice out the complement of [start, end).
301
+ other_matrix = std::move(row_sub_matrix.SplitByColumn(
302
+ end_col, start_col, /*keep_rhs_size=*/true));
303
+ other_matrix.PrepareForThreads(1);
304
+ full_bias =
305
+ std::move(CacheAlignedVector<BiasType>(bias, start_row, end_row));
306
+ // TODO(b/189958858): Eliminate the quarter bias from all the code.
307
+ quarter_bias =
308
+ std::move(CacheAlignedVector<BiasType>(bias_4, start_row, end_row));
309
+ }
310
+ // The part of the matrix that only depends on this thread for rhs inputs.
311
+ CsrBlockSparseMatrix<WeightType, RhsType> self_matrix;
312
+ CacheAlignedVector<BiasType> full_bias;
313
+ CacheAlignedVector<BiasType> quarter_bias;
314
+ // The part of the matrix that uses rhs inputs from other threads.
315
+ CsrBlockSparseMatrix<WeightType, RhsType> other_matrix;
316
+ };
317
+ CsrBlockSparseMatrix<WeightType, RhsType, DeltaType> sparse_matrix_;
318
+ CacheAlignedVector<BiasType> bias_;
319
+ CacheAlignedVector<BiasType> full_bias_;
320
+ // Output from the self_matrix that will be given to |other_matrix| as bias.
321
+ CacheAlignedVector<BiasType> mid_output_;
322
+ // One partitioned pair of matrices for each thread.
323
+ std::vector<PartLinearLayer> thread_layers_;
324
+ // Producer-consumer lock used to wait between computing |self_matrix| and
325
+ // |other_matrix| for the other threads to finish the *previous* matvec.
326
+ std::unique_ptr<ProducerConsumer> split_pc_;
327
+ int num_threads_ = 0;
328
+ };
329
+
330
+ template <typename WeightType, typename RhsType>
331
+ SparseLinearLayer<WeightType, RhsType> CreateRandomLayer(int rows, int cols,
332
+ float sparsity,
333
+ int block_height = 1,
334
+ int block_width = 1) {
335
+ typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType;
336
+ CacheAlignedVector<BiasType> bias(rows);
337
+ bias.FillRandom();
338
+
339
+ auto masked_matrix = MaskedSparseMatrix<float>(rows, cols, sparsity,
340
+ block_height, block_width);
341
+ auto sparse_matrix = CsrBlockSparseMatrix<WeightType, RhsType>(masked_matrix);
342
+
343
+ return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix),
344
+ std::move(bias));
345
+ }
346
+
347
+ template <typename WeightType, typename RhsType>
348
+ SparseLinearLayer<WeightType, RhsType> CreateConstantLayer(
349
+ int rows, int cols, float sparsity, float constant = 1.f) {
350
+ typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType;
351
+ CacheAlignedVector<BiasType> bias(rows);
352
+ bias.FillOnes();
353
+
354
+ MaskedSparseMatrix<float> masked_matrix(rows, cols, sparsity,
355
+ /*block_height=*/1, /*block_width=*/1,
356
+ constant, /*random=*/false);
357
+ CsrBlockSparseMatrix<WeightType, RhsType> sparse_matrix(masked_matrix);
358
+
359
+ return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix),
360
+ std::move(bias));
361
+ }
362
+
363
+ } // namespace csrblocksparse
364
+
365
+ #endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
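
As a quick orientation for the layer API above, a single-threaded sketch built on the CreateRandomLayer helper from this header; the sizes, sparsity, and function name are arbitrary.

```cpp
#include "sparse_matmul/layers/sparse_linear_layer.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"

void RunRandomLayerOnce() {
  using csrblocksparse::CacheAlignedVector;
  // 256x256 layer, 95% sparse, 4x4 blocks, float weights and float inputs.
  auto layer = csrblocksparse::CreateRandomLayer<float, float>(
      /*rows=*/256, /*cols=*/256, /*sparsity=*/0.95f,
      /*block_height=*/4, /*block_width=*/4);
  layer.PrepareForThreads(1);

  CacheAlignedVector<float> rhs(layer.cols());
  CacheAlignedVector<float> out(layer.rows());
  rhs.FillRandom();
  out.FillZero();
  // out = relu(A * rhs + bias), computed on thread 0.
  layer.SpMM_bias(rhs, &out, /*relu=*/true, /*tid=*/0);
}
```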
sparse_matmul/layers/sparse_linear_layer_test.cc ADDED
@@ -0,0 +1,187 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/layers/sparse_linear_layer.h"
16
+
17
+ #include "gmock/gmock.h"
18
+ #include "gtest/gtest.h"
19
+ #include "sparse_matmul/numerics/test_utils.h"
20
+
21
+ namespace csrblocksparse {
22
+ namespace {
23
+
24
+ constexpr int kBlockSize = 4;
25
+ constexpr int kSize = 256;
26
+ constexpr int kNumThreads = 4;
27
+ constexpr int kCols = 1;
28
+
29
+ void SlicedThreadBody(SpinBarrier* spin_barrier, int tid,
30
+ const FatCacheAlignedVector<float>& rhs,
31
+ SparseLinearLayer<float, float>* sparse_linear_layer,
32
+ FatCacheAlignedVector<float>* out, bool use_relu) {
33
+ sparse_linear_layer->MatVec(rhs, use_relu, tid, /*replicas=*/1,
34
+ /*output_stride=*/0, out);
35
+ spin_barrier->barrier();
36
+ }
37
+
38
+ // Tests that a Layer that has been SliceForThreads computes the same result as
39
+ // the original layer. This is a basic test that all the slicing didn't mess up
40
+ // any of the computations.
41
+ TEST(CsrBlockSparseMatrix, SliceForThreads) {
42
+ MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
43
+ FatCacheAlignedVector<float> rhs(kSize, kCols);
44
+ CacheAlignedVector<float> bias(kSize);
45
+ FatCacheAlignedVector<float> out1(kSize, kCols);
46
+
47
+ bias.FillRandom();
48
+ rhs.FillRandom();
49
+ out1.FillZero();
50
+ FatCacheAlignedVector<float> out_reference = out1;
51
+ CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
52
+ SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
53
+ std::move(bias));
54
+ sparse_linear_layer.PrepareForThreads(1);
55
+ sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
56
+ /*output_stride=*/0, &out_reference);
57
+ std::vector<int> fake_split_points = {0, 48 / kBlockSize, 128 / kBlockSize,
58
+ 208 / kBlockSize, kSize / kBlockSize};
59
+ sparse_linear_layer.PrepareForThreads(kNumThreads);
60
+ sparse_linear_layer.SliceForThreads(fake_split_points);
61
+ csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, SlicedThreadBody, rhs,
62
+ &sparse_linear_layer, &out1,
63
+ /*relu=*/true);
64
+
65
+ CheckResult(out_reference, out1, kCols);
66
+ }
67
+
68
+ void LayersThreadBody(SpinBarrier* spin_barrier, int tid,
69
+ const FatCacheAlignedVector<float>& rhs,
70
+ SparseLinearLayer<float, float>* sparse_linear_layer1,
71
+ SparseLinearLayer<float, float>* sparse_linear_layer2,
72
+ FatCacheAlignedVector<float>* out1,
73
+ FatCacheAlignedVector<float>* out2, bool use_relu) {
74
+ sparse_linear_layer1->MatVec(rhs, use_relu, tid, /*replicas=*/1,
75
+ /*output_stride=*/0, out1);
76
+ // NOTE no barrier here!
77
+ sparse_linear_layer2->MatVec(*out1, use_relu, tid, /*replicas=*/1,
78
+ /*output_stride=*/0, out2);
79
+ spin_barrier->barrier();
80
+ }
81
+
82
+ // Tests that a pair of layers computes the same result whether or not the
83
+ // second layer has been SliceForThreads. This is a more critical test that
84
+ // the replacement of barriers with producer-consumer locks works.
85
+ // Must be run with tsan to really test it properly.
86
+ TEST(CsrBlockSparseMatrix, SliceForThreadsLayers) {
87
+ MaskedSparseMatrix<float> matrix1(kSize, kSize, 0.95, kBlockSize, kBlockSize);
88
+ FatCacheAlignedVector<float> rhs(kSize, kCols);
89
+ CacheAlignedVector<float> bias1(kSize);
90
+ FatCacheAlignedVector<float> out1(kSize, kCols);
91
+ MaskedSparseMatrix<float> matrix2(kSize, kSize, 0.95, kBlockSize, kBlockSize);
92
+ CacheAlignedVector<float> bias2(kSize);
93
+ FatCacheAlignedVector<float> out2(kSize, kCols);
94
+
95
+ bias1.FillRandom();
96
+ rhs.FillRandom();
97
+ bias2.FillRandom();
98
+ out1.FillZero();
99
+ out2.FillZero();
100
+ FatCacheAlignedVector<float> out_reference = out2;
101
+ CsrBlockSparseMatrix<float, float> sparse_matrix1(matrix1);
102
+ SparseLinearLayer<float, float> layer1(std::move(sparse_matrix1),
103
+ std::move(bias1));
104
+ CsrBlockSparseMatrix<float, float> sparse_matrix2(matrix2);
105
+ SparseLinearLayer<float, float> layer2(std::move(sparse_matrix2),
106
+ std::move(bias2));
107
+ layer1.PrepareForThreads(1);
108
+ layer2.PrepareForThreads(1);
109
+ layer1.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
110
+ /*output_stride=*/0, &out1);
111
+ layer2.MatVec(out1, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
112
+ /*output_stride=*/0, &out_reference);
113
+ layer1.PrepareForThreads(kNumThreads);
114
+ layer2.PrepareForThreads(kNumThreads);
115
+ layer2.SliceForThreads(layer1.split_points());
116
+ csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, LayersThreadBody, rhs,
117
+ &layer1, &layer2, &out1, &out2,
118
+ /*relu=*/true);
119
+
120
+ CheckResult(out_reference, out2, kCols);
121
+ }
122
+
123
+ // Tests that a Layer that has been DoubleBlockHeight()-ed computes the same
124
+ // result as original layer. (Float compute type).
125
+ TEST(CsrBlockSparseMatrix, Float8x4) {
126
+ using ComputeType = float;
127
+ using RhsType = float;
128
+ using BiasType = float;
129
+ MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
130
+ matrix.CastWeights<ComputeType>();
131
+ FatCacheAlignedVector<RhsType> rhs(kSize, kCols);
132
+ CacheAlignedVector<BiasType> bias(kSize);
133
+ FatCacheAlignedVector<BiasType> out1(kSize, kCols);
134
+
135
+ bias.FillRandom();
136
+ rhs.FillRandom();
137
+ out1.FillZero();
138
+ FatCacheAlignedVector<BiasType> out_reference = out1;
139
+ CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
140
+ SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
141
+ std::move(sparse_matrix), std::move(bias));
142
+ sparse_linear_layer.PrepareForThreads(1);
143
+ sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
144
+ /*output_stride=*/0, &out_reference);
145
+ sparse_linear_layer.DoubleBlockHeight();
146
+ sparse_linear_layer.PrepareForThreads(1);
147
+ sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
148
+ /*output_stride=*/0, &out1);
149
+ CheckResult(out_reference, out1, kCols);
150
+ }
151
+
152
+ // Tests that a Layer that has been DoubleBlockHeight()-ed computes the same
153
+ // result as original layer. (Fixed16 compute type).
154
+ TEST(CsrBlockSparseMatrix, Fixed8x4) {
155
+ using ComputeType = csrblocksparse::fixed16<4>;
156
+ using RhsType = csrblocksparse::fixed16<4>;
157
+ using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
158
+ MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
159
+ matrix.CastWeights<ComputeType>();
160
+ FatCacheAlignedVector<RhsType> rhs(kSize, kCols);
161
+ CacheAlignedVector<BiasType> bias(kSize);
162
+ FatCacheAlignedVector<BiasType> out1(kSize, kCols);
163
+
164
+ bias.FillRandom();
165
+ rhs.FillRandom();
166
+ out1.FillZero();
167
+ FatCacheAlignedVector<BiasType> out_reference = out1;
168
+ CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
169
+ SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
170
+ std::move(sparse_matrix), std::move(bias));
171
+ sparse_linear_layer.PrepareForThreads(1);
172
+ sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1,
173
+ /*output_stride=*/0, &out_reference);
174
+ sparse_linear_layer.DoubleBlockHeight();
175
+ sparse_linear_layer.PrepareForThreads(1);
176
+ sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1,
177
+ /*output_stride=*/0, &out1);
178
+ CheckResult(out_reference, out1, kCols);
179
+ }
180
+
181
+ TEST(SparseLinearLayerTest, PrintCompiles) {
182
+ SparseLinearLayer<float, float> sparse_linear_layer;
183
+ sparse_linear_layer.Print();
184
+ }
185
+
186
+ } // namespace
187
+ } // namespace csrblocksparse
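
One detail worth calling out from the tests and from PartLinearLayer in the header: split_points are expressed in block rows, so raw row boundaries are obtained by multiplying by the block height. The small helper below (RawRowRanges is not part of the library, just a hypothetical illustration) makes the arithmetic explicit.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Converts block-granular split points into [start_row, end_row) raw row
// ranges, mirroring the arithmetic inside PartLinearLayer.
std::vector<std::pair<int, int>> RawRowRanges(
    const std::vector<int>& split_points, int block_height) {
  std::vector<std::pair<int, int>> ranges;
  for (size_t tid = 0; tid + 1 < split_points.size(); ++tid) {
    ranges.emplace_back(split_points[tid] * block_height,
                        split_points[tid + 1] * block_height);
  }
  return ranges;
}
// With the fake_split_points used in the SliceForThreads test above
// ({0, 12, 32, 52, 64} blocks of height 4) this yields the raw row ranges
// [0,48), [48,128), [128,208), [208,256).
```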
sparse_matmul/layers/status_macros.h ADDED
@@ -0,0 +1,34 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
16
+ #define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
17
+
18
+ #include "absl/status/status.h"
19
+ #include "absl/status/statusor.h"
20
+
21
+ #define SPARSE_MATMUL_RETURN_IF_ERROR(expr) \
22
+ do { \
23
+ const absl::Status _status = (expr); \
24
+ if (!_status.ok()) return _status; \
25
+ } while (0)
26
+ template <typename T>
27
+ absl::Status DoAssignOrReturn(T& lhs, absl::StatusOr<T> result) {
28
+ if (result.ok()) {
29
+ lhs = result.value();
30
+ }
31
+ return result.status();
32
+ }
33
+
34
+ #endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
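
A short usage sketch for the two helpers above; CheckHeader and ReadPayloadSize are hypothetical stand-ins for any absl::Status / absl::StatusOr-returning calls.

```cpp
#include <cstdint>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "sparse_matmul/layers/status_macros.h"

// Hypothetical status-returning helpers.
absl::Status CheckHeader();
absl::StatusOr<int64_t> ReadPayloadSize();

absl::Status LoadEverything() {
  // Propagates any non-OK status to the caller.
  SPARSE_MATMUL_RETURN_IF_ERROR(CheckHeader());
  // DoAssignOrReturn unwraps the StatusOr into |payload_size| when it is OK
  // and otherwise hands its non-OK status to the macro, which returns it.
  int64_t payload_size = 0;
  SPARSE_MATMUL_RETURN_IF_ERROR(
      DoAssignOrReturn(payload_size, ReadPayloadSize()));
  return absl::OkStatus();
}
```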
sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50f861af29b1f767830d74ef83874944b18d80157b6b0256fdc4c14fa79ec936
3
+ size 20852
sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2d534bde2caf6e59990a46b4b1907088b8144c53d62d97de7e2b4bdc956da68
3
+ size 5133
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11399f9d0e8f8dfbef6eb37e0c096f858658bc650f728a08f3135ccca44f0a5a
3
+ size 1062
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3d971e067a6df985d68beac26bcf4e9a6cc13ff328599e84d50a0fc9a7c103b
3
+ size 2382
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1376ef7a360699dae24a49f40a254990d4a70b844dadcdbe9dcbf1a306999a8
3
+ size 55829
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffcc8ccf086fccfacc928877aa29ef03ce51cce0f0b7d2aacf81782b7b527089
3
+ size 2003
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a16f98ba6f09031ea9fefb79fdc9ba90e44f0046ab70dab014ac971ca7f7186
3
+ size 4684
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1b91304f5b6f7b53651ec7f9c827d4a2447366d1f990032adff46b18377741f
3
+ size 113777
sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ebb84ab4e16408f898b41a28c0d2c611f6735c8d9ad96a6805947c57cb547c7
3
+ size 1055
sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:071159e5397eff604ff3f1fca3ba90980a1ff9ae12838022179709d2c50e4627
3
+ size 2322
sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fdd0cbc0e79ea0a0dc1fc2ce8b10c5f25387fb4fd2ca019b66ac7ad7f44d219
3
+ size 51615
sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abd83a1795fd5e7044200029eae3ce6406b84095b7128288ac0dda1de5746b59
3
+ size 2001
sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:455e1c142dd29bc4a4bb5a15c1f88ef3e0fbb580425620ef6f923b6e04faab01
3
+ size 4459
sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:171d1e86e04fbefeca7dcce59817ad82d30556a110b4552cd5757a9348405d1c
3
+ size 111636
sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fba804daa5c3c4d5c87ca1ff4060d118c33f8e2201077e6faa233822c5f0c511
3
+ size 10706
sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62c03b31f5f58eb67773dcc5b0bae5b4790a26dca1934d79802342b4175e7a74
3
+ size 50978
sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:679c5bd2d5ca6abaae96225e8bab2ce9f9d57170027471465c85fc220c0c44a8
3
+ size 1361746