diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..6f6e89490e089bee56acc8bc7816dd18a86db6de --- /dev/null +++ b/.gitattributes @@ -0,0 +1,30 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +bazelisk-linux-amd64 filter=lfs diff=lfs merge=lfs -text +wavegru_mod.so filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b694934fbf9b49ee808b6dfc7292c28e2c46a97e --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.venv \ No newline at end of file diff --git a/BUILD b/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..74e1b1d34af9eae1bebb3605c2196b9cafb0d086 --- /dev/null +++ b/BUILD @@ -0,0 +1,44 @@ +# [internal] load cc_fuzz_target.bzl +# [internal] load cc_proto_library.bzl +# [internal] load android_cc_test:def.bzl + +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package(default_visibility = [":__subpackages__"]) + +licenses(["notice"]) + +# To run all cc_tests in this directory: +# bazel test //:all + +# [internal] Command to run dsp_util_android_test. + +# [internal] Command to run lyra_integration_android_test. + +exports_files( + srcs = [ + "wavegru_mod.cc", + ], +) + +pybind_extension( + name = "wavegru_mod", # This name is not actually created! + srcs = ["wavegru_mod.cc"], + deps = [ + "//sparse_matmul", + ], +) + +py_library( + name = "wavegru_mod", + data = [":wavegru_mod.so"], +) + +py_binary( + name = "wavegru", + srcs = ["wavegru.py"], + deps = [ + ":wavegru_mod" + ], +) + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..10acf1dee98979dd39babc9dd46f964c25b244f4 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,32 @@ +# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker +# you will also find guides on how best to write your Dockerfile + +FROM python:3.11 + +RUN apt update; apt install libsndfile1-dev make autoconf automake libtool gcc pkg-config -y + +WORKDIR /code + +COPY ./requirements.txt /code/requirements.txt + +RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt + +# Set up a new user named "user" with user ID 1000 +RUN useradd -m -u 1000 user + +# Switch to the "user" user +USER user + +# Set home to the user's home directory +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +# Set the working directory to the user's home directory +WORKDIR $HOME/app + +# Copy the current directory contents into the container at $HOME/app setting the owner to the user +COPY --chown=user . $HOME/app + +RUN bash ./install_espeak_ng.sh + +CMD ["python", "main.py"] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..232a850b602f3d97be7758c309e01c36a9e4d147 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +--- +title: WaveGRU Text To Speech +emoji: 🌍 +colorFrom: blue +colorTo: blue +sdk: gradio +sdk_version: 2.8.10 +app_file: app.py +pinned: false +license: mit +duplicated_from: ntt123/WaveGRU-Text-To-Speech +--- + + +## Build wavenet-cpp + + + ./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native + cp -f bazel-bin/wavegru_mod.so . + diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000000000000000000000000000000000000..b7317411b729bf64f9f988581abbc62db20a045b --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,154 @@ +######################## +# Platform Independent # +######################## + +load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# GoogleTest/GoogleMock framework. +git_repository( + name = "com_google_googletest", + remote = "https://github.com/google/googletest.git", + tag = "release-1.10.0", +) + +# Google benchmark. +http_archive( + name = "com_github_google_benchmark", + urls = ["https://github.com/google/benchmark/archive/bf585a2789e30585b4e3ce6baf11ef2750b54677.zip"], # 2020-11-26T11:14:03Z + strip_prefix = "benchmark-bf585a2789e30585b4e3ce6baf11ef2750b54677", + sha256 = "2a778d821997df7d8646c9c59b8edb9a573a6e04c534c01892a40aa524a7b68c", +) + +# proto_library, cc_proto_library, and java_proto_library rules implicitly +# depend on @com_google_protobuf for protoc and proto runtimes. +# This statement defines the @com_google_protobuf repo. +git_repository( + name = "com_google_protobuf", + remote = "https://github.com/protocolbuffers/protobuf.git", + tag = "v3.15.4", +) + +load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") +protobuf_deps() + +# Google Abseil Libs +git_repository( + name = "com_google_absl", + remote = "https://github.com/abseil/abseil-cpp.git", + branch = "lts_2020_09_23", +) + +# Filesystem +# The new_* prefix is used because it is not a bazel project and there is +# no BUILD file in that repo. +FILESYSTEM_BUILD = """ +cc_library( + name = "filesystem", + hdrs = glob(["include/ghc/*"]), + visibility = ["//visibility:public"], +) +""" + +new_git_repository( + name = "gulrak_filesystem", + remote = "https://github.com/gulrak/filesystem.git", + tag = "v1.3.6", + build_file_content = FILESYSTEM_BUILD +) + +# Audio DSP +git_repository( + name = "com_google_audio_dsp", + remote = "https://github.com/google/multichannel-audio-tools.git", + # There are no tags for this repo, we are synced to bleeding edge. + branch = "master", + repo_mapping = { + "@com_github_glog_glog" : "@com_google_glog" + } +) + + +http_archive( + name = "pybind11_bazel", + strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672", + urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"], +) +# We still require the pybind library. +http_archive( + name = "pybind11", + build_file = "@pybind11_bazel//:pybind11.BUILD", + strip_prefix = "pybind11-2.9.0", + urls = ["https://github.com/pybind/pybind11/archive/v2.9.0.tar.gz"], +) +load("@pybind11_bazel//:python_configure.bzl", "python_configure") +python_configure(name = "local_config_python") + + + +# Transitive dependencies of Audio DSP. +http_archive( + name = "eigen_archive", + build_file = "eigen.BUILD", + sha256 = "f3d69ac773ecaf3602cb940040390d4e71a501bb145ca9e01ce5464cf6d4eb68", + strip_prefix = "eigen-eigen-049af2f56331", + urls = [ + "http://mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz", + "https://bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz", + ], +) + +http_archive( + name = "fft2d", + build_file = "fft2d.BUILD", + sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9", + urls = [ + "http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz", + ], +) + +# Google logging +git_repository( + name = "com_google_glog", + remote = "https://github.com/google/glog.git", + branch = "master" +) +# Dependency for glog +git_repository( + name = "com_github_gflags_gflags", + remote = "https://github.com/mchinen/gflags.git", + branch = "android_linking_fix" +) + +# Bazel/build rules + +http_archive( + name = "bazel_skylib", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", + ], + sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44", +) +load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") +bazel_skylib_workspace() + +http_archive( + name = "rules_android", + sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806", + strip_prefix = "rules_android-0.1.1", + urls = ["https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip"], +) + +# Google Maven Repository +GMAVEN_TAG = "20180625-1" + +http_archive( + name = "gmaven_rules", + strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG, + url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG, +) + +load("@gmaven_rules//:gmaven.bzl", "gmaven_rules") + +gmaven_rules() diff --git a/alphabet.txt b/alphabet.txt new file mode 100644 index 0000000000000000000000000000000000000000..4f964d9409a1a9d126d91e5ee1ff0fc429a42f53 --- /dev/null +++ b/alphabet.txt @@ -0,0 +1,57 @@ +_ +■ + +! +" +, +. +: +; +? +a +b +d +e +f +h +i +j +k +l +m +n +o +p +r +s +t +u +v +w +x +z +æ +ð +ŋ +ɐ +ɑ +ɔ +ə +ɚ +ɛ +ɜ +ɡ +ɪ +ɹ +ɾ +ʃ +ʊ +ʌ +ʒ +ʔ +ˈ +ˌ +ː +̩ +θ +ᵻ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..1fcfc69d670d9016ceee3402578708b91a200f50 --- /dev/null +++ b/app.py @@ -0,0 +1,60 @@ +## build wavegru-cpp +# import os +# os.system("./bazelisk-linux-amd64 clean --expunge") +# os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native") + +# install espeak +import os + +if not os.path.isdir("./espeak"): + os.system("bash ./install_espeak_ng.sh") + +import gradio as gr +from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel +from wavegru_cpp import extract_weight_mask, load_wavegru_cpp + + +alphabet, tacotron_net, tacotron_config = load_tacotron_model( + "./alphabet.txt", "./tacotron.toml", "./tacotrons_ljs_24k_v1_0300000.ckpt" +) + +wavegru_config, wavegru_net = load_wavegru_net( + "./wavegru.yaml", "./wavegru_vocoder_1024_v4_1320000.ckpt" +) + +wave_cpp_weight_mask = extract_weight_mask(wavegru_net) +wavecpp = load_wavegru_cpp( + wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1] +) + + +def speak(text): + mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config) + y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config) + return 24_000, y + + +title = "WaveGRU-TTS" +description = "WaveGRU text-to-speech demo." + +gr.Interface( + fn=speak, + inputs="text", + examples=[ + "This is a test!", + "President Trump met with other leaders at the Group of 20 conference.", + "The buses aren't the problem, they actually provide a solution.", + "Generative adversarial network or variational auto-encoder.", + "Basilar membrane and otolaryngology are not auto-correlations.", + "There are several variations on the full gated unit, with gating done using the previous hidden state and the bias in various combinations, and a simplified form called minimal gated unit.", + "October arrived, spreading a damp chill over the grounds and into the castle. Madam Pomfrey, the nurse, was kept busy by a sudden spate of colds among the staff and students.", + "Artificial intelligence is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans.", + 'Uncle Vernon entered the kitchen as Harry was turning over the bacon. "Comb your hair!" he barked, by way of a morning greeting. About once a week, Uncle Vernon looked over the top of his newspaper and shouted that Harry needed a haircut. Harry must have had more haircuts than the rest of the boys in his class put together, but it made no difference, his hair simply grew that way - all over the place.', + ], + outputs="audio", + title=title, + description=description, + theme="default", + allow_screenshot=False, + allow_flagging="never", +).launch(enable_queue=True) diff --git a/bazelisk-linux-amd64 b/bazelisk-linux-amd64 new file mode 100755 index 0000000000000000000000000000000000000000..d84f631d0e3b6a85f40c1cbff58722624847ae49 --- /dev/null +++ b/bazelisk-linux-amd64 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:231ec5ca8115e94c75a1f4fbada1a062b48822ca04f21f26e4cb1cd8973cd458 +size 5152768 diff --git a/extract_model.py b/extract_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f03adc21e8d3ef05c4040e53797ed0cb3d748b9a --- /dev/null +++ b/extract_model.py @@ -0,0 +1,5 @@ +import pickle + +dic = pickle.load(open("./tacotrons_ljs_24k_v1_0300000.ckpt", "rb")) +del dic["optim_state_dict"] +pickle.dump(dic, open("./tacotrons_ljs_24k_v1_0300000.ckpt", "wb")) diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a3e87ff49b2ed4f666559655c8e7bf3801a174 --- /dev/null +++ b/inference.py @@ -0,0 +1,91 @@ +import os + +import jax +import jax.numpy as jnp +import librosa +import numpy as np +import pax + +from text import english_cleaners +from utils import ( + create_tacotron_model, + load_tacotron_ckpt, + load_tacotron_config, + load_wavegru_ckpt, + load_wavegru_config, +) +from wavegru import WaveGRU + +os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = "./espeak/usr/lib/libespeak-ng.so.1.1.51" +from phonemizer.backend import EspeakBackend + +backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True) + + +def load_tacotron_model(alphabet_file, config_file, model_file): + """load tacotron model to memory""" + with open(alphabet_file, "r", encoding="utf-8") as f: + alphabet = f.read().split("\n") + + config = load_tacotron_config(config_file) + net = create_tacotron_model(config) + _, net, _ = load_tacotron_ckpt(net, None, model_file) + net = net.eval() + net = jax.device_put(net) + return alphabet, net, config + + +tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=2400)) + + +def text_to_mel(net, text, alphabet, config): + """convert text to mel spectrogram""" + text = english_cleaners(text) + text = backend.phonemize([text], strip=True)[0] + text = text + config["END_CHARACTER"] + text = text + config["PAD"] * (100 - (len(text) % 100)) + tokens = [] + for c in text: + if c in alphabet: + tokens.append(alphabet.index(c)) + tokens = jnp.array(tokens, dtype=jnp.int32) + mel = tacotron_inference_fn(net, tokens[None]) + return mel + + +def load_wavegru_net(config_file, model_file): + """load wavegru to memory""" + config = load_wavegru_config(config_file) + net = WaveGRU( + mel_dim=config["mel_dim"], + rnn_dim=config["rnn_dim"], + upsample_factors=config["upsample_factors"], + has_linear_output=True, + ) + _, net, _ = load_wavegru_ckpt(net, None, model_file) + net = net.eval() + net = jax.device_put(net) + return config, net + + +wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=True)) + + +def mel_to_wav(net, netcpp, mel, config): + """convert mel to wav""" + if len(mel.shape) == 2: + mel = mel[None] + pad = config["num_pad_frames"] // 2 + 2 + mel = np.pad(mel, [(0, 0), (pad, pad), (0, 0)], mode="edge") + ft = wavegru_inference(net, mel) + ft = jax.device_get(ft[0]) + wav = netcpp.inference(ft, 1.0) + wav = np.array(wav) + wav = librosa.mu_expand(wav - 127, mu=255) + wav = librosa.effects.deemphasis(wav, coef=0.86) + wav = wav * 2.0 + wav = wav / max(1.0, np.max(np.abs(wav))) + wav = wav * 2**15 + wav = np.clip(wav, a_min=-(2**15), a_max=(2**15) - 1) + wav = wav.astype(np.int16) + return wav diff --git a/install_espeak_ng.sh b/install_espeak_ng.sh new file mode 100755 index 0000000000000000000000000000000000000000..cf3c52d273384786eaaffa9546fc5102b866d516 --- /dev/null +++ b/install_espeak_ng.sh @@ -0,0 +1,18 @@ +pip install -U pip +pip install gradio==3.42.0 +( + rm -rf espeak + mkdir -p espeak + cd espeak + wget https://github.com/espeak-ng/espeak-ng/archive/refs/tags/1.51.zip + unzip -qq 1.51.zip + cd espeak-ng-1.51 + ./autogen.sh + ./configure --prefix=`pwd`/../usr + make + make install +) +# build bazel too +rm wavegru_mod.so +USE_BAZEL_VERSION=5.0.0 ./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native +cp -f bazel-bin/wavegru_mod.so . \ No newline at end of file diff --git a/mynumbers.py b/mynumbers.py new file mode 100644 index 0000000000000000000000000000000000000000..5c30252e1c96fd9d3c762491e3107f0b5e811041 --- /dev/null +++ b/mynumbers.py @@ -0,0 +1,73 @@ +""" from https://github.com/keithito/tacotron """ + +import inflect +import re + + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words( + num, andword="", zero="oh", group=2 + ).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..edbf7d49a69a51a1f23759f35f1b87e92fe6633f --- /dev/null +++ b/packages.txt @@ -0,0 +1,7 @@ +libsndfile1-dev +make +autoconf +automake +libtool +gcc +pkg-config diff --git a/pooch.py b/pooch.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee727a247d443f7fd78b4f9e393a8f935f5776d --- /dev/null +++ b/pooch.py @@ -0,0 +1,10 @@ +def os_cache(x): + return x + + +def create(*args, **kwargs): + class T: + def load_registry(self, *args, **kwargs): + return None + + return T() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0d32ee9c31690268bc6795e6b46d81dfba49adbf --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +inflect +jax +jaxlib +jinja2 +librosa +numpy +pax3 +pyyaml +toml +unidecode +phonemizer +gradio==3.42.0 \ No newline at end of file diff --git a/sparse_matmul/BUILD b/sparse_matmul/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..e56e1269143df73d0b8ae72c4ee85f4bcf999149 --- /dev/null +++ b/sparse_matmul/BUILD @@ -0,0 +1,22 @@ +# [internal] load placeholder + +licenses(["notice"]) + +cc_library( + name = "sparse_matmul", + hdrs = [ + "sparse_matmul.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//sparse_matmul/compute:gru_gates", + "//sparse_matmul/layers:layer", + "//sparse_matmul/layers:matrix", + "//sparse_matmul/layers:utils", + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:types", + "//sparse_matmul/os:coop_threads", + "//sparse_matmul/vector:cache_aligned_vector", + ], # internal :sparse_matmul deps placeholder +) + diff --git a/sparse_matmul/compute/BUILD b/sparse_matmul/compute/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..41f5b3f620fd0bcd53962b6b623589dbe1ad3233 --- /dev/null +++ b/sparse_matmul/compute/BUILD @@ -0,0 +1,88 @@ +# Low-level computation code, including generic and architecture-specific +# variants. + +licenses(["notice"]) + +cc_library( + name = "gru_gates", + srcs = [ + "ar_inputs.h", + "gru_gates_arm.h", + "gru_gates_avx_fixed.h", + "gru_gates_generic.h", + ], + hdrs = ["gru_gates.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":matmul", + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:types", + "//sparse_matmul/vector:cache_aligned_vector", + ], +) + +cc_library( + name = "kernels", + srcs = [ + "kernels_arm.h", + "kernels_avx.h", + ], + hdrs = [ + "kernels_generic.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:types", + ], +) + +cc_library( + name = "matmul", + srcs = [ + "matmul_fixed_avx2.cc", + "matmul_fixed_avx2.h", + "matmul_generic.cc", + "matmul_generic.h", + ], + hdrs = [ + "matmul.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "//sparse_matmul/numerics:types", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "thread_bounds", + srcs = ["thread_bounds.cc"], + hdrs = ["thread_bounds.h"], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "@com_google_glog//:glog", + ], +) + +cc_test( + name = "gru_gates_test", + size = "small", + srcs = [ + "gru_gates_test.cc", + ], + deps = [ + ":gru_gates", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/sparse_matmul/compute/ar_inputs.h b/sparse_matmul/compute/ar_inputs.h new file mode 100644 index 0000000000000000000000000000000000000000..d10e2d9635f11636edc0a7b647bae5876b3656c5 --- /dev/null +++ b/sparse_matmul/compute/ar_inputs.h @@ -0,0 +1,37 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_ + +namespace csrblocksparse { + +// Possible numbers of Autoregressive inputs. +// TODO(b/188702959): Generalize to any non-negative integer value? +enum class ARInputsMode { + // There are no autoregressive inputs. Inputs to the GRU gates are strictly + // from the gate-recurrent matmul and other unrelated inputs. + k0ARInputs, + // Two autoregressive inputs, such as coarse and fine for WaveRNN. + k2ARInputs, + // Three autoregressive inputs, such as prev coarse and fine plus current + // coarse for WaveRNN. + k3ARInputs, +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_ diff --git a/sparse_matmul/compute/gru_gates.h b/sparse_matmul/compute/gru_gates.h new file mode 100644 index 0000000000000000000000000000000000000000..7b8cd489f5c6ef42de262d54727f99c5f9020b82 --- /dev/null +++ b/sparse_matmul/compute/gru_gates.h @@ -0,0 +1,214 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_ + +#include +#include + +// IWYU pragma: begin_exports +#include "sparse_matmul/compute/ar_inputs.h" +#include "sparse_matmul/compute/gru_gates_arm.h" +#include "sparse_matmul/compute/gru_gates_avx_fixed.h" +#include "sparse_matmul/compute/gru_gates_generic.h" +#include "sparse_matmul/compute/matmul.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" +// IWYU pragma: end_exports + +namespace csrblocksparse { + +// The master template is really a catch-all for the unimplemented cases to +// run the generics. +template +class GruGates : public MatmulBase { + public: + using SampleWeightType = float; + static constexpr int kSIMDWidth = kGenericSIMDWidth; + + // Generic GRU function covers all uses for WaveRNN-like architectures and + // conditioning. + // Controlled by template parameters thus: + // - |kInputsMode| == |k0ARInputs|: There are no autoregressive inputs so + // |ar_sample0|, |ar_sample1|, |ar_sample2|, |ar_01_weights|, + // |ar_2_weights| are ignored. + // - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied + // by |ar_01_weights| and added to the (conditioning) input. + // - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by + // |ar_2_weights| and added to the other two |ar_inputs| (and added to the + // conditioning input). + // - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary + // recurrent input that must be added to |*gru_recurrent_ptr|. + // - |num_replicas| determines the number of duplicates of the output to be + // written, separated by |replica_stride|. + // - |start|, |end| are |rows| in [0, |state_size|] to be processed by this + // thread. + // + // Previous state is read from |*gru_state_ptr| and the new state is written + // to *(|gru_state_ptr| + i * |replica_stride| for i in [0, |num_replicas|)). + template + void GruWithARInput(int start, int end, int state_size, + const InputType* gru_recurrent_ptr, + const InputType* input_ptr, GRUStateType* gru_state_ptr, + const SampleType* ar_sample0 = nullptr, + const SampleType* ar_sample1 = nullptr, + const SampleWeightType* ar_01_weights = nullptr, + int num_replicas = 1, int replica_stride = 0, + const SampleType* ar_sample2 = nullptr, + const SampleWeightType* ar_2_weights = nullptr, + const InputType* gru_recurrent_other_ptr = nullptr) { + CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica"; + GoThroughGates( + start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr, + input_ptr, gru_state_ptr, ar_2_weights, state_size, ar_sample0, + ar_sample1, ar_sample2); + } + + // No AR inputs, no split gates, no batching, no replicated outputs. + // TODO(b/188702959): Redirect conditioning GRU here, removing code from + // gru_layer.h. + // Copy to specializations. + void PlainGru(int start, int end, int state_size, + const InputType* gru_recurrent_ptr, const InputType* input_ptr, + GRUStateType* gru_state_ptr) { + GruWithARInput( + start, end, state_size, gru_recurrent_ptr, input_ptr, gru_state_ptr); + } +}; + +#if defined __ARM_NEON || defined __aarch64__ +// Partial specialization for float. +template <> +class GruGates : public MatmulBase { + public: + static constexpr int kSIMDWidth = kNeonSIMDWidth; + + // Generic GRU function covers all uses for WaveRNN-like architectures and + // conditioning. + template + void GruWithARInput(int start, int end, int state_size, + const float* gru_recurrent_data, const float* input_data, + float* gru_state_data, const float* ar_sample0 = nullptr, + const float* ar_sample1 = nullptr, + const float* ar_01_weights = nullptr, + int num_replicas = 1, int replica_stride = 0, + const float* ar_sample2 = nullptr, + const float* ar_2_weights = nullptr, + const float* gru_recurrent_other_data = nullptr) { + DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica"; + GoThroughGatesFloat( + start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data, + input_data, gru_state_data, ar_2_weights, state_size, ar_sample0, + ar_sample1, ar_sample2); + } +}; +#endif // defined __ARM_NEON || defined __aarch64__ + +// Partial specialization for fixed types. The sample weights are always float +// whatever the fixed type of the other weights. +template +class GruGates, fixed32, + fixed16> : public MatmulBase { + public: +#if defined __ARM_NEON || defined __aarch64__ + static constexpr int kSIMDWidth = kNeonSIMDWidth; +#elif defined __AVX2__ + static constexpr int kSIMDWidth = kAVX2SIMDWidth * 2; +#else // Generic case. + static constexpr int kSIMDWidth = kGenericSIMDWidth; +#endif // __ARM_NEON || defined __aarch64__ / __AVX2__ + + using GRUStateType = fixed16; + using InputType = fixed32; + using SampleType = fixed16; + using SampleWeightType = float; + static constexpr int kInputMantissaBits = InputType::kMantissaBits; + static constexpr int kSampleMantissaBits = SampleType::kMantissaBits; + static constexpr int kStateMantissaBits = GRUStateType::kMantissaBits; + // Generic GRU function covers all uses for WaveRNN-like architectures and + // conditioning. + template + void GruWithARInput(int start, int end, int state_size, + const InputType* gru_recurrent_data, + const InputType* input_data, GRUStateType* gru_state_data, + const SampleType* ar_sample0 = nullptr, + const SampleType* ar_sample1 = nullptr, + const SampleWeightType* ar_01_weights = nullptr, + int num_replicas = 1, int replica_stride = 0, + const SampleType* ar_sample2 = nullptr, + const SampleWeightType* ar_2_weights = nullptr, + const InputType* gru_recurrent_other_data = nullptr) { +#if defined __ARM_NEON || defined __aarch64__ || defined __AVX2__ + const int32_t* gru_recurrent_ptr = + reinterpret_cast(gru_recurrent_data); + const int32_t* gru_recurrent_other_ptr = + reinterpret_cast(gru_recurrent_other_data); + const int32_t* input_ptr = reinterpret_cast(input_data); + int16_t* gru_state_ptr = reinterpret_cast(gru_state_data); +#if defined __AVX2__ + // The samples are fixed16, but we scale them up here and convert to float + // so that the product with the QR weights is always on the same scale as + // InputType, so we don't have to do any more scaling inside. + const float sample_factor = static_cast(1 << kInputMantissaBits); +#else + const float sample_factor = 1.0f; +#endif + // AR sample 0 and 1 are packed into a pair because the QR weights are + // formatted with the weights interleaved for sample 0 and 1. + std::pair ar_sample01; + float ar_sample2_float = 0.0f; + if (kInputsMode == ARInputsMode::k2ARInputs || + kInputsMode == ARInputsMode::k3ARInputs) { + ar_sample01 = {static_cast(*ar_sample0) * sample_factor, + static_cast(*ar_sample1) * sample_factor}; + if (kInputsMode == ARInputsMode::k3ARInputs) { + ar_sample2_float = static_cast(*ar_sample2) * sample_factor; + } + } +#if defined __AVX2__ + CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!"; + GruGatesAVXFixed( + start, end, state_size, gru_recurrent_ptr, input_ptr, &ar_sample01, + ar_01_weights, num_replicas, replica_stride, &ar_sample2_float, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); +#else // ARM. + DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica"; + GoThroughGatesFixed( + start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr, + input_ptr, gru_state_ptr, ar_2_weights, state_size, &ar_sample01, + &ar_sample2_float); +#endif // __AVX2__ / ARM. +#else // Generic case. + CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica"; + GoThroughGates( + start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data, + input_data, gru_state_data, ar_2_weights, state_size, ar_sample0, + ar_sample1, ar_sample2); +#endif // __ARM_NEON || defined __aarch64__ / __AVX2__ + } +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_ diff --git a/sparse_matmul/compute/gru_gates_arm.h b/sparse_matmul/compute/gru_gates_arm.h new file mode 100644 index 0000000000000000000000000000000000000000..d95805da4165df71c00b4e82557c647e2f746d1a --- /dev/null +++ b/sparse_matmul/compute/gru_gates_arm.h @@ -0,0 +1,288 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_ + +#if defined __ARM_NEON || defined __aarch64__ +#include +#endif +#include + +#include "sparse_matmul/compute/ar_inputs.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" + +namespace csrblocksparse { + +static constexpr int kNeonSIMDWidth = 4; + +// ------ Scalar calculation -------- +// See "Efficient Neural Audio Synthesis" for a description of the calculation. +// https://arxiv.org/abs/1802.08435 +// +// NOTE: +// |sample| = (|coarse_at_sminus1|, |fine_at_sminus1|, +// |coarse_at_sminus1|, |fine_at_sminus1|) +// |w_sample| = (|coarse_at_s|, |coarse_at_s|, |coarse_at_s|, |coarse_at_s|) +// +// CHEATSHEET: +// vld1q_f32 = load 4 32-bit floats +// vmulq_f32(a, b) : return a * b; +// vaddq_f32(a, b) : return a + b; +// vmlaq_f32(c, a, b) : return c + a * b; +// vpaddq_f32(a, b) : return (a0 + a1, a2 + a3, b0 + b1, b2 + b3) +// vsubq_f32(a, b) : return a - b; +// vst1q_f32 = store 4 32-bit floats +#if defined __ARM_NEON || defined __aarch64__ + +#if !defined __aarch64__ +// Backport of vpaddq_f32 to ARM32. +inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) { + float32x2_t a10 = vget_low_f32(a); + float32x2_t a32 = vget_high_f32(a); + float32x2_t b10 = vget_low_f32(b); + float32x2_t b32 = vget_high_f32(b); + return vcombine_f32(vpadd_f32(a10, a32), vpadd_f32(b10, b32)); +} +#endif + +template +void GoThroughGatesFloat(int start, int end, const float* qr_ptr, + const float* gru_gates_ptr, + const float* gru_gates_other_ptr, + const float* conditioning_ptr, float* gru_h_ptr, + const float* w_hat, int proj_size, + const float* coarse_at_sminus1, + const float* fine_at_sminus1, + const float* coarse_at_s) { + // Increment all the pointers to save on pointer arithmetic in the loop. + conditioning_ptr += start; + gru_h_ptr += start; + gru_gates_ptr += start; + if (SplitGates) { + DCHECK_NE(gru_gates_other_ptr, nullptr); + gru_gates_other_ptr += start; + } + if (kInputsMode != ARInputsMode::k0ARInputs) { + DCHECK_NE(qr_ptr, nullptr); + qr_ptr += 2 * start; + DCHECK_NE(coarse_at_sminus1, nullptr); + DCHECK_NE(fine_at_sminus1, nullptr); + if (kInputsMode == ARInputsMode::k3ARInputs) { + DCHECK_NE(w_hat, nullptr); + DCHECK_NE(coarse_at_s, nullptr); + w_hat += start; + } + } + for (int i = start; i < end; i += kNeonSIMDWidth) { + float32x4_t reset = vld1q_f32(gru_gates_ptr); + float32x4_t update = vld1q_f32(gru_gates_ptr + proj_size); + float32x4_t cell = vld1q_f32(gru_gates_ptr + 2 * proj_size); + float32x4_t qr_cell; + if (SplitGates) { + reset = vaddq_f32(reset, vld1q_f32(gru_gates_other_ptr)); + update = vaddq_f32(update, vld1q_f32(gru_gates_other_ptr + proj_size)); + cell = vaddq_f32(cell, vld1q_f32(gru_gates_other_ptr + 2 * proj_size)); + } + if (kInputsMode != ARInputsMode::k0ARInputs) { + // Setup the sample vector. + float32x4_t sample = vdupq_n_f32(*coarse_at_sminus1); + sample = vsetq_lane_f32(*fine_at_sminus1, sample, 1); + sample = vsetq_lane_f32(*fine_at_sminus1, sample, 3); + + // All auto types are float32x4_t, auto used to fit statements on one line + // for readability. Do two rows of QR at once. + auto qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample); + auto qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample); + auto qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1); + + auto qr_update_0 = vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample); + auto qr_update_1 = + vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample); + auto qr_update = vpaddq_f32(qr_update_0, qr_update_1); + + auto qr_cell_0 = vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample); + auto qr_cell_1 = vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample); + qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1); + + if (kInputsMode == ARInputsMode::k3ARInputs) { + float32x4_t w_sample = vdupq_n_f32(*coarse_at_s); + qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample); + qr_update = + vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample); + qr_cell = + vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample); + } + reset = vaddq_f32(reset, qr_reset); + update = vaddq_f32(update, qr_update); + } + auto reset_conditioning = vld1q_f32(conditioning_ptr); + auto update_conditioning = vld1q_f32(conditioning_ptr + proj_size); + auto cell_conditioning = vld1q_f32(conditioning_ptr + 2 * proj_size); + + reset = fast_sigmoid(vaddq_f32(reset, reset_conditioning)); + update = fast_sigmoid(vaddq_f32(update, update_conditioning)); + if (kInputsMode == ARInputsMode::k0ARInputs) { + cell = vmulq_f32(reset, cell); + } else { + cell = vmlaq_f32(qr_cell, reset, cell); + } + auto hbar = fast_tanh(vaddq_f32(cell, cell_conditioning)); + + auto prev_h = vld1q_f32(gru_h_ptr); + auto diff = vsubq_f32(prev_h, hbar); + auto new_h = vmlaq_f32(hbar, diff, update); + + vst1q_f32(gru_h_ptr, new_h); + // Increment all the pointers. + conditioning_ptr += kNeonSIMDWidth; + gru_h_ptr += kNeonSIMDWidth; + gru_gates_ptr += kNeonSIMDWidth; + if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth; + if (kInputsMode != ARInputsMode::k0ARInputs) { + qr_ptr += 2 * kNeonSIMDWidth; + if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth; + } + } +} + +// This version should only be used if all of the 32-bit fixed point +// representations have the same number of mantissa bits. +// |ar_at_sminus1| packs sample 0 and 1 into a pair because the QR weights are +// formatted with the weights interleaved for sample 0 and 1. The two samples +// represent coarse and fine for WaveRNN. +template +void GoThroughGatesFixed(int start, int end, const float* qr_ptr, + const int32_t* gru_gates_ptr, + const int32_t* gru_gates_other_ptr, + const int32_t* conditioning_ptr, int16_t* gru_h_ptr, + const float* w_hat, int proj_size, + const std::pair* ar_at_sminus1, + const float* coarse_at_s) { + // Increment all the pointers to save on pointer arithmetic in the loop. + conditioning_ptr += start; + gru_h_ptr += start; + gru_gates_ptr += start; + if (SplitGates) { + DCHECK_NE(gru_gates_other_ptr, nullptr); + gru_gates_other_ptr += start; + } + float32x4_t sample01; + float32x4_t w_sample; + if (kInputsMode != ARInputsMode::k0ARInputs) { + DCHECK_NE(qr_ptr, nullptr); + qr_ptr += 2 * start; + DCHECK_NE(ar_at_sminus1, nullptr); + sample01 = vdupq_n_f32(ar_at_sminus1->first); + sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 1); + sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 3); + if (kInputsMode == ARInputsMode::k3ARInputs) { + DCHECK_NE(w_hat, nullptr); + DCHECK_NE(coarse_at_s, nullptr); + w_hat += start; + w_sample = vdupq_n_f32(*coarse_at_s); + } + } + for (int i = start; i < end; i += kNeonSIMDWidth) { + auto reset = vld1q_s32(gru_gates_ptr); + auto update = vld1q_s32(gru_gates_ptr + proj_size); + // vcvtq_n_f32_s32 = convert 32-bit fixed point to fp32 + auto cell_int = vld1q_s32(gru_gates_ptr + 2 * proj_size); + if (SplitGates) { + reset = vaddq_s32(reset, vld1q_s32(gru_gates_other_ptr)); + update = vaddq_s32(update, vld1q_s32(gru_gates_other_ptr + proj_size)); + cell_int = + vaddq_s32(cell_int, vld1q_s32(gru_gates_other_ptr + 2 * proj_size)); + } + float32x4_t cell = + vcvtq_n_f32_s32(cell_int, GRUMatMulOutType::kMantissaBits); + float32x4_t qr_cell; + if (kInputsMode != ARInputsMode::k0ARInputs) { + // Do two rows of QR at once. + float32x4_t qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample01); + float32x4_t qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample01); + float32x4_t qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1); + + float32x4_t qr_update_0 = + vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample01); + float32x4_t qr_update_1 = + vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample01); + float32x4_t qr_update = vpaddq_f32(qr_update_0, qr_update_1); + + float32x4_t qr_cell_0 = + vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample01); + float32x4_t qr_cell_1 = + vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample01); + qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1); + if (kInputsMode == ARInputsMode::k3ARInputs) { + float32x4_t w_sample = vdupq_n_f32(*coarse_at_s); + qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample); + qr_update = + vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample); + qr_cell = + vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample); + } + reset = vaddq_s32( + reset, vcvtq_n_s32_f32(qr_reset, GRUMatMulOutType::kMantissaBits)); + update = vaddq_s32( + update, vcvtq_n_s32_f32(qr_update, GRUMatMulOutType::kMantissaBits)); + } + + auto reset_conditioning = vld1q_s32(conditioning_ptr); + auto update_conditioning = vld1q_s32(conditioning_ptr + proj_size); + float32x4_t cell_conditioning = + vcvtq_n_f32_s32(vld1q_s32(conditioning_ptr + 2 * proj_size), + GRUMatMulOutType::kMantissaBits); + + float32x4_t reset_f32 = fast_sigmoid( + vaddq_s32(reset, reset_conditioning)); + float32x4_t update_f32 = fast_sigmoid( + vaddq_s32(update, update_conditioning)); + if (kInputsMode == ARInputsMode::k0ARInputs) { + cell = vmulq_f32(reset_f32, cell); + } else { + cell = vmlaq_f32(qr_cell, reset_f32, cell); + } + float32x4_t hbar = fast_tanh(vaddq_f32(cell, cell_conditioning)); + + float32x4_t prev_h = vcvtq_n_f32_s32(vmovl_s16(vld1_s16(gru_h_ptr)), + GRUStateType::kMantissaBits); + float32x4_t diff = vsubq_f32(prev_h, hbar); + float32x4_t new_h = vmlaq_f32(hbar, diff, update_f32); + + // vcvtq_n_s32_f32 = convert fp32 to signed 32-bit fixed point + // vqrshrn_n_s32 = saturating, rounding, narrowing right shift - used to + // convert a 32-bit fixed point value to a 16-bit fixed point value + vst1_s16(gru_h_ptr, + vqrshrn_n_s32( + vcvtq_n_s32_f32(new_h, GRUStateType::kMantissaBits + 16), 16)); + // Increment all the pointers. + conditioning_ptr += kNeonSIMDWidth; + gru_h_ptr += kNeonSIMDWidth; + gru_gates_ptr += kNeonSIMDWidth; + if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth; + if (kInputsMode != ARInputsMode::k0ARInputs) { + qr_ptr += 2 * kNeonSIMDWidth; + if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth; + } + } +} +#endif // defined __ARM_NEON || defined __aarch64__ + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_ diff --git a/sparse_matmul/compute/gru_gates_avx_fixed.h b/sparse_matmul/compute/gru_gates_avx_fixed.h new file mode 100644 index 0000000000000000000000000000000000000000..cf7cf0e770d27d583dd63116c350c6dd49d8a528 --- /dev/null +++ b/sparse_matmul/compute/gru_gates_avx_fixed.h @@ -0,0 +1,348 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_ + +#include +#if defined __AVX2__ +#include +#endif +#include + +#include "sparse_matmul/compute/ar_inputs.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" + +namespace csrblocksparse { + +#if defined __AVX2__ + +constexpr int kAVX2SIMDWidth = 8; + +// Loads 8x fixed32 from |ptr0| and adds to |input|. +// If |kTwoInputs|, also loads from |ptr1| and adds that as well. +// Returns the 2 or 3-way sum. +template +inline __m256i LoadAndAddFixed32(const int32_t* ptr0, const int32_t* ptr1, + const __m256i& input) { + __m256i data0 = _mm256_load_si256(reinterpret_cast(ptr0)); + if (kTwoInputs) { + __m256i data1 = _mm256_load_si256(reinterpret_cast(ptr1)); + data0 = _mm256_add_epi32(data0, data1); + } + return _mm256_add_epi32(data0, input); +} + +// Loads 8x fixed32 from ptr0. +// If |kTwoInputs|, also loads from |ptr1| and adds. +// Multiplies the loaded values by the factor and adds to |input|, which also +// is converted to float. +// Returns the sum. +template +inline __m256 LoadMultiplyAddToFloat(const int32_t* ptr0, const int32_t* ptr1, + const __m256& float_factor, + const __m256& input) { + __m256i data0 = _mm256_load_si256(reinterpret_cast(ptr0)); + if (kTwoInputs) { + __m256i data1 = _mm256_load_si256(reinterpret_cast(ptr1)); + data0 = _mm256_add_epi32(data0, data1); + } + __m256 float_result = _mm256_cvtepi32_ps(data0); + float_result = _mm256_mul_ps(float_result, float_factor); + return _mm256_add_ps(float_result, input); +} + +// Loads 16x float in 2x 8x registers from |ptr0_1| and multiplies by +// |input_pairs|, likewise formatted as 8x floats, alternating between the two +// AR inputs and sums each pair of results, making 8x float results. +// If |kThreeInputs|, also loads 8x float from |ptr2| and multiplies by +// |third_input|, which must be formatted as 8x float. The second product is +// added to the previous result. +// Returns the sum added to |accumulator|. +template +inline __m256 MultiplyAddFloat(const __m256& input_pairs, + const __m256& third_input, const float* ptr0_1, + const float* ptr2, const __m256& accumulator) { + __m256 data_pair0 = _mm256_load_ps(ptr0_1); + __m256 data_pair1 = _mm256_load_ps(ptr0_1 + 8); + data_pair0 = _mm256_mul_ps(data_pair0, input_pairs); + data_pair1 = _mm256_mul_ps(data_pair1, input_pairs); + data_pair0 = _mm256_hadd_ps(data_pair0, data_pair1); + // Swap the middle 2 64 bit pairs to correct the hadd result. + data_pair0 = _mm256_permute4x64_pd((__m256d)data_pair0, 0xd8); + if (kThreeInputs) { + // Load 256 bits (8 x float) of data, then multiply-accumulate. + data_pair1 = _mm256_load_ps(ptr2); + data_pair1 = _mm256_mul_ps(data_pair1, third_input); + data_pair0 = _mm256_add_ps(data_pair0, data_pair1); + } + // Add conditioning. + return _mm256_add_ps(data_pair0, accumulator); +} + +// Processes the tanh and the final combination, returns the new GRU state. +template +inline __m256i GRUComputeState(const __m256& cell0, const __m256& cell1, + const __m256& reset0, const __m256& reset1, + const __m256& update0, const __m256& update1, + const int32_t* gate_ptr, + const int32_t* gate_other_ptr, + const void* gru_h_ptr) { + // Multiply the cell gru output and the reset. + __m256 float_gru0 = LoadMultiplyAddToFloat( + gate_ptr, gate_other_ptr, reset0, cell0); + __m256 float_gru1 = LoadMultiplyAddToFloat( + gate_ptr + kAVX2SIMDWidth, gate_other_ptr + kAVX2SIMDWidth, reset1, + cell1); + // Compute tanh on the result. + __m256 hbar0, hbar1; + float_tanh_float(float_gru0, float_gru1, + hbar0, hbar1); + // Load the 16-bit previous gru state and update. + __m256i gru = _mm256_load_si256(reinterpret_cast<__m256i const*>(gru_h_ptr)); + __m256 state_factor = + _mm256_set1_ps(1.0f / (static_cast(1 << kStateMantissaBits))); + float_gru0 = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(gru))); + float_gru1 = _mm256_cvtepi32_ps( + _mm256_cvtepi16_epi32(_mm256_extractf128_si256(gru, 1))); + float_gru0 = _mm256_mul_ps(float_gru0, state_factor); + float_gru1 = _mm256_mul_ps(float_gru1, state_factor); + float_gru0 = _mm256_sub_ps(float_gru0, hbar0); + float_gru1 = _mm256_sub_ps(float_gru1, hbar1); + float_gru0 = _mm256_mul_ps(float_gru0, update0); + float_gru1 = _mm256_mul_ps(float_gru1, update1); + state_factor = _mm256_set1_ps(static_cast(1 << kStateMantissaBits)); + float_gru0 = _mm256_add_ps(float_gru0, hbar0); + float_gru1 = _mm256_add_ps(float_gru1, hbar1); + float_gru0 = _mm256_mul_ps(float_gru0, state_factor); + float_gru1 = _mm256_mul_ps(float_gru1, state_factor); + return PackFloatsToFixed16(float_gru0, float_gru1); +} + +// According to |kInputsMode|, processes 0, 2 or 3 autoregressive inputs and +// combines with |input| and |gates*|. +// With 2 AR inputs, loads 8x pairs of float from |pair_weights| and multiplies +// by |paired_ar|, likewise formatted as 8x float, but scaled such that the +// product with pair_weights is on the same scale as |*input| and |*gates0|, +// and sums each pair result, making 8x float results. +// If 3 AR inputs, also loads 8x float from |third_weights| and multiplies by +// |third_ar|, which must be formatted as 8x scaled floats. The second product +// is added to the previous result. +// Inputs, 8x fixed32 are loaded from |input|, and added to the total. +// Finally 8x fixed32 from |gates0| (and |gates1| if |kTwoGates|) are added as +// well. +// Returns the total sum as a float, but on the scale of |*input|. +template +inline __m256 GruInput32ToFloat(const __m256& paired_ar, + const __m256& third_ar, + const float* pair_weights, + const float* third_weights, + const int32_t* gates0, const int32_t* gates1, + const int32_t* input) { + __m256i data32 = _mm256_load_si256(reinterpret_cast<__m256i const*>(input)); + data32 = LoadAndAddFixed32(gates0, gates1, data32); + __m256 float_data = _mm256_cvtepi32_ps(data32); + if (kInputsMode != ARInputsMode::k0ARInputs) { + float_data = MultiplyAddFloat( + paired_ar, third_ar, pair_weights, third_weights, float_data); + } + return float_data; +} + +// Generic GRU gates function controlled by template parameters thus: +// - |kInputBits|: the mantissa bits in |*input_ptr|, |*gru_recurrent_ptr|. +// - |kStateBits|: the mantissa_bits in |*gru_state_ptr|. +// - |kInputsMode == |k0ARInputs|: There are no autoregressive inputs so +// |ar_sample, |ar_sample1|, |ar_sample2|, |ar_01_weights|, |ar_2_weights| are +// ignored. +// - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied by +// |ar_01_weights| and added to the (conditioning) input. +// - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by |ar_2_weights| +// and added to the other two AR inputs (and added to the conditioning input). +// - |kReplicas| determines the number of duplicates of the output to be +// written, separated by |replica_stride|. If zero, then the number of +// replicas is variable and taken from the |replicas| argument. +// - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary +// recurrent input that must be added to |*gru_recurrent_ptr|. +// - |start|, |end| are |rows| in [0, |state_size|] to be processed by this +// thread. +// +// Previous state is read from |*gru_state_ptr| and the new state is written to +// *(|gru_state_ptr| + i * |replica_stride| for i in [0, |kReplicas|]). +template +inline void GruGatesTemplate( + int start, int end, int state_size, int replicas, int replica_stride, + const int32_t* gru_recurrent_ptr, const int32_t* input_ptr, + const std::pair* ar_sample01, const float* ar_01_weights, + const float* ar_sample2, const float* ar_2_weights, + const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) { + constexpr int kQRIncrement = kAVX2SIMDWidth; + // Increment all the pointers to save on pointer arithmetic in the loop. + input_ptr += start; + gru_state_ptr += start; + gru_recurrent_ptr += start; + if (kSplitGates) gru_recurrent_other_ptr += start; + __m256 ar_2_inputs, ar_3rd_input; + if (kInputsMode != ARInputsMode::k0ARInputs) { + ar_01_weights += 2 * start; + ar_2_inputs = _mm256_castsi256_ps( + _mm256_set1_epi64x(*reinterpret_cast(ar_sample01))); + if (kInputsMode == ARInputsMode::k3ARInputs) { + ar_2_weights += start; + ar_3rd_input = _mm256_set1_ps(*ar_sample2); + } else { + ar_3rd_input = {}; + } + } else { + ar_2_inputs = {}; + ar_3rd_input = {}; + } + // The transcendentals handle 2x registers of data at once, so we have to do + // everything in duplicate. + for (int i = start; i < end; i += kQRIncrement * 2) { + // Load 8 pairs of fixed16s for each of reset, update and cell. + __m256 reset0 = GruInput32ToFloat( + ar_2_inputs, ar_3rd_input, ar_01_weights, ar_2_weights, + gru_recurrent_ptr, gru_recurrent_other_ptr, input_ptr); + __m256 reset1 = GruInput32ToFloat( + ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * kQRIncrement, + ar_2_weights + kQRIncrement, gru_recurrent_ptr + kAVX2SIMDWidth, + gru_recurrent_other_ptr + kAVX2SIMDWidth, input_ptr + kAVX2SIMDWidth); + float_sigmoid_float(reset0, reset1); + __m256 update0 = GruInput32ToFloat( + ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * state_size, + ar_2_weights + state_size, gru_recurrent_ptr + state_size, + gru_recurrent_other_ptr + state_size, input_ptr + state_size); + __m256 update1 = GruInput32ToFloat( + ar_2_inputs, ar_3rd_input, + ar_01_weights + 2 * state_size + 2 * kQRIncrement, + ar_2_weights + state_size + kQRIncrement, + gru_recurrent_ptr + state_size + kAVX2SIMDWidth, + gru_recurrent_other_ptr + state_size + kAVX2SIMDWidth, + input_ptr + state_size + kAVX2SIMDWidth); + float_sigmoid_float(update0, update1); + __m256 cell0 = _mm256_cvtepi32_ps(_mm256_load_si256( + reinterpret_cast<__m256i const*>(input_ptr + 2 * state_size))); + __m256 cell1 = + _mm256_cvtepi32_ps(_mm256_load_si256(reinterpret_cast<__m256i const*>( + input_ptr + 2 * state_size + kAVX2SIMDWidth))); + if (kInputsMode != ARInputsMode::k0ARInputs) { + cell0 = MultiplyAddFloat( + ar_2_inputs, ar_3rd_input, ar_01_weights + 4 * state_size, + ar_2_weights + 2 * state_size, cell0); + cell1 = MultiplyAddFloat( + ar_2_inputs, ar_3rd_input, + ar_01_weights + 4 * state_size + 2 * kQRIncrement, + ar_2_weights + 2 * state_size + kQRIncrement, cell1); + } + __m256i gru_state = GRUComputeState( + cell0, cell1, reset0, reset1, update0, update1, + gru_recurrent_ptr + 2 * state_size, + gru_recurrent_other_ptr + 2 * state_size, gru_state_ptr); + if (kReplicas > 0) { + // With |kReplicas| a template parameter, the compiler will unroll the + // loop. + for (int j = 0; j < kReplicas; ++j) { + _mm256_store_si256( + reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride), + gru_state); + } + } else { + // This loop will not unroll as replicas is variable. + for (int j = 0; j < replicas; ++j) { + _mm256_store_si256( + reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride), + gru_state); + } + } + // Increment all the pointers. + input_ptr += 2 * kAVX2SIMDWidth; + gru_state_ptr += 2 * kAVX2SIMDWidth; + gru_recurrent_ptr += 2 * kAVX2SIMDWidth; + if (kSplitGates) gru_recurrent_other_ptr += 2 * kAVX2SIMDWidth; + if (kInputsMode != ARInputsMode::k0ARInputs) { + ar_01_weights += 4 * kQRIncrement; + if (kInputsMode == ARInputsMode::k3ARInputs) + ar_2_weights += 2 * kQRIncrement; + } + } +} + +// Dispatches calls to the GruGatesTemplate function above converting the +// replicas variable argument to a template parameter to allow the compiler to +// unroll the write loop. +// |ar_sample01| packs sample 0 and 1 into a pair because the QR weights are +// formatted with the weights interleaved for sample 0 and 1. The two samples +// represent coarse and fine for WaveRNN. +template +inline void GruGatesAVXFixed( + int start, int end, int state_size, const int32_t* gru_recurrent_ptr, + const int32_t* input_ptr, const std::pair* ar_sample01, + const float* ar_01_weights, int num_replicas, int replica_stride, + const float* ar_sample2, const float* ar_2_weights, + const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) { + // Convert the number of replicas from a variable to a template parameter + // with a switch. This enables the compiler to unroll the loop for + // the write, making it faster for common numbers of threads. + switch (num_replicas) { + case 1: + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + break; + case 2: + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + break; + case 4: + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + break; + case 6: + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + break; + default: + // Zero |kReplicas| tells the function to use the |num_replicas| variable. + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + } +} + +#endif // __AVX2__ + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_ diff --git a/sparse_matmul/compute/gru_gates_generic.h b/sparse_matmul/compute/gru_gates_generic.h new file mode 100644 index 0000000000000000000000000000000000000000..691efb1f822e7f1e4862a99ef5ccb495fbc000d8 --- /dev/null +++ b/sparse_matmul/compute/gru_gates_generic.h @@ -0,0 +1,97 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_ + +#include "sparse_matmul/compute/ar_inputs.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" + +namespace csrblocksparse { + +constexpr int kGenericSIMDWidth = 4; + +// TODO(b/188702959): Rename arguments to match gru_gates.h. +template +void GoThroughGates(int start, int end, const QR_W_Type* qr_ptr, + const GRUMatMulOutType* gru_gates_ptr, + const GRUMatMulOutType* gru_gates_other_ptr, + const GRUMatMulOutType* conditioning_ptr, + GRUStateType* gru_h_ptr, const QR_W_Type* w_hat, + int proj_size, const SampleType* coarse_at_sminus1, + const SampleType* fine_at_sminus1, + const SampleType* coarse_at_s = nullptr) { + float qr_cell = 0.0f, reset, update, cell; + for (int i = start; i < end; ++i) { + if (kInputsMode == ARInputsMode::k0ARInputs) { + reset = static_cast(gru_gates_ptr[i]); + update = static_cast(gru_gates_ptr[proj_size + i]); + } else { + float qr_c_reset = static_cast(qr_ptr[2 * i + 0]); + float qr_f_reset = static_cast(qr_ptr[2 * i + 1]); + float qr_c_update = static_cast(qr_ptr[2 * proj_size + 2 * i + 0]); + float qr_f_update = static_cast(qr_ptr[2 * proj_size + 2 * i + 1]); + float qr_c_cell = static_cast(qr_ptr[4 * proj_size + 2 * i + 0]); + float qr_f_cell = static_cast(qr_ptr[4 * proj_size + 2 * i + 1]); + float w_hat_i_reset = 0.0f; + float w_hat_i_update = 0.0f; + float w_hat_i_cell = 0.0f; + if (kInputsMode == ARInputsMode::k3ARInputs) { + w_hat_i_reset = static_cast(w_hat[i]); + w_hat_i_update = static_cast(w_hat[proj_size + i]); + w_hat_i_cell = static_cast(w_hat[2 * proj_size + i]); + } + float coarse = static_cast(coarse_at_sminus1[0]); + float fine = static_cast(fine_at_sminus1[0]); + reset = qr_c_reset * coarse + qr_f_reset * fine; + update = qr_c_update * coarse + qr_f_update * fine; + qr_cell = qr_c_cell * coarse + qr_f_cell * fine; + if (kInputsMode == ARInputsMode::k3ARInputs) { + float coarse = static_cast(coarse_at_s[0]); + reset += w_hat_i_reset * coarse; + update += w_hat_i_update * coarse; + qr_cell += w_hat_i_cell * coarse; + } + reset += static_cast(gru_gates_ptr[i]); + update += static_cast(gru_gates_ptr[proj_size + i]); + } + cell = static_cast(gru_gates_ptr[2 * proj_size + i]); + if (SplitGates) { + reset += static_cast(gru_gates_other_ptr[i]); + update += static_cast(gru_gates_other_ptr[proj_size + i]); + cell += static_cast(gru_gates_other_ptr[2 * proj_size + i]); + } + float reset_conditioning = static_cast(conditioning_ptr[i]); + float update_conditioning = + static_cast(conditioning_ptr[proj_size + i]); + float cell_conditioning = + static_cast(conditioning_ptr[2 * proj_size + i]); + reset = fast_sigmoid(reset + reset_conditioning); + update = fast_sigmoid(update + update_conditioning); + float hbar = fast_tanh(qr_cell + reset * cell + cell_conditioning); + int h_index = i; + float prev_h = static_cast(gru_h_ptr[h_index]); + float diff = prev_h - hbar; + float new_h = hbar + diff * update; + gru_h_ptr[h_index] = static_cast(new_h); + } +} + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_ diff --git a/sparse_matmul/compute/gru_gates_test.cc b/sparse_matmul/compute/gru_gates_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f626c98481c97dfc00c348818b988cf9e174d1c --- /dev/null +++ b/sparse_matmul/compute/gru_gates_test.cc @@ -0,0 +1,164 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/compute/gru_gates.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/span.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace { + +using csrblocksparse::ARInputsMode; + +template +csrblocksparse::CacheAlignedVector TestGruGates() { + using SampleWeightType = float; + constexpr int kStateSize = 16; + csrblocksparse::CacheAlignedVector qr(6 * kStateSize); + csrblocksparse::CacheAlignedVector w(3 * kStateSize); + csrblocksparse::CacheAlignedVector gru_gates(3 * kStateSize); + csrblocksparse::CacheAlignedVector gru_other_gates(3 * kStateSize); + csrblocksparse::CacheAlignedVector conditioning(3 * kStateSize); + csrblocksparse::CacheAlignedVector gru_h(kStateSize); + csrblocksparse::GruGates gru_gates_impl; + const SampleType kCoarseAtSMinus1(0.03f); + const SampleType kFineAtSMinus1(0.07f); + const SampleType kCoarseAtS(-0.02f); + + qr.FillOnes(); + w.FillOnes(); + gru_gates.FillRandom(); + gru_other_gates.FillRandom(); + conditioning.FillRandom(); + gru_h.FillZero(); + + gru_gates_impl.template GruWithARInput( + /*start=*/0, /*end=*/kStateSize, kStateSize, gru_gates.data(), + conditioning.data(), gru_h.data(), &kCoarseAtSMinus1, &kFineAtSMinus1, + qr.data(), + /*num_replicas=*/1, /*replica_stride=*/0, &kCoarseAtS, w.data(), + gru_other_gates.data()); + return gru_h; +} + +TEST(GruGates, FloatWaveRNNCoarseMatchesGolden) { + // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers + // will also need to change. + const std::vector kGoldenValues = { + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.746f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.993f}; + csrblocksparse::CacheAlignedVector gru_h = + TestGruGates(); + + ASSERT_EQ(kGoldenValues.size(), gru_h.size()); + for (int i = 0; i < gru_h.size(); ++i) { + EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i; + } +} + +TEST(GruGates, FloatWaveRNNFineMatchesGolden) { + // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers + // will also need to change. + const std::vector kGoldenValues = { + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.737f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.969f, 0.0f, 0.0f, 1.0f, 0.0f, -0.994f}; + csrblocksparse::CacheAlignedVector gru_h = + TestGruGates(); + + ASSERT_EQ(kGoldenValues.size(), gru_h.size()); + for (int i = 0; i < gru_h.size(); ++i) { + EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i; + } +} + +TEST(GruGates, FloatTwoArInputsNonSplitGateMatchesGolden) { + // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers + // will also need to change. + const std::vector kGoldenValues = { + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.714f, 0.0f, -0.002f, + 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.965f}; + csrblocksparse::CacheAlignedVector gru_h = + TestGruGates(); + + ASSERT_EQ(kGoldenValues.size(), gru_h.size()); + for (int i = 0; i < gru_h.size(); ++i) { + EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i; + } +} + +TEST(GruGates, FixedWaveRNNCoarseMatchesFloat) { + using GRUMatMulOutType = csrblocksparse::fixed32<11>; + using GRUStateType = csrblocksparse::fixed16<2>; + using SampleType = csrblocksparse::fixed16<0>; + csrblocksparse::CacheAlignedVector float_gru_h = + TestGruGates(); + csrblocksparse::CacheAlignedVector fixed_gru_h = + TestGruGates(); + + ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size()); + for (int i = 0; i < fixed_gru_h.size(); ++i) { + EXPECT_NEAR(float_gru_h[i], static_cast(fixed_gru_h[i]), 1e-3) + << "i=" << i; + } +} + +TEST(GruGates, FixedWaveRNNFineMatchesFloat) { + using GRUMatMulOutType = csrblocksparse::fixed32<11>; + using GRUStateType = csrblocksparse::fixed16<2>; + using SampleType = csrblocksparse::fixed16<0>; + csrblocksparse::CacheAlignedVector float_gru_h = + TestGruGates(); + csrblocksparse::CacheAlignedVector fixed_gru_h = + TestGruGates(); + + ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size()); + for (int i = 0; i < fixed_gru_h.size(); ++i) { + EXPECT_NEAR(float_gru_h[i], static_cast(fixed_gru_h[i]), 1e-3) + << "i=" << i; + } +} + +TEST(GruGates, FixedTwoArInputsNonSplitGateMatchesFloat) { + using GRUMatMulOutType = csrblocksparse::fixed32<11>; + using GRUStateType = csrblocksparse::fixed16<2>; + using SampleType = csrblocksparse::fixed16<0>; + csrblocksparse::CacheAlignedVector float_gru_h = + TestGruGates(); + csrblocksparse::CacheAlignedVector fixed_gru_h = + TestGruGates(); + + ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size()); + for (int i = 0; i < fixed_gru_h.size(); ++i) { + EXPECT_NEAR(float_gru_h[i], static_cast(fixed_gru_h[i]), 1e-3) + << "i=" << i; + } +} + +} // namespace diff --git a/sparse_matmul/compute/kernels_arm.h b/sparse_matmul/compute/kernels_arm.h new file mode 100644 index 0000000000000000000000000000000000000000..494430fef873ebd86064263b7ab4d401906910e8 --- /dev/null +++ b/sparse_matmul/compute/kernels_arm.h @@ -0,0 +1,2886 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_ARM_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_ARM_H_ + +#if defined __aarch64__ + +#include + +#include + +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" + +#define LABEL_COL_LOOP "1" +#define LABEL_ROW_LOOP "2" +#define LABEL_SKIP_COL_LOOP "3" +#define LABEL_TOP_LOOP "4" + +namespace csrblocksparse { +namespace detail { + +template +struct IsFloatOrBfloat + : std::integral_constant::value || + std::is_same::value> {}; + +template +struct IsAllowableFloatTypes + : std::integral_constant::value && + std::is_same::value && + std::is_same::value> {}; + +// 16-bit inputs, 32-bit output exponent matches sum of input exponents +// OR +// 16-bit inputs, 16-bit output - will shift to match exponent +template +struct IsAllowableFixedTypes + : std::integral_constant::value && + IsFixed16Type::value) && + (IsFixed32Type::value || + IsFixed16Type::value)> {}; + +template +struct ShouldEnableGenericKernel + : std::integral_constant< + bool, + !IsAllowableFloatTypes::value && + !IsAllowableFixedTypes::value> {}; + +template +struct ShouldEnableGenericSpMV_4x4 + : ShouldEnableGenericKernel {}; +template +struct ShouldEnableGenericSpMM5_4x4 + : ShouldEnableGenericKernel {}; +template +struct ShouldEnableGenericSpMV_1x1 : std::true_type {}; +template +struct ShouldEnableGenericSpMM5_1x1 : std::true_type {}; +template +struct IsAddableFixedTypes + : std::integral_constant::value || + IsFixed16Type::value> {}; +template +struct ShouldEnableGenericAdd + : std::integral_constant::value> {}; + +// The computational routines do NO error checking for speed. It is assumed +// that this has been handled by CSRBlockSparseMatrix. + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a vector and b is vector. Weights are stored for this +// routine by making each 4x4 block contiguous. Blocks are ordered in standard +// row-major format. column indices are converted to deltas and then multiplied +// by 2 to convert to bytes, so that the value can be used directly to offset +// the pointer into the rhs vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMV_4x4(const bfloat16* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const float* rhs_ptr, + const float* bias_ptr, float* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + /* This instrinsic version exists for reference, note that in the + intrinsic version col_deltas_bytes should NOT actually be in bytes, + but rather elements. Intrinsics are 25-35% slower than the + assembly version. + + for (int r = 0; r < rows; r += 4) { + int reduced_col_count = nnz_per_row[r / 4]; + float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); + for (int c = 0; c < reduced_col_count; ++c) { + int32_t offset = *col_deltas_bytes; col_deltas_bytes++; + rhs_ptr += offset; + float32x4_t rhs = vld1q_f32(rhs_ptr); + + uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; + + float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); + float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); + float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); + float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); + + accum0 = vmlaq_f32(accum0, lhs0, rhs); + accum1 = vmlaq_f32(accum1, lhs1, rhs); + accum2 = vmlaq_f32(accum2, lhs2, rhs); + accum3 = vmlaq_f32(accum3, lhs3, rhs); + } + + float32x4_t reduce0 = vpaddq_f32(accum0, accum1); + float32x4_t reduce1 = vpaddq_f32(accum2, accum3); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out_ptr + r, reduce2); + } */ + + // If the relu is handled in the routine with a comparison and vbit (insert + // if true), or by branching, then it is slightly, but noticeably slower + // ~5%, the outer branch avoids that penalty. + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Convert bfloat16 -> float32. + "shll v4.4s, v2.4h, #16\n" + "shll2 v5.4s, v2.8h, #16\n" + "shll v6.4s, v3.4h, #16\n" + "shll2 v7.4s, v3.8h, #16\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "fmla v31.4s, v7.4s, v0.4s\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result + "faddp v28.4s, v28.4s, v29.4s\n" + "faddp v30.4s, v30.4s, v31.4s\n" + "faddp v28.4s, v28.4s, v30.4s\n" + + // Do relu if requested. + "fmax v28.4s, v28.4s, v25.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Convert bfloat16 -> float32. + "shll v4.4s, v2.4h, #16\n" + "shll2 v5.4s, v2.8h, #16\n" + "shll v6.4s, v3.4h, #16\n" + "shll2 v7.4s, v3.8h, #16\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "fmla v31.4s, v7.4s, v0.4s\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" + "faddp v30.4s, v30.4s, v31.4s\n" + "faddp v28.4s, v28.4s, v30.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMM5_4x4(const bfloat16* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const float* rhs_ptr, + const float* bias_ptr, float* out_ptr, int64_t assigned_rows, + int64_t rows, int64_t cols, int relu) { + /* This instrinsic version exists for reference, note that in the + intrinsic version col_deltas_bytes should NOT actually be in bytes, + but rather elements. Intrinsics are 25-35% slower than the + assembly version. + + for (int r = 0; r < rows; r += 4) { + int reduced_col_count = nnz_per_row[r / 4]; + float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); + float32x4_t accum4 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum5 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum6 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum7 = vdupq_n_f32(bias_ptr + r + 3); + ... + for (int c = 0; c < reduced_col_count; ++c) { + int32_t offset = *col_deltas_bytes; col_deltas_bytes++; + rhs_ptr += offset; + float32x4_t rhs = vld1q_f32(rhs_ptr); + float32x4_t rhs2 = vld1q_f32(rhs2_ptr); + float32x4_t rhs3 = vld1q_f32(rhs3_ptr); + float32x4_t rhs4 = vld1q_f32(rhs4_ptr); + float32x4_t rhs5 = vld1q_f32(rhs5_ptr); + + uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; + + float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); + float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); + float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); + float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); + + accum0 = vmlaq_f32(accum0, lhs0, rhs); + accum1 = vmlaq_f32(accum1, lhs1, rhs); + accum2 = vmlaq_f32(accum2, lhs2, rhs); + accum3 = vmlaq_f32(accum3, lhs3, rhs); + accum4 = vmlaq_f32(accum0, lhs0, rhs2); + accum5 = vmlaq_f32(accum1, lhs1, rhs2); + accum6 = vmlaq_f32(accum2, lhs2, rhs2); + accum7 = vmlaq_f32(accum3, lhs3, rhs2); + ... + } + + float32x4_t reduce0 = vpaddq_f32(accum0, accum1); + float32x4_t reduce1 = vpaddq_f32(accum2, accum3); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out_ptr + r, reduce2); + + float32x4_t reduce0 = vpaddq_f32(accum4, accum5); + float32x4_t reduce1 = vpaddq_f32(accum6, accum7); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out2_ptr + r, reduce2); + + ... + } */ + + // If the relu is handled in the routine with a comparison and vbit (insert + // if true), or by branching, then it is slightly, but noticeably slower + // ~5%, the outer branch avoids that penalty. + // + // Pointers to the columns. + const float* rhs2_ptr = rhs_ptr + cols; + float* out2_ptr = out_ptr + rows; + const float* rhs3_ptr = rhs_ptr + 2 * cols; + float* out3_ptr = out_ptr + 2 * rows; + const float* rhs4_ptr = rhs_ptr + 3 * cols; + float* out4_ptr = out_ptr + 3 * rows; + const float* rhs5_ptr = rhs_ptr + 4 * cols; + float* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" + "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" + "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" + "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Convert bfloat16 -> float32. + "shll v4.4s, v2.4h, #16\n" + "shll2 v5.4s, v2.8h, #16\n" + "shll v6.4s, v3.4h, #16\n" + "shll2 v7.4s, v3.8h, #16\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column + "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column + "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column + "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column + "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column + "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column + "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column + "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column + "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column + "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column + "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column + "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column + "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column + "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column + "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column + "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column + "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column + "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column + "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column + "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "faddp v28.4s, v28.4s, v29.4s\n" // 1st column + "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v16.4s\n" // 4th column + "faddp v11.4s, v11.4s, v12.4s\n" // 5th column + + "faddp v30.4s, v30.4s, v31.4s\n" // 1st column + "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column + "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column + "faddp v17.4s, v17.4s, v18.4s\n" // 4th column + "faddp v13.4s, v13.4s, v14.4s\n" // 5th column + + "faddp v28.4s, v28.4s, v30.4s\n" // 1st column + "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v17.4s\n" // 4th column + "faddp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), + [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), + [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" + "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" + "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" + "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Convert bfloat16 -> float32. + "shll v4.4s, v2.4h, #16\n" + "shll2 v5.4s, v2.8h, #16\n" + "shll v6.4s, v3.4h, #16\n" + "shll2 v7.4s, v3.8h, #16\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column + "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column + "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column + "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column + "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column + "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column + "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column + "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column + "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column + "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column + "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column + "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column + "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column + "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column + "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column + "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column + "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column + "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column + "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column + "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" // 1st column + "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v16.4s\n" // 4th column + "faddp v11.4s, v11.4s, v12.4s\n" // 5th column + + "faddp v30.4s, v30.4s, v31.4s\n" // 1st column + "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column + "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column + "faddp v17.4s, v17.4s, v18.4s\n" // 4th column + "faddp v13.4s, v13.4s, v14.4s\n" // 5th column + + "faddp v28.4s, v28.4s, v30.4s\n" // 1st column + "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v17.4s\n" // 4th column + "faddp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), + [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), + [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// float implementations below the line. + +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMV_4x4(const float* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const float* rhs_ptr, + const float* bias_ptr, float* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + /* This instrinsic version exists for reference, note that in the + intrinsic version col_deltas_bytes should NOT actually be in bytes, + but rather elements. Intrinsics are 25-35% slower than the + assembly version. + + for (int r = 0; r < rows; r += 4) { + int reduced_col_count = nnz_per_row[r / 4]; + float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); + for (int c = 0; c < reduced_col_count; ++c) { + int32_t offset = *col_deltas_bytes; col_deltas_bytes++; + rhs_ptr += offset; + float32x4_t rhs = vld1q_f32(rhs_ptr); + + uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; + + float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); + float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); + float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); + float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); + + accum0 = vmlaq_f32(accum0, lhs0, rhs); + accum1 = vmlaq_f32(accum1, lhs1, rhs); + accum2 = vmlaq_f32(accum2, lhs2, rhs); + accum3 = vmlaq_f32(accum3, lhs3, rhs); + } + + float32x4_t reduce0 = vpaddq_f32(accum0, accum1); + float32x4_t reduce1 = vpaddq_f32(accum2, accum3); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out_ptr + r, reduce2); + } */ + + // If the relu is handled in the routine with a comparison and vbit (insert + // if true), or by branching, then it is slightly, but noticeably slower + // ~5%, the outer branch avoids that penalty. + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "fmla v31.4s, v7.4s, v0.4s\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" + "faddp v30.4s, v30.4s, v31.4s\n" + "faddp v28.4s, v28.4s, v30.4s\n" + + // Do relu as requested. + "fmax v28.4s, v28.4s, v25.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "fmla v31.4s, v7.4s, v0.4s\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" + "faddp v30.4s, v30.4s, v31.4s\n" + "faddp v28.4s, v28.4s, v30.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in sparse_linear_layer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMM5_4x4(const float* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const float* rhs_ptr, + const float* bias_ptr, float* out_ptr, int64_t assigned_rows, + int64_t rows, int64_t cols, int relu) { + /* This instrinsic version exists for reference, note that in the + intrinsic version col_deltas_bytes should NOT actually be in bytes, + but rather elements. Intrinsics are 25-35% slower than the + assembly version. + + for (int r = 0; r < rows; r += 4) { + int reduced_col_count = nnz_per_row[r / 4]; + float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); + float32x4_t accum4 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum5 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum6 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum7 = vdupq_n_f32(bias_ptr + r + 3); + ... + for (int c = 0; c < reduced_col_count; ++c) { + int32_t offset = *col_deltas_bytes; col_deltas_bytes++; + rhs_ptr += offset; + float32x4_t rhs = vld1q_f32(rhs_ptr); + float32x4_t rhs2 = vld1q_f32(rhs2_ptr); + float32x4_t rhs3 = vld1q_f32(rhs3_ptr); + float32x4_t rhs4 = vld1q_f32(rhs4_ptr); + float32x4_t rhs5 = vld1q_f32(rhs5_ptr); + + uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; + + float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); + float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); + float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); + float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); + + accum0 = vmlaq_f32(accum0, lhs0, rhs); + accum1 = vmlaq_f32(accum1, lhs1, rhs); + accum2 = vmlaq_f32(accum2, lhs2, rhs); + accum3 = vmlaq_f32(accum3, lhs3, rhs); + accum4 = vmlaq_f32(accum0, lhs0, rhs2); + accum5 = vmlaq_f32(accum1, lhs1, rhs2); + accum6 = vmlaq_f32(accum2, lhs2, rhs2); + accum7 = vmlaq_f32(accum3, lhs3, rhs2); + ... + } + + float32x4_t reduce0 = vpaddq_f32(accum0, accum1); + float32x4_t reduce1 = vpaddq_f32(accum2, accum3); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out_ptr + r, reduce2); + + float32x4_t reduce0 = vpaddq_f32(accum4, accum5); + float32x4_t reduce1 = vpaddq_f32(accum6, accum7); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out2_ptr + r, reduce2); + + ... + } */ + + // If the relu is handled in the routine with a comparison and vbit (insert + // if true), or by branching, then it is slightly, but noticeably slower + // ~5%, the outer branch avoids that penalty. + // + // Pointers to the columns. + const float* rhs2_ptr = rhs_ptr + cols; + float* out2_ptr = out_ptr + rows; + const float* rhs3_ptr = rhs_ptr + 2 * cols; + float* out3_ptr = out_ptr + 2 * rows; + const float* rhs4_ptr = rhs_ptr + 3 * cols; + float* out4_ptr = out_ptr + 3 * rows; + const float* rhs5_ptr = rhs_ptr + 4 * cols; + float* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" + "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" + "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" + "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column + "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column + "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column + "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column + "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column + "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column + "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column + "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column + "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column + "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column + "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column + "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column + "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column + "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column + "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column + "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column + "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column + "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column + "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column + "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "faddp v28.4s, v28.4s, v29.4s\n" // 1st column + "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v16.4s\n" // 4th column + "faddp v11.4s, v11.4s, v12.4s\n" // 5th column + + "faddp v30.4s, v30.4s, v31.4s\n" // 1st column + "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column + "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column + "faddp v17.4s, v17.4s, v18.4s\n" // 4th column + "faddp v13.4s, v13.4s, v14.4s\n" // 5th column + + "faddp v28.4s, v28.4s, v30.4s\n" // 1st column + "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v17.4s\n" // 4th column + "faddp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), + [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), + [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" + "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" + "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" + "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column + "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column + "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column + "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column + "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column + "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column + "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column + "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column + "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column + "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column + "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column + "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column + "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column + "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column + "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column + "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column + "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column + "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column + "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column + "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" // 1st column + "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v16.4s\n" // 4th column + "faddp v11.4s, v11.4s, v12.4s\n" // 5th column + + "faddp v30.4s, v30.4s, v31.4s\n" // 1st column + "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column + "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column + "faddp v17.4s, v17.4s, v18.4s\n" // 4th column + "faddp v13.4s, v13.4s, v14.4s\n" // 5th column + + "faddp v28.4s, v28.4s, v30.4s\n" // 1st column + "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v17.4s\n" // 4th column + "faddp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), + [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), + [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + std::is_same::type>::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Do relu if requested. + "smax v28.4s, v28.4s, v25.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + std::is_same::type>::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + // Pointers to the columns. + const RhsType* rhs2_ptr = rhs_ptr + cols; + OutType* out2_ptr = out_ptr + rows; + const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; + OutType* out3_ptr = out_ptr + 2 * rows; + const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; + OutType* out4_ptr = out_ptr + 3 * rows; + const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; + OutType* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "smax v28.4s, v28.4s, v0.4s\n" + "smax v23.4s, v23.4s, v0.4s\n" + "smax v19.4s, v19.4s, v0.4s\n" + "smax v15.4s, v15.4s, v0.4s\n" + "smax v11.4s, v11.4s, v0.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the bias must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if::value && + IsFixed16Type::value && + IsFixed16Type::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + constexpr int kShiftAmount = 15 - WeightType::kExponentBits - + RhsType::kExponentBits + OutType::kExponentBits; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Do relu if requested. + "smax v28.4s, v28.4s, v25.4s\n" + "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v26.4h}, [%[out_ptr]], #8\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v26.4h}, [%[out_ptr]], #8\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if::value && + IsFixed16Type::value && + IsFixed16Type::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + constexpr int kShiftAmount = 15 - WeightType::kExponentBits - + RhsType::kExponentBits + OutType::kExponentBits; + // Pointers to the columns. + const RhsType* rhs2_ptr = rhs_ptr + cols; + OutType* out2_ptr = out_ptr + rows; + const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; + OutType* out3_ptr = out_ptr + 2 * rows; + const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; + OutType* out4_ptr = out_ptr + 3 * rows; + const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; + OutType* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "smax v28.4s, v28.4s, v0.4s\n" + "smax v23.4s, v23.4s, v0.4s\n" + "smax v19.4s, v19.4s, v0.4s\n" + "smax v15.4s, v15.4s, v0.4s\n" + "smax v11.4s, v11.4s, v0.4s\n" + "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" + "sqrshrn v22.4h, v23.4s, %[shift_amount]\n" + "sqrshrn v18.4h, v19.4s, %[shift_amount]\n" + "sqrshrn v14.4h, v15.4s, %[shift_amount]\n" + "sqrshrn v10.4h, v11.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v26.4h}, [%[out_ptr]], #8\n" + "st1 {v22.4h}, [%[out2_ptr]], #8\n" + "st1 {v18.4h}, [%[out3_ptr]], #8\n" + "st1 {v14.4h}, [%[out4_ptr]], #8\n" + "st1 {v10.4h}, [%[out5_ptr]], #8\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" + "sqrshrn v22.4h, v23.4s, %[shift_amount]\n" + "sqrshrn v18.4h, v19.4s, %[shift_amount]\n" + "sqrshrn v14.4h, v15.4s, %[shift_amount]\n" + "sqrshrn v10.4h, v11.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v26.4h}, [%[out_ptr]], #8\n" + "st1 {v22.4h}, [%[out2_ptr]], #8\n" + "st1 {v18.4h}, [%[out3_ptr]], #8\n" + "st1 {v14.4h}, [%[out4_ptr]], #8\n" + "st1 {v10.4h}, [%[out5_ptr]], #8\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + IsFixed32Type::value && + !std::is_same::type>::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount > 0, + "Result must have fewer mantissa bits than product"); + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Do relu if requested. + "smax v28.4s, v28.4s, v25.4s\n" + "srshr v28.4s, v28.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + "srshr v28.4s, v28.4s, %[shift_amount]\n" + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + IsFixed32Type::value && + !std::is_same::type>::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount > 0, + "Result must have fewer mantissa bits than product"); + // Pointers to the columns. + const RhsType* rhs2_ptr = rhs_ptr + cols; + OutType* out2_ptr = out_ptr + rows; + const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; + OutType* out3_ptr = out_ptr + 2 * rows; + const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; + OutType* out4_ptr = out_ptr + 3 * rows; + const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; + OutType* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "smax v28.4s, v28.4s, v0.4s\n" + "smax v23.4s, v23.4s, v0.4s\n" + "smax v19.4s, v19.4s, v0.4s\n" + "smax v15.4s, v15.4s, v0.4s\n" + "smax v11.4s, v11.4s, v0.4s\n" + + "srshr v28.4s, v28.4s, %[shift_amount]\n" + "srshr v23.4s, v23.4s, %[shift_amount]\n" + "srshr v19.4s, v19.4s, %[shift_amount]\n" + "srshr v15.4s, v15.4s, %[shift_amount]\n" + "srshr v11.4s, v11.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + "srshr v28.4s, v28.4s, %[shift_amount]\n" + "srshr v23.4s, v23.4s, %[shift_amount]\n" + "srshr v19.4s, v19.4s, %[shift_amount]\n" + "srshr v15.4s, v15.4s, %[shift_amount]\n" + "srshr v11.4s, v11.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + constexpr int kSIMDWidth = 4; + for (int i = start; i < end; i += kSIMDWidth) { + int32x4_t add1_int = vld1q_s32(reinterpret_cast(add1 + i)); + int32x4_t add2_int = vld1q_s32(reinterpret_cast(add2 + i)); + int32x4_t result_int = vqaddq_s32(add1_int, add2_int); + vst1q_s32(reinterpret_cast(result + i), result_int); + } +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + constexpr int kSIMDWidth = 8; + for (int i = start; i < end; i += kSIMDWidth) { + int16x8_t add1_int = vld1q_s16(reinterpret_cast(add1 + i)); + int16x8_t add2_int = vld1q_s16(reinterpret_cast(add2 + i)); + int16x8_t result_int = vqaddq_s16(add1_int, add2_int); + vst1q_s16(reinterpret_cast(result + i), result_int); + } +} + +} // namespace detail +} // namespace csrblocksparse + +#undef LABEL_COL_LOOP +#undef LABEL_ROW_LOOP +#undef LABEL_SKIP_COL_LOOP +#undef LABEL_TOP_LOOP + +#endif // defined __aarch64__ +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_ARM_H_ diff --git a/sparse_matmul/compute/kernels_avx.h b/sparse_matmul/compute/kernels_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..a56fb9cdeabaeb5d2c2613f1fb84ef5c67be435d --- /dev/null +++ b/sparse_matmul/compute/kernels_avx.h @@ -0,0 +1,601 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_ + +#if defined __AVX__ +#include + +#include +#include +// TODO(b/188702959): Remove fast_transcendentals with GRU refactor. +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" + +namespace csrblocksparse { +namespace detail { + +template +struct IsAllowableFloatTypes + : std::integral_constant::value && + std::is_same::value && + std::is_same::value> {}; + +#if defined __AVX2__ +// 16-bit inputs, 32-bit output exponent matches sum of input exponents +// OR +// 16-bit inputs, 16-bit output - will shift to match exponent +template +struct IsAllowableFixedTypes + : std::integral_constant::value && + IsFixed16Type::value) && + (IsFixed32Type::value || + IsFixed16Type::value)> {}; + +template +struct ShouldEnableGenericKernel + : std::integral_constant< + bool, + !IsAllowableFloatTypes::value && + !IsAllowableFixedTypes::value> {}; + +template +struct IsAddableFixedTypes + : std::integral_constant::value || + IsFixed16Type::value> {}; +template +struct ShouldEnableGenericAdd + : std::integral_constant::value> {}; + +#else // No AVX2. + +template +struct ShouldEnableGenericKernel + : std::integral_constant< + bool, !IsAllowableFloatTypes::value> {}; + +template +struct ShouldEnableGenericAdd : std::true_type {}; +#endif // __AVX2__ + +template +struct ShouldEnableGenericSpMV_4x4 + : ShouldEnableGenericKernel {}; +template +struct ShouldEnableGenericSpMM5_4x4 + : ShouldEnableGenericKernel {}; +template +struct ShouldEnableGenericSpMV_1x1 : std::true_type {}; +template +struct ShouldEnableGenericSpMM5_1x1 : std::true_type {}; + +// The computational routines do NO error checking for speed. It is assumed +// that this has been handled by CSRBlockSparseMatrix. + +// In-line function to extract results from a pair of registers and store in +// memory. Note that the non-const references are registers, and are modified +// by this function! +inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2, + float** out_ptr) { + // Horizontally add the results. We have 2 registers, |sum1| and |sum2| that + // each contain 2 sets of 4 values that need to be added. + sum1 = _mm256_hadd_ps(sum1, sum2); + sum1 = _mm256_hadd_ps(sum1, sum1); + // Now |sum1| contains [|res0|, |res2|, |res0|, |res2|, |res1|, |res3|, + // |res1|, |res3|] + if (relu) { + sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps()); + } + // It is really hard in AVX to cross the 128 bit 'lanes' and this is the + // *only* way to do it. + // Get the top half of |sum1| in to bottom of |sum2|. + sum2 = _mm256_permute2f128_ps(sum1, sum1, 1); + // Interleave the values between the two registers. + sum1 = _mm256_unpacklo_ps(sum1, sum2); + // Save the lower 128 bits (4 floats). + __m128 result = _mm256_extractf128_ps(sum1, 0); + _mm_store_ps(*out_ptr, result); + *out_ptr += 4; +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a vector and b is vector. Weights are stored for this +// routine by making each 4x4 block contiguous. Blocks are ordered in standard +// row-major format. column indices are converted to deltas and then multiplied +// by 2 to convert to bytes, so that the value can be used directly to offset +// the pointer into the rhs vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + // Broadcast the biases by 4 to undo the division by 4 in the input biases. + __m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), + _mm_broadcast_ss(bias_ptr)); + bias_ptr += 2; + __m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), + _mm_broadcast_ss(bias_ptr)); + bias_ptr += 2; + + int reduced_col_count = *nnz_per_row++; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + rhs_ptr += col_delta; + // Multiply this 4x4 block. + __m256 rhs = + _mm256_broadcast_ps(reinterpret_cast(rhs_ptr)); + __m256 weights1 = _mm256_load_ps(weights_ptr); + weights_ptr += 8; + sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs)); + __m256 weights2 = _mm256_load_ps(weights_ptr); + weights_ptr += 8; + sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs)); + } + Extract4Results(relu, sum1, sum2, &out_ptr); + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + const RhsType* rhs_ptrs[5]; + for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; + + OutType* out_ptrs[5]; + for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; + + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + // We will acumulate the results in 10 registers, |sum1_0| to |sum2_4|. + // Broadcast the biases by 4 to undo the division by 4 in the input biases. + __m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), + _mm_broadcast_ss(bias_ptr)); + bias_ptr += 2; + __m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), + _mm_broadcast_ss(bias_ptr)); + bias_ptr += 2; + __m256 sum1_1 = sum1_0; + __m256 sum2_1 = sum2_0; + __m256 sum1_2 = sum1_0; + __m256 sum2_2 = sum2_0; + __m256 sum1_3 = sum1_0; + __m256 sum2_3 = sum2_0; + __m256 sum1_4 = sum1_0; + __m256 sum2_4 = sum2_0; + + int reduced_col_count = *nnz_per_row++; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta; + + // Multiply this 4x4 block. + __m256 rhs = + _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[0])); + __m256 weights1 = _mm256_load_ps(weights_ptr); + weights_ptr += 8; + sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs)); + __m256 weights2 = _mm256_load_ps(weights_ptr); + weights_ptr += 8; + sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs)); + rhs = _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[1])); + sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs)); + sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs)); + rhs = _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[2])); + sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs)); + sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs)); + rhs = _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[3])); + sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs)); + sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs)); + rhs = _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[4])); + sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs)); + sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs)); + } + + Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]); + Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]); + Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]); + Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]); + Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]); + } +} + +#ifdef __AVX2__ + +// In-line function to finish the computation of the result as 4x int32 in +// |sum|. +inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) { + // Horizontally add the results. We have 1 register that contains results + // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not + // cross lanes, so we end up with [0 1 0 1 2 3 2 3] + sum = _mm256_hadd_epi32(sum, sum); + // Permutes the middle two pairs to get the answers together. + sum = _mm256_permute4x64_epi64(sum, 0xd8); + if (kShiftAmount > 0) { + // Shift right with rounding to get the right number of mantissa bits. + __m256i rounding = _mm256_set1_epi32(1 << (kShiftAmount - 1)); + sum = _mm256_add_epi32(sum, rounding); + sum = _mm256_srai_epi32(sum, kShiftAmount); + } + // Now |sum| contains [|res0|, |res1|, |res2|, |res3|, |res0|, |res1|, + // |res2|, |res3|] + if (relu) { + sum = _mm256_max_epi32(sum, _mm256_setzero_si256()); + } +} + +// In-line function to extract the 4x int32 results from |sum| to memory. +// Non-const reference for |sum| as it is a register. +inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum, + int32_t** out_ptr) { + Compute4Results(relu, kShiftAmount, sum); + // Save the lower 128 bits (4x int32). + __m128i result = _mm256_extractf128_si256(sum, 0); + _mm_store_si128(reinterpret_cast<__m128i*>(*out_ptr), result); + *out_ptr += 4; +} + +// In-line function to extract the 4x int32 results from sum to 4x int16 in +// memory. +// Non-const reference for |sum| as it is a register. +inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum, + int16_t** out_ptr) { + Compute4Results(relu, kShiftAmount, sum); + // Clip to 16 bit range (with saturation) and pack in the bottom 64 bits. + // Converts the lower 4x int32 in bottom 128 bits to 4x int16 in bottom 64 + // bits, replicated in the next 64 bits. + sum = _mm256_packs_epi32(sum, sum); + // Save 4x int 16 from the bottom 64 bits. + *reinterpret_cast(*out_ptr) = _mm256_extract_epi64(sum, 0); + *out_ptr += 4; +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a vector and b is vector. Weights are stored for this +// routine by making each 4x4 block contiguous. Blocks are ordered in standard +// row-major format. column indices are converted to deltas and then multiplied +// by 2 to convert to bytes, so that the value can be used directly to offset +// the pointer into the rhs vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + (IsFixed32Type::value || IsFixed16Type::value)>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount >= 0, + "Result must have fewer mantissa bits than product"); + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3]. + __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr)); + __m256i biases = _mm256_set_m128i(bias, bias); + bias_ptr += 4; + // Swap the top two pairs: [0 1 2 3 2 3 0 1] + // TODO(b/188702959): consider |_mm256_permutevar8x32|, and set the index + // register outside the row loop. + biases = _mm256_permute4x64_epi64(biases, 0xb4); + // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3]. + biases = _mm256_unpacklo_epi32(biases, biases); + // Double the results to make up for the division by 4. + // TODO(b/188702959): consider moving this to where the biases are computed. + __m256i sum = _mm256_add_epi32(biases, biases); + + // TODO(b/188702959): People don't like the old-fashioned, close-to-the- + // metal notation of *|nnz_per_row|++, so measure the effect of putting the + // increment in the for loop. + int reduced_col_count = *nnz_per_row; + ++nnz_per_row; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + rhs_ptr += col_delta; + // Multiply this 4x4 block. + // Get the 4x int16 into the bottom of rhs_64. + __m128i rhs_64 = + _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr)); + // Load all 16 weights. + __m256i weights = + _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + // Broadcast the rhs, pretending that each is a 64-bit unit: + // [0123 0123 0123 0123]. + __m256i rhs = _mm256_broadcastq_epi64(rhs_64); + weights_ptr += 16; + // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally + // adds adjacent pairs to make 8x32 bit results. Add these to the sum. + sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs)); + } + static_assert( + IsFixed16Type::value || IsFixed32Type::value, + "AVX2 kernel only supports fixed16 and fixed32 types"); + // The only significant difference between fixed16 and fixed32 is the size + // of the storage unit. The registers have to be repacked accordingly. + if (IsFixed32Type::value) { + Extract4xint32(relu, kShiftAmount, sum, + reinterpret_cast(&out_ptr)); + } else { + Extract4xint16(relu, kShiftAmount, sum, + reinterpret_cast(&out_ptr)); + } + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + (IsFixed32Type::value || IsFixed16Type::value)>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount >= 0, + "Result must have fewer mantissa bits than product"); + const RhsType* rhs_ptrs[5]; + for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; + + OutType* out_ptrs[5]; + for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; + + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + // We will acumulate the results in 5 registers, sum_0 to sum_4. + // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3]. + __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr)); + __m256i biases = _mm256_set_m128i(bias, bias); + bias_ptr += 4; + // Swap the top two pairs: [0 1 2 3 2 3 0 1] + biases = _mm256_permute4x64_epi64(biases, 0xb4); + // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3]. + biases = _mm256_unpacklo_epi32(biases, biases); + // Double the results to make up for the division by 4. + __m256i sum_0 = _mm256_add_epi32(biases, biases); + __m256i sum_1 = sum_0; + __m256i sum_2 = sum_0; + __m256i sum_3 = sum_0; + __m256i sum_4 = sum_0; + + int reduced_col_count = *nnz_per_row; + ++nnz_per_row; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta; + // Multiply this 4x4 block. + // Get the 4x int16 into the bottom of |rhs_64|. + __m128i rhs_64 = + _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0])); + // Load all 16 weights. + __m256i weights = + _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + // Broadcast the rhs, pretending that each is a 64-bit unit: + // [0123 0123 0123 0123]. + __m256i rhs = _mm256_broadcastq_epi64(rhs_64); + weights_ptr += 16; + // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally + // adds adjacent pairs to make 8x32 bit results. Add these to the sum. + sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs)); + rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1])); + rhs = _mm256_broadcastq_epi64(rhs_64); + sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs)); + rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2])); + rhs = _mm256_broadcastq_epi64(rhs_64); + sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs)); + rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3])); + rhs = _mm256_broadcastq_epi64(rhs_64); + sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs)); + rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4])); + rhs = _mm256_broadcastq_epi64(rhs_64); + sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs)); + } + static_assert( + IsFixed16Type::value || IsFixed32Type::value, + "AVX2 kernel only supports fixed16 and fixed32 types"); + // The only significant difference between fixed16 and fixed32 is the size + // of the storage unit. The registers have to be repacked accordingly. + if (IsFixed32Type::value) { + Extract4xint32(relu, kShiftAmount, sum_0, + reinterpret_cast(&out_ptrs[0])); + Extract4xint32(relu, kShiftAmount, sum_1, + reinterpret_cast(&out_ptrs[1])); + Extract4xint32(relu, kShiftAmount, sum_2, + reinterpret_cast(&out_ptrs[2])); + Extract4xint32(relu, kShiftAmount, sum_3, + reinterpret_cast(&out_ptrs[3])); + Extract4xint32(relu, kShiftAmount, sum_4, + reinterpret_cast(&out_ptrs[4])); + } else { + Extract4xint16(relu, kShiftAmount, sum_0, + reinterpret_cast(&out_ptrs[0])); + Extract4xint16(relu, kShiftAmount, sum_1, + reinterpret_cast(&out_ptrs[1])); + Extract4xint16(relu, kShiftAmount, sum_2, + reinterpret_cast(&out_ptrs[2])); + Extract4xint16(relu, kShiftAmount, sum_3, + reinterpret_cast(&out_ptrs[3])); + Extract4xint16(relu, kShiftAmount, sum_4, + reinterpret_cast(&out_ptrs[4])); + } + } +} + +// Processes one GRU gate input with sigmoid. +template +inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr, + const __m256i& input, + const int32_t* sigmoid_table) { + __m256i gate = _mm256_loadu_si256(reinterpret_cast(gate_ptr)); + if (SplitGates) { + __m256i other = + _mm256_loadu_si256(reinterpret_cast(gate_other_ptr)); + gate = _mm256_add_epi32(gate, other); + } + gate = _mm256_add_epi32(gate, input); + // Compute sigmoids on reset and update. + return csrblocksparse::fixed32_sigmoid_fixed16( + sigmoid_table, gate); +} + +// Processes the tanh and the final combination, returning the new GRU state. +template +inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset, + const __m256i& update, + const __m256i& rounding_offset, + const void* gate_ptr, const void* gate_other_ptr, + const void* gru_h_ptr, const int32_t* tanh_table) { + // Multiply the cell GRU output and the reset. There is a slight danger of + // loss of precision here, so use 32x32=64 bit and shift back after. + __m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr)); + if (SplitGates) { + __m256i other_gru = + _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr)); + gru = _mm256_add_epi32(gru, other_gru); + } + // This only computes the products of the low-order 32 bits of each pair. + __m256i gru_lo = _mm256_mul_epi32(gru, reset); + // Swap odd and even 32-bit units and do it again to get the high products. + gru = _mm256_shuffle_epi32(gru, 0xb1); + __m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1)); + // Now shift right to compensate for the multiply and re-interleave the + // 32-bit results. + // NOTE: There is no shift right arithmetic for 64 bit values until AVX512! + // Fortunately it doesn't matter, as the results are being truncated to 32 + // bits and we aren't shifting right by more than 32 bits here. + gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits); + // The upper results are shifted LEFT, so we can use blend to recombine in + // a single instruction. + gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits); + // Recombine the 32 bit results from lo and hi, alternating. + gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa); + gru = _mm256_add_epi32(cell, gru); + // Compute tanh on the result. Although this instantly discards a bunch of + // bits, there were only 7 surplus bits for the multiply, which isn't enough + // to do it as 16x16=32. + __m256i hbar = + csrblocksparse::fixed32_tanh_fixed16(tanh_table, gru); + // Load the 16-bit previous GRU state and sign-extend to 32 bits. + gru = _mm256_cvtepi16_epi32( + _mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr))); + gru = _mm256_sub_epi32(gru, hbar); + // Since |gru| is 16 bit sign-extended to 32, and |update| is the output of + // sigmoid, it is always contained within 16 bits and never negative, we can + // use |madd_epi16| to do 16x16=32 multiply with horizontal adding as the + // addend will always be zero, and this is twice as fast as full blown + // 32x32=32. The only possible problem is if the subtract above caused + // overflow. + gru = _mm256_madd_epi16(gru, update); + // Renormalize to fixed16. This time rounding is critical, as this is the + // output GRU state. + gru = _mm256_add_epi32(gru, rounding_offset); + gru = _mm256_srai_epi32(gru, StateMantissaBits); + return _mm256_add_epi32(gru, hbar); +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + constexpr int kSIMDWidth = 8; + for (int i = start; i < end; i += kSIMDWidth) { + __m256i data1 = + _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i)); + __m256i data2 = + _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i)); + data1 = _mm256_add_epi32(data1, data2); + _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1); + } +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + constexpr int kSIMDWidth = 16; + for (int i = start; i < end; i += kSIMDWidth) { + __m256i data1 = + _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i)); + __m256i data2 = + _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i)); + data1 = _mm256_add_epi16(data1, data2); + _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1); + } +} + +#endif // __AVX2__ + +} // namespace detail +} // namespace csrblocksparse + +#undef LABEL_COL_LOOP +#undef LABEL_ROW_LOOP +#undef LABEL_SKIP_COL_LOOP +#undef LABEL_TOP_LOOP + +#endif // __AVX__ + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_ diff --git a/sparse_matmul/compute/kernels_generic.h b/sparse_matmul/compute/kernels_generic.h new file mode 100644 index 0000000000000000000000000000000000000000..2ff9c7ecddfc4b94457a81aa5ff0edd4ae86e134 --- /dev/null +++ b/sparse_matmul/compute/kernels_generic.h @@ -0,0 +1,273 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_ + +#include +#include + +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" + +// Separate out the assembly kernels for readability. Eventually this will +// become an ifdef switch on the architecture type. +#if defined __aarch64__ +#include "sparse_matmul/compute/kernels_arm.h" +#elif defined __AVX__ +#include "sparse_matmul/compute/kernels_avx.h" +#else // defined __AVX__ +// If there is no architecture-specific implementation, then always use generic. +template +struct ShouldEnableGenericSpMV_4x4 : std::true_type {}; +template +struct ShouldEnableGenericSpMM5_4x4 : std::true_type {}; +template +struct ShouldEnableGenericSpMV_1x1 : std::true_type {}; +template +struct ShouldEnableGenericSpMM5_1x1 : std::true_type {}; +template +struct ShouldEnableGenericAdd : std::true_type {}; +#endif // defined __arch64__ + +namespace csrblocksparse { +namespace detail { + +// The computational routines do NO error checking for speed. It is assumed +// that this has been handled by CSRBlockSparseMatrix. + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a vector and b is vector. Weights are stored for this +// routine by making each 4x4 block contiguous. Blocks are ordered in standard +// row-major format. column indices are converted to deltas and then multiplied +// by 2 to convert to bytes, so that the value can be used directly to offset +// the pointer into the rhs vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + ShouldEnableGenericSpMV_4x4::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + float accumulators[4]; + // Undo the divion by the happens for the assembly version. + for (int i = 0; i < 4; ++i) + accumulators[i] = 4.f * static_cast(*bias_ptr++); + + int reduced_col_count = *nnz_per_row++; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + rhs_ptr += col_delta; + + // Multiply this 4x4 block. + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + accumulators[i] += static_cast(*weights_ptr++) * + static_cast(rhs_ptr[j]); + } + } + } + + for (int i = 0; i < 4; ++i) + *out_ptr++ = static_cast(relu ? std::max(accumulators[i], 0.f) + : accumulators[i]); + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + ShouldEnableGenericSpMM5_4x4::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + const RhsType* rhs_ptrs[5]; + for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; + + OutType* out_ptrs[5]; + for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; + + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + float accumulators[4][5]; + // Undo the divion by the happens for the assembly version. + for (int i = 0; i < 4; ++i) { + for (int k = 0; k < 5; ++k) { + accumulators[i][k] = 4.f * static_cast(*bias_ptr); + } + ++bias_ptr; + } + + int reduced_col_count = *nnz_per_row++; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta; + + // multiply this 4x4 block + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 5; ++k) { + accumulators[i][k] += static_cast(*weights_ptr) * + static_cast(rhs_ptrs[k][j]); + } + weights_ptr++; + } + } + } + + for (int k = 0; k < 5; ++k) { + for (int i = 0; i < 4; ++i) { + out_ptrs[k][0] = static_cast( + relu ? std::max(accumulators[i][k], 0.f) : accumulators[i][k]); + out_ptrs[k]++; + } + } + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with +// a 1x1 blocked pattern (ie unstructured), x is a +// vector and b is vector. +// Weights are stored for this routine in standard CSR format. Each row must +// have a multiple of 8 columns. +// column indices are converted to deltas and then multiplied by 2 to convert +// to bytes, so that the value can be used directly to offset the pointer +// into the rhs vector. +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + ShouldEnableGenericSpMV_1x1::value>::type +SpMV_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + for (int row = 0; row < assigned_rows; ++row) { + // Undo the divion by the happens for the assembly version. + float accumulator = 4.f * static_cast(*bias_ptr++); + + int col_count = *nnz_per_row++; + for (int c = 0; c < col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + rhs_ptr += col_delta; + + accumulator += + static_cast(*weights_ptr++) * static_cast(*rhs_ptr); + } + + *out_ptr++ = + static_cast(relu ? std::max(accumulator, 0.f) : accumulator); + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with +// a 1x1 blocked pattern (ie unstructured), x is a +// vector and b is vector. +// Weights are stored for this routine in standard CSR format. Each row must +// have a multiple of 8 columns. +// column indices are converted to deltas and then multiplied by 2 to convert +// to bytes, so that the value can be used directly to offset the pointer +// into the rhs vector. +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + ShouldEnableGenericSpMM5_1x1::value>::type +SpMM5_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + const RhsType* rhs_ptrs[5]; + for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; + + OutType* out_ptrs[5]; + for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; + + for (int row = 0; row < assigned_rows; ++row) { + // Undo the divion by the happens for the assembly version. + float accumulator[5]; + for (int i = 0; i < 5; ++i) + accumulator[i] = 4.f * static_cast(*bias_ptr); + + ++bias_ptr; + + int col_count = *nnz_per_row++; + for (int c = 0; c < col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + for (int i = 0; i < 5; ++i) { + rhs_ptrs[i] += col_delta; + accumulator[i] += static_cast(*weights_ptr) * + static_cast(rhs_ptrs[i][0]); + } + weights_ptr++; + } + + for (int i = 0; i < 5; ++i) { + out_ptrs[i][0] = static_cast(relu ? std::max(accumulator[i], 0.f) + : accumulator[i]); + out_ptrs[i]++; + } + } +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + LOG_FIRST_N(WARNING, 1) << "SumVectors: using generic kernel!"; + for (int i = start; i < end; ++i) { + Type sum = static_cast(static_cast(add1[i]) + + static_cast(add2[i])); + result[i] = sum; + } +} + +} // namespace detail +} // namespace csrblocksparse + +#undef LABEL_COL_LOOP +#undef LABEL_ROW_LOOP +#undef LABEL_SKIP_COL_LOOP +#undef LABEL_TOP_LOOP + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_ diff --git a/sparse_matmul/compute/matmul.h b/sparse_matmul/compute/matmul.h new file mode 100644 index 0000000000000000000000000000000000000000..442164defab63430fc9545fd56a790618fe4796d --- /dev/null +++ b/sparse_matmul/compute/matmul.h @@ -0,0 +1,199 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_ + +#include +#include + +#include "absl/time/time.h" +#include "sparse_matmul/compute/matmul_fixed_avx2.h" +#include "sparse_matmul/compute/matmul_generic.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32) +#include +#endif + +namespace csrblocksparse { + +// The number of elements in a block. +constexpr int kBlockSize = 4; + +// Base class for Matmul containing the members that are non type-specicfic. +class MatmulBase { + public: + // Constructor initializes the flags that determine which implementation to + // use at run-time, constrained by both compiler flags and cpuid. + MatmulBase() { +#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32) + // Code tested to work on Linux systems and multiple Android emulators. + unsigned int eax, ebx, ecx, edx; + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) { + using_avx_ = (ecx & bit_AVX) != 0; + if (using_avx_) { + __get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx); + using_avx2_ = (ebx & bit_AVX2) != 0; + using_avx512_ = (ebx & bit_AVX512F) != 0 && (ebx & bit_AVX512DQ) && + (ebx & bit_AVX512BW) != 0; + VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_; + } else { + LOG(ERROR) << "AVX not found at all!"; + } + } +#else + using_aarch64_ = true; +#endif + } + + protected: + // Flags that define what (runtime) architectures are available. Flags that + // are set are limited by both the compiler flags and runtime environment. + bool using_avx512_ = false; + bool using_avx2_ = false; + bool using_avx_ = false; + bool using_aarch64_ = false; +}; + +// The master template is really a catch-all for the unimplmented cases to +// report an error. +template +class Matmul : public MatmulBase { + public: + // Sparse inputs, outputs replicated strided for each thread. + template + void MatVec4x4(const WeightType* weights, const RhsType* rhs, + const typename TypeOfProduct::type* bias, + const int32_t* nnz_per_row, const int16_t* rhs_indices, + int start_row, int end_row, bool relu, int replicas, + int stride, OutType* output) { + // The specializations should take care of every real case. + CHECK(false) << "Unsupported combination of types used!"; + } + template + void MatVec8x4(const WeightType* weights, const RhsType* rhs, + const typename TypeOfProduct::type* bias, + const int32_t* nnz_per_row, const int16_t* rhs_indices, + int start_row, int end_row, bool relu, int replicas, + int stride, OutType* output) { + // The specializations should take care of every real case. + CHECK(false) << "Unsupported combination of types used!"; + } +}; + +// Full specialization for float. +template <> +class Matmul : public MatmulBase { + public: + void MatVec4x4(const float* weights, const float* rhs, const float* bias, + const int32_t* nnz_per_row, const int16_t* rhs_indices, + int start_row, int end_row, bool relu, int replicas, + int stride, float* output) { + detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, /*block_height=*/4, + /*block_width=*/4, relu, replicas, stride, + output); + } + void MatVec8x4(const float* weights, const float* rhs, const float* bias, + const int32_t* nnz_per_row, const int16_t* rhs_indices, + int start_row, int end_row, bool relu, int replicas, + int stride, float* output) { + detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, /*block_height=*/8, + /*block_width=*/4, relu, replicas, stride, + output); + } +}; + +// Partial specialization for fixed types. Covers fixed16xfixed16 = OutType, +// where OutType should be fixed16 or fixed32. The mantissa bits don't have +// to match. +template +class Matmul, fixed16> : public MatmulBase { + public: + using WeightType = fixed16; + using RhsType = fixed16; + + template + void MatVec4x4(const int16_t* weights, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int replicas, int stride, OutType* output) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount >= 0, + "OutType must not have more mantissa bits than inputs"); +#if defined __AVX2__ + CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!"; + if (sizeof(*output) == 4) { + int32_t* out32 = reinterpret_cast(output); + detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, relu, kShiftAmount, + replicas, stride, out32); + } else { + int16_t* out16 = reinterpret_cast(output); + detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, relu, kShiftAmount, + replicas, stride, out16); + } +#elif defined __aarch64__ + if (using_aarch64_) { + LOG(FATAL) << "Fixed16 MatVec4x4 not yet implemented!"; + } + +#else + detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, /*block_height=*/4, + /*block_width=*/4, relu, sizeof(*output), + kShiftAmount, replicas, stride, output); +#endif // __AVX2__ + } + + template + void MatVec8x4(const int16_t* weights, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int replicas, int stride, OutType* output) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount >= 0, + "OutType must not have more mantissa bits than inputs"); +#if defined __AVX2__ + CHECK(replicas == 1 && sizeof(*output) == 4) + << "Only replicas == 1 and fixed32 output are implemented for AVX2!"; + CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!"; + int32_t* out32 = reinterpret_cast(output); + detail::MatVec8x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, relu, kShiftAmount, out32); +#elif defined __aarch64__ + if (using_aarch64_) { + LOG(FATAL) << "Fixed16 MatVec8x4 not yet implemented!"; + } +#else + detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, /*block_height=*/8, + /*block_width=*/4, relu, sizeof(*output), + kShiftAmount, replicas, stride, output); +#endif // __AVX2__ + } +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_ diff --git a/sparse_matmul/compute/matmul_fixed_avx2.cc b/sparse_matmul/compute/matmul_fixed_avx2.cc new file mode 100644 index 0000000000000000000000000000000000000000..827d8c48b6c19cc7d38749d3857aa8951257eb8b --- /dev/null +++ b/sparse_matmul/compute/matmul_fixed_avx2.cc @@ -0,0 +1,235 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/compute/matmul_fixed_avx2.h" + +#include + +#if defined __AVX__ +#include +#endif + +#include "sparse_matmul/compute/matmul.h" + +namespace csrblocksparse { +namespace detail { + +static const int32_t kint32min = static_cast(~0x7FFFFFFF); +static const int32_t kint32max = static_cast(0x7FFFFFFF); + +#if defined __AVX2__ +// In-line function computes and returns the result of one row (of blocks) as +// 4x int32_t. |weights_ptr| is a non-const reference so it can easily be +// interpreted as belonging to the caller. +inline __m256i ComputeRowResults(const __m128i& bias128, const int16_t* rhs, + const int16_t* rhs_indices, int nnz, + int16_t const*& weights_ptr) { + // Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is + // Zero and 0-3 are the 4x32 bit bias values. + __m256i sum = _mm256_cvtepu32_epi64(bias128); + + for (int c = 0; c < nnz; ++c) { + int rhs_index = rhs_indices[c]; + // Load all 16 weights. + __m256i weights = + _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + // Get the 4x int16_t into the bottom of |rhs_64|. + __m128i rhs_64 = _mm_loadl_epi64( + reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize)); + // Broadcast the rhs, pretending that each is a 64-bit unit: + // [0123 0123 0123 0123]. + __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64); + weights_ptr += 16; + sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs_value)); + } + // Horizontally add the results. We have 1 register that contains results + // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not + // cross lanes, so we end up with [0 1 0 1 2 3 2 3] + sum = _mm256_hadd_epi32(sum, sum); + // Permutes the middle two pairs to get the answers together. + return _mm256_permute4x64_epi64(sum, 0xd8); +} + +// Template that allows any fixed combination of OutType and replicas, plus +// variable |relu|, |shift_out|. Note that |kReplicas| is a template arg as +// well as a function arg so we can hard-code a limited amount of unrolling. +template +void MatVec4x4FixedAVX2Template(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, + int end_row, bool relu, int shift_out, + int replicas, int stride, OutType* output) { + int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0; + __m256i rounding = _mm256_set1_epi32(rounding_addon); + __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min); + for (int row_block = start_row; row_block < end_row; ++row_block) { + // Load 4 biases [0 1 2 3]. + __m128i bias128 = _mm_load_si128(reinterpret_cast<__m128i const*>(bias)); + bias += kBlockSize; + int nnz = nnz_per_row[row_block]; + __m256i sum = + ComputeRowResults(bias128, rhs, rhs_indices, nnz, weights_ptr); + rhs_indices += nnz; + // Shift right with rounding to get the right number of mantissa bits. + sum = _mm256_add_epi32(sum, rounding); + sum = _mm256_srai_epi32(sum, shift_out); + // Now sum contains [res0, res1, res2, res3, res0, res1, res2, res3] + sum = _mm256_max_epi32(sum, zero); + if (sizeof(OutType) == 2) { + // Clip to 16 bit range (with saturation) and pack in the bottom 64 + // bits. The 64 bit result is replicated across the whole 256 bit + // register. [0123 0123 0123 0123] + sum = _mm256_packs_epi32(sum, sum); + int64_t result = _mm256_extract_epi64(sum, 0); + *reinterpret_cast(output) = result; + if (kReplicas > 1) { + *reinterpret_cast(output + stride) = result; + if (kReplicas > 2) { + for (int r = 2; r < replicas; ++r) { + *reinterpret_cast(output + r * stride) = result; + } + } + } + } else { + // Save the lower 128 bits (4x int32_t). + __m128i result = _mm256_extractf128_si256(sum, 0); + _mm_store_si128(reinterpret_cast<__m128i*>(output), result); + if (kReplicas > 1) { + _mm_store_si128(reinterpret_cast<__m128i*>(output + stride), result); + if (kReplicas > 2) { + for (int r = 2; r < replicas; ++r) { + _mm_store_si128(reinterpret_cast<__m128i*>(output + r * stride), + result); + } + } + } + } + output += kBlockSize; + } +} + +// Version that covers all possible combinations of the variable conditions: +// |relu|, |shift_out|, |replicas|, with int16_t |output|. +void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int replicas, int stride, + int16_t* output) { + if (replicas <= 1) { + MatVec4x4FixedAVX2Template(weights_ptr, rhs, bias, nnz_per_row, + rhs_indices, start_row, end_row, + relu, shift_out, 1, stride, output); + } else if (replicas == 2) { + MatVec4x4FixedAVX2Template(weights_ptr, rhs, bias, nnz_per_row, + rhs_indices, start_row, end_row, + relu, shift_out, 2, stride, output); + } else { + MatVec4x4FixedAVX2Template( + weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row, + relu, shift_out, replicas, stride, output); + } +} + +// Version that covers all possible combinations of the variable conditions: +// |relu|, |shift_out|, |replicas|, with int32_t |output|. +void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int replicas, int stride, + int32_t* output) { + if (replicas <= 1) { + MatVec4x4FixedAVX2Template(weights_ptr, rhs, bias, nnz_per_row, + rhs_indices, start_row, end_row, + relu, shift_out, 1, stride, output); + } else if (replicas == 2) { + MatVec4x4FixedAVX2Template(weights_ptr, rhs, bias, nnz_per_row, + rhs_indices, start_row, end_row, + relu, shift_out, 2, stride, output); + } else { + MatVec4x4FixedAVX2Template( + weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row, + relu, shift_out, replicas, stride, output); + } +} + +// In-line function computes and returns the result of one row (of blocks) as +// 8x int32_t. weights_ptr is a non-const reference so it can easily be +// interpreted as belonging to the caller. +inline __m256i Compute8RowResults(const __m256i& bias256, const int16_t* rhs, + const int16_t* rhs_indices, int nnz, + int16_t const*& weights_ptr) { + // Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is + // Zero and 0-3 are the 4x32 bit bias values from 128 bit half of the input. + __m256i sum1 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(bias256)); + // Plus 4 more in another sum register from the upper 128 bit half. + __m256i sum2 = _mm256_cvtepu32_epi64(_mm256_extractf128_si256(bias256, 1)); + + for (int c = 0; c < nnz; ++c) { + int rhs_index = rhs_indices[c]; + // Load all 16 weights. + __m256i weights = + _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + // Get the 4x int16_t into the bottom of |rhs_64|. + __m128i rhs_64 = _mm_loadl_epi64( + reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize)); + // Broadcast the rhs, pretending that each is a 64-bit unit: + // [0123 0123 0123 0123]. + __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64); + weights_ptr += 16; + sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(weights, rhs_value)); + // Same again for the other 4 results, re-using the same rhs value. + weights = _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + weights_ptr += 16; + sum2 = _mm256_add_epi32(sum2, _mm256_madd_epi16(weights, rhs_value)); + } + // Horizontally add the results. We have 2 registers that contain results + // [0 0 1 1 2 2 3 3], and [4 4 5 5 6 6 7 7] but hadd (and almost no other AVX + // instruction) will not cross lanes, so we end up with [0 1 4 5 2 3 6 7] + sum1 = _mm256_hadd_epi32(sum1, sum2); + // Permutes the middle two pairs to get the answers in the right order. + return _mm256_permute4x64_epi64(sum1, 0xd8); +} + +// Version that covers the main conditions used with 8x4: +// |relu|, |shift_out|, with int32_t |output|. +void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int32_t* output) { + int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0; + __m256i rounding = _mm256_set1_epi32(rounding_addon); + __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min); + for (int row_block = start_row; row_block < end_row; ++row_block) { + // Load 4 biases [0 1 2 3 4 5 6 7]. + __m256i bias256 = _mm256_load_si256(reinterpret_cast<__m256i const*>(bias)); + bias += kBlockSize * 2; + int nnz = nnz_per_row[row_block]; + __m256i sum = + Compute8RowResults(bias256, rhs, rhs_indices, nnz, weights_ptr); + rhs_indices += nnz; + // Shift right with rounding to get the right number of mantissa bits. + sum = _mm256_add_epi32(sum, rounding); + sum = _mm256_srai_epi32(sum, shift_out); + // Now sum contains [res0, res1, res2, res3, res0, res1, res2, res3] + sum = _mm256_max_epi32(sum, zero); + // Save the all 256 bits (8x int32_t). + _mm256_store_si256(reinterpret_cast<__m256i*>(output), sum); + output += kBlockSize * 2; + } +} + +#endif + +} // namespace detail +} // namespace csrblocksparse diff --git a/sparse_matmul/compute/matmul_fixed_avx2.h b/sparse_matmul/compute/matmul_fixed_avx2.h new file mode 100644 index 0000000000000000000000000000000000000000..59e7d0eaa9aa576543ca428d3ad983c6ffa6b62a --- /dev/null +++ b/sparse_matmul/compute/matmul_fixed_avx2.h @@ -0,0 +1,49 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_ + +#include + +namespace csrblocksparse { +namespace detail { + +// Version that covers all possible combinations of the variable conditions: +// |relu|, |shift_out|, |replicas|, with int16 output. +void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int replicas, int stride, + int16_t* output); +// Version that covers all possible combinations of the variable conditions: +// |relu|, |shift_out|, |replicas|, with int32 output. +void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int replicas, int stride, + int32_t* output); +// Version that covers the main conditions used with 8x4: +// |relu|, |shift_out|, with int32 output. +void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int32_t* output); + +} // namespace detail +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_ diff --git a/sparse_matmul/compute/matmul_generic.cc b/sparse_matmul/compute/matmul_generic.cc new file mode 100644 index 0000000000000000000000000000000000000000..1cf4fe53fadc7717c1d15c086898041d6467a519 --- /dev/null +++ b/sparse_matmul/compute/matmul_generic.cc @@ -0,0 +1,122 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/compute/matmul_generic.h" + +#include +#include + +#include "sparse_matmul/compute/matmul.h" + +namespace csrblocksparse { +namespace detail { + +void MatVecFloatGeneric(const float* weights, const float* rhs, + const float* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + int block_height, int block_width, bool relu, + int replicas, int stride, float* output) { + int weight_index = 0; + int bias_index = 0; + std::vector accumulators(block_height); + for (int row_block = start_row; row_block < end_row; + ++row_block, output += block_height) { + int nnz = nnz_per_row[row_block]; + // Biases are now stored and used directly without pre-division. + for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++]; + + for (int c = 0; c < nnz; ++c) { + int rhs_index = rhs_indices[c]; + const float* block_rhs = rhs + rhs_index * block_width; + // Multiply this |block_height| x |block_width| block. + for (int i = 0; i < block_height; ++i) { + for (int j = 0; j < block_width; ++j) { + accumulators[i] += weights[weight_index++] * block_rhs[j]; + } + } + } + rhs_indices += nnz; + // Apply relu if desired. + if (relu) { + for (int i = 0; i < block_height; ++i) { + if (accumulators[i] < 0) accumulators[i] = 0; + } + } + for (int r = 0; r < replicas; ++r) { + for (int i = 0; i < block_height; ++i) { + output[i + r * stride] = accumulators[i]; + } + } + } +} + +void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + int block_height, int block_width, bool relu, + int bytes_out, int shift_out, int replicas, int stride, + void* output) { + int weight_index = 0; + int bias_index = 0; + std::vector accumulators(block_height); + for (int row_block = start_row; row_block < end_row; ++row_block) { + int nnz = nnz_per_row[row_block]; + // Biases are now stored and used directly without pre-division. + for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++]; + + for (int c = 0; c < nnz; ++c) { + int rhs_index = rhs_indices[c]; + const int16_t* block_rhs = rhs + rhs_index * block_width; + // Multiply this |block_height| x |block_width| block. + for (int i = 0; i < block_height; ++i) { + for (int j = 0; j < block_width; ++j) { + accumulators[i] += weights[weight_index++] * block_rhs[j]; + } + } + } + rhs_indices += nnz; + // Apply relu if desired. + if (relu) { + for (int i = 0; i < block_height; ++i) { + if (accumulators[i] < 0) accumulators[i] = 0; + } + } + // Output shift. + if (shift_out > 0) { + for (int i = 0; i < block_height; ++i) { + accumulators[i] >>= shift_out; + } + } + if (bytes_out == 2) { + int16_t* out16 = reinterpret_cast(output); + output = out16 + block_height; + for (int r = 0; r < replicas; ++r, out16 += stride) { + for (int i = 0; i < block_height; ++i) { + out16[i] = accumulators[i]; + } + } + } else { + int32_t* out32 = reinterpret_cast(output); + output = out32 + block_height; + for (int r = 0; r < replicas; ++r, out32 += stride) { + for (int i = 0; i < block_height; ++i) { + out32[i] = accumulators[i]; + } + } + } + } +} + +} // namespace detail +} // namespace csrblocksparse diff --git a/sparse_matmul/compute/matmul_generic.h b/sparse_matmul/compute/matmul_generic.h new file mode 100644 index 0000000000000000000000000000000000000000..415d71cd86321836f74830f44546ce4be4e903a8 --- /dev/null +++ b/sparse_matmul/compute/matmul_generic.h @@ -0,0 +1,41 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_ + +#include + +namespace csrblocksparse { +namespace detail { + +// Generic version uses plain C++ code. +void MatVecFloatGeneric(const float* weights, const float* rhs, + const float* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + int block_height, int block_width, bool relu, + int replicas, int stride, float* output); +void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + int block_height, int block_width, bool relu, + int bytes_out, int shift_out, int replicas, int stride, + void* output); + +} // namespace detail +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_ diff --git a/sparse_matmul/compute/thread_bounds.cc b/sparse_matmul/compute/thread_bounds.cc new file mode 100644 index 0000000000000000000000000000000000000000..e37a395e7585740d4e71acbeeffc3c319081fed4 --- /dev/null +++ b/sparse_matmul/compute/thread_bounds.cc @@ -0,0 +1,106 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/compute/thread_bounds.h" + +#include + +#include "glog/logging.h" + +namespace csrblocksparse { + +void ThreadBounds::PrepareForThreads(int block_width, int block_height, + int num_threads, + int reduced_rows_per_cache_row, + int reduced_rows, const int* nnz_per_row) { + CHECK_GT(num_threads, 0); + block_width_ = block_width; + block_height_ = block_height; + ComputeThreadSplitPoints(num_threads, reduced_rows_per_cache_row, + reduced_rows, nnz_per_row); + weight_starts_.clear(); + rhs_indices_starts_.clear(); + bias_starts_.clear(); + weight_starts_.reserve(row_starts_.size()); + rhs_indices_starts_.reserve(row_starts_.size()); + bias_starts_.reserve(row_starts_.size()); + + // Compute the start indices of each of the types, given what we know about + // padding, and number of |nnz_per_row|. + int weight_index = 0; + int rhs_indices_index = 0; + int bias_index = 0; + int row = 0; + for (int start : row_starts_) { + while (row < start) { + weight_index += nnz_per_row[row] * block_width_ * block_height_; + rhs_indices_index += nnz_per_row[row]; + bias_index += block_height_; + ++row; + } + weight_starts_.push_back(weight_index); + rhs_indices_starts_.push_back(rhs_indices_index); + bias_starts_.push_back(bias_index); + } +} + +// Computes the block row (reduced) index of the start of each thread. +void ThreadBounds::ComputeThreadSplitPoints(int num_threads, + int reduced_rows_per_cache_row, + int reduced_rows, + const int* nnz_per_row) { + row_starts_.assign(/*n=*/1, /*val=*/0); + // Break the rule if the matrix is too small to allow one per thread, which + // occurs only during tests. + if (reduced_rows_per_cache_row * num_threads > reduced_rows) + reduced_rows_per_cache_row = std::max(reduced_rows / num_threads, 1); + int cache_rows = (reduced_rows + reduced_rows_per_cache_row - 1) / + reduced_rows_per_cache_row; + + // Compute exclusive prefix sum of the amount of work per row. + std::vector work_upto_row(cache_rows + 1, 0); + int extra_row_work = 2 * reduced_rows_per_cache_row; + for (int i = 0; i < cache_rows; ++i) { + int new_nnz = 0; + for (int j = 0; j < reduced_rows_per_cache_row; ++j) { + // if |reduced_rows_per_cache_row| isn't an exact multiple of the + // matrix size, then we need to be careful here. + int index = i * reduced_rows_per_cache_row + j; + if (index < reduced_rows) new_nnz += nnz_per_row[index]; + } + work_upto_row[i + 1] = new_nnz + extra_row_work + work_upto_row[i]; + } + int total_work = work_upto_row.back(); + // Find the split point point based on assigned approximately equal amount + // of work for each thread. + int prev_split = 0; + for (int i = 1; i <= num_threads; ++i) { + int split = std::distance( + work_upto_row.begin(), + std::lower_bound(work_upto_row.begin(), work_upto_row.end(), + i * total_work / num_threads)); + int split_row = split * reduced_rows_per_cache_row; + if (i == num_threads) { + split_row = reduced_rows; + } + + VLOG(2) << "tid=" << i - 1 << " num rows=" << split_row - row_starts_.back() + << " work=" << work_upto_row[split] - work_upto_row[prev_split]; + row_starts_.push_back(split_row); + prev_split = split; + } + VLOG(2) << "total rows=" << reduced_rows << " total work=" << total_work; +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/compute/thread_bounds.h b/sparse_matmul/compute/thread_bounds.h new file mode 100644 index 0000000000000000000000000000000000000000..fd8a7d2b0e4e2fe5288efbb2e301f1a9475a9c5e --- /dev/null +++ b/sparse_matmul/compute/thread_bounds.h @@ -0,0 +1,74 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_ + +#include + +namespace csrblocksparse { + +// Class to compute and store the bounds of each thread used in a computation, +// and to provide corresponding spans of vectors. +class ThreadBounds { + public: + ThreadBounds() : block_width_(0), block_height_(0) {} + + void PrepareForThreads(int block_width, int block_height, int num_threads, + int reduced_rows_per_cache_row, int reduced_rows, + const int* nnz_per_row); + + // Functions that offset the appropriate type to the start of the data + // needed by the given thread id (|tid|). + template + const WeightType* OffsetWeights(const WeightType* weights, int tid) const { + return weights + weight_starts_[tid]; + } + template + const RhsIndType* OffsetRhsIndices(const RhsIndType* rhs_indices, + int tid) const { + return rhs_indices + rhs_indices_starts_[tid]; + } + template + const BiasType* OffsetBias(const BiasType* bias, int tid) const { + return bias + bias_starts_[tid]; + } + template + OutType* OffsetOutput(OutType* output, int tid) const { + return output + block_height_ * row_starts_[tid]; + } + int StartRow(int tid) const { return row_starts_[tid]; } + const std::vector& row_starts() const { return row_starts_; } + + private: + // Computes the block row (reduced) index of the start of each thread. + void ComputeThreadSplitPoints(int num_threads, int reduced_rows_per_cache_row, + int reduced_rows, const int* nnz_per_row); + + // Sizes of a sparse block. + int block_width_; + int block_height_; + // Start indices of each data type by thread-id with an extra value at the + // end. + std::vector row_starts_; + std::vector weight_starts_; + std::vector rhs_indices_starts_; + std::vector bias_starts_; +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_ diff --git a/sparse_matmul/layers/BUILD b/sparse_matmul/layers/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..7c4ed36d05f66d5dea8c8a42091036a40d813f40 --- /dev/null +++ b/sparse_matmul/layers/BUILD @@ -0,0 +1,146 @@ +# Sparse/Masked Matrix and Layer. + +# [internal] load android_library_selector +# [internal] load android_cc_test:def.bzl + +licenses(["notice"]) + +cc_library( + name = "layer", + hdrs = [ + "sparse_linear_layer.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + ":matrix", + "//sparse_matmul/numerics:types", + "//sparse_matmul/os:coop_threads", + "//sparse_matmul/vector:cache_aligned_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_glog//:glog", + ], +) + +cc_library( + name = "matrix", + hdrs = [ + "csr_blocksparse_matrix.h", + "masked_sparse_matrix.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "//sparse_matmul/compute:kernels", + "//sparse_matmul/compute:matmul", + "//sparse_matmul/compute:thread_bounds", + "//sparse_matmul/numerics:types", + "//sparse_matmul/os:coop_threads", + "//sparse_matmul/vector:cache_aligned_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_glog//:glog", + ], +) + +cc_library( + name = "utils", + srcs = [ + "utils.cc", + ], + hdrs = [ + "read_array_ifstream.h", + "utils.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + ":layer", + ":matrix", + ":status", + "//sparse_matmul/numerics:types", + "//sparse_matmul/vector:cache_aligned_vector", + "//sparse_matmul/zlib_wrapper", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@gulrak_filesystem//:filesystem", + ], +) + +cc_library( + name = "status", + srcs = [ + "errno_mapping.cc", + ], + hdrs = [ + "errno_mapping.h", + "status_macros.h", + ], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +cc_test( + name = "csrblocksparse_test", + size = "small", + srcs = [ + "csrblocksparse_test.cc", + ], + data = glob(["testdata/*"]), + linkopts = select({ + "@bazel_tools//platforms:android": ["-landroid"], + "//conditions:default": [], + }), + shard_count = 10, + deps = [ + ":status", + ":utils", + "//sparse_matmul/compute:matmul", + "//sparse_matmul/numerics:test_utils", + "//sparse_matmul/os:coop_threads", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@gulrak_filesystem//:filesystem", + ], +) + +cc_test( + name = "sparse_linear_layer_test", + srcs = [ + "sparse_linear_layer_test.cc", + ], + deps = [ + ":layer", + "//sparse_matmul/numerics:test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + deps = [ + ":layer", + ":matrix", + ":status", + ":utils", + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:test_utils", + "//sparse_matmul/numerics:types", + "//sparse_matmul/vector:cache_aligned_vector", + "@com_google_absl//absl/flags:flag", + "@com_google_googletest//:gtest_main", + "@gulrak_filesystem//:filesystem", + ], +) diff --git a/sparse_matmul/layers/csr_blocksparse_matrix.h b/sparse_matmul/layers/csr_blocksparse_matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..be51573515e4433758ea3416265504308e2440f7 --- /dev/null +++ b/sparse_matmul/layers/csr_blocksparse_matrix.h @@ -0,0 +1,835 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_ + +#include +#include +#include +#include +#include +#include + +#include "glog/logging.h" +// IWYU pragma: begin_exports +#include "sparse_matmul/compute/kernels_generic.h" +#include "sparse_matmul/compute/matmul.h" +#include "sparse_matmul/compute/thread_bounds.h" +#include "sparse_matmul/layers/masked_sparse_matrix.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/os/coop_threads.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" +// IWYU pragma: end_exports +#include "absl/memory/memory.h" + +namespace csrblocksparse { +// CsrBlockSparseMatrix stores a modified block compressed sparse row +// representation of a sparse matrix. The ordering of the weights is modified +// in the 16x1 and 1x1 cases so that a certain number (4 and 8 respectively) +// of columns of weights are stored contiguously before moving on to the next +// row. The 4x4 case stores each block contiguously. +// +// Currently it is constructed from a MaskedSparseMatrix which usees a dense +// binary mask representation. The construction generates the compressed +// representation. Further iterations will support a direct serialization +// of the compressed representation. +// +// MaskedSparseMatrix masked_matrix(rows, cols, existing_mask, existing_values) +// CsrBlockSparseMatrix matrix(masked_matrix) +// +// matrix.SpMV_bias(rhs, bias, &out); +// +// This class is thread compatible. +template +class CsrBlockSparseMatrix { + public: + CsrBlockSparseMatrix() {} + + // Reference used to indicate that this is an input and not an output. + CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len) { + ReadFromFlatBuffer(buffer, len); + ComputeRHSIndices(); + } + + template + CsrBlockSparseMatrix(const MaskedSparseMatrix& masked_matrix) { + sparsity_ = masked_matrix.sparsity(); + rows_ = masked_matrix.rows(); + cols_ = masked_matrix.cols(); + + DetermineBlockSize(masked_matrix); + + if (block_width_ == 1 && block_height_ == 1) + col_multiple_ = 8; + else + col_multiple_ = 1; + + std::vector weights(masked_matrix.values().begin(), + masked_matrix.values().end()); + + reduced_rows_ = (rows_ + block_height_ - 1) / block_height_; + rows_ = reduced_rows_ * block_height_; + reduced_cols_ = cols_ / block_width_; + + // Calculate the reduced CSR representation of the matrix. + std::vector reduced_mask(reduced_rows_ * reduced_cols_); + std::vector row_offsets = {0}; + int nnz = 0; + const auto& mask = masked_matrix.mask(); + for (int r = 0; r < reduced_rows_; ++r) { + for (int c = 0; c < reduced_cols_; ++c) { + int mask_val = mask[r * block_height_ * cols_ + c * block_width_]; + reduced_mask[r * reduced_cols_ + c] = mask_val; + nnz += mask_val; + } + row_offsets.push_back(nnz); + } + + // Make sure the reduced representation has the correct number of columns. + MakeColumnsMultiple(row_offsets, &reduced_mask, &weights); + + std::vector col_indices; + std::vector weights_csr; + std::vector nnz_per_row; + MaskAndWeightsToCsr(reduced_mask, weights, &nnz_per_row, &col_indices, + &weights_csr); + + // Generate column deltas from |col_indices|. + std::vector col_deltas; + for (int i = 0; i < col_indices.size(); ++i) { + // |col_indices| are used to index the RHS vector which is always float. + int64_t diff = sizeof(RhsType); + if (i == 0) + diff *= block_width_ * (col_indices[i]); + else + diff *= block_width_ * (col_indices[i] - col_indices[i - 1]); + + CHECK(diff < std::numeric_limits::max()) + << "delta between column indices in bytes " << diff + << " exceeded the maximum size of the DeltaType " + << std::numeric_limits::max(); + col_deltas.push_back(static_cast(diff)); + } + + // Because of pre-fetching we need some extra values at the end. + col_deltas.insert(col_deltas.end(), std::max(2, col_multiple_ + 1), 0); + nnz_per_row.insert(nnz_per_row.end(), 2, nnz_per_row.back()); + + weights_ = CacheAlignedVector(weights_csr); + col_deltas_ = CacheAlignedVector(col_deltas); + nnz_per_row_ = CacheAlignedVector(nnz_per_row); + ComputeRHSIndices(); + + num_threads_ = 0; + PrepareForThreads(1); + } + + // Constructor makes a matrix from the given weights, deltas and nnz, taking + // the other parameters from |src_matrix|. |cols| is the number of raw columns + // (NOT blocks) of the new matrix. + CsrBlockSparseMatrix( + const CsrBlockSparseMatrix& src_matrix, + const std::vector& new_weights, + const std::vector& new_deltas, const std::vector& new_nnz, + int cols) { + num_threads_ = 0; + col_multiple_ = src_matrix.col_multiple_; + block_width_ = src_matrix.block_width_; + block_height_ = src_matrix.block_height_; + reduced_rows_ = new_nnz.size(); + rows_ = reduced_rows_ * block_height_; + cols_ = cols; + reduced_cols_ = cols_ / block_width_; + weights_ = CacheAlignedVector(new_weights); + col_deltas_ = CacheAlignedVector(new_deltas); + nnz_per_row_ = CacheAlignedVector(new_nnz); + sparsity_ = 1.0f - static_cast(new_weights.size()) / (rows_ * cols_); + ComputeRHSIndices(); + name_ = src_matrix.name_; + PrepareForThreads(1); + } + + // Factory method takes a column slice out of *this and returns a sparse + // matrix that takes as inputs [|start_col|, |end_col|) of *this, and + // returns the same number of outputs, but only a partial result. + // If |keep_rhs_size|, then the new matrix takes the same rhs as the current + // matrix, but uses a subset of it, instead of expecting just the reduced rhs. + // If |start_col| > |end_col|, then we slice out the complement of the defined + // interval, ie [0, |end_col|) + [|start_col|, current end). + // NOTE That |start_col| and |end_col| are in raw column coordinates, NOT + // block units. + CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col, + bool keep_rhs_size = false) const { + int weight_index = 0; + int delta_index = 0; + std::vector new_deltas; + std::vector new_weights; + std::vector new_nnz(reduced_rows_); + int col = 0; + int prev_col = keep_rhs_size ? 0 : start_col; + for (int r = 0; r < reduced_rows_; ++r) { + int reduced_col_count = nnz_per_row_[r]; + for (int c = 0; c < reduced_col_count; ++c, ++delta_index) { + col += col_deltas_[delta_index] / sizeof(RhsType); + if ((start_col < end_col && start_col <= col && col < end_col) || + (start_col > end_col && (col < end_col || col >= start_col))) { + ++new_nnz[r]; + new_deltas.push_back((col - prev_col) * sizeof(RhsType)); + prev_col = col; + for (int i = 0; i < block_width_ * block_height_; + ++i, ++weight_index) { + new_weights.push_back(weights_[weight_index]); + } + } else { + weight_index += block_width_ * block_height_; + } + } + } + int new_cols = keep_rhs_size ? cols_ : end_col - start_col; + return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, + new_cols); + } + + // Factory method takes a row slice out of *this and returns a sparse + // matrix that takes the sampe inputs as *this, and returns the outputs for + // the range [|start_row|, |end_row|). + // NOTE That |start_row| and |end_row| are in raw column coordinates, NOT + // block units. + CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const { + int start_reduced = start_row / block_height_; + int end_reduced = end_row / block_height_; + std::vector new_nnz(nnz_per_row_.data() + start_reduced, + nnz_per_row_.data() + end_reduced); + int weight_start = 0; + for (int r = 0; r < start_reduced; ++r) { + weight_start += nnz_per_row_[r]; + } + int weight_end = weight_start; + for (int r = start_reduced; r < end_reduced; ++r) { + weight_end += nnz_per_row_[r]; + } + int delta_start = 0; + for (int i = 0; i < weight_start; ++i) { + delta_start += col_deltas_[i]; + } + std::vector new_deltas(col_deltas_.data() + weight_start, + col_deltas_.data() + weight_end); + new_deltas[0] += delta_start; + int block_size = block_height_ * block_width_; + std::vector new_weights( + weights_.data() + weight_start * block_size, + weights_.data() + weight_end * block_size); + return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, cols_); + } + + // Combines adjacent row blocks, doubling the block height. + // This necessarily involves adding zero weights where the blocks don't align + // across adjacent pairs of rows, so use with caution, as the resulting matrix + // is most likely to run slower if very sparse to begin with. + // In the few cases where the blocks do mostly align, the resulting matmul + // could be much faster, as the number of reads of the rhs will be halved. + void DoubleBlockHeight() { + int new_rows = reduced_rows_ / 2; + std::vector new_nnz(new_rows); + std::vector new_rhs_indices; + std::vector new_weights; + int rhs_index1 = 0; + int rhs_index2 = 0; + int block_size = block_height_ * block_width_; + for (int r = 0; r < new_rows; ++r) { + int start_nnz = new_rhs_indices.size(); + rhs_index2 += nnz_per_row_[r * 2]; + int end1 = rhs_index1 + nnz_per_row_[r * 2]; + int end2 = rhs_index2 + nnz_per_row_[r * 2 + 1]; + // Run over a pair of rows with 2 iterators, combining blocks as we go, or + // padding with zeros where the block positions don't match. + while (rhs_index1 < end1 || rhs_index2 < end2) { + int col1 = rhs_index1 < end1 ? rhs_indices_[rhs_index1] : reduced_cols_; + int col2 = rhs_index2 < end2 ? rhs_indices_[rhs_index2] : reduced_cols_; + if (col1 < col2) { + // Need zero weights for row2 to pad out weights block. + new_rhs_indices.push_back(col1); + new_weights.insert(new_weights.end(), + weights_.data() + rhs_index1 * block_size, + weights_.data() + (rhs_index1 + 1) * block_size); + new_weights.insert(new_weights.end(), block_size, + static_cast(0.0f)); + ++rhs_index1; + } else if (col1 > col2) { + // Need zero weights for row1 to pad out weights block. + new_rhs_indices.push_back(col2); + new_weights.insert(new_weights.end(), block_size, + static_cast(0.0f)); + new_weights.insert(new_weights.end(), + weights_.data() + rhs_index2 * block_size, + weights_.data() + (rhs_index2 + 1) * block_size); + ++rhs_index2; + } else { + // Combine weights for both row1 and row2. + new_rhs_indices.push_back(col1); + new_weights.insert(new_weights.end(), + weights_.data() + rhs_index1 * block_size, + weights_.data() + (rhs_index1 + 1) * block_size); + new_weights.insert(new_weights.end(), + weights_.data() + rhs_index2 * block_size, + weights_.data() + (rhs_index2 + 1) * block_size); + ++rhs_index1; + ++rhs_index2; + } + } + rhs_index1 = rhs_index2; + new_nnz[r] = new_rhs_indices.size() - start_nnz; + } + block_height_ *= 2; + reduced_rows_ /= 2; + weights_ = CacheAlignedVector(new_weights); + rhs_indices_ = CacheAlignedVector(new_rhs_indices); + nnz_per_row_ = CacheAlignedVector(new_nnz); + sparsity_ = 1.0f - static_cast(new_weights.size()) / (rows_ * cols_); + ComputeColDeltas(); + if (num_threads_ > 0) { + int num_threads = num_threads_; + num_threads_ = 0; + PrepareForThreads(num_threads); + } + } + + // Allocates memory and fills buffer. + // Caller is responsible for the memory de-allocation. + // TODO(b/189958858): Both Read and Write need to eventually handle the + // different possible HalfType and DeltaType values, but punting for now as + // there is only one supported combination. + std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer) { + std::size_t bytes = 0; + bytes += FixedParameterSize(); + bytes += weights_.size() * sizeof(WeightType); + bytes += col_deltas_.size() * sizeof(DeltaType); + bytes += nnz_per_row_.size() * sizeof(int); + + uint8_t* bytes_ptr_ptr = + reinterpret_cast(CHECK_NOTNULL(malloc(bytes))); + + int* int_bytes_ptr = reinterpret_cast(bytes_ptr_ptr); + + *int_bytes_ptr++ = rows_; + *int_bytes_ptr++ = cols_; + *int_bytes_ptr++ = reduced_rows_; + *int_bytes_ptr++ = reduced_cols_; + *int_bytes_ptr++ = block_width_; + *int_bytes_ptr++ = block_height_; + *int_bytes_ptr++ = col_multiple_; + *int_bytes_ptr++ = num_threads_; + *int_bytes_ptr++ = weights_.size(); + *int_bytes_ptr++ = col_deltas_.size(); + *int_bytes_ptr++ = nnz_per_row_.size(); + + float* float_bytes_ptr = reinterpret_cast(int_bytes_ptr); + *float_bytes_ptr++ = sparsity_; + + uint8_t* bytes_ptr = reinterpret_cast(float_bytes_ptr); + + memcpy(bytes_ptr, weights_.data(), weights_.size() * sizeof(WeightType)); + bytes_ptr += weights_.size() * sizeof(WeightType); + + memcpy(bytes_ptr, col_deltas_.data(), + col_deltas_.size() * sizeof(DeltaType)); + bytes_ptr += col_deltas_.size() * sizeof(DeltaType); + + memcpy(bytes_ptr, nnz_per_row_.data(), nnz_per_row_.size() * sizeof(int)); + bytes_ptr += nnz_per_row_.size() * sizeof(int); + + csr_flatbuffer->resize(bytes); + csr_flatbuffer->assign(reinterpret_cast(bytes_ptr_ptr), bytes); + free(bytes_ptr_ptr); + + return bytes; + } + + void ReadFromFlatBuffer(const uint8_t* const& bytes, const std::size_t& len) { + CHECK_GE(len, FixedParameterSize()); + + const int* int_bytes_ptr = reinterpret_cast(bytes); + rows_ = *int_bytes_ptr++; + cols_ = *int_bytes_ptr++; + reduced_rows_ = *int_bytes_ptr++; + reduced_cols_ = *int_bytes_ptr++; + block_width_ = *int_bytes_ptr++; + block_height_ = *int_bytes_ptr++; + col_multiple_ = *int_bytes_ptr++; + int num_threads = *int_bytes_ptr++; + int32_t weights_size = *int_bytes_ptr++; + int32_t col_deltas_size = *int_bytes_ptr++; + int32_t nnz_per_row_size = *int_bytes_ptr++; + + // Make sure negative sizes don't mess things up. + weights_size = std::max(0, weights_size); + col_deltas_size = std::max(0, col_deltas_size); + nnz_per_row_size = std::max(0, nnz_per_row_size); + + const float* float_bytes_ptr = + reinterpret_cast(int_bytes_ptr); + sparsity_ = *float_bytes_ptr++; + + std::size_t total_bytes = + FixedParameterSize() + weights_size * sizeof(WeightType) + + col_deltas_size * sizeof(DeltaType) + nnz_per_row_size * sizeof(int); + + CHECK_EQ(total_bytes, len) + << "total bytes: " << total_bytes << ", actual len given: " << len; + + const uint8_t* bytes_ptr = + reinterpret_cast(float_bytes_ptr); + std::vector weights_raw(weights_size); + memcpy(weights_raw.data(), bytes_ptr, weights_size * sizeof(WeightType)); + weights_ = CacheAlignedVector(weights_raw); + bytes_ptr += weights_size * sizeof(WeightType); + + std::vector deltas_raw(col_deltas_size); + memcpy(deltas_raw.data(), bytes_ptr, col_deltas_size * sizeof(DeltaType)); + col_deltas_ = CacheAlignedVector(deltas_raw); + bytes_ptr += col_deltas_size * sizeof(DeltaType); + + std::vector nnz_raw(nnz_per_row_size); + memcpy(nnz_raw.data(), bytes_ptr, nnz_per_row_size * sizeof(int)); + nnz_per_row_ = CacheAlignedVector(nnz_raw); + num_threads_ = 0; + PrepareForThreads(num_threads); + } + + // Multiply a Sparse matrix by a possibly dense matrix. Often the matrix is + // a vector with a small number of columns, hence the term "fat vector". + // 1x1 and 4x4 have specializations for output columns (ie fatness) > 5, + // and often achieve twice as many GFlops when multiplying a right hand side + // that has 5 or more columns. (Best is a multiple of 5). + // 16x1 doesn't have enough registers and just loops over the width 1 kernel. + // + // |rhs| and |out| are COLUMN MAJOR. + + // Fast Tuples WeightType, BiasType, RhsType, OutType are: + // (float, float, float, float) + // (bfloat16, float, float, float) + // and only on ARM64. All other cases use a slow generic implementation. + template + void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out, + bool relu = false, int tid = 0, + SpinBarrier* barrier = nullptr) const { + static_assert(std::is_same::value, + "Rhs types must match"); + CHECK_LT(tid, num_threads_); + CHECK_EQ(rhs.cols(), out->cols()); + CHECK_EQ(rhs.rows(), cols_); + CHECK_GE(out->rows(), rows_); + int cols_to_go = out->cols(); + int rhs_index = *thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid); + const RhsType* rhs_ptr = rhs.data() + rhs_index * block_height_; + OutType* out_ptr = thread_bounds_.OffsetOutput(out->data(), tid); + const WeightType* weights_ptr = + thread_bounds_.OffsetWeights(weights_.data(), tid); + const DeltaType* delta_ptr = + thread_bounds_.OffsetRhsIndices(col_deltas_.data(), tid); + int offset = *delta_ptr / sizeof(RhsType); + rhs_ptr -= offset; + const int* nnz_ptr = nnz_per_row_.data() + thread_bounds_.StartRow(tid); + int assigned_rows = + thread_bounds_.StartRow(tid + 1) - thread_bounds_.StartRow(tid); + const BiasType* bias_ptr = thread_bounds_.OffsetBias(bias.data(), tid); + + while (cols_to_go > 0) { + if (block_width_ == 4 && block_height_ == 4) { + if (cols_to_go >= 5) { + detail::SpMM5_4x4( + weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, + assigned_rows, out->col_stride(), rhs.col_stride(), relu); + } else { + detail::SpMV_4x4( + weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, + assigned_rows, out->col_stride(), rhs.col_stride(), relu); + } + } else { + if (cols_to_go >= 5) { + detail::SpMM5_1x1( + weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, + assigned_rows, out->col_stride(), rhs.col_stride(), relu); + } else { + detail::SpMV_1x1( + weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, + assigned_rows, out->col_stride(), rhs.col_stride(), relu); + } + } + + if (cols_to_go >= 5) { + cols_to_go -= 5; + rhs_ptr += rhs.col_stride() * 5; + out_ptr += out->col_stride() * 5; + } else { + cols_to_go--; + rhs_ptr += rhs.col_stride(); + out_ptr += out->col_stride(); + } + if (barrier) barrier->barrier(); + } + } + template + void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid, + int replicas, int output_stride, OutType* output) { + CHECK_LT(tid, num_threads_); + CHECK_EQ(block_width_, 4) << "Block width must be 4!"; + if (block_height_ == 8) { + matmul_.MatVec8x4( + thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs, + thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(), + thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid), + thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu, + replicas, output_stride, thread_bounds_.OffsetOutput(output, tid)); + } else { + CHECK_EQ(block_height_, 4) << "Block height must be 4 or 8!"; + matmul_.MatVec4x4( + thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs, + thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(), + thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid), + thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu, + replicas, output_stride, thread_bounds_.OffsetOutput(output, tid)); + } + } + + int rows() const { return rows_; } + int cols() const { return cols_; } + int block_height() const { return block_height_; } + int block_width() const { return block_width_; } + float sparsity() const { return sparsity_; } + int num_threads() const { return num_threads_; } + const ThreadBounds& thread_bounds() const { return thread_bounds_; } + const CacheAlignedVector& rhs_indices() const { + return rhs_indices_; + } + const std::string& name() const { return name_; } + void set_name(const std::string& name) { name_ = name; } + const std::vector& split_points() const { + return thread_bounds_.row_starts(); + } + + std::size_t bytes() const { + return weights_.size() * sizeof(WeightType) + + col_deltas_.size() * sizeof(DeltaType) + + nnz_per_row_.size() * sizeof(int); + } + + // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above, + // and then samples from the output (softmax distribution) layer. + template + typename std::enable_if::value, int>::type + SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out, + float temperature, int tid, SpinBarrier* barrier, + std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + SpMM_bias(rhs, bias, out, /*relu=*/false, tid, barrier); + return out->Sample(temperature, gen, scratch); + } + // Fixed32 version. + template + typename std::enable_if::value, int>::type + SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out, + float temperature, int tid, SpinBarrier* barrier, + std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + // We don't pass the barrier on, as we have more work to do. + SpMM_bias(rhs, bias, out, /*relu=*/false, tid); + return out->ReducingSample(gen, scratch, tid, temperature, barrier); + } + + void Print() const { + std::cout << "Weights\n"; + weights_.Print(); + std::cout << std::endl; + std::cout << "Deltas\n"; + col_deltas_.Print(); + std::cout << std::endl; + std::cout << "nnz\n"; + nnz_per_row_.Print(); + std::cout << std::endl; + } + + // Split the computation amongst threads by rows based on the number of + // non zeros, with the addition of a constant to account for the work of the + // bias and the horizontal add at the end, and also guarantees that each + // thread writes only whole cache lines, based on the size of OutType. + // The |cache_line_size| arg is used only for testing. Normally it is provided + // through the architecture #defines. + // Each thread gets a contiguous row range (|split_points|). + // Thread t does rows [ split_points[t], split_points[t + 1] ) + // Each thread also needs to know how many non zeros were before it to skip + // (|nnz_to_skip|). And finally it also needs to know what the offset into + // the rhs vector would have been at the split point (|rhs_to_skip|). + // + // Some tricky corner cases where the number of non-zeros doesn't split + // nicely amongst the number of requested threads are not handled and default + // to one thread; these cases are only going to happen in tests and not in + // the matrices that correspond in real models. + // + // Returns the maximum number of threads that can be used; <= |num_threads|. + template + int PrepareForThreads(int num_threads, int cache_line_size = -1) { + CHECK_GT(num_threads, 0); + // we've already prepared for this number of threads, nothing to do + if (num_threads == num_threads_) return num_threads_; + + num_threads_ = num_threads; + thread_bounds_.PrepareForThreads( + block_width_, block_height_, num_threads_, + ReducedRowsPerCacheLine(cache_line_size), reduced_rows_, + nnz_per_row_.data()); + return num_threads_; + } + + // Computes and stores the |rhs_indices_| from the |col_deltas_|. + void ComputeRHSIndices() { + std::vector cumulative_deltas = CumulativeColDeltas(); + std::vector rhs_indices(cumulative_deltas.size() + + reduced_rows_); + int total_indices = 0; + int delta_index = 0; + for (int r = 0; r < reduced_rows_; ++r) { + for (int n = 0; n < nnz_per_row_[r]; ++n, ++delta_index) { + rhs_indices[total_indices++] = + cumulative_deltas[delta_index] / block_width_; + } + } + rhs_indices_ = CacheAlignedVector(rhs_indices); + } + + // Computes and stores the |col_deltas_| from the |rhs_indices_|. + void ComputeColDeltas() { + std::vector col_deltas(rhs_indices_.size()); + int prev_index = 0; + for (int i = 0; i < rhs_indices_.size(); ++i) { + int offset = rhs_indices_[i] - prev_index; + prev_index = rhs_indices_[i]; + col_deltas[i] = offset * block_width_ * sizeof(RhsType); + } + col_deltas_ = CacheAlignedVector(col_deltas); + } + + // Computes and returns the inclusive prefix sum of the deltas, ie absolute + // positions. + std::vector CumulativeColDeltas() const { + std::vector cum_col_deltas(col_deltas_.size()); + for (int i = 0; i < col_deltas_.size(); ++i) { + cum_col_deltas[i] = col_deltas_[i] / sizeof(RhsType); + if (i > 0) cum_col_deltas[i] += cum_col_deltas[i - 1]; + } + return cum_col_deltas; + } + + private: + constexpr std::size_t FixedParameterSize() const { + return sizeof(int) // rows + + sizeof(int) // cols + + sizeof(int) // reduced_rows + + sizeof(int) // reduced_cols + + sizeof(int) // block_width + + sizeof(int) // block_height + + sizeof(float) // sparsity + + sizeof(int) // col_multiple + + sizeof(int) // num_threads_ + + sizeof(int) // weights_.size() + + sizeof(int) // col_deltas_.size() + + sizeof(int); // nnz_per_row_.size() + } + // Possible block sizes are only those that are supported by the computation + // default is 1x1, other options are 4x4 and 16x1. + template + void DetermineBlockSize(const MaskedSparseMatrix& masked_matrix) { + const std::vector> kPreferredOrder = {{4, 4}}; + int rows = masked_matrix.rows(); + int cols = masked_matrix.cols(); + + for (const auto& block_size : kPreferredOrder) { + int block_height, block_width; + std::tie(block_height, block_width) = block_size; + if (cols % block_width != 0) continue; + + int reduced_rows = (rows + block_height - 1) / block_height; + int reduced_cols = cols / block_width; + + // For each possible block, confirm that it is either all 0s or all 1s. + bool all_same = true; + const auto& mask = masked_matrix.mask(); + for (int r = 0; r < reduced_rows; ++r) { + for (int c = 0; c < reduced_cols; ++c) { + int val = mask[r * block_height * cols + c * block_width]; + for (int i = 0; i < block_height; ++i) { + for (int j = 0; j < block_width; ++j) { + int index = (r * block_height + i) * cols + c * block_width + j; + if (index < masked_matrix.mask().size()) { + all_same &= (masked_matrix.mask()[index] == val); + } + } + } + } + } + + // If this block configuration is possible, accept it. + if (all_same) { + block_height_ = block_height; + block_width_ = block_width; + return; + } + } + + // No large blocks were found, default to 1x1. + block_height_ = 1; + block_width_ = 1; + } + + // CSR descriptors are for the reduced matrix, weights is the full matrix. + template + void MakeColumnsMultiple(const std::vector& row_offsets, + std::vector* reduced_mask, + std::vector* weights) { + if (col_multiple_ > 0) { + // Make sure each row has a number of columns that is a multiple of + // |col_multiple|. + for (int r = 1; r < row_offsets.size(); ++r) { + int num_row = row_offsets[r] - row_offsets[r - 1]; + int num_needed = col_multiple_ - num_row % col_multiple_; + if (num_needed < col_multiple_) { + // Find gaps in the columns where we can insert a column of 0 weights. + int num_added = 0; + for (int c = 0; c < reduced_cols_; ++c) { + if ((*reduced_mask)[(r - 1) * reduced_cols_ + c] == 0) { + (*reduced_mask)[(r - 1) * reduced_cols_ + c] = 1; + + // Zero out the weights that correspond to this block. + for (int i = 0; i < block_height_; ++i) { + for (int j = 0; j < block_width_; ++j) { + (*weights)[((r - 1) * block_height_ + i) * cols_ + + block_width_ * c + j] = InputType(0.f); + } + } + num_added++; + } + + if (num_added == num_needed) break; + } + } + } + } + } + + // Given the final dense mask and weights, convert to the compressed + // block CSR representation. + template + void MaskAndWeightsToCsr(const std::vector& mask, + const std::vector& weights, + std::vector* nnz_per_row, + std::vector* col_indices, + std::vector* weights_csr) { + std::vector row_offsets = {0}; + int nnz = 0; + // Standard CSR format. + if (block_width_ == 1 && block_height_ == 1) { + for (int r = 0; r < rows_; ++r) { + for (int c = 0; c < cols_; ++c) { + if (mask[r * cols_ + c] == 1) { + nnz++; + col_indices->push_back(c); + weights_csr->push_back(WeightType(weights[r * cols_ + c])); + } + } + row_offsets.push_back(nnz); + } + } else if (block_width_ == 4 && block_height_ == 4) { + // Weights are stored contiguously for each block in this case. + for (int r = 0; r < reduced_rows_; ++r) { + for (int c = 0; c < reduced_cols_; ++c) { + if (mask[r * reduced_cols_ + c] == 1) { + col_indices->push_back(c); + nnz++; + for (int i = 0; i < block_height_; ++i) { + for (int j = 0; j < block_width_; ++j) { + int row_index = (block_height_ * r + i) * cols_; + int w_index = row_index + block_width_ * c + j; + WeightType weight = w_index < weights.size() + ? WeightType(weights[w_index]) + : WeightType(0.0f); + weights_csr->push_back(weight); + } + } + } + } + row_offsets.push_back(nnz); + } + } + for (int i = 1; i < row_offsets.size(); ++i) + nnz_per_row->push_back(row_offsets[i] - row_offsets[i - 1]); + } + + // Returns the number of block rows per cache line. This is the minimum unit + // into which the calculation is broken for threads. + template + int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const { + int line_size = kCacheLineSize; + if (override_cache_line_size >= 1) line_size = override_cache_line_size; + return std::max(line_size / (block_height_ * sizeof(OutType)), 1); + } + + int col_multiple_; + int rows_; + int cols_; + int reduced_rows_; + int reduced_cols_; + float sparsity_; + int block_width_; + int block_height_; + int num_threads_; + std::string name_; + + CacheAlignedVector weights_; + CacheAlignedVector col_deltas_; + CacheAlignedVector nnz_per_row_; + // |thread_bounds_| and |rhs_indices_| don't need to be serialized as they are + // always recalculated from serialized data. + CacheAlignedVector rhs_indices_; + Matmul matmul_; + ThreadBounds thread_bounds_; + static constexpr int kCacheLineSize = 64; +}; + +// Converts a sparse matrix represented with (|mask|, |weights|, |size|) into +// the CSR format, and returns that as a serialized string. +template +std::string ConvertDenseToSparseRepresentation_Int16Deltas( + const std::vector& mask, const std::vector& weights, + const int rows, const int cols) { + MaskedSparseMatrix masked_weights(rows, cols, mask.data(), + weights.data()); + CsrBlockSparseMatrix + sparse_masked_weights(masked_weights); + std::string buffer; + sparse_masked_weights.WriteToFlatBuffer(&buffer); + return buffer; +} + +} // namespace csrblocksparse +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_ diff --git a/sparse_matmul/layers/csrblocksparse_test.cc b/sparse_matmul/layers/csrblocksparse_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..08a42ca31a4ba133d23feb0c2e8c1e7f826f636a --- /dev/null +++ b/sparse_matmul/layers/csrblocksparse_test.cc @@ -0,0 +1,977 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +// Placeholder for get runfiles header. +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "include/ghc/filesystem.hpp" +#include "sparse_matmul/compute/matmul.h" +#include "sparse_matmul/layers/utils.h" +#include "sparse_matmul/numerics/test_utils.h" +#include "sparse_matmul/os/coop_threads.h" + +namespace csrblocksparse { +namespace { + +inline constexpr absl::string_view kTestdataPath = "layers/testdata"; + +TEST(CSRBlockSparseMatrix, FlatBufferSerialization) { + const int kRows = 8; + const int kCols = 8; + std::vector mask = {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}; + std::vector values(kRows * kCols, 1.f); + values[1] = 2.f; + values[3] = 3.f; + values[36] = -1.f; + values[45] = -2.f; + + csrblocksparse::CacheAlignedVector bias(kRows); + csrblocksparse::CacheAlignedVector rhs(kCols); + csrblocksparse::CacheAlignedVector out_ref(kRows); + csrblocksparse::CacheAlignedVector out_test(kRows); + + bias.FillZero(); + rhs.FillOnes(); + + csrblocksparse::MaskedSparseMatrix matrix(kRows, kCols, mask.data(), + values.data()); + + matrix.SpMM_bias(rhs, bias, &out_ref); + + csrblocksparse::CsrBlockSparseMatrix + block_sparse_matrix(matrix); + + std::string buffer; + std::size_t num_bytes = block_sparse_matrix.WriteToFlatBuffer(&buffer); + + csrblocksparse::CsrBlockSparseMatrix + new_block_sparse_matrix(reinterpret_cast(buffer.c_str()), + num_bytes); + + new_block_sparse_matrix.SpMM_bias(rhs, bias, &out_test); + + CheckResult(out_ref, out_test, kCols); +} + +template +void CorrectnessCheckBlockSpMM(int rows, int cols, int block_height, + int block_width, float sparsity, + bool use_relu = false, int num_threads = 1, + int fatness = 1, bool test_matmul = false) { + using BiasType = typename TypeOfProduct::type; + MaskedSparseMatrix matrix(rows, cols, sparsity, block_height, + block_width); + matrix.CastWeights(); + FatCacheAlignedVector rhs(cols, fatness); + CacheAlignedVector bias(rows); + FatCacheAlignedVector out(rows, fatness); + + bias.FillRandom(); + rhs.FillRandom(); + out.FillZero(); + FatCacheAlignedVector out_reference = out; + + matrix.SpMM_bias(rhs, bias, &out_reference, use_relu); + + CsrBlockSparseMatrix sparse_matrix(matrix); + + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias)); + num_threads = sparse_linear_layer.PrepareForThreads(num_threads); + + // Checks that the result of applying each thread's portion serially is + // correct. + for (int thread_id = 0; thread_id < num_threads; ++thread_id) { + sparse_linear_layer.SpMM_bias(rhs, &out, use_relu, thread_id); + } + + CheckResult(out_reference, out, sparse_linear_layer.cols()); + + if (test_matmul) { + for (int thread_id = 0; thread_id < num_threads; ++thread_id) { + sparse_linear_layer.MatVec(rhs, use_relu, thread_id, + /*replicas=*/1, /*output_stride=*/0, &out); + } + + CheckResult(out_reference, out, sparse_linear_layer.cols()); + } +} + +// Does: +// y = Ax + b; +// x = Ay + b; +// y = Ax + b; +// +// to make sure that dependent multiplies are correct. +template +void ThreadBody( + SpinBarrier* spin_barrier, int tid, + const SparseLinearLayer& sparse_linear_layer, + FatCacheAlignedVector* rhs, FatCacheAlignedVector* out, + bool use_relu) { + sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid); + spin_barrier->barrier(); + sparse_linear_layer.SpMM_bias(*out, rhs, use_relu, tid); + spin_barrier->barrier(); + sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid); +} + +template +void CorrectnessCheckBlockSpMM_MultiThread(int rows, int cols, int block_height, + int block_width, float sparsity, + bool use_relu = false, + int num_threads = 1, + int fatness = 1) { + typedef typename TypeOfProduct::type BiasType; + CHECK(rows == cols); + MaskedSparseMatrix matrix(rows, cols, sparsity, block_height, + block_width); + matrix.CastWeights(); + FatCacheAlignedVector rhs(cols, fatness); + FatCacheAlignedVector rhs_mt(cols, fatness); + CacheAlignedVector bias(rows); + FatCacheAlignedVector out(rows, fatness); + + bias.FillOnes(); + rhs.FillOnes(); + rhs_mt.FillOnes(); + out.FillZero(); + FatCacheAlignedVector out_reference = out; + + matrix.SpMM_bias(rhs, bias, &out_reference, use_relu); + matrix.SpMM_bias(out_reference, bias, &rhs, use_relu); + matrix.SpMM_bias(rhs, bias, &out_reference, use_relu); + + CsrBlockSparseMatrix sparse_matrix(matrix); + + num_threads = sparse_matrix.PrepareForThreads(num_threads, + /*cache_line_size=*/1); + + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias)); + + csrblocksparse::LaunchOnThreadsWithBarrier( + num_threads, ThreadBody, + sparse_linear_layer, &rhs_mt, &out, use_relu); + + CheckResult(out_reference, out, cols); +} + +} // namespace + +TEST(MaskedSparseCorrectness, HandCoded) { + const int kRows = 8; + const int kCols = 8; + // clang-format off + std::vector mask = {1, 1, 0, 0, 0, 1, 1, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 1, 0, 0, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 1, 0, 1}; + // clang-format on + std::vector values(kRows * kCols, 1.f); + + std::vector answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f}; + + MaskedSparseMatrix matrix(kRows, kCols, mask.data(), values.data()); + CacheAlignedVector rhs(kCols); + CacheAlignedVector bias(kRows); + CacheAlignedVector out(kRows); + + bias.FillOnes(); + rhs.FillOnes(); + out.FillZero(); + + MaskedLinearLayer masked_linear_layer(std::move(matrix), + std::move(bias)); + + masked_linear_layer.SpMM_bias(rhs, &out); + + for (int i = 0; i < kRows; ++i) { + EXPECT_EQ(answer[i], out[i]); + } +} + +TEST(MaskedSparseCorrectness, HandCodedFatVector) { + const int kRows = 8; + const int kCols = 8; + // clang-format off + std::vector mask = {1, 1, 0, 0, 0, 1, 1, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 1, 0, 0, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 1, 0, 1}; + // clang-format on + + std::vector values(kRows * kCols, 1.f); + std::vector answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f}; + + MaskedSparseMatrix matrix(kRows, kCols, mask.data(), values.data()); + const int kMaxWidth = 5; + for (int width = 5; width <= kMaxWidth; ++width) { + FatCacheAlignedVector rhs(kCols, width); + CacheAlignedVector bias(kRows); + FatCacheAlignedVector out(kRows, width); + + bias.FillOnes(); + rhs.FillOnes(); + out.FillZero(); + + MaskedLinearLayer masked_linear_layer(std::move(matrix), + std::move(bias)); + + masked_linear_layer.SpMM_bias(rhs, &out); + + for (int i = 0; i < kRows; ++i) { + for (int width = 0; width < kMaxWidth; ++width) { + EXPECT_EQ(answer[i], out[i + width * kRows]); + } + } + } +} + +TEST(CsrBlockSparseMatrix, HandCodedMultiThread) { + const int kRows = 8; + const int kCols = 8; + // clang-format off + std::vector mask = {1, 1, 0, 0, 0, 1, 1, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 1, 0, 0, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 1, 0, 1}; + // clang-format on + std::vector values(kRows * kCols, 1.f); + + std::vector answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f}; + + MaskedSparseMatrix matrix(kRows, kCols, mask.data(), values.data()); + CacheAlignedVector rhs(kCols); + CacheAlignedVector bias(kRows); + CacheAlignedVector out(kRows); + + bias.FillOnes(); + rhs.FillOnes(); + out.FillZero(); + + CacheAlignedVector bias_csr = bias; + + CsrBlockSparseMatrix sparse_matrix(matrix); + + MaskedLinearLayer masked_linear_layer(std::move(matrix), + std::move(bias)); + + masked_linear_layer.SpMM_bias(rhs, &out); + + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias_csr)); + sparse_linear_layer.PrepareForThreads(2, /*cache_line_size=*/1); + + CacheAlignedVector out_tmp(kRows); + const bool kUseRelu = false; + sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/0); + sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/1); + + for (int i = 0; i < kRows; ++i) { + EXPECT_EQ(answer[i], out_tmp[i]); + } +} + +TEST(TestCasts, TestBfloat16) { + const int kRows = 1000; + const int kCols = 100; + const float kSparsity = 0.f; + + MaskedSparseMatrix matrix(kRows, kCols, kSparsity); + MaskedSparseMatrix matrix_bfloat16(kRows, kCols, matrix.mask().data(), + matrix.values().data()); + + matrix_bfloat16.CastWeights(); + + CheckResult(matrix.values(), matrix_bfloat16.values(), kCols); +} + +TEST(TestCasts, TestFP16) { + const int kRows = 1000; + const int kCols = 100; + const float kSparsity = 0.f; + + MaskedSparseMatrix matrix(kRows, kCols, kSparsity); +#if !defined __arm__ && !defined __aarch64__ + // Conversion doesn't handle denormals, so flush denormals to zero first. + for (int i = 0; i < matrix.values().size(); ++i) { + if (matrix.data()[i] < 1. / static_cast(1 << 14)) + matrix.data()[i] = 0.f; + } +#endif + MaskedSparseMatrix matrix_fp16(kRows, kCols, matrix.mask().data(), + matrix.values().data()); + + matrix_fp16.CastWeights(); + + CheckResult(matrix.values(), matrix_fp16.values(), kCols); +} + +TEST(TestCasts, TestFixed16) { + const int kRows = 100000; + const int kCols = 1; + const float kSparsity = 0.f; + + MaskedSparseMatrix matrix(kRows, kCols, kSparsity); + + // Relative error for fixed point is high near 0. + for (int i = 0; i < matrix.values().size(); ++i) { + // 1.1e-3 is based on the max error of .013 and a grid spacing of 1 / 2**16 + // == 3e-5. 3e-5 / .013 / 2 = 1.1e-3. + if (std::abs(matrix.data()[i]) < 1.1e-3) { + matrix.data()[i] = 0.f; + } + } + + MaskedSparseMatrix matrix_fixed16 = matrix; + + matrix_fixed16.CastWeights>(); + + CheckResult(matrix.values(), matrix_fixed16.values(), kCols); +} + +TEST(TestCasts, TestFixed32) { + const int kRows = 100000; + const int kCols = 1; + const float kSparsity = 0.f; + + MaskedSparseMatrix matrix(kRows, kCols, kSparsity); + MaskedSparseMatrix matrix_fixed32 = matrix; + + matrix_fixed32.CastWeights>(); + + CheckResult(matrix.values(), matrix_fixed32.values(), kCols); +} + +template +void TestSpMM(int block_width, int block_height, int fatness, + bool test_matmul = false) { + std::array use_relu = {false, true}; + std::vector sparsity_levels = {.5, .8, .9, .95, .98}; + std::vector> sizes = {{8, 8}, {128, 128}, {128, 64}, + {256, 192}, {512, 512}, {1024, 512}, + {384, 384}, {512, 384}}; + for (int num_threads = 1; num_threads < 2 + test_matmul; ++num_threads) { + for (const auto& relu : use_relu) { + for (const auto& sparsity : sparsity_levels) { + for (const auto& size : sizes) { + int rows, cols; + std::tie(rows, cols) = size; + CorrectnessCheckBlockSpMM( + rows, cols, block_height, block_width, sparsity, relu, + num_threads, fatness, test_matmul); + } + } + } + } +} + +template +void TestSpMM_MultiThread(int block_width, int block_height, int fatness) { + std::array use_relu = {false, true}; + std::vector sparsity_levels = {.5, .8, .9, .95, .98}; + std::vector> sizes = { + {48, 48}, {128, 128}, {512, 512}, {384, 384}}; + for (int num_threads = 1; num_threads < 5; ++num_threads) { + for (const auto& relu : use_relu) { + for (const auto& sparsity : sparsity_levels) { + for (const auto& size : sizes) { + int rows, cols; + std::tie(rows, cols) = size; + CorrectnessCheckBlockSpMM_MultiThread( + rows, cols, block_height, block_width, sparsity, relu, + num_threads, fatness); + } + } + } + } +} + +template +void TestSumVectors(int start = 0, int end = -1, int size = 6) { + std::vector values; + std::vector answer; + + for (int i = 1; i < size + 1; ++i) { + const float x = static_cast(i); + values.push_back(static_cast(x)); + answer.push_back(static_cast(x * 2)); + } + + if (end == -1) { + end = values.size(); + } + + csrblocksparse::CacheAlignedVector result(values.size()); + csrblocksparse::CacheAlignedVector values_aligned(values); + detail::SumVectors(start, end, values_aligned.data(), values_aligned.data(), + result.data()); + for (int i = start; i < end; ++i) { + EXPECT_EQ(static_cast(answer[i]), static_cast(result[i])); + } +} + +TEST(CsrBlockSparseMatrix, SumVectors_Generic) { + TestSumVectors(); + TestSumVectors(1); + TestSumVectors(1, 4); +} + +TEST(CsrBlockSparseMatrix, SumVectors_Bfloat16) { + TestSumVectors(); + TestSumVectors(1); + TestSumVectors(1, 4); +} + +// For SIMD-optimized SumVectors, the memory of the vector should be at least +// |kSIMDWidth * sizeof(float)| long, and the start position has to be an +// aligned memory location. So setting |size| to be 100 to be safe and +// |start| to be 0 (|start| == 1 is not aligned). +TEST(CsrBlockSparseMatrix, SumVectors_Fixed16) { + TestSumVectors>(0, -1, 100); + TestSumVectors>(0, 4, 100); +} + +TEST(CsrBlockSparseMatrix, SumVectors_Fixed32) { + TestSumVectors>(0, -1, 100); + TestSumVectors>(0, 4, 100); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_Bfloat16) { + TestSpMM(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +// This actually uses multiple threads, and uses the output as the input for +// multiple steps to test that synchronization and memory visibility is +// working correctly.Requires square matrices. +TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_Bfloat16) { + TestSpMM_MultiThread( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_Bfloat16) { + TestSpMM_MultiThread( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_Bfloat16) { + TestSpMM(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_Bfloat16) { + TestSpMM(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +// This actually uses multiple threads, and uses the output as the input for +// multiple steps to test that synchronization and memory visibility is +// working correctly.Requires square matrices. +TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_Bfloat16) { + TestSpMM_MultiThread( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_Bfloat16) { + TestSpMM_MultiThread( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block4x4_float) { + TestSpMM(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1, + /*test_matmul=*/true); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_float) { + TestSpMM(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +// This actually uses multiple threads, and uses the output as the input for +// multiple steps to test that synchronization and memory visibility is +// working correctly.Requires square matrices. +TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_float) { + TestSpMM_MultiThread(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_float) { + TestSpMM_MultiThread(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_float) { + TestSpMM(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_float) { + TestSpMM(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +// This actually uses multiple threads, and uses the output as the input for +// multiple steps to test that synchronization and memory visibility is +// working correctly.Requires square matrices. +TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_float) { + TestSpMM_MultiThread(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_float) { + TestSpMM_MultiThread(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32) { + TestSpMM, csrblocksparse::fixed16<4>, + typename csrblocksparse::TypeOfProduct< + csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1, + /*test_matmul=*/true); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32) { + TestSpMM, csrblocksparse::fixed16<4>, + typename csrblocksparse::TypeOfProduct< + csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32) { + TestSpMM, csrblocksparse::fixed16<4>, + typename csrblocksparse::TypeOfProduct< + csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32) { + TestSpMM, csrblocksparse::fixed16<4>, + typename csrblocksparse::TypeOfProduct< + csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_16) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed16<8>>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1, + /*test_matmul=*/true); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_16) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed16<8>>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_16) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed16<8>>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_16) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed16<8>>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32_unmatched) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed32<13>>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1, + /*test_matmul=*/true); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32_unmatched) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed32<13>>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32_unmatched) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed32<13>>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32_unmatched) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed32<13>>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, RhsIndicesDeltasRoundTrip) { + MaskedSparseMatrix matrix(/*rows=*/256, /*cols=*/256, + /*sparsity=*/0.9, /*block_height=*/4, + /*block_width=*/4); + CsrBlockSparseMatrix sparse_matrix(matrix); + CacheAlignedVector copy_indices = sparse_matrix.rhs_indices(); + sparse_matrix.ComputeColDeltas(); + sparse_matrix.ComputeRHSIndices(); + // They get padded when created, so the newer one could be bigger. + EXPECT_LE(copy_indices.size(), sparse_matrix.rhs_indices().size()); + for (int i = 0; i < copy_indices.size(); ++i) { + EXPECT_EQ(copy_indices[i], sparse_matrix.rhs_indices()[i]) << "i=" << i; + } +} + +// Tests that a Layer that is split into 2 by columns (inputs) computes the same +// result as the original layer. +TEST(CsrBlockSparseMatrix, SplitByCol) { + int kRows = 1024; + int kCols = 1024; + MaskedSparseMatrix matrix(kRows, kCols, 0.95, /*block_height=*/4, + /*block_width=*/4); + FatCacheAlignedVector rhs(kCols, /*cols=*/1); + CacheAlignedVector bias(kRows); + FatCacheAlignedVector out1(kRows, /*cols=*/1); + FatCacheAlignedVector out2(kRows, /*cols=*/1); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + out2.FillZero(); + FatCacheAlignedVector out_reference = out1; + + CsrBlockSparseMatrix sparse_matrix(matrix); + + SparseLinearLayer sparse_linear_layer(std::move(sparse_matrix), + std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false, + /*tid=*/0); + // Split the layer into 2 parts. + SparseLinearLayer part1, part2; + sparse_linear_layer.SplitInputs(&part1, &part2); + part1.PrepareForThreads(1); + part2.PrepareForThreads(1); + EXPECT_EQ(kRows, part1.rows()); + EXPECT_EQ(kCols / 2, part1.cols()); + EXPECT_EQ(kRows, part2.rows()); + EXPECT_EQ(kCols / 2, part2.cols()); + MutableVectorView rhs1(&rhs, 0, kCols / 2); + MutableVectorView rhs2(&rhs, kCols / 2, kCols / 2); + for (int i = 0; i < kCols / 2; ++i) { + EXPECT_FLOAT_EQ(rhs[i], rhs1.data()[i]); + EXPECT_FLOAT_EQ(rhs[i + kCols / 2], rhs2.data()[i]); + } + part1.SpMM_bias(rhs1, &out1, /*relu=*/false, /*tid=*/0); + part2.SpMM_bias(rhs2, &out2, /*relu=*/false, /*tid=*/0); + // Check that out1 + out2 = out_reference. + for (int i = 0; i < kRows; ++i) { + EXPECT_NEAR(out_reference[i], out1[i] + out2[i], 2e-5) + << " i=" << i << " out1=" << out1[i] << " out2=" << out2[i]; + } +} +// Tests that a Layer that is split into 2 by rows (outputs) computes the same +// result as the original layer. +TEST(CsrBlockSparseMatrix, SplitByRow) { + int kRows = 1024; + int kCols = 1024; + MaskedSparseMatrix matrix(kRows, kCols, 0.95, /*block_height=*/4, + /*block_width=*/4); + FatCacheAlignedVector rhs(kCols, /*cols=*/1); + CacheAlignedVector bias(kRows); + FatCacheAlignedVector out1(kRows, /*cols=*/1); + FatCacheAlignedVector out2(kRows, /*cols=*/1); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + out2.FillZero(); + FatCacheAlignedVector out_reference = out1; + + CsrBlockSparseMatrix sparse_matrix(matrix); + + SparseLinearLayer sparse_linear_layer(std::move(sparse_matrix), + std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false, + /*tid=*/0); + // Split the layer into 2 parts. + SparseLinearLayer part1, part2; + sparse_linear_layer.SplitOutputs(&part1, &part2); + part1.PrepareForThreads(1); + part2.PrepareForThreads(1); + EXPECT_EQ(kRows / 2, part1.rows()); + EXPECT_EQ(kCols, part1.cols()); + EXPECT_EQ(kRows / 2, part2.rows()); + EXPECT_EQ(kCols, part2.cols()); + MutableVectorView out2a(&out2, 0, kRows / 2); + MutableVectorView out2b(&out2, kRows / 2, kRows / 2); + part1.SpMM_bias(rhs, &out2a, /*relu=*/false, /*tid=*/0); + part2.SpMM_bias(rhs, &out2b, /*relu=*/false, /*tid=*/0); + // Check that out2 = out_reference. + for (int i = 0; i < kRows; ++i) { + EXPECT_NEAR(out_reference[i], out2[i], 2e-5) + << " i=" << i << " out1=" << out_reference[i] << " out2=" << out2[i]; + } +} + +TEST(CsrBlockSparseMatrix, MutableVectorView) { + const int kRows = 1024; + const int kCols = 1024; + const int kFatness = 2; + + std::vector values(kRows * kCols, 1.f); + std::vector mask(kRows * kCols); + for (int i = 0; i < mask.size(); ++i) mask[i] = i % 2; + + auto masked_matrix = + MaskedSparseMatrix(kRows, kCols, mask.data(), values.data()); + auto sparse_matrix = CsrBlockSparseMatrix(masked_matrix); + FatCacheAlignedVector x(kCols, kFatness); + x.FillOnes(); + + CacheAlignedVector bias(kRows); + bias.FillZero(); + + // First check that we can use spans as output. Split a multiplication + // into upper and lower halves times the full vector: + // --------------- x t + // | | x t + // | | x t + // --------------- = + // | | x b + // | | x b + // --------------- x b + + FatCacheAlignedVector out(kRows, kFatness); + FatCacheAlignedVector out_view(kRows, kFatness); + + MutableVectorView out_view_top(&out_view, 0, kRows / 2); + MutableVectorView out_view_bottom(&out_view, kRows / 2, kRows / 2); + + sparse_matrix.SpMM_bias(x, bias, &out); + + auto masked_matrix_top = + MaskedSparseMatrix(kRows / 2, kCols, mask.data(), values.data()); + auto masked_matrix_bottom = MaskedSparseMatrix( + kRows / 2, kCols, mask.data() + kRows * kCols / 2, + values.data() + kRows * kCols / 2); + auto sparse_matrix_top = + CsrBlockSparseMatrix(masked_matrix_top); + auto sparse_matrix_bottom = + CsrBlockSparseMatrix(masked_matrix_bottom); + + sparse_matrix_top.SpMM_bias(x, bias, &out_view_top); + sparse_matrix_bottom.SpMM_bias(x, bias, &out_view_bottom); + + CheckResult(out, out_view, kCols); + + // Check that we can use a span as an input vector. Multiply upper left + // portion of the matrix by the top half of the vector. + // --------------- + // |oooooo | x q + // |oooooo | x q + // | | = + // | | + // --------------- + + auto masked_matrix_quarter = MaskedSparseMatrix( + kRows / 2, kCols / 2, mask.data(), values.data()); + auto sparse_matrix_quarter = + CsrBlockSparseMatrix(masked_matrix_quarter); + + MutableVectorView x_top(&x, 0, kCols / 2); + FatCacheAlignedVector out_correct(kRows / 2, /*cols=*/2); + + for (int i = 0; i < kFatness * (kRows / 2); ++i) out_correct[i] = 256.f; + + MutableVectorView bias_top(&bias, 0, kRows / 2); + FatCacheAlignedVector out_quarter(kRows / 2, kFatness); + + sparse_matrix_quarter.SpMM_bias(x_top, bias_top, &out_quarter); + + CheckResult(out_correct, out_quarter, kCols / 2); +} + +namespace { + +bool skip_test(const absl::Status& status, absl::string_view msg) { + if (!status.ok()) { + LOG(INFO) << "Couldn't load " << msg << ", skipping test " << status; + return true; + } + + return false; +} + +} // namespace + +TEST(CsrBlockSparseMatrix, ModelMatrices_Bfloat16) { + std::vector names = { + "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_", + "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_", + "768_512_95_4x4_finelogit_", "lyra_conv1d_"}; + const std::string kPath = +#if defined __arm__ || defined __aarch64__ + "/data/local/tmp/"; +#else + (ghc::filesystem::current_path() / kTestdataPath).string(); +#endif + for (auto& layer_name : names) { + SparseLinearLayer sparse_linear_layer; + auto status = LoadSparseLayer(layer_name, /*zipped=*/true, + &sparse_linear_layer, kPath); + // If the files don't exist on the device we're running on, just skip this + // test and log that it was skipped. + if (skip_test(status, layer_name)) return; + + int rows = sparse_linear_layer.rows(); + int cols = sparse_linear_layer.cols(); + + MaskedLinearLayer masked_linear_layer; + status = LoadMaskedLayer(layer_name, /*zipped=*/true, + &masked_linear_layer, kPath); + if (skip_test(status, layer_name)) return; + masked_linear_layer.CastWeights(); + + CacheAlignedVector rhs(cols); + CacheAlignedVector out_ref(rows); + CacheAlignedVector out_spmv(rows); + + rhs.FillRandom(); + out_ref.FillZero(); + out_spmv.FillZero(); + + std::array use_relus = {false, true}; + for (bool use_relu : use_relus) { + masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu); + sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu); + + CheckResult(out_ref, out_spmv, cols); + } + } +} + +TEST(CsrBlockSparseMatrix, ModelMatrices_float) { + std::vector names = { + "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_", + "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_", + "768_512_95_4x4_finelogit_", "lyra_conv1d_"}; + const std::string kPath = +#if defined __arm__ || defined __aarch64__ + "/data/local/tmp/"; +#else + (ghc::filesystem::current_path() / kTestdataPath).string(); +#endif + for (auto& layer_name : names) { + SparseLinearLayer sparse_linear_layer; + auto status = LoadSparseLayer(layer_name, /*zipped=*/true, + &sparse_linear_layer, kPath); + // If the files don't exist on the device we're running on, just skip this + // test and log that it was skipped. + if (skip_test(status, layer_name)) return; + + int rows = sparse_linear_layer.rows(); + int cols = sparse_linear_layer.cols(); + + MaskedLinearLayer masked_linear_layer; + status = LoadMaskedLayer(layer_name, /*zipped=*/true, + &masked_linear_layer, kPath); + if (skip_test(status, layer_name)) return; + + CacheAlignedVector rhs(cols); + CacheAlignedVector out_ref(rows); + CacheAlignedVector out_spmv(rows); + + rhs.FillRandom(); + out_ref.FillZero(); + out_spmv.FillZero(); + + std::array use_relus = {false, true}; + for (bool use_relu : use_relus) { + masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu); + sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu); + + CheckResult(out_ref, out_spmv, cols); + } + } +} + +#undef SKIP_TEST + +} // namespace csrblocksparse diff --git a/sparse_matmul/layers/errno_mapping.cc b/sparse_matmul/layers/errno_mapping.cc new file mode 100644 index 0000000000000000000000000000000000000000..558abb33937619edc9bcc6a242e414d57bfcc11c --- /dev/null +++ b/sparse_matmul/layers/errno_mapping.cc @@ -0,0 +1,195 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/layers/errno_mapping.h" + +#include + +#include "absl/strings/str_cat.h" + +namespace csrblocksparse { + +namespace { + +absl::StatusCode ErrnoToCode(int error_number) { + switch (error_number) { + case 0: + return absl::StatusCode::kOk; + case EINVAL: // Invalid argument + case ENAMETOOLONG: // Filename too long + case E2BIG: // Argument list too long + case EDESTADDRREQ: // Destination address required + case EDOM: // Mathematics argument out of domain of function + case EFAULT: // Bad address + case EILSEQ: // Illegal byte sequence + case ENOPROTOOPT: // Protocol not available + case ENOSTR: // Not a STREAM + case ENOTSOCK: // Not a socket + case ENOTTY: // Inappropriate I/O control operation + case EPROTOTYPE: // Protocol wrong type for socket + case ESPIPE: // Invalid seek + return absl::StatusCode::kInvalidArgument; + case ETIMEDOUT: // Connection timed out + case ETIME: // Timer expired + return absl::StatusCode::kDeadlineExceeded; + case ENODEV: // No such device + case ENOENT: // No such file or directory +#ifdef ENOMEDIUM + case ENOMEDIUM: // No medium found +#endif + case ENXIO: // No such device or address + case ESRCH: // No such process + return absl::StatusCode::kNotFound; + case EEXIST: // File exists + case EADDRNOTAVAIL: // Address not available + case EALREADY: // Connection already in progress +#ifdef ENOTUNIQ + case ENOTUNIQ: // Name not unique on network +#endif + return absl::StatusCode::kAlreadyExists; + case EPERM: // Operation not permitted + case EACCES: // Permission denied +#ifdef ENOKEY + case ENOKEY: // Required key not available +#endif + case EROFS: // Read only file system + return absl::StatusCode::kPermissionDenied; + case ENOTEMPTY: // Directory not empty + case EISDIR: // Is a directory + case ENOTDIR: // Not a directory + case EADDRINUSE: // Address already in use + case EBADF: // Invalid file descriptor +#ifdef EBADFD + case EBADFD: // File descriptor in bad state +#endif + case EBUSY: // Device or resource busy + case ECHILD: // No child processes + case EISCONN: // Socket is connected +#ifdef EISNAM + case EISNAM: // Is a named type file +#endif +#ifdef ENOTBLK + case ENOTBLK: // Block device required +#endif + case ENOTCONN: // The socket is not connected + case EPIPE: // Broken pipe +#ifdef ESHUTDOWN + case ESHUTDOWN: // Cannot send after transport endpoint shutdown +#endif + case ETXTBSY: // Text file busy +#ifdef EUNATCH + case EUNATCH: // Protocol driver not attached +#endif + return absl::StatusCode::kFailedPrecondition; + case ENOSPC: // No space left on device +#ifdef EDQUOT + case EDQUOT: // Disk quota exceeded +#endif + case EMFILE: // Too many open files + case EMLINK: // Too many links + case ENFILE: // Too many open files in system + case ENOBUFS: // No buffer space available + case ENODATA: // No message is available on the STREAM read queue + case ENOMEM: // Not enough space + case ENOSR: // No STREAM resources +#ifdef EUSERS + case EUSERS: // Too many users +#endif + return absl::StatusCode::kResourceExhausted; +#ifdef ECHRNG + case ECHRNG: // Channel number out of range +#endif + case EFBIG: // File too large + case EOVERFLOW: // Value too large to be stored in data type + case ERANGE: // Result too large + return absl::StatusCode::kOutOfRange; +#ifdef ENOPKG + case ENOPKG: // Package not installed +#endif + case ENOSYS: // Function not implemented + case ENOTSUP: // Operation not supported + case EAFNOSUPPORT: // Address family not supported +#ifdef EPFNOSUPPORT + case EPFNOSUPPORT: // Protocol family not supported +#endif + case EPROTONOSUPPORT: // Protocol not supported +#ifdef ESOCKTNOSUPPORT + case ESOCKTNOSUPPORT: // Socket type not supported +#endif + case EXDEV: // Improper link + return absl::StatusCode::kUnimplemented; + case EAGAIN: // Resource temporarily unavailable +#ifdef ECOMM + case ECOMM: // Communication error on send +#endif + case ECONNREFUSED: // Connection refused + case ECONNABORTED: // Connection aborted + case ECONNRESET: // Connection reset + case EINTR: // Interrupted function call +#ifdef EHOSTDOWN + case EHOSTDOWN: // Host is down +#endif + case EHOSTUNREACH: // Host is unreachable + case ENETDOWN: // Network is down + case ENETRESET: // Connection aborted by network + case ENETUNREACH: // Network unreachable + case ENOLCK: // No locks available + case ENOLINK: // Link has been severed +#ifdef ENONET + case ENONET: // Machine is not on the network +#endif + return absl::StatusCode::kUnavailable; + case EDEADLK: // Resource deadlock avoided +#ifdef ESTALE + case ESTALE: // Stale file handle +#endif + return absl::StatusCode::kAborted; + case ECANCELED: // Operation cancelled + return absl::StatusCode::kCancelled; + default: + return absl::StatusCode::kUnknown; + } +} + +// POSIX `strerror_r()` returns `int`. +ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(int result, const char* buffer, + int error_code) { + if (ABSL_PREDICT_FALSE(result != 0)) { + return absl::StrCat("Unknown error ", error_code); + } + return buffer; +} + +// GNU `strerror_r()` returns `char*`. +ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(char* result, + const char* buffer, + int error_code) { + return result; +} + +std::string StrError(int error_code) { + char message[256]; + return StrErrorResult(strerror_r(error_code, message, sizeof(message)), + message, error_code); +} + +} // namespace + +absl::Status ErrnoToCanonicalStatus(int error_number, + absl::string_view message) { + return absl::Status(ErrnoToCode(error_number), + absl::StrCat(message, ": ", StrError(error_number))); +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/layers/errno_mapping.h b/sparse_matmul/layers/errno_mapping.h new file mode 100644 index 0000000000000000000000000000000000000000..747d3b4d4b9c2761f1a3f24f8bfa0da49a34ec19 --- /dev/null +++ b/sparse_matmul/layers/errno_mapping.h @@ -0,0 +1,29 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_ +#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace csrblocksparse { + +// Converts |error_number| value to absl::Status. +absl::Status ErrnoToCanonicalStatus(int error_number, + absl::string_view message); + +} // namespace csrblocksparse + +#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_ diff --git a/sparse_matmul/layers/masked_sparse_matrix.h b/sparse_matmul/layers/masked_sparse_matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..a905ba4befcdc845834c37a4c07c8331deb8bd70 --- /dev/null +++ b/sparse_matmul/layers/masked_sparse_matrix.h @@ -0,0 +1,206 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_ + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" + +namespace csrblocksparse { + +// MaskedSparseMatrix serves two purposes: +// 1) It is useful as a reference implementation of SpMV for correctness +// checking the much more complicated implementations in CSRBlockSparseMatrix +// 2) This is the format that sparse matrices are represented after pruning +// in TF. This class provides a bridge to getting these parameters into +// a compressed form suitable for computation and serialization. +// +// MaskedSparseMatrix matrix(rows, cols, mask_from_tf, values_from_tf); +// CSRBlockSparseMatrix csr_matrix(matrix); +// csr_matrix.Multiply(rhs, bias, &out); +template +class MaskedSparseMatrix { + public: + MaskedSparseMatrix() {} + + // Construct a MaskedSparseMatrix of the given size, sparsity and block size. + // This is mainly useful for testing. + MaskedSparseMatrix(int rows, int cols, float sparsity, int block_height = 1, + int block_width = 1, float constant = 1.f, + bool random = true) + : rows_(rows), cols_(cols), sparsity_(sparsity) { + CHECK_EQ(rows % block_height, 0); + CHECK_EQ(cols % block_width, 0); + + init(sparsity, block_height, block_width, constant, random); + } + + // Construct from an existing mask and values (most likely from a TF model). + template + MaskedSparseMatrix(int rows, int cols, const MaskType* mask, const T* values) + : rows_(rows), cols_(cols) { + mask_.resize(rows * cols); + values_.resize(rows * cols); + std::copy_n(mask, rows * cols, mask_.begin()); + std::copy_n(values, rows * cols, values_.begin()); + sparsity_ = + 1.f - std::accumulate(mask_.begin(), mask_.end(), 0.f) / mask_.size(); + } + + const std::vector& mask() const { return mask_; } + const std::vector& values() const { return values_; } + T* data() { return values_.data(); } + const T* data() const { return values_.data(); } + + int rows() const { return rows_; } + int cols() const { return cols_; } + float sparsity() const { return sparsity_; } + + void Print() const { + absl::PrintF("-------Values---------\n"); + for (int r = 0; r < rows_; ++r) { + for (int c = 0; c < cols_; ++c) { + absl::PrintF("%+6.3f ", static_cast(values_[r * cols_ + c])); + } + absl::PrintF("\n"); + } + absl::PrintF("-------Mask---------\n"); + for (int r = 0; r < rows_; ++r) { + for (int c = 0; c < cols_; ++c) { + printf("%2d ", mask_[r * cols_ + c]); + } + absl::PrintF("\n"); + } + } + + // This routine is useful for rounding the possibly higher precision values + // stored in this class to a lower precision, so that correctness checks + // between this class and CSRBlockSparseMatrix can have a tighter tolerance. + template + void CastWeights() { + for (int i = 0; i < values_.size(); ++i) { + values_[i] = static_cast(U(values_[i])); + } + } + + // Only meant for correctness checking. + // RhsClassType is meant to be either CacheAlignedVector OR + // FatCacheAlignedVector. + // The weight matrix is ROW MAJOR and RhsClassType is COLUMN MAJOR. + // |bias| is broadcast if |rhs| has more than one column. + template + void SpMM_bias(const RhsClassType& rhs, + const CacheAlignedVector& bias, OutClassType* out, + bool relu = false) { + for (int r = 0; r < rows_; ++r) { + for (int n = 0; n < rhs.cols(); ++n) { + float sum = 0.f; + const RhsType* rhs_ptr = rhs.data() + n * rhs.rows(); + OutType* out_ptr = out->data() + n * out->rows(); + const int* mask_ptr = mask_.data() + r * cols_; + const T* value_ptr = values_.data() + r * cols_; + for (int c = 0; c < cols_; ++c) { + sum += mask_ptr[c] * static_cast(value_ptr[c]) * + static_cast(rhs_ptr[c]); + } + out_ptr[r] = static_cast( + relu ? std::max(sum + static_cast(bias[r]), 0.f) + : sum + static_cast(bias[r])); + } + } + } + + private: + // Generate a random matrix with the specified sparsity. + // Useful for testing. + void init(float sparsity, int block_height, int block_width, float constant, + bool random = true) { + int reduced_rows = rows_ / block_height; + int reduced_cols = cols_ / block_width; + mask_.resize(rows_ * cols_, 0); + + // Fill with non-zero value to make sure masking works. + values_.resize(rows_ * cols_, static_cast(2.f)); + + std::mt19937 generator(0); + std::uniform_real_distribution dist_sparsity; + std::uniform_real_distribution dist_value(-1.f, 1.f); + int nnz = 0; + while (nnz == 0) { + for (int r = 0; r < reduced_rows; ++r) { + for (int c = 0; c < reduced_cols; ++c) { + if (dist_sparsity(generator) > sparsity) { + nnz++; + for (int i = 0; i < block_height; ++i) { + for (int j = 0; j < block_width; ++j) { + mask_[(r * block_height + i) * cols_ + block_width * c + j] = 1; + values_[(r * block_height + i) * cols_ + block_width * c + j] = + static_cast(random ? dist_value(generator) : constant); + } + } + } + } + } + } + } + + std::vector mask_; + std::vector values_; + int rows_; + int cols_; + float sparsity_; +}; + +template +class MaskedLinearLayer { + public: + MaskedLinearLayer(MaskedSparseMatrix&& weights, + CacheAlignedVector&& bias) + : weights_(std::move(weights)), bias_(std::move(bias)) {} + + MaskedLinearLayer() {} + + template + void CastWeights() { + weights_.template CastWeights(); + } + + // Does Ax + b where A is a masked sparse ROW MAJOR matrix and + // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is + // broadcast is rhs has more than one column. + template + void SpMM_bias(const FatVector& rhs, FatVector* out, bool relu = false) { + static_assert(std::is_same::value, + "FatVector value_type must match masked_linear_layer type"); + weights_.SpMM_bias(rhs, bias_, out, relu); + } + + private: + MaskedSparseMatrix weights_; + CacheAlignedVector bias_; +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_ diff --git a/sparse_matmul/layers/read_array_ifstream.h b/sparse_matmul/layers/read_array_ifstream.h new file mode 100644 index 0000000000000000000000000000000000000000..3ea2bd1375435cc316e18c619767334e80040ac1 --- /dev/null +++ b/sparse_matmul/layers/read_array_ifstream.h @@ -0,0 +1,66 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Low-level array reading function using std::ifstream. + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/substitute.h" +#include "include/ghc/filesystem.hpp" + +namespace csrblocksparse { +namespace detail { + +template +absl::Status ReadArrayIfstream(const std::string& file_name, + const std::string& path, std::vector* array, + int64_t* length) { + ghc::filesystem::path complete_path(path); + complete_path /= file_name; + std::ifstream in_stream(complete_path.u8string(), std::ios::binary); + if (!in_stream.is_open()) { + return absl::UnknownError( + absl::Substitute("Error opening $0", complete_path.string())); + } + + std::stringstream buffer; + buffer << in_stream.rdbuf(); + if (buffer.str().empty()) { + LOG(ERROR) << "File " << complete_path << " was empty."; + return absl::UnknownError( + absl::Substitute("File $0 was empty", complete_path.string())); + } + std::string contents = buffer.str(); + *length = contents.length(); + int64_t elem = (*length + sizeof(T) - 1) / sizeof(T); + array->resize(elem); + std::move(contents.begin(), contents.end(), + reinterpret_cast(array->data())); + + return absl::OkStatus(); +} + +} // namespace detail +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_ diff --git a/sparse_matmul/layers/sparse_linear_layer.h b/sparse_matmul/layers/sparse_linear_layer.h new file mode 100644 index 0000000000000000000000000000000000000000..9363f30113ec51b97437c62aeb512a51e13c1d71 --- /dev/null +++ b/sparse_matmul/layers/sparse_linear_layer.h @@ -0,0 +1,365 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_ + +#include + +#include "absl/memory/memory.h" +#include "glog/logging.h" +#include "sparse_matmul/layers/csr_blocksparse_matrix.h" +#include "sparse_matmul/layers/masked_sparse_matrix.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/os/coop_threads.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" + +namespace csrblocksparse { + +template ::type, + typename DeltaType = int16_t> +class SparseLinearLayer { + public: + SparseLinearLayer() {} + + SparseLinearLayer(CsrBlockSparseMatrix&& sparse_matrix, + CacheAlignedVector&& bias) + : sparse_matrix_(std::move(sparse_matrix)), full_bias_(std::move(bias)) { + CHECK_EQ(sparse_matrix_.rows(), full_bias_.size()); + // Some kernels expect that the bias is divided by 4, so we store a second + // copy of a quarter of the bias. + // TODO(b/189958858): Remove the quartered bias if it can be done without + // loss of speed, and rename the |full_bias_| member back to |bias_|. + bias_ = full_bias_; + for (int i = 0; i < bias_.size(); ++i) { + bias_[i] = static_cast(.25f * static_cast(bias_[i])); + } + } + SparseLinearLayer( + const SparseLinearLayer& src) { + *this = src; + } + SparseLinearLayer& operator=( + const SparseLinearLayer& src) { + sparse_matrix_ = src.sparse_matrix_; + bias_ = src.bias_; + full_bias_ = src.full_bias_; + mid_output_ = src.mid_output_; + thread_layers_ = src.thread_layers_; + num_threads_ = src.num_threads_; + if (src.split_pc_) { + split_pc_ = absl::make_unique( + src.split_pc_->num_producers(), src.split_pc_->num_consumers()); + } + return *this; + } + + // Does Ax + b where A is a block sparse compressed sparse row matrix and + // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is + // broadcast if rhs has more than one column. + template + void SpMM_bias(const RhsClassType& rhs, OutType* out, bool relu = false, + int tid = 0, SpinBarrier* barrier = nullptr) const { + static_assert( + std::is_same::value, ""); + sparse_matrix_.SpMM_bias(rhs, bias_, out, relu, tid, barrier); + } + // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above, + // and then samples from the output (softmax distribution) layer. + template + int SpMM_bias_Sample(const RhsClassType& rhs, OutType* out, float temperature, + int tid, SpinBarrier* barrier, std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + static_assert( + std::is_same::value, ""); + return sparse_matrix_.SpMM_bias_Sample(rhs, bias_, out, temperature, tid, + barrier, gen, scratch); + } + template + void MatVec(const RhsClassType& rhs, bool relu, int tid, int replicas, + int output_stride, OutType* output, + SpinBarrier* barrier = nullptr) { + static_assert( + std::is_same::value, ""); +#ifdef __AVX2__ + if (block_width() == 4 && (block_height() == 4 || block_height() == 8) && + !IsCustomFloatType::value) { + if (!IsSplit()) { + sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu, + tid, replicas, output_stride, output->data()); + if (barrier != nullptr) barrier->barrier(); + return; + } + // NOTE: Until the quartered bias is removed it is a bad idea to split + // for ARM in the same way, as we would have to quarter the output of + // the first part of the split before running the second part. + // Signal completion of the previous MatVec. + split_pc_->produce(); + PartLinearLayer& thread_part = thread_layers_[tid]; + auto offset_output = + sparse_matrix_.thread_bounds().OffsetOutput(output->data(), tid); + auto mid_output = + sparse_matrix_.thread_bounds().OffsetOutput(mid_output_.data(), tid); + auto offset_bias = sparse_matrix_.thread_bounds().OffsetOutput( + mid_output_.cast_data(), tid); + // We can continue to consume the data that this thread produced and + // compute just the |self_matrix| part. + // No |relu| or |replicas|, as this is only a partial matmul. + // |tid| is always zero because the matrix has been split by tid. + thread_part.self_matrix.MatVec( + rhs.cast_data(), thread_part.full_bias.cast_data(), /*relu=*/false, + /*tid=*/0, /*replicas=*/1, output_stride, mid_output); + // We have to wait for the other threads to finish working on the previous + // MatMul before consuming the rest of |rhs|. + split_pc_->consume(); + thread_part.other_matrix.MatVec(rhs.cast_data(), offset_bias, relu, + /*tid=*/0, replicas, output_stride, + offset_output); + return; + } +#endif + DCHECK_EQ(replicas, 1) << "Must have single replica for SpMM API"; + if (IsSplit()) { + // Generics aren't setup to use a split matrix. This will be inefficient. + split_pc_->produce(); + split_pc_->consume(); + } + if (block_height() == 8) { + // We are currently forced to use MatVec generics for this case. + LOG(WARNING) << "Need to implement MatVec for 8x4 for non-AVX2 targets!!"; + sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu, tid, + replicas, output_stride, output->data()); + if (barrier != nullptr) barrier->barrier(); + } else { + sparse_matrix_.SpMM_bias(rhs, bias_, output, relu, tid, barrier); + } + } + + int rows() const { return sparse_matrix_.rows(); } + int cols() const { return sparse_matrix_.cols(); } + float sparsity() const { return sparse_matrix_.sparsity(); } + int block_width() const { return sparse_matrix_.block_width(); } + int block_height() const { return sparse_matrix_.block_height(); } + int num_threads() const { return sparse_matrix_.num_threads(); } + const CacheAlignedVector& bias() const { return bias_; } + const std::vector& split_points() const { + return sparse_matrix_.split_points(); + } + bool IsSplit() const { + return !thread_layers_.empty() && split_pc_ != nullptr; + } + + std::size_t bytes() const { return sparse_matrix_.bytes() + bias_.bytes(); } + void Print() const { + printf("Matrix\n"); + sparse_matrix_.Print(); + printf("Bias\n"); + bias_.Print(); + } + + // Combines adjacent row blocks, doubling the block height. + // This necessarily involves adding zero weights where the blocks don't align + // across adjacent pairs of rows, so use with caution, as the resulting matrix + // is most likely to run slower if very sparse to begin with. + // In the few cases where the blocks do mostly align, the resulting matmul + // could be much faster, as the number of reads of the rhs will be halved. + void DoubleBlockHeight() { sparse_matrix_.DoubleBlockHeight(); } + + // Cache_line_size is provided only for testing. Normally uses a value for + // the current architecture. + int PrepareForThreads(int num_threads, int cache_line_size = -1) { + num_threads_ = num_threads; + if (num_threads_ > 1) { + split_pc_ = + absl::make_unique(num_threads_, num_threads_); + } else { + split_pc_.reset(nullptr); + } + return sparse_matrix_.PrepareForThreads(num_threads, cache_line_size); + } + + // Partitions the matrix into pieces by thread. + // In this matrix, we can go ahead and calculate the part that only depends + // on rhs inputs that were generated by this thread in the previous matvec, + // without having to use any thread synchronization, and only after that do we + // have to wait for the other threads to finish the previous matvec. + // So we split the matrix using the |split_points| from the previous matrix + // into 2 * |num_threads_| pieces: self and other for each thread, being the + // parts that can be calculated before and after the other threads have + // completed their calculation of the previous matvec. + // We then have to use a ProducerConsumer lock instead of a SpinBarrier to + // synchronize the data produced by the other threads. + void SliceForThreads(const std::vector& split_points) { + thread_layers_.clear(); + thread_layers_.reserve(num_threads_); + LOG(INFO) << "Slicing " << rows() << "x" << cols() << " matrix for " + << num_threads_ << " threads"; + for (int tid = 0; tid < num_threads_; ++tid) { + thread_layers_.emplace_back( + sparse_matrix_, full_bias_, bias_, tid, + split_points[tid] * sparse_matrix_.block_height(), + split_points[tid + 1] * sparse_matrix_.block_height()); + } + mid_output_ = + std::move(csrblocksparse::CacheAlignedVector(rows())); + mid_output_.FillZero(); + } + + // Splits the layer by inputs into 2 equal pieces. Each of the resulting + // layers should be computed independently on the first and second halves of + // the inputs respectively and the results added to achieve the same effect + // as the original layer. + void SplitInputs( + SparseLinearLayer* part1, + SparseLinearLayer* part2) { + CsrBlockSparseMatrix matrix1( + sparse_matrix_.SplitByColumn(0, sparse_matrix_.cols() / 2)); + CsrBlockSparseMatrix matrix2( + sparse_matrix_.SplitByColumn(sparse_matrix_.cols() / 2, + sparse_matrix_.cols())); + *part1 = + std::move(SparseLinearLayer( + std::move(matrix1), + std::move(CacheAlignedVector(full_bias_)))); + CacheAlignedVector bias2(sparse_matrix_.rows()); + bias2.FillZero(); + *part2 = + std::move(SparseLinearLayer( + std::move(matrix2), std::move(bias2))); + } + + // Splits the layer by outputs into 2 equal pieces. Each of the resulting + // layers should be computed independently on the full inputs and the results + // concatenated to achieve the same effect as the original layer. + void SplitOutputs( + SparseLinearLayer* part1, + SparseLinearLayer* part2) { + LOG(INFO) << "input rows=" << sparse_matrix_.rows() + << ", cols=" << sparse_matrix_.cols(); + CsrBlockSparseMatrix matrix1( + sparse_matrix_.SplitByRow(0, sparse_matrix_.rows() / 2)); + CsrBlockSparseMatrix matrix2(sparse_matrix_.SplitByRow( + sparse_matrix_.rows() / 2, sparse_matrix_.rows())); + CacheAlignedVector bias1(full_bias_, 0, full_bias_.size() / 2); + *part1 = + std::move(SparseLinearLayer( + std::move(matrix1), std::move(bias1))); + CacheAlignedVector bias2(full_bias_, full_bias_.size() / 2, + full_bias_.size()); + *part2 = + std::move(SparseLinearLayer( + std::move(matrix2), std::move(bias2))); + } + + private: + // Simple struct to hold a partitioned layer. + struct PartLinearLayer { + // The original matrix is first split by row to generate only the outputs + // for the given tid. The |row_sub_matrix| is then split by column into two + // partitions: + // self is the part for which the rhs elements in [|start_col|, |end_col|) + // were generated by this thread in some previous matmul. + // |other| is the rest of the columns that require rhs elements from other + // threads. + // NOTE that| start_col|, |end_col| are in raw columns, not blocks. + PartLinearLayer(const CsrBlockSparseMatrix& matrix, + const CacheAlignedVector& bias, + const CacheAlignedVector& bias_4, int tid, + int start_col, int end_col) { + int block_height = matrix.block_height(); + // Split the input matrix by row, selecting only the rows relevant to + // thread tid. + int start_row = matrix.split_points()[tid] * block_height; + int end_row = matrix.split_points()[tid + 1] * block_height; + LOG(INFO) << "input cols [" << start_col << "," << end_col << ") rows [" + << start_row << "," << end_row << ")"; + CsrBlockSparseMatrix row_sub_matrix = + matrix.SplitByRow(start_row, end_row); + // Partition into the columns that use rhs elements that thread tid + // produced in a previous matmul, and the other rhs elements. + // NOTE that we |keep_rhs_size|=true so that each matrix can operate on + // the same rhs input vector. The self matrix just guarantees not to + // access any of the elements that are generated by another thread. + self_matrix = std::move(row_sub_matrix.SplitByColumn( + start_col, end_col, /*keep_rhs_size=*/true)); + self_matrix.PrepareForThreads(1); + // The reversed start and end slice out the complement of [start, end). + other_matrix = std::move(row_sub_matrix.SplitByColumn( + end_col, start_col, /*keep_rhs_size=*/true)); + other_matrix.PrepareForThreads(1); + full_bias = + std::move(CacheAlignedVector(bias, start_row, end_row)); + // TODO(b/189958858): Eliminate the quarter bias from all the code. + quarter_bias = + std::move(CacheAlignedVector(bias_4, start_row, end_row)); + } + // The part of the matrix that only depends on this thread for rhs inputs. + CsrBlockSparseMatrix self_matrix; + CacheAlignedVector full_bias; + CacheAlignedVector quarter_bias; + // The part of the matrix that uses rhs inputs from other threads. + CsrBlockSparseMatrix other_matrix; + }; + CsrBlockSparseMatrix sparse_matrix_; + CacheAlignedVector bias_; + CacheAlignedVector full_bias_; + // Output from the self_matrix that will be given to |other_matrix| as bias. + CacheAlignedVector mid_output_; + // One partitioned pair of matrices for each thread. + std::vector thread_layers_; + // Producer-consumer lock used to wait between computing |self_matrix| and + // |other_matrix| for the other threads to finish the *previous* matvec. + std::unique_ptr split_pc_; + int num_threads_ = 0; +}; + +template +SparseLinearLayer CreateRandomLayer(int rows, int cols, + float sparsity, + int block_height = 1, + int block_width = 1) { + typedef typename TypeOfProduct::type BiasType; + CacheAlignedVector bias(rows); + bias.FillRandom(); + + auto masked_matrix = MaskedSparseMatrix(rows, cols, sparsity, + block_height, block_width); + auto sparse_matrix = CsrBlockSparseMatrix(masked_matrix); + + return SparseLinearLayer(std::move(sparse_matrix), + std::move(bias)); +} + +template +SparseLinearLayer CreateConstantLayer( + int rows, int cols, float sparsity, float constant = 1.f) { + typedef typename TypeOfProduct::type BiasType; + CacheAlignedVector bias(rows); + bias.FillOnes(); + + MaskedSparseMatrix masked_matrix(rows, cols, sparsity, + /*block_height=*/1, /*block_width=*/1, + constant, /*random=*/false); + CsrBlockSparseMatrix sparse_matrix(masked_matrix); + + return SparseLinearLayer(std::move(sparse_matrix), + std::move(bias)); +} + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_ diff --git a/sparse_matmul/layers/sparse_linear_layer_test.cc b/sparse_matmul/layers/sparse_linear_layer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bb256ec05965c3ed39b657ec43ba9a58ba415857 --- /dev/null +++ b/sparse_matmul/layers/sparse_linear_layer_test.cc @@ -0,0 +1,187 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/layers/sparse_linear_layer.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/test_utils.h" + +namespace csrblocksparse { +namespace { + +constexpr int kBlockSize = 4; +constexpr int kSize = 256; +constexpr int kNumThreads = 4; +constexpr int kCols = 1; + +void SlicedThreadBody(SpinBarrier* spin_barrier, int tid, + const FatCacheAlignedVector& rhs, + SparseLinearLayer* sparse_linear_layer, + FatCacheAlignedVector* out, bool use_relu) { + sparse_linear_layer->MatVec(rhs, use_relu, tid, /*replicas=*/1, + /*output_stride=*/0, out); + spin_barrier->barrier(); +} + +// Tests that a Layer that has been SliceForThreads computes the same result as +// the original layer. This is a basic test that all the slicing didn't mess up +// any of the computations. +TEST(CsrBlockSparseMatrix, SliceForThreads) { + MaskedSparseMatrix matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize); + FatCacheAlignedVector rhs(kSize, kCols); + CacheAlignedVector bias(kSize); + FatCacheAlignedVector out1(kSize, kCols); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + FatCacheAlignedVector out_reference = out1; + CsrBlockSparseMatrix sparse_matrix(matrix); + SparseLinearLayer sparse_linear_layer(std::move(sparse_matrix), + std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out_reference); + std::vector fake_split_points = {0, 48 / kBlockSize, 128 / kBlockSize, + 208 / kBlockSize, kSize / kBlockSize}; + sparse_linear_layer.PrepareForThreads(kNumThreads); + sparse_linear_layer.SliceForThreads(fake_split_points); + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, SlicedThreadBody, rhs, + &sparse_linear_layer, &out1, + /*relu=*/true); + + CheckResult(out_reference, out1, kCols); +} + +void LayersThreadBody(SpinBarrier* spin_barrier, int tid, + const FatCacheAlignedVector& rhs, + SparseLinearLayer* sparse_linear_layer1, + SparseLinearLayer* sparse_linear_layer2, + FatCacheAlignedVector* out1, + FatCacheAlignedVector* out2, bool use_relu) { + sparse_linear_layer1->MatVec(rhs, use_relu, tid, /*replicas=*/1, + /*output_stride=*/0, out1); + // NOTE no barrier here! + sparse_linear_layer2->MatVec(*out1, use_relu, tid, /*replicas=*/1, + /*output_stride=*/0, out2); + spin_barrier->barrier(); +} + +// Tests that a pair of layers computes the same result whether or not the +// second layer has been SliceForThreads. This is a more critical test that +// the replacement of barriers with producer-consumer locks works. +// Must be run with tsan to really test it properly. +TEST(CsrBlockSparseMatrix, SliceForThreadsLayers) { + MaskedSparseMatrix matrix1(kSize, kSize, 0.95, kBlockSize, kBlockSize); + FatCacheAlignedVector rhs(kSize, kCols); + CacheAlignedVector bias1(kSize); + FatCacheAlignedVector out1(kSize, kCols); + MaskedSparseMatrix matrix2(kSize, kSize, 0.95, kBlockSize, kBlockSize); + CacheAlignedVector bias2(kSize); + FatCacheAlignedVector out2(kSize, kCols); + + bias1.FillRandom(); + rhs.FillRandom(); + bias2.FillRandom(); + out1.FillZero(); + out2.FillZero(); + FatCacheAlignedVector out_reference = out2; + CsrBlockSparseMatrix sparse_matrix1(matrix1); + SparseLinearLayer layer1(std::move(sparse_matrix1), + std::move(bias1)); + CsrBlockSparseMatrix sparse_matrix2(matrix2); + SparseLinearLayer layer2(std::move(sparse_matrix2), + std::move(bias2)); + layer1.PrepareForThreads(1); + layer2.PrepareForThreads(1); + layer1.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out1); + layer2.MatVec(out1, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out_reference); + layer1.PrepareForThreads(kNumThreads); + layer2.PrepareForThreads(kNumThreads); + layer2.SliceForThreads(layer1.split_points()); + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, LayersThreadBody, rhs, + &layer1, &layer2, &out1, &out2, + /*relu=*/true); + + CheckResult(out_reference, out2, kCols); +} + +// Tests that a Layer that has been DoubleBlockHeight()-ed computes the same +// result as original layer. (Float compute type). +TEST(CsrBlockSparseMatrix, Float8x4) { + using ComputeType = float; + using RhsType = float; + using BiasType = float; + MaskedSparseMatrix matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize); + matrix.CastWeights(); + FatCacheAlignedVector rhs(kSize, kCols); + CacheAlignedVector bias(kSize); + FatCacheAlignedVector out1(kSize, kCols); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + FatCacheAlignedVector out_reference = out1; + CsrBlockSparseMatrix sparse_matrix(matrix); + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out_reference); + sparse_linear_layer.DoubleBlockHeight(); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out1); + CheckResult(out_reference, out1, kCols); +} + +// Tests that a Layer that has been DoubleBlockHeight()-ed computes the same +// result as original layer. (Fixed16 compute type). +TEST(CsrBlockSparseMatrix, Fixed8x4) { + using ComputeType = csrblocksparse::fixed16<4>; + using RhsType = csrblocksparse::fixed16<4>; + using BiasType = typename TypeOfProduct::type; + MaskedSparseMatrix matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize); + matrix.CastWeights(); + FatCacheAlignedVector rhs(kSize, kCols); + CacheAlignedVector bias(kSize); + FatCacheAlignedVector out1(kSize, kCols); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + FatCacheAlignedVector out_reference = out1; + CsrBlockSparseMatrix sparse_matrix(matrix); + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out_reference); + sparse_linear_layer.DoubleBlockHeight(); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out1); + CheckResult(out_reference, out1, kCols); +} + +TEST(SparseLinearLayerTest, PrintCompiles) { + SparseLinearLayer sparse_linear_layer; + sparse_linear_layer.Print(); +} + +} // namespace +} // namespace csrblocksparse diff --git a/sparse_matmul/layers/status_macros.h b/sparse_matmul/layers/status_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..d2ebeaa7db5b49abc5065c7fd4a2d57564aa2550 --- /dev/null +++ b/sparse_matmul/layers/status_macros.h @@ -0,0 +1,34 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_ +#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +#define SPARSE_MATMUL_RETURN_IF_ERROR(expr) \ + do { \ + const absl::Status _status = (expr); \ + if (!_status.ok()) return _status; \ + } while (0) +template +absl::Status DoAssignOrReturn(T& lhs, absl::StatusOr result) { + if (result.ok()) { + lhs = result.value(); + } + return result.status(); +} + +#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..4b986882aef8e7d3474812a75e450d27054f0fd4 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50f861af29b1f767830d74ef83874944b18d80157b6b0256fdc4c14fa79ec936 +size 20852 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..0fc7369bbda29d0d4c14e23f551a4c8476de0f0f --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a2d534bde2caf6e59990a46b4b1907088b8144c53d62d97de7e2b4bdc956da68 +size 5133 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..657ec5a95a304a3adc2bacf85b974ae7fee0dff1 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11399f9d0e8f8dfbef6eb37e0c096f858658bc650f728a08f3135ccca44f0a5a +size 1062 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..4674dd5ab09c2daf2b5334249d4b30bd497e47d6 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d3d971e067a6df985d68beac26bcf4e9a6cc13ff328599e84d50a0fc9a7c103b +size 2382 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..db9fdb3ec4579f3a50dfe6ccf458b96d7077b3d2 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1376ef7a360699dae24a49f40a254990d4a70b844dadcdbe9dcbf1a306999a8 +size 55829 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..5e29ff26bdc576e9e09ab3ef52dc5bd57f7f7c6d --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffcc8ccf086fccfacc928877aa29ef03ce51cce0f0b7d2aacf81782b7b527089 +size 2003 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..46768d586a447e5b0d364bae233d9205bbbdbd57 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a16f98ba6f09031ea9fefb79fdc9ba90e44f0046ab70dab014ac971ca7f7186 +size 4684 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..8b1d454c91e2c376657ac2a572b234e293d617e4 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1b91304f5b6f7b53651ec7f9c827d4a2447366d1f990032adff46b18377741f +size 113777 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..e84744f9a7988ebb01615eb61ef648d3060ae62e --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ebb84ab4e16408f898b41a28c0d2c611f6735c8d9ad96a6805947c57cb547c7 +size 1055 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..5daa30ccd0068a27cdc1f00938e366e62b169989 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:071159e5397eff604ff3f1fca3ba90980a1ff9ae12838022179709d2c50e4627 +size 2322 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..5a15b1591b978af58f596fc4c00f6e28e67dc9a5 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1fdd0cbc0e79ea0a0dc1fc2ce8b10c5f25387fb4fd2ca019b66ac7ad7f44d219 +size 51615 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..5c85faf2cb5bed32b4279384356805625504e8c0 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abd83a1795fd5e7044200029eae3ce6406b84095b7128288ac0dda1de5746b59 +size 2001 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..3562f8f341479cda655bb342a3166901ccb439ec --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:455e1c142dd29bc4a4bb5a15c1f88ef3e0fbb580425620ef6f923b6e04faab01 +size 4459 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..2fcdead28a8ffdfaab831e77c4d2b8bd70e980dd --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:171d1e86e04fbefeca7dcce59817ad82d30556a110b4552cd5757a9348405d1c +size 111636 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..ead150fe78c6ef1940aff8da99dc93c583c70af4 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fba804daa5c3c4d5c87ca1ff4060d118c33f8e2201077e6faa233822c5f0c511 +size 10706 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..9360b1d755076cc714ad19bf3dd277662ba71198 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62c03b31f5f58eb67773dcc5b0bae5b4790a26dca1934d79802342b4175e7a74 +size 50978 diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..43d829f3aa406da96497c52b7ec577305dd20118 --- /dev/null +++ b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:679c5bd2d5ca6abaae96225e8bab2ce9f9d57170027471465c85fc220c0c44a8 +size 1361746 diff --git a/sparse_matmul/layers/testdata/lyra_conv1d_bias.raw.gz b/sparse_matmul/layers/testdata/lyra_conv1d_bias.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..221d4230c25e16f0101875ca2f01ca5305fe3c9a --- /dev/null +++ b/sparse_matmul/layers/testdata/lyra_conv1d_bias.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14cca1e77c2a87ac161c0273153b46630e8a00718c6aac0479f13cda6f07ad81 +size 1980 diff --git a/sparse_matmul/layers/testdata/lyra_conv1d_mask.raw.gz b/sparse_matmul/layers/testdata/lyra_conv1d_mask.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..1e6b4cb5826eb7e0f261d8d25f973edad469f241 --- /dev/null +++ b/sparse_matmul/layers/testdata/lyra_conv1d_mask.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbbcc8c024f0e67a58eca0e3007c396b384312b1f437f7634ba61e4ad6780908 +size 953 diff --git a/sparse_matmul/layers/testdata/lyra_conv1d_weights.raw.gz b/sparse_matmul/layers/testdata/lyra_conv1d_weights.raw.gz new file mode 100644 index 0000000000000000000000000000000000000000..013a94bb3f55a8904409754c9d17e4ab2421db62 --- /dev/null +++ b/sparse_matmul/layers/testdata/lyra_conv1d_weights.raw.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7272b077f8c569922d756a2b1a29cd94db7651c86fcf66b175b1905f6c11351 +size 858640 diff --git a/sparse_matmul/layers/utils.cc b/sparse_matmul/layers/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..0a8d5796afe65c15cc4b59789dfcc876d98cc40a --- /dev/null +++ b/sparse_matmul/layers/utils.cc @@ -0,0 +1,129 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Source for various utility functions related to reading and writing files +// and vectors. Would be much simpler if Android and Windows supported File. + +#include "sparse_matmul/layers/utils.h" + +#ifdef _WIN32 +#include + +#include +#include // NOLINT +#else +#include +#endif // _WIN32 + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" + +namespace csrblocksparse { + +namespace { + +// Helper to test if a filename is "." or "..". +template +bool IsDotOrDotDot(const CharType* filename) { + if (filename[0] == '.') { + if (filename[1] == '\0') { + return true; + } + if (filename[1] == '.' && filename[2] == '\0') { + return true; + } + } + + return false; +} + +#ifdef _WIN32 // We only define these conversion routines on Win32. +static std::mutex g_converter_mutex; +static std::wstring_convert> g_converter; + +std::string Narrow(const std::wstring& wide) { + std::lock_guard auto_lock(g_converter_mutex); + return g_converter.to_bytes(wide); +} + +std::wstring Widen(const std::string& narrow) { + std::lock_guard auto_lock(g_converter_mutex); + return g_converter.from_bytes(narrow); +} + +inline constexpr char kLongPathPrefix[] = R"(\\?\)"; + +std::wstring ConvertToWindowsPathFormat(const std::string& path, + int max_path_length = MAX_PATH) { + if (path.length() + 1 > max_path_length && + !absl::StartsWith(path, kLongPathPrefix)) { + return Widen(absl::StrCat(kLongPathPrefix, path)); + } + return Widen(path); +} +#endif // _WIN32 + +} // namespace + +// Return all files in a given directory. +absl::Status FilesInDirectory(const std::string& path, + const std::string& must_contain, + std::vector* result) { +#ifdef _WIN32 + WIN32_FIND_DATAW child_data; + HANDLE find_handle = FindFirstFileW( + ConvertToWindowsPathFormat(absl::StrCat(path, "\\*")).c_str(), + &child_data); + if (find_handle == INVALID_HANDLE_VALUE) { + return absl::UnknownError( + absl::Substitute("Couldn't open: $0 (error $1)", path, GetLastError())); + } + do { + if (IsDotOrDotDot(child_data.cFileName)) continue; + const std::string name = Narrow(child_data.cFileName); + if (name.find(must_contain) == std::string::npos) continue; + result->push_back(name); + } while (FindNextFileW(find_handle, &child_data) != 0); + const auto err = GetLastError(); + FindClose(find_handle); + if (err != ERROR_NO_MORE_FILES) + return absl::UnknownError( + absl::Substitute("Error in FindNextFileW: $0", err)); +#else + DIR* dirp = opendir(path.c_str()); + if (dirp == nullptr) { + return absl::UnknownError(absl::Substitute("Couldn't open: $0", path)); + } + + dirent* dp; + errno = 0; + while ((dp = readdir(dirp)) != nullptr) { + if (IsDotOrDotDot(dp->d_name)) continue; + const std::string name(dp->d_name); + if (name.find(must_contain) == std::string::npos) continue; + result->push_back(name); + } + closedir(dirp); + if (errno != 0) + return absl::UnknownError(absl::Substitute("Error in readdir: $0", errno)); +#endif // _WIN32 + + return absl::OkStatus(); +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/layers/utils.h b/sparse_matmul/layers/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..e10b1b957d8faa5c253f86d99d7a3738e480df9f --- /dev/null +++ b/sparse_matmul/layers/utils.h @@ -0,0 +1,338 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Various utility functions related to reading and writing files, vectors, etc. +// Would be much simpler if Android supported File. + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_UTILS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_UTILS_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/substitute.h" +#include "include/ghc/filesystem.hpp" +#include "sparse_matmul/layers/errno_mapping.h" +#include "sparse_matmul/layers/masked_sparse_matrix.h" +#include "sparse_matmul/layers/read_array_ifstream.h" +#include "sparse_matmul/layers/sparse_linear_layer.h" +#include "sparse_matmul/layers/status_macros.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" +#include "sparse_matmul/zlib_wrapper/zlibwrapper.h" + +namespace csrblocksparse { + +template +void unzip(int64_t st_size, std::vector* array) { + ZLib z; + z.SetGzipHeaderMode(); + if (z.HasGzipHeader(reinterpret_cast(array->data()), st_size)) { + const std::size_t kMaxBufferSize = 1 << 27; // 128MB + + Bytef* dest; + uLongf dest_len = kMaxBufferSize; + CHECK_EQ(z.UncompressGzipAndAllocate(&dest, &dest_len, + (Bytef*)array->data(), st_size), + Z_OK); + CHECK_EQ(dest_len % sizeof(T), 0); + array->assign(reinterpret_cast(dest), + reinterpret_cast(dest + dest_len)); + free(dest); + } else { + CHECK_EQ(st_size % sizeof(T), 0); + } +} + +// Reads a file that contains an array of a single POD type. Eventually we +// will replace serializiation with protos, but for now this is the easiest way +// to interface with the rest of the pipeline. +// +// StatusOr might be preferred but does not compile on ARM. +// |DiskType| and |ElemType| template types have no effect in this function +// version and are only used to handle fixed_type disk storage. +template +typename std::enable_if::value, + absl::Status>::type +ReadArrayFromFile(const std::string& file_name, std::vector* array, + const std::string& path = "/data/local/tmp/") { + int64_t length = 0; + const absl::Status status = + detail::ReadArrayIfstream(file_name, path, array, &length); + if (!status.ok()) { + return status; + } + unzip(length, array); + + return absl::OkStatus(); +} + +// If the metatype |DiskType| is of fixed16_type, we load int16_ts from disk and +// construct |ElemType| from them. |ElemType| is necessary because we need to +// know the mantissa/exponent bit split before casting to float. We need a +// separate function template for fixed rather than an if block because the +// compiler will complain bfloat not having an int16_t constructor. +template +typename std::enable_if::value && + csrblocksparse::IsFixed16Type::value, + absl::Status>::type +ReadArrayFromFile(const std::string& file_name, std::vector* array, + const std::string& path = "/data/local/tmp/") { + std::vector disk_values; + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(file_name, &disk_values, path)); + array->resize(disk_values.size()); + std::transform( + disk_values.begin(), disk_values.end(), array->begin(), + [](int16_t disk_value) { return static_cast(ElemType(disk_value)); }); + return absl::OkStatus(); +} + +// Writes a vector to a binary file. Eventually serialization will be handled +// with protos. +template +absl::Status WriteArrayToFile(const std::vector& array, + const std::string& file_name, + std::string path = "/data/local/tmp/") { + path = (ghc::filesystem::path(path) / file_name).string(); + FILE* fp = fopen(path.c_str(), "wb"); + if (fp == nullptr) + return ErrnoToCanonicalStatus(errno, + absl::Substitute("Error opening $0", path)); + size_t write_count = fwrite(array.data(), sizeof(T), array.size(), fp); + if (write_count != array.size()) { + return ErrnoToCanonicalStatus( + errno, + absl::Substitute( + "Error writing array, only wrote $0 of $1 elements for file $2", + write_count, array.size(), path)); + } + SPARSE_MATMUL_RETURN_IF_ERROR(ErrnoToCanonicalStatus( + fclose(fp), absl::Substitute("Error closing $0", path))); + return absl::OkStatus(); +} + +// Reads an entire layer that consists of weights, bias and mask as a +// SparseLinearLayer. Eventually this serialization will be handled with +// protos, but the rest of the system currently does naive serialization. +// +// StatusOr might be preferred but does not compile on ARM. +// +// Here |DiskWeightType| is the metatype used to store the weights, usually +// fixed16_type, float, or bfloat. +// For |DiskWeightType| = fixed16_type specialization, this loads a file with a +// "fixed16_weights.raw" suffix which stores int16_ts as its element datatype. +// The disk elements should match fixed16. This cuts +// down disk storage of weights by +// >= half. For all other types it reads the weights as floats. +template +absl::Status LoadGenericLayer( + const std::string& prefix, bool zipped, const std::string& path, + float default_bias, + SparseLinearLayer* sparse_linear_layer) { + std::string fixed_prefix = + csrblocksparse::IsFixed16Type::value ? "fixed16_" : ""; + std::string extension = zipped ? ".gz" : ""; + std::string weight_name = + absl::StrCat(prefix, fixed_prefix, "weights.raw", extension); + std::string mask_name = absl::StrCat(prefix, "mask.raw", extension); + std::string bias_name = absl::StrCat(prefix, "bias.raw", extension); + + std::vector weight_vector; + std::vector mask_vector; + std::vector bias_vector; + + const auto status = ReadArrayFromFile( + weight_name, &weight_vector, path); + SPARSE_MATMUL_RETURN_IF_ERROR(status); + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(mask_name, &mask_vector, path)); + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(bias_name, &bias_vector, path)); + + CHECK(weight_vector.size() == mask_vector.size()) + << "Weight and mask must be" + << " the same size, weights: " << weight_vector.size() + << " mask: " << mask_vector.size(); + CHECK(weight_vector.size() % bias_vector.size() == 0) + << "Weights size must " + "be a multiple of the bias size. Weights: " + << weight_vector.size() + << " " + "bias: " + << bias_vector.size() + << " remainder: " << weight_vector.size() % bias_vector.size(); + + int rows = bias_vector.size(); + int cols = weight_vector.size() / rows; + + MaskedSparseMatrix weights_masked(rows, cols, mask_vector.data(), + weight_vector.data()); + + weights_masked.template CastWeights(); + using csrmatrix = CsrBlockSparseMatrix; + + csrmatrix weights(weights_masked); + // If the weights were not a multiple of the block size in rows, we need to + // expand the bias vector to match using the provided default_bias value. + bias_vector.resize(weights.rows(), default_bias); + using BiasType = typename TypeOfProduct::type; + CacheAlignedVector bias(bias_vector); + + *sparse_linear_layer = std::move(SparseLinearLayer( + std::move(weights), std::move(bias))); + + return absl::OkStatus(); +} +template +absl::Status LoadSparseLayer( + const std::string& prefix, bool zipped, + SparseLinearLayer* sparse_linear_layer, + const std::string& path = "/data/local/tmp/") { + return LoadGenericLayer( + prefix, zipped, path, 0.0f, sparse_linear_layer); +} +template +absl::Status LoadLogitLayer( + const std::string& prefix, bool zipped, const std::string& path, + SparseLinearLayer* sparse_linear_layer) { + return LoadGenericLayer( + prefix, zipped, path, std::numeric_limits::lowest(), + sparse_linear_layer); +} + +// Reads an entire layer that consists of weights, bias and mask as a +// MaskedLinearLayer. Eventually this serialization will be handled with +// protos, but the rest of the system currently does naive serialization. +// +// StatusOr might be preferred but does not compile on ARM. +template +absl::Status LoadMaskedLayer(const std::string& prefix, bool zipped, + MaskedLinearLayer* masked_sparse_matrix, + const std::string& path = "/data/local/tmp/") { + std::string extension = zipped ? ".gz" : ""; + std::string weight_name = absl::StrCat(prefix, "weights.raw", extension); + std::string mask_name = absl::StrCat(prefix, "mask.raw", extension); + std::string bias_name = absl::StrCat(prefix, "bias.raw", extension); + + std::vector weight_vector; + std::vector mask_vector; + std::vector bias_vector; + + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(weight_name, &weight_vector, path)); + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(mask_name, &mask_vector, path)); + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(bias_name, &bias_vector, path)); + + CHECK(weight_vector.size() == mask_vector.size()) + << "Weight and mask must be" + << " the same size, weights: " << weight_vector.size() + << " mask: " << mask_vector.size(); + CHECK(weight_vector.size() % bias_vector.size() == 0) + << "Weights size must " + "be a multiple of the bias size. Weights: " + << weight_vector.size() + << " " + "bias: " + << bias_vector.size() + << " remainder: " << weight_vector.size() % bias_vector.size(); + + int rows = bias_vector.size(); + int cols = weight_vector.size() / rows; + + MaskedSparseMatrix weights_masked(rows, cols, mask_vector.data(), + weight_vector.data()); + CacheAlignedVector bias(bias_vector); + + *masked_sparse_matrix = + MaskedLinearLayer(std::move(weights_masked), std::move(bias)); + return absl::OkStatus(); +} + +// Load a vector of POD into a CacheAlignedVector. +// +// StatusOr might be preferred but does not compile on ARM. +template +absl::Status LoadVector(const std::string& file_name, + CacheAlignedVector* cache_aligned_vector, + const std::string& path = "/data/local/tmp/") { + std::vector values; + + SPARSE_MATMUL_RETURN_IF_ERROR(ReadArrayFromFile(file_name, &values, path)); + + *cache_aligned_vector = std::move(CacheAlignedVector(values)); + + return absl::OkStatus(); +} + +// Loads a 2D vector from a file. One of rows or cols can optionally be +// -1 to indicate that dimension should be inferred. +template +absl::Status LoadFatVector(const std::string& file_name, int rows, int cols, + FatCacheAlignedVector* fat_cache_aligned_vector, + const std::string& path = "/data/local/tmp/") { + // neither can be zero + CHECK(rows != 0 && cols != 0); + // only one can be -1 + CHECK(rows != -1 || cols != -1); + // otherwise must be positive + CHECK(rows >= -1 && cols >= -1); + + CacheAlignedVector values; + + SPARSE_MATMUL_RETURN_IF_ERROR(LoadVector(file_name, &values, path)); + + if (rows > 0) + CHECK_EQ(values.size() % rows, 0); + else + rows = values.size() / cols; + + if (cols > 0) + CHECK_EQ(values.size() % cols, 0); + else + cols = values.size() / rows; + + *fat_cache_aligned_vector = std::move(FatCacheAlignedVector(values, rows)); + + return absl::OkStatus(); +} + +// Return all files in a given directory +// If only File worked on Android and Windows... +absl::Status FilesInDirectory(const std::string& path, + const std::string& must_contain, + std::vector* result); + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_UTILS_H_ diff --git a/sparse_matmul/layers/utils_test.cc b/sparse_matmul/layers/utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c70ee07472bce894921a0ede7880c46098987816 --- /dev/null +++ b/sparse_matmul/layers/utils_test.cc @@ -0,0 +1,185 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/layers/utils.h" + +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "include/ghc/filesystem.hpp" +#include "sparse_matmul/layers/csr_blocksparse_matrix.h" +#include "sparse_matmul/layers/errno_mapping.h" +#include "sparse_matmul/layers/sparse_linear_layer.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/test_utils.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" + +namespace csrblocksparse { +namespace { + +static constexpr char kTempOutputDir[] = + "third_party/lyra_codec/sparse_matmul/layers/testdata/"; +static constexpr int kTestExponentBits = 5; + +template +class CsrBlockSparseMatrixUtilsTest : public testing::Test { + protected: + CsrBlockSparseMatrixUtilsTest() + : output_dir_((ghc::filesystem::path(testing::TempDir()) / kTempOutputDir) + .string()) { + if (std::is_floating_point::value) { + tolerance_ = 1e-5; + } else if (csrblocksparse::IsCustomFloatType::value) { + // Casting float --> bfloat truncates the least significant 16 bits from + // the mantissa, thus the larger the exponent bits the larger the rounding + // error. + // The exponent for max_val is 2^4, meaning the max rounding error + // for the weight input is ~ 0.124. The tolerance is 2x this because + // although the intermediate multiplications are accumulated in float, + // the output is cast to bfloat. + // Placeholder for internal diagram. + float max_val = + std::pow(2, kTestExponentBits) - + std::pow(2, -fixed16::kMantissaBits); + tolerance_ = 2 * (max_val - static_cast(ComputeType(max_val))); + } else { + tolerance_ = std::pow(2, -MantissaBitsOf::value); + } + } + + void SetUp() override { + std::error_code error_code; + ghc::filesystem::create_directories(output_dir_, error_code); + ASSERT_FALSE(error_code); + } + + void TearDown() override { + std::error_code error_code; + ghc::filesystem::remove_all(output_dir_, error_code); + ASSERT_FALSE(error_code); + } + + const std::string output_dir_; + float tolerance_; +}; + +void GenerateRandomWeightBiasMaskVectors( + int weight_vector_size, int bias_vector_size, + std::vector* weight_vector, std::vector* bias_vector, + std::vector* mask_vector, std::vector* masked_weight_vector) { + weight_vector->resize(weight_vector_size); + bias_vector->resize(bias_vector_size); + mask_vector->resize(weight_vector_size); + masked_weight_vector->resize(weight_vector_size); + // Fill Weight and Bias with random values between +/-[2^|kTestExponentBits| - + // 1] - 0.5 to prevent clipping in the fixed16 case when the weight and bias + // are added with all 1s in the exponent and mantissa. + const float max_abs_random_value = + std::pow(2, kTestExponentBits - 1) - 0.5; + std::uniform_real_distribution distribution(-max_abs_random_value, + max_abs_random_value); + std::default_random_engine generator(1337); + std::generate(weight_vector->begin(), weight_vector->end(), + [&]() { return distribution(generator); }); + std::generate(bias_vector->begin(), bias_vector->end(), + [&]() { return distribution(generator); }); + std::bernoulli_distribution mask_distribution(0.5); + std::generate(mask_vector->begin(), mask_vector->end(), + [&]() { return mask_distribution(generator) ? 1 : 0; }); + // Construct the combined weight and mask vector. + std::transform(mask_vector->begin(), mask_vector->end(), + weight_vector->begin(), masked_weight_vector->begin(), + [&](float mask_value, float weight_value) { + return mask_value * weight_value; + }); +} + +using ComputeTypes = + testing::Types, + csrblocksparse::bfloat16>; +TYPED_TEST_SUITE(CsrBlockSparseMatrixUtilsTest, ComputeTypes); + +TYPED_TEST(CsrBlockSparseMatrixUtilsTest, LoadLayer) { + const int kWeightVectorSize = 16; + const int kBiasVectorSize = 4; + std::vector ref_weight_vector; + std::vector ref_bias_vector; + std::vector ref_mask_vector; + std::vector ref_masked_weight_vector; + + GenerateRandomWeightBiasMaskVectors( + kWeightVectorSize, kBiasVectorSize, &ref_weight_vector, &ref_bias_vector, + &ref_mask_vector, &ref_masked_weight_vector); + + // This fixed16_weights.raw vector should only be read by LoadGenericLayer + // when |TypeParam| is a fixed16_type. + std::vector fixed_weight_vector(ref_weight_vector.size()); + std::transform(ref_weight_vector.begin(), ref_weight_vector.end(), + fixed_weight_vector.begin(), [](float weight) { + return fixed16(weight).raw_val(); + }); + ASSERT_TRUE(WriteArrayToFile(fixed_weight_vector, "fixed16_weights.raw", + this->output_dir_) + .ok()); + ASSERT_TRUE( + WriteArrayToFile(ref_weight_vector, "weights.raw", this->output_dir_) + .ok()); + ASSERT_TRUE( + WriteArrayToFile(ref_bias_vector, "bias.raw", this->output_dir_).ok()); + ASSERT_TRUE( + WriteArrayToFile(ref_mask_vector, "mask.raw", this->output_dir_).ok()); + + // Read in the weights, mask, and bias to a layer. + SparseLinearLayer actual_layer; + using DiskWeightType = + typename std::conditional::value, + csrblocksparse::fixed16_type, TypeParam>::type; + auto status = LoadGenericLayer( + /*prefix=*/"", /*zipped=*/false, this->output_dir_, + /*default_bias=*/0.f, &actual_layer); + ASSERT_TRUE(status.ok()); + // Multiply the read in layer with an identity matrix so we just get + // the weights added with bias. + std::vector identity(kBiasVectorSize * kBiasVectorSize, + TypeParam(0.f)); + for (int i = 0; i < identity.size(); i += kBiasVectorSize + 1) { + identity.at(i) = TypeParam(1.f); + } + FatCacheAlignedVector masked_weights_plus_bias(kBiasVectorSize, + kBiasVectorSize); + actual_layer.SpMM_bias( + VectorView(identity.data(), /*rows=*/kBiasVectorSize, + /*cols=*/kBiasVectorSize), + &masked_weights_plus_bias); + // |masked_weights_plus_bias| - bias = masked weights. + for (int col = 0; col < masked_weights_plus_bias.cols(); col++) { + MutableVectorView col_data = masked_weights_plus_bias.slice(col); + for (int row = 0; row < masked_weights_plus_bias.rows(); row++) { + int flat_index = row * masked_weights_plus_bias.cols() + col; + EXPECT_NEAR(static_cast(col_data[row]) - ref_bias_vector.at(row), + ref_masked_weight_vector.at(flat_index), this->tolerance_); + } + } +} +} // namespace +} // namespace csrblocksparse diff --git a/sparse_matmul/numerics/BUILD b/sparse_matmul/numerics/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0a81aafb8a025a4af52aab884ce068c6a75504ec --- /dev/null +++ b/sparse_matmul/numerics/BUILD @@ -0,0 +1,160 @@ +# Base numeric types and transcendental functions. + +licenses(["notice"]) + +cc_library( + name = "fast_transcendentals", + srcs = [ + "fast_transcendentals.cc", + ], + hdrs = [ + "fast_transcendentals.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [":types"], +) + +cc_library( + name = "test_utils", + testonly = 1, + hdrs = [ + "test_utils.h", + ], + visibility = ["//sparse_matmul:__subpackages__"], + deps = [ + ":types", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "types", + hdrs = [ + "fixed_types.h", + "float16_types.h", + "type_utils.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "@com_google_glog//:glog", + ], +) + +cc_library( + name = "fast_transcendentals_cc", + srcs = ["fast_transcendentals.cc"], + hdrs = ["fast_transcendentals.h"], + deps = [":types"], +) + +cc_test( + name = "fasttranscendentals_test", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_fast", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = ["-DFAST_TRANSCENDENTALS"], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_fast_accurate", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DACCURATE_TRANSCENDENTAL_APPROX", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_fast_accurate_sigmoidastanh", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DACCURATE_TRANSCENDENTAL_APPROX", + "-DSIGMOID_AS_TANH", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_fast_sigmoidastanh", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DSIGMOID_AS_TANH", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_faster_sigmoid", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = [ + "-DFASTER_TRANSCENDENTALS", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fixed_types_test", + size = "small", + srcs = [ + "fixed_types_test.cc", + ], + deps = [ + ":test_utils", + ":types", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/sparse_matmul/numerics/fast_transcendentals.cc b/sparse_matmul/numerics/fast_transcendentals.cc new file mode 100644 index 0000000000000000000000000000000000000000..75adf01aa612f130b0e56862eb510308ea63e0d0 --- /dev/null +++ b/sparse_matmul/numerics/fast_transcendentals.cc @@ -0,0 +1,81 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/numerics/fast_transcendentals.h" + +namespace csrblocksparse { + +// Maximum desired precision of the output. +static constexpr int kMaxMantissaBits = 14; + +// Returns (and builds if not done yet) a static data table that implements +// tanh on fixed32 input, returning another fixed32 with the given number of +// mantissa bits (which is assumed to be less than the input mantissa bits). +// NOTE that this function is intended to be used only with fixed16 outputs that +// are sign-extended to 32 bits for convenience, and will return a nullptr +// if asked for more than |kMaxMantissaBits| of precision in the output table. +const int32_t* TanhTable(int num_mantissa_bits_out) { + if (num_mantissa_bits_out > kMaxMantissaBits) return nullptr; + // Static data dynamically created and never destructed. + static const int32_t* tanh_luts[kMaxMantissaBits]; + if (tanh_luts[num_mantissa_bits_out - 1] == nullptr) { + // Total bits is number each side of the binary point. + int tanh_lut_bits = num_mantissa_bits_out + kNumTanhExpBits; + // Offset is the number of negative numbers represented. + int tanh_offset = 1 << tanh_lut_bits; + // Size is double the offset plus one more for zero. + int tanh_size = tanh_offset * 2 + 1; + // Conversion between int and float. + float float_factor = static_cast(1 << num_mantissa_bits_out); + int* tanh_lut = new int[tanh_size]; + // Initialize the table. + for (int i = 0; i < tanh_size; ++i) { + float x = (i - tanh_offset) / float_factor; + tanh_lut[i] = static_cast(std::round(tanhf(x) * float_factor)); + } + tanh_luts[num_mantissa_bits_out - 1] = tanh_lut; + } + return tanh_luts[num_mantissa_bits_out - 1]; +} + +// As TanhTable, but for Sigmoid. +const int32_t* SigmoidTable(int num_mantissa_bits_out) { + if (num_mantissa_bits_out > kMaxMantissaBits) return nullptr; + // Static data dynamically created and never destructed. + static const int32_t* sigmoid_luts[kMaxMantissaBits]; + if (sigmoid_luts[num_mantissa_bits_out - 1] == nullptr) { + // Total bits is number each side of the binary point minus one for the fact + // that the gradient never exceeds 1/4. (Could probably use -2.) + int sigmoid_lut_bits = + num_mantissa_bits_out + kNumSigmoidExpBits - kNumExtraSigmoidShiftBits; + // Offset is the number of negative numbers represented. + int sigmoid_offset = 1 << sigmoid_lut_bits; + // Size is double the offset plus one more for zero. + int sigmoid_size = sigmoid_offset * 2 + 1; + // Conversion between int and float. + float float_factor = static_cast(1 << num_mantissa_bits_out); + int* sigmoid_lut = new int[sigmoid_size]; + // Initialize the table. + for (int i = 0; i < sigmoid_size; ++i) { + constexpr int kSigmoidFactor = 1 << kNumExtraSigmoidShiftBits; + float x = ((i - sigmoid_offset) * kSigmoidFactor) / float_factor; + float sigmoid = 1.0f / (1.0f + expf(-x)); + sigmoid_lut[i] = static_cast(std::round(sigmoid * float_factor)); + } + sigmoid_luts[num_mantissa_bits_out - 1] = sigmoid_lut; + } + return sigmoid_luts[num_mantissa_bits_out - 1]; +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/numerics/fast_transcendentals.h b/sparse_matmul/numerics/fast_transcendentals.h new file mode 100644 index 0000000000000000000000000000000000000000..2c73eeec3ddfbb214da3a47281f832abdde64929 --- /dev/null +++ b/sparse_matmul/numerics/fast_transcendentals.h @@ -0,0 +1,1177 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_ + +#include +#if defined __ARM_NEON || defined __aarch64__ +#include +#else +#include +#endif +#if defined __AVX__ || defined __AVX2__ +#include +#endif +#include + +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/type_utils.h" + +namespace csrblocksparse { + +// The input to exp is clipped to bounds that prevent overflow/underflow in a +// 32 bit float representation. e^80 ~ 6e34, which is close to maxfloat. +constexpr float kMaxExpInput = 80.f; +constexpr int kMaxExpInputInt = static_cast(kMaxExpInput); +constexpr float kMinExpInput = -80.f; +// tanh(9) ~ 0.99999997, which cannot be resolved from 1 in a float32. +constexpr float kMaxTanhInput = 9.f; +constexpr float kMinTanhInput = -9.f; +// sigmoid(18) ~ 0.999999985, which cannot be resolved from 1 in a float32. +constexpr float kMaxSigmoidInput = 18.f; +constexpr float kMinSigmoidInput = -18.f; +// kAConstant ~= 2^23 / ln 2 +constexpr uint32_t kAConstant = 0x4b38aa3b; +// kBConstant ~= (127 << 23) - 366000 +constexpr uint32_t kBConstant = 0x4e7de9a9; +// Coefficients of the rational approximation to tanh. +// Coefficients of the numerator polynomial (odd). +constexpr float kTanhAlpha1 = 4.89352455891786e-03; +constexpr float kTanhAlpha3 = 6.37261928875436e-04; +constexpr float kTanhAlpha5 = 1.48572235717979e-05; +constexpr float kTanhAlpha7 = 5.12229709037114e-08; +constexpr float kTanhAlpha9 = -8.60467152213735e-11; +constexpr float kTanhAlpha11 = 2.00018790482477e-13; +constexpr float kTanhAlpha13 = -2.76076847742355e-16; +// The monomial coefficients of the denominator polynomial (even). +constexpr float kTanhBeta0 = 4.89352518554385e-03; +constexpr float kTanhBeta2 = 2.26843463243900e-03; +constexpr float kTanhBeta4 = 1.18534705686654e-04; +constexpr float kTanhBeta6 = 1.19825839466702e-06; + +// Coefficients of the rational approximation to sigmoid. +// Coefficients of the numerator polynomial (odd). +constexpr float kSigmoidAlpha1 = 2.48287947061529e-01; +constexpr float kSigmoidAlpha3 = 8.51377133304701e-03; +constexpr float kSigmoidAlpha5 = 6.08574864600143e-05; +constexpr float kSigmoidAlpha7 = 1.15627324459942e-07; +constexpr float kSigmoidAlpha9 = 4.37031012579801e-11; + +// The monomial coefficients of the denominator polynomial (even). +constexpr float kSigmoidBeta0 = 9.93151921023180e-01; +constexpr float kSigmoidBeta2 = 1.16817656904453e-01; +constexpr float kSigmoidBeta4 = 1.70198817374094e-03; +constexpr float kSigmoidBeta6 = 6.29106785017040e-06; +constexpr float kSigmoidBeta8 = 5.76102136993427e-09; +constexpr float kSigmoidBeta10 = 6.10247389755681e-13; + +// x is the first term of the Taylor series approximation of tanh near 0 and +// because the leading error term of tanh(x) - x is O(x^3), it is good for a +// wide interval, use it in this region where the other approximation is +// inaccurate. tanh(x) = x - x^3 / 3 + 2x^5 / 15 - 17x^7 / 315 + ... +// Similarly for sigmoid where the first term is .25x +constexpr float kTanhLinearRegion = .15f; +constexpr float kSigmoidLinearRegion = .75f; + +// Maximum shift factor for 1/log 2 to keep it inside int32. +constexpr int kMaxLog2Shift = 30; +static const int kLogFactor = static_cast((1 << kMaxLog2Shift) / log(2.f)); +static const float kOneOverLog2 = 1.0f / log(2.f); +// Number of real mantissa bits in IEEE float32. +constexpr int kFloatMantissaBits = 23; +// Offset to correct the exponent value in the resulting float. +constexpr int kFloatExponentOffset = 127 << kFloatMantissaBits; +// Mask for mantissa. +constexpr int kFloatMantissaMask = (1 << kFloatMantissaBits) - 1; +// Mask for exponent; +constexpr int kFloatExponentMask = (-1) ^ kFloatMantissaMask; + +// ========== COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK ============ +// Summary: Use the exponent-mantissa representation of a floating point number +// to give exponentiation of 2 for free. If we desire f(z) = e^z = 2^(x+n), (for +// some fixed-point z expressed as an integer with imaginary binary point within +// it) then we have to compute x+n = z / ln 2 and then splitting x+n into +// n = int(x+n) and x = fract(x+n) in [0, 1), we can use n and 2^x as the +// exponent and mantissa of a floating point number, and that float is equal to +// e^z. For original reference see: +// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.9.4508&rep=rep1&type=pdf +// Important detail: +// IEEE floats are stored normalized, ie 1.bbbbbbb... x 2^exponent. The leading +// 1 bit is not actually stored, (as it is always 1), providing an extra bit of +// precision. +// Since 2^0=1 and 2^1=2, we can treat the problem as 2^x = 1 + u and we thus +// need a mapping x in [0, 1) -> u in [0, 1) and the 1 + is provided by the +// representation. +// In the original paper cited above, the mapping is u = x - c, where c is set +// to minimize the average error. The function to compute exp(x) this way is +// incredibly simple and computationally cheap, but not very accurate. +// Fortunately, the problem has been reduced to u = 2^x - 1 over [0, 1) for +// which it is far easier to construct accurate approximations with small +// polynomials than a full range exp(x), and this is what the cubic and quartic +// versions below do. An important feature of these functions is that they +// constrain the solution to be exact at 0 and 1 so there is continuity at each +// integer boundary where we wrap from 1 to 0 and increment the power of 2. + +// Coefficients for quartic representation of 2^x - 1 for x on [0,1). +// The quartic representation is 2^x - 1 ~ x - x(1-x)(ax^2 + bx + c), hence the +// coefficients of a quadratic are all that is required. +// Coefficients came from numerical experiments. +constexpr float kExpQuarticFactor2 = 0.0135302434f; +constexpr float kExpQuarticFactor1 = 0.0656107542f; +constexpr float kExpQuarticFactor0 = 0.306963906f; +// Coefficients for cubic representation of 2^x - 1 for x on [0,1] +// The cubic representation is 2^x - 1 ~ x - x(1-x)(mx + c), hence the +// coefficients of a linear function are all that is required. +// Coefficients came from numerical experiments. +constexpr float kExpCubicFactor1 = 0.0780252018f; +constexpr float kExpCubicFactor0 = 0.304684167f; +// Coefficients are optimized to minimize the absolute error on +// tanh = (e^2x - 1) / (e^2x + 1) instead of on pure e^x. + +// Enum that determines how a transcendental is computed. +enum TranscendentalMode { + // Cubic using 16 bit integer arithmetic. + TM_ORDER3_16BIT, + // Quartic using 16 bit integer arithmetic. + TM_ORDER4_16BIT, + // Quartic using 32 bit float arithmetic. + TM_ORDER4_FLOAT, +}; + +inline int FloatAsInt16(float x) { + return static_cast(x * (1 << 15) + 0.5f); +} + +inline int FloatAsInt32(float x) { + return static_cast(x * (1 << 30) + 0.5f); +} + +#if defined __ARM_NEON || defined __aarch64__ + +constexpr int kMaxSigmoidInputInt = static_cast(kMaxSigmoidInput); + +// Computes and returns 2^(x>>23) ie 2^u where x = u << 23 bits. +// Uses the quartic floating point exponent trick, see COMMON DOCUMENTATION FOR +// THE FLOATING EXPONENT TRICK above for details. +// Returns the true value, ie not scaled. +inline float32x4_t float32_pow2(float32x4_t x) { + // The input is already shifted left by 23 bits, so when we convert to int, + // the bottom 23 bits are the fractional part, and the top bits are the + // integer part. We want to compute a function of the fractional part, so + // we will mask it off and manipulate it. + int32x4_t exp_int_x = vcvtq_s32_f32(x); + // Mask to allow conversion of just the fractional part of x to fixed16<0>. + int32x4_t mantissa_mask16 = vdupq_n_s32(0x7fff00); + // Mask to allow conversion of just the fractional part of x to fixed32<1>. + int32x4_t mantissa_mask32 = vdupq_n_s32(0x7fffff); + // Narrowing shift to convert to fixed16<0>. + int16x4_t x_16 = vshrn_n_s32(vandq_s32(mantissa_mask16, exp_int_x), 8); + // Shift to convert to fixed32<1>. + int32x4_t x_32 = vshlq_n_s32(vandq_s32(mantissa_mask32, exp_int_x), 7); + // Compute the polynomial x(x - 1)(ax^2 + bx + c) of the fractional part. + // Ordering these lines carefully makes it faster, as some of the multiply + // operations can pipeline instead of waiting for the previous result. + int32x4_t x_squared = vmull_s16(x_16, x_16); + int16x4_t b = vdup_n_s16(FloatAsInt16(kExpQuarticFactor1)); + int32x4_t c = vdupq_n_s32(FloatAsInt32(kExpQuarticFactor0)); + int32x4_t bx_plus_c = vmlal_s16(c, b, x_16); + int16x4_t a = vdup_n_s16(FloatAsInt16(kExpQuarticFactor2)); + // Finish the quadratic: result = ax^2 + bx + c. + int32x4_t result = vmlal_s16(bx_plus_c, a, vshrn_n_s32(x_squared, 15)); + int32x4_t x_squared_minus_x = vsubq_s32(x_squared, x_32); + + // Multiply by x^2 - x. + result = vqrdmulhq_s32(result, x_squared_minus_x); + // Shift back to mantissa position. vqrdmulhq_s32 took 2x 30-mantissa bit + // inputs, made 60-mantissa bit result, doubled it to 61 bits, then discarded + // the bottom 32 making 29, so shift right 6 to get 23. + result = vshrq_n_s32(result, 6); + // Add the constant to normalize the exponent for IEEE format. + int32x4_t exp_offset = vdupq_n_s32(kFloatExponentOffset); + exp_int_x = vaddq_s32(exp_int_x, exp_offset); + exp_int_x = vaddq_s32(exp_int_x, result); + // Cast back to float, as we just computed the exponent and mantissa and + // assembled them in IEEE format. + return vreinterpretq_f32_s32(exp_int_x); +} + +// Scaled float to float exp approximation, using a quartic refinement of +// the exponent trick. See COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK +// above for details. Input is a fixed32<31 - mantissa_bits> that has been +// converted to a float without any further shifting. MUST HAVE ALREADY BEEN +// CLIPPED to a suitable range for exp! +// Returns a vector of standard unscaled floats. +inline float32x4_t fixed32_exp_float_preclipped(const int mantissa_bits, + float32x4_t x) { + // Divide by log 2 to convert problem to 2^x, and scale to match the + // mantissa bits required by IEEE floats. + // This is the shift of the FP mantissa relative to the input mantissa. + const int kXShift = kFloatMantissaBits - mantissa_bits; + const float kLogFactor = static_cast(1 << kXShift); + float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2); + float32x4_t y = vmulq_f32(x, factor); + // Now compute 2^x. + return float32_pow2(y); +} + +// uses trick that 2^x can be computed by shifting integer into the +// exponent, see the following reference for a derivation using double: +// goo.gl/aUVTK3 +// Input x is clamped to [-64, 64], even infinity and NaN. +// Accurate to within 3% relative across the entire range. +// Fully pipelined throughput is about 10 cycles per fast_exp call. +inline float32x4_t fast_exp(float32x4_t x) { +#if defined FAST_TRANSCENDENTALS && __ARM_ARCH >= 800 + // Uses vcvtnq_s32_f32, not available on ARM v7 NEON. + + // Load A and B, which are defined as integers into float registers. + float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant)); + float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant)); + + // Make sure x within the allowed range. + x = vminq_f32(x, vdupq_n_f32(kMaxExpInput)); + x = vmaxq_f32(x, vdupq_n_f32(kMinExpInput)); + + // res = A * x + B. + // This shifts x into the exponent field and adds the bias. + res = vmlaq_f32(res, A, x); + + // Convert back to an integer, this is what uses the floating point + // unit to compute 2^x. + int32x4_t x_int = vcvtnq_s32_f32(res); + + return vreinterpretq_f32_s32(x_int); +#else + float32x4_t return_val = vdupq_n_f32(0.f); + + float exponent = expf(vgetq_lane_f32(x, 0)); + return_val = vld1q_lane_f32(&exponent, return_val, 0); + + exponent = expf(vgetq_lane_f32(x, 1)); + return_val = vld1q_lane_f32(&exponent, return_val, 1); + exponent = expf(vgetq_lane_f32(x, 2)); + return_val = vld1q_lane_f32(&exponent, return_val, 2); + exponent = expf(vgetq_lane_f32(x, 3)); + return_val = vld1q_lane_f32(&exponent, return_val, 3); + + return return_val; +#endif // FAST_TRANSCENDENTALS +} + +// This version does a conversion of the input to floating point, then calls +// the floating point fast_exp function. There is another version +// fast_exp_fixed, that never does a conversion and is less accurate, but much +// faster. +template +inline float32x4_t fast_exp(int32x4_t x) { + return fast_exp(vcvtq_n_f32_s32(x, 31 - ExponentBits)); +} + +// Performs an exp estimate without doing any floating point operations. The +// result is a floating point number. See scalar version for an explanation. +template +inline float32x4_t fast_exp_fixed(int32x4_t x) { + static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits"); + constexpr int kA = 1.4426950408889634 * (1 << (ExponentBits - 8)); + constexpr int kB = (127 << 23) - 366000; + + constexpr int maxInput = 80 << (31 - ExponentBits); + constexpr int minInput = -maxInput; + + int32x4_t A = vdupq_n_s32(kA); + int32x4_t res = vdupq_n_s32(kB); + + // Make sure x within the allowed range. + x = vminq_s32(x, vdupq_n_s32(maxInput)); + x = vmaxq_s32(x, vdupq_n_s32(minInput)); + + // res = A * x + B. + // This shifts x into the exponent field and adds the bias. + res = vmlaq_s32(res, A, x); + + return vreinterpretq_f32_s32(res); +} + +// fast_exp_norange_check uses vcvtnq_s32_f32, not available on ARM v7 NEON. +#if __ARM_ARCH >= 800 +namespace detail { +// tanh can do range check once. +// Input x is clamped to [-64, 64], even infinity and NaN. +inline float32x4_t fast_exp_norange_check(float32x4_t x) { + float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant)); + float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant)); + + res = vmlaq_f32(res, A, x); + + int32x4_t x_int = vcvtnq_s32_f32(res); + + return vreinterpretq_f32_s32(x_int); +} + +} // namespace detail +#endif // __ARM_ARCH >= 800 + +// Clips float input to [-kLimit,kLimit]. +inline float32x4_t ClipToFloatBounds(const float kLimit, const float32x4_t x) { + // Clip to the input bounds for this approximation. + float32x4_t clip_limit = vdupq_n_f32(kLimit); + float32x4_t clipped_x = vminq_f32(x, clip_limit); + clip_limit = vnegq_f32(clip_limit); + return vmaxq_f32(clipped_x, clip_limit); +} + +inline float32x4_t float_tanh_float(const float32x4_t& x) { + float32x4_t clipped_x = ClipToFloatBounds(kMaxTanhInput, x); + // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and + // scale to the mantissa bits required by float32_pow2 all in one multiply. + // Add one to double the input. + const float kLogFactor = static_cast(1 << (kFloatMantissaBits + 1)); + float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2); + clipped_x = vmulq_f32(clipped_x, factor); + // Now compute 2^x. + float32x4_t exp_result = float32_pow2(clipped_x); + // Now compute tanh using (e^2x - 1) / (e^2x + 1). + float32x4_t one = vdupq_n_f32(1.0f); + float32x4_t numerator = vsubq_f32(exp_result, one); + float32x4_t denominator = vaddq_f32(exp_result, one); + float32x4_t recp = vrecpeq_f32(denominator); + // Newton-Raphson iteration, accuracy is important for audio quality + recp = vmulq_f32(recp, vrecpsq_f32(recp, denominator)); + recp = vmulq_f32(recp, numerator); + // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low + // relative error close to 0. + float32x4_t third = vdupq_n_f32(1.0f / 3.0f); + float32x4_t taylor = vmulq_f32(x, x); + taylor = vmulq_f32(taylor, x); + taylor = vmulq_f32(taylor, third); + taylor = vsubq_f32(x, taylor); + // Test |x| <= 1/9, roughly where the errors cross over, without needing yet + // another constant. + float32x4_t ninth = vmulq_f32(third, third); + uint32x4_t cmp_results = vcaleq_f32(x, ninth); + return vbslq_f32(cmp_results, taylor, recp); +} + +// Calculates (exp(x) - exp(-x)) / (exp(x) + exp(-x)). +// Input x is clamped to [-9, 9], even infinity and NaN. +// See test program for bounds. Throughput of FAST is 334 Mega/sec, +// throughput of accurate is 232 Mega/sec. +inline float32x4_t fast_tanh(float32x4_t x) { +#if defined FASTER_TRANSCENDENTALS + return float_tanh_float(x); +#elif defined ACCURATE_TRANSCENDENTAL_APPROX && defined FAST_TRANSCENDENTALS + x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput)); + x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput)); + + // The monomial coefficients of the numerator polynomial (odd). + const float32x4_t alpha_1 = vdupq_n_f32(kTanhAlpha1); + const float32x4_t alpha_3 = vdupq_n_f32(kTanhAlpha3); + const float32x4_t alpha_5 = vdupq_n_f32(kTanhAlpha5); + const float32x4_t alpha_7 = vdupq_n_f32(kTanhAlpha7); + const float32x4_t alpha_9 = vdupq_n_f32(kTanhAlpha9); + const float32x4_t alpha_11 = vdupq_n_f32(kTanhAlpha11); + const float32x4_t alpha_13 = vdupq_n_f32(kTanhAlpha13); + + // The monomial coefficients of the denominator polynomial (even). + const float32x4_t beta_0 = vdupq_n_f32(kTanhBeta0); + const float32x4_t beta_2 = vdupq_n_f32(kTanhBeta2); + const float32x4_t beta_4 = vdupq_n_f32(kTanhBeta4); + const float32x4_t beta_6 = vdupq_n_f32(kTanhBeta6); + + // Since the polynomials are odd/even, we need x^2. + const float32x4_t x2 = vmulq_f32(x, x); + + // Evaluate the numerator polynomial |p|. + float32x4_t p = vmlaq_f32(alpha_11, x2, alpha_13); + p = vmlaq_f32(alpha_9, x2, p); + p = vmlaq_f32(alpha_7, x2, p); + p = vmlaq_f32(alpha_5, x2, p); + p = vmlaq_f32(alpha_3, x2, p); + p = vmlaq_f32(alpha_1, x2, p); + p = vmulq_f32(x, p); + + // Evaluate the denominator polynomial p. + float32x4_t q = vmlaq_f32(beta_4, x2, beta_6); + q = vmlaq_f32(beta_2, x2, q); + q = vmlaq_f32(beta_0, x2, q); + + // Divide the numerator by the denominator. + float32x4_t recp = vrecpeq_f32(q); + recp = vmulq_f32(recp, vrecpsq_f32(recp, q)); + return vmulq_f32(p, recp); +#elif defined FAST_TRANSCENDENTALS && __ARM_ARCH >= 800 + // Uses vcvtnq_s32_f32, not available on ARM v7 NEON. + + x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput)); + x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput)); + float32x4_t exp_est = detail::fast_exp_norange_check(x); + float32x4_t neg_exp_est = detail::fast_exp_norange_check(-x); + + // If we're in the linear region. + // caleq = compare absolute <= + uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kTanhLinearRegion)); + + float32x4_t diff = vsubq_f32(exp_est, neg_exp_est); + float32x4_t sum = vaddq_f32(exp_est, neg_exp_est); + float32x4_t recp = vrecpeq_f32(sum); + recp = vmulq_f32(recp, vrecpsq_f32(recp, sum)); + float32x4_t tanh_estimate = vmulq_f32(diff, recp); + + // Based on comparison, possibly copy x through instead of calculated value. + // TODO(b/191497441): Is the compiler generating VBIT or VBSL ? VBIT is one + // cycle and VBSL is two... documentation suggests it can do either. + return vbslq_f32(cmp_results, x, tanh_estimate); +#else + float32x4_t return_val = vdupq_n_f32(0.f); + + float tanh_value = tanhf(vgetq_lane_f32(x, 0)); + return_val = vld1q_lane_f32(&tanh_value, return_val, 0); + tanh_value = tanhf(vgetq_lane_f32(x, 1)); + return_val = vld1q_lane_f32(&tanh_value, return_val, 1); + tanh_value = tanhf(vgetq_lane_f32(x, 2)); + return_val = vld1q_lane_f32(&tanh_value, return_val, 2); + tanh_value = tanhf(vgetq_lane_f32(x, 3)); + return_val = vld1q_lane_f32(&tanh_value, return_val, 3); + + return return_val; +#endif // FAST_TRANSCENDENTALS +} + +// Input x is clamped to [-18, 18], even infinity and NaN. +// See tests for error bounds. Using SIGMOID_AS_TANH with +// ACCURATE_TRANSCENDENTAL_APPROX is both faster and more accurate. Using +// SIGMOID_AS_TANH with just FAST is slower, but more accurate. +// SIGMOID_AS_TANH, ACCURATE is 205 Mega/sec +// SIGMOID_AS_TANH, FAST is 290 Mega/sec +// FAST is 340 Mega/sec +inline float32x4_t fast_sigmoid(float32x4_t x) { +#ifdef SIGMOID_AS_TANH + float32x4_t half = vdupq_n_f32(0.5f); + return vmlaq_f32(half, half, fast_tanh(vmulq_f32(half, x))); +#else // SIGMOID_AS_TANH +#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX + x = vminq_f32(x, vdupq_n_f32(kMaxSigmoidInput)); + x = vmaxq_f32(x, vdupq_n_f32(kMinSigmoidInput)); + + // The monomial coefficients of the numerator polynomial (odd). + const float32x4_t alpha_1 = vdupq_n_f32(kSigmoidAlpha1); + const float32x4_t alpha_3 = vdupq_n_f32(kSigmoidAlpha3); + const float32x4_t alpha_5 = vdupq_n_f32(kSigmoidAlpha5); + const float32x4_t alpha_7 = vdupq_n_f32(kSigmoidAlpha7); + const float32x4_t alpha_9 = vdupq_n_f32(kSigmoidAlpha9); + + // The monomial coefficients of the denominator polynomial (even). + const float32x4_t beta_0 = vdupq_n_f32(kSigmoidBeta0); + const float32x4_t beta_2 = vdupq_n_f32(kSigmoidBeta2); + const float32x4_t beta_4 = vdupq_n_f32(kSigmoidBeta4); + const float32x4_t beta_6 = vdupq_n_f32(kSigmoidBeta6); + const float32x4_t beta_8 = vdupq_n_f32(kSigmoidBeta8); + const float32x4_t beta_10 = vdupq_n_f32(kSigmoidBeta10); + + // Since the polynomials are odd/even, we need x^2. + const float32x4_t x2 = vmulq_f32(x, x); + + // Evaluate the numerator polynomial p. + float32x4_t p = vmlaq_f32(alpha_7, x2, alpha_9); + p = vmlaq_f32(alpha_5, x2, p); + p = vmlaq_f32(alpha_3, x2, p); + p = vmlaq_f32(alpha_1, x2, p); + p = vmulq_f32(x, p); + + // Evaluate the denominator polynomial p. + float32x4_t q = vmlaq_f32(beta_8, x2, beta_10); + q = vmlaq_f32(beta_6, x2, q); + q = vmlaq_f32(beta_4, x2, q); + q = vmlaq_f32(beta_2, x2, q); + q = vmlaq_f32(beta_0, x2, q); + + // Divide the numerator by the denominator. + float32x4_t recp = vrecpeq_f32(q); + recp = vmulq_f32(recp, vrecpsq_f32(recp, q)); + return vmlaq_f32(vdupq_n_f32(0.5f), p, recp); +#elif defined FAST_TRANSCENDENTALS + float32x4_t denom = vaddq_f32(fast_exp(vnegq_f32(x)), vdupq_n_f32(1.f)); + + float32x4_t recp = vrecpeq_f32(denom); + // Newton-Raphson iteration, accuracy is important for audio quality. + recp = vmulq_f32(recp, vrecpsq_f32(recp, denom)); + float32x4_t half = vdupq_n_f32(0.5f); + float32x4_t quarter = vdupq_n_f32(0.245f); + float32x4_t linear_approx = vmlaq_f32(half, quarter, x); + uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kSigmoidLinearRegion)); + + return vbslq_f32(cmp_results, linear_approx, recp); +#else + float32x4_t return_val = vdupq_n_f32(0.f); + + float result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 0))); + return_val = vld1q_lane_f32(&result, return_val, 0); + result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 1))); + return_val = vld1q_lane_f32(&result, return_val, 1); + result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 2))); + return_val = vld1q_lane_f32(&result, return_val, 2); + result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 3))); + return_val = vld1q_lane_f32(&result, return_val, 3); + + return return_val; +#endif // FAST_TRANSCENDENTALS +#endif // SIGMOID_AS_TANH +} + +// Scalar implementations, mainly useful for testing. +inline float fast_exp(float x) { + return vgetq_lane_f32(fast_exp(vdupq_n_f32(x)), 0); +} + +template +inline float fast_exp(fixed32 x) { + return vgetq_lane_f32(fast_exp(vdupq_n_s32(x.raw_val())), 0); +} + +// Returns the exponent of a fixed point number in floating point without ever +// doing any conversions. Less accurate than the version that does conversions, +// but still accurate to within 4% relative for x < 16. +template +inline float fast_exp_fixed(fixed32 x) { + return vgetq_lane_f32(fast_exp_fixed(vdupq_n_s32(x.raw_val())), + 0); +} + +inline float fast_sigmoid(float x) { + return vgetq_lane_f32(fast_sigmoid(vdupq_n_f32(x)), 0); +} + +inline float fast_tanh(float x) { + return vgetq_lane_f32(fast_tanh(vdupq_n_f32(x)), 0); +} + +// Clips integer input to [-|kLimit|, |kLimit|]. +// Input: register containins 4x fixed32 with mantissa_bits. +// Output: register containing 4x fixed32 limited to +// [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|]. +template +inline int32x4_t ClipToBounds(const int mantissa_bits, const int32x4_t x) { + // Clip to the input bounds for this approximation. + int32x4_t clip_limit = vdupq_n_s32(-(kLimit << mantissa_bits)); + int32x4_t clipped_x = vmaxq_s32(x, clip_limit); + clip_limit = vnegq_s32(clip_limit); + return vminq_s32(clipped_x, clip_limit); +} + +// Fixed32 sigmoid approximation via a quadratic refinement of the exponent +// trick. +// Input: Register containing 4x fixed32 with |mantissa_bits|. +// Output: Register containing 4x float results. +inline float32x4_t fixed32_sigmoid_float(const int mantissa_bits, + const int32x4_t x) { + int32x4_t input = vnegq_s32(x); + float32x4_t y = + vcvtq_f32_s32(ClipToBounds(mantissa_bits, input)); + y = fixed32_exp_float_preclipped(mantissa_bits, y); + float32x4_t one = vdupq_n_f32(1.0f); + // Approximate reciprocal is not accurate enough - use full division. + float32x4_t denom = vaddq_f32(y, one); + float32x4_t recp = vrecpeq_f32(denom); + // Newton-Raphson iteration, accuracy is important for audio quality + recp = vmulq_f32(recp, vrecpsq_f32(recp, denom)); + return recp; +} + +template +inline float32x4_t fast_sigmoid(int32x4_t x) { +#if defined FASTER_TRANSCENDENTALS + // Computation will fail to produce the right result if the input mantissa + // bits exceeds the number in a float. + static_assert(kFloatMantissaBits >= fixed32::kMantissaBits, + "Mantissa bits must be at most 23!"); + return fixed32_sigmoid_float(fixed32::kMantissaBits, x); +#else + return fast_sigmoid(vcvtq_n_f32_s32(x, fixed32::kMantissaBits)); +#endif // FASTER_TRANSCENDENTALS +} + +template +inline float fast_sigmoid(fixed32 x) { + return vgetq_lane_f32(fast_sigmoid(vdupq_n_s32(x.raw_val())), + 0); +} + +#else // defined __ARM_NEON || defined __aarch64__ + +inline float fast_exp(float x) { +#ifdef FAST_TRANSCENDENTALS + if (isnan(x)) return 0.0f; + x = std::max(std::min(x, kMaxExpInput), kMinExpInput); + float AConstant, BConstant; + memcpy(&AConstant, &kAConstant, sizeof(int)); + memcpy(&BConstant, &kBConstant, sizeof(int)); + float y = x * AConstant + BConstant; + int x_int = static_cast(y); + float ret; + memcpy(&ret, &x_int, sizeof(float)); + return ret; +#else + return expf(x); +#endif // FAST_TRANSCENDENTALS +} + +template +inline float fast_exp(fixed32 x) { + return fast_exp(static_cast(x)); +} + +template +inline float fast_exp_fixed(fixed32 x) { + static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits"); + int matched_decimal = + std::max(std::min(x.raw_val(), (80 << (31 - ExponentBits))), + -(80 << (31 - ExponentBits))); + // Convert 1 / log(2) to 16-bit fixed point with 1 exponent bit + // (1 / log(2)) * (1 << 14), but then right shift by the appropriate amount to + // line the decimal point up with the 32-bit float representation. + // (MantissaBits of x) + (MantissaBits of constant) = 23 + // 23 - (MantissaBits of x) = MantissaBits of constant + // 23 - (31 - ExponentBits of x) = ... + // (ExponentBits of x - 8) = MantissaBits of constant + const int16_t A = (1.f / logf(2.f)) * (1 << (ExponentBits - 8)); + // Same rationale as for floating point versions, bias exponent, subtract + // 366000 to reduce error by centering approximation, instead of being + // one-sided. + const int B = (127 << 23) - 366000; + matched_decimal = A * matched_decimal + B; + float ret_val; + memcpy(&ret_val, &matched_decimal, sizeof(float)); + return ret_val; +} + +inline float fast_tanh(float x) { +#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX + // Doesn't do anything fancy, just a 13/6-degree rational interpolant which + // is accurate up to a couple of ulp in the range [-9, 9], outside of which + // fl(tanh(x)) = +/-1. + x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput); + + // Since the polynomials are odd/even, we need x^2. + float x2 = x * x; + + // Evaluate numerator. + float p = kTanhAlpha11 + x2 * kTanhAlpha13; + p = kTanhAlpha9 + x2 * p; + p = kTanhAlpha7 + x2 * p; + p = kTanhAlpha5 + x2 * p; + p = kTanhAlpha3 + x2 * p; + p = kTanhAlpha1 + x2 * p; + p = x * p; + + // Evaluate denominator. + float q = kTanhBeta4 + x2 * kTanhBeta6; + q = kTanhBeta2 + x2 * q; + q = kTanhBeta0 + x2 * q; + + return p / q; +#elif defined FAST_TRANSCENDENTALS + if (std::abs(x) < kTanhLinearRegion) { + return x; + } else { + x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput); + float positive = fast_exp(x); + float negative = fast_exp(-x); + return (positive - negative) / (positive + negative); + } +#else + return tanhf(x); +#endif // FAST_TRANSCENDENTALS +} + +inline float fast_sigmoid(float x) { +#ifdef SIGMOID_AS_TANH + return .5f * fast_tanh(.5f * x) + .5f; +#else +#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX + // Doesn't do anything fancy, just a 9/10-degree rational interpolant which + // interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulp in the range + // [-18, 18], outside of which the fl(sigmoid(x)) = {0|1}. The shifted + // sigmoid is interpolated because it was easier to make the fit converge. + // See GenericPacketMath.h* in the open source Eigen library. + x = std::max(std::min(x, kMaxSigmoidInput), kMinSigmoidInput); + + // Since the polynomials are odd/even, we need x^2. + float x2 = x * x; + + // Evaluate numerator. + float p = kSigmoidAlpha7 + x2 * kSigmoidAlpha9; + p = kSigmoidAlpha5 + x2 * p; + p = kSigmoidAlpha3 + x2 * p; + p = kSigmoidAlpha1 + x2 * p; + p = x * p; + + // Evaluate denominator. + float q = kSigmoidBeta8 + x2 * kSigmoidBeta10; + q = kSigmoidBeta6 + x2 * q; + q = kSigmoidBeta4 + x2 * q; + q = kSigmoidBeta2 + x2 * q; + q = kSigmoidBeta0 + x2 * q; + + return p / q + 0.5f; +#elif defined FAST_TRANSCENDENTALS + if (std::abs(x) < kSigmoidLinearRegion) { + return .245 * x + .5; + } else { + return 1.f / (1.f + fast_exp(-x)); + } +#else + return 1.f / (1.f + expf(-x)); +#endif // FAST_TRANSCENDENTALS +#endif // SIGMOID_AS_TANH +} + +template +inline float fast_sigmoid(fixed32 x) { + return fast_sigmoid(static_cast(x)); +} + +#endif // defined __aarch64__ + +// Number of exponent bits to use for tanh. +static constexpr int kNumTanhExpBits = 3; +// Number of exponent bits to use for sigmoid. +static constexpr int kNumSigmoidExpBits = 4; +// Number of extra bits to shift sigmoid, due to its low gradient. +static constexpr int kNumExtraSigmoidShiftBits = 1; + +// Returns (and builds if not done yet) a static data table (that is never +// deleted, as per the style guide) that implements tanh on fixed32 input, +// returning another fixed32 with the given number of mantissa bits (which is +// assumed to be less than the input mantissa bits). +// NOTE that this function is intended to be used only with fixed16 outputs that +// are sign-extended to 32 bits for convenience, and will return a nullptr +// if asked for more than |kMaxMantissaBits| of precision in the output table. +const int* TanhTable(int num_mantissa_bits_out); +// As TanhTable, but for Sigmoid. +const int* SigmoidTable(int num_mantissa_bits_out); + +// Scalar/generic function to compute and return the fast approximation to exp +// via a polynomial refinement of the floating point exponent trick. +// TM_ORDER4_16BIT:Max relative error < 5e-6, absolute error < 1e-5 for x < 1. +// TM_ORDER3_16BIT:Max relative error < 1.1e-4, absolute error < 3e-4 for x +// < 1. +template +float fixed32_exp(fixed32 x) { + constexpr int kMantissaBits = MantissaBitsOf>::value; + // Clip x to min/max exp input to avoid infinities. + int64_t clipped_x = + std::max(std::min(x.raw_val(), kMaxExpInputInt << kMantissaBits), + -(kMaxExpInputInt << kMantissaBits)); + // First convert problem from e^x to 2^x by multiplying by 1/log(2). + // To maximize precision, log_factor is shifted left the maximum amount to + // keep within int32, and we shift x left a further amount such that the + // binary point of the product sits in the correct place in the top 32 bits of + // the result to be used directly as a float. We can't do that directly, as x + // would overflow, so we have to shift by 1 bit less and shift the result by + // 1 bit less to match. + constexpr int kXShift = + kFloatMantissaBits + 31 - kMaxLog2Shift - kMantissaBits; + static_assert(kXShift >= 0, + "Mantissa bits > kFloatMantissaBits + 31 - kMaxLog2Shift"); + clipped_x <<= kXShift; + int float_as_int = (kLogFactor * clipped_x >> 31) + kFloatExponentOffset; + // Separate the resulting fixed-point into integer and fractional parts. + int int_part = float_as_int & kFloatExponentMask; + int float_part = float_as_int & kFloatMantissaMask; + float fraction = static_cast(float_part) / (1 << kFloatMantissaBits); + // Compute the mantissa = 2^fraction using: + // fraction - fraction*(1-fraction)*(polynomial of fraction) + // This guarantees exactness at 0 and 1, providing continuity of the error at + // integer boundaries. + float mantissa; + if (kOrder == TM_ORDER4_16BIT || kOrder == TM_ORDER4_FLOAT) { + mantissa = (kExpQuarticFactor2 * fraction + kExpQuarticFactor1) * fraction + + kExpQuarticFactor0; + } else if (kOrder == TM_ORDER3_16BIT) { + mantissa = kExpCubicFactor1 * fraction + kExpCubicFactor0; + } + mantissa = fraction - fraction * (1.0f - fraction) * mantissa; + // Since the function above guarantees to stay within [0, 1), we could do all + // the above in fixed point if necessary, in which case, we can just stuff + // the bottom kFloatMantissaBits in with the exponent and we are done. + // In the floating point world, it is simpler to just multiply them together. + float result; + memcpy(&result, &int_part, sizeof(float)); + return result * (1.0f + mantissa); +} + +// Computes and returns tanh(x) fixed32->float using a polynomial refinement of +// the floating point exponent trick. +// kOrder=4: Absolute error < 1.8e-6. Relative error < 1.2e-4 for |x| > 0.01. +// kOrder=3: Absolute error < 6e-5. Relative error < 3e-3 for |x| > 0.01 +template +float fixed32_tanh(fixed32 x) { + float float_x = static_cast(x); + if (std::abs(float_x) < 1.0f / 9.0f) { + return float_x * (1 - float_x * float_x / 3.0f); + } + x = static_cast>(x.raw_val() * 2); + float exp_2x = fixed32_exp(x); + return (exp_2x - 1.0f) / (exp_2x + 1.0f); +} + +// Computes and returns sigmoid(x) fixed32->float using a polynomial refinement +// of the floating point exponent trick. +// TM_ORDER4_16BIT: Absolute error < 9e-7, relative < 4e-6. +// TM_ORDER3_16BIT: Absolute error < 3e-5, relative < 1.1e-4. +template +float fixed32_sigmoid(fixed32 x) { + x = static_cast>(-x.raw_val()); + float exp_x = fixed32_exp(x); + return 1.0f / (exp_x + 1.0f); +} + +#if defined __AVX2__ + +// Inline function to access an int32 data table by shifting |x| right by +// |kNumShiftBits|, and adding |kTableOffset| to the result. |x| contains 8 +// indices and 8 results are returned. The data table is of size +// |kTableOffset| * 2 + 1. +template +inline __m256i index_data_table(const int32_t* data_table, const __m256i& x) { + // Shift right with rounding to match input and output precision. + __m256i shifted = _mm256_set1_epi32(1 << (kNumShiftBits - 1)); + shifted = _mm256_add_epi32(x, shifted); + shifted = _mm256_srai_epi32(shifted, kNumShiftBits); + // Add the offset. + __m256i addend = _mm256_set1_epi32(kTableOffset); + shifted = _mm256_add_epi32(shifted, addend); + // And clamp to the indices of the LUT. + addend = _mm256_add_epi32(addend, addend); + shifted = _mm256_min_epi32(shifted, addend); + shifted = _mm256_max_epi32(shifted, _mm256_setzero_si256()); + // Lookup the results in the table. + return _mm256_i32gather_epi32(data_table, shifted, 4); +} + +// Fixed32 to fixed16-in-an-int32 tanh LUT function. +// Input: register containins 8x fixed32 with |NumInputMantissaBits|. +// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but +// note that they are sign-extended to 32 bits and are therefore basically the +// same as fixed32 with |NumOutputMantissaBits|. +template +inline __m256i fixed32_tanh_fixed16(const int* tanh_table, const __m256i& x) { + // Lose the unnecessary input precision. + constexpr int kNumShiftBits = NumInputMantissaBits - NumOutputMantissaBits; + constexpr int kTableOffset = 1 << (NumOutputMantissaBits + kNumTanhExpBits); + return index_data_table(tanh_table, x); +} + +// Fixed32 to fixed16-in-an-int32 sigmoid LUT function. +// Input: register containins 8x fixed32 with |NumInputMantissaBits|. +// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but +// note that they are sign-extended to 32 bits and are therefore basically the +// same as fixed32 with |NumOutputMantissaBits|. +template +inline __m256i fixed32_sigmoid_fixed16(const int* sigmoid_table, + const __m256i& x) { + // Lose the unnecessary input precision. + constexpr int kNumShiftBits = + kNumExtraSigmoidShiftBits + NumInputMantissaBits - NumOutputMantissaBits; + constexpr int kTableOffset = 1 + << (NumOutputMantissaBits + kNumSigmoidExpBits - + kNumExtraSigmoidShiftBits); + return index_data_table(sigmoid_table, x); +} + +// Convert 2x registers of 8x float32 into 1 register of 16x16 bit fixed int, +// assuming that the floats are already scaled up. +inline __m256i PackFloatsToFixed16(const __m256& x0, const __m256& x1) { + __m256i int0 = _mm256_cvtps_epi32(x0); + __m256i int1 = _mm256_cvtps_epi32(x1); + int0 = _mm256_packs_epi32(int0, int1); + // Swap the middle 64 bit elements so the results are in the right order. + return _mm256_permute4x64_epi64(int0, 0xd8); +} + +// Clips integer input to [-|kLimit|, |kLimit|]. +// Input: register containins 8x fixed32 with |mantissa_bits|. +// Output: register containing 8x fixed32 limited to +// [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|]. +template +inline __m256i ClipToBounds(const int mantissa_bits, const __m256i& x) { + // Clip to the input bounds for this approximation. + __m256i clip_limit = _mm256_set1_epi32(-(kLimit << mantissa_bits)); + __m256i clipped_x = _mm256_max_epi32(x, clip_limit); + // This quickly negates the limit without having to load another constant. + clip_limit = _mm256_sign_epi32(clip_limit, clip_limit); + return _mm256_min_epi32(clipped_x, clip_limit); +} + +// Clips float input to [-|kLimit|, |kLimit|]. +// Input: register containins 8x float. +// Output: register containing 8x float limited to [-|kLimit|, |kLimit|]. +inline __m256 ClipToFloatBounds(const float kLimit, const __m256& x) { + __m256 clip_limit = _mm256_set1_ps(kLimit); + __m256 clipped_x = _mm256_min_ps(x, clip_limit); + clip_limit = _mm256_set1_ps(-kLimit); + return _mm256_max_ps(clipped_x, clip_limit); +} + +// Float to float power of 2 approximation, using a quartic refinement of +// the exponent trick. For TM_ORDER4_16BIT and TM_ORDER3_16BIT, implementation +// is entirely in integer, using 16x16=16 multiplication, using AVX2, which +// enables 16 elements to be computed in parallel, hence the double register +// input/output args. +// The price paid for this speed is an increase in error over the (scalar) int32 +// example implementations above by a variable factor of 4-10. +// For the TM_ORDER4_FLOAT case, the computation is all done in float, solving +// this lower precision problem. +// NOTE: The input must have already been clipped to prevent overflow, which +// sets the practical limit to +/-126 << kFloatMantissaBits. +// NOTE: The input is a scaled float, as if converted raw from int, and the +// scale factor is fixed at kFloatMantissaBits! +// Input: 2x register containining 8x float * 1 << kFloatMantissaBits. +// Output: 2x register containing 8x float. +// TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1. +// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1. +// TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1. +template +inline void float32_pow2(__m256& x0, __m256& x1) { + // Convert straight to int. + __m256i exp_int_x0 = _mm256_cvtps_epi32(x0); + __m256i exp_int_x1 = _mm256_cvtps_epi32(x1); + __m256i result_x0, result_x1; + + static_assert(kOrder == TM_ORDER4_FLOAT || kOrder == TM_ORDER4_16BIT || + kOrder == TM_ORDER3_16BIT, + "Invalid order."); + + if (kOrder == TM_ORDER4_FLOAT) { + __m256i mantissa_mask = _mm256_set1_epi32(0x7fffff); + __m256 float_factor = + _mm256_set1_ps(1.0f / static_cast(1 << kFloatMantissaBits)); + __m256i fract0 = _mm256_and_si256(mantissa_mask, exp_int_x0); + __m256i fract1 = _mm256_and_si256(mantissa_mask, exp_int_x1); + __m256 float0 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract0), float_factor); + __m256 float1 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract1), float_factor); + // Compute the polynomial of the fractional part. + // Ordering these lines carefully makes it faster, as some of the multiply + // operations can pipeline instead of waiting for the previous result. + __m256 x_squared0 = _mm256_mul_ps(float0, float0); + __m256 x_squared1 = _mm256_mul_ps(float1, float1); + __m256 b = _mm256_set1_ps(kExpQuarticFactor1); + __m256 b_x0 = _mm256_mul_ps(b, float0); + __m256 b_x1 = _mm256_mul_ps(b, float1); + __m256 a = _mm256_set1_ps(kExpQuarticFactor2); + __m256 a_x_squared0 = _mm256_mul_ps(a, x_squared0); + __m256 a_x_squared1 = _mm256_mul_ps(a, x_squared1); + __m256 x_squared_minus_x0 = _mm256_sub_ps(x_squared0, float0); + __m256 x_squared_minus_x1 = _mm256_sub_ps(x_squared1, float1); + __m256 c = _mm256_set1_ps(kExpQuarticFactor0); + b_x0 = _mm256_add_ps(b_x0, c); + b_x1 = _mm256_add_ps(b_x1, c); + float_factor = _mm256_set1_ps(static_cast(1 << kFloatMantissaBits)); + a_x_squared0 = _mm256_add_ps(a_x_squared0, b_x0); + a_x_squared1 = _mm256_add_ps(a_x_squared1, b_x1); + a_x_squared0 = _mm256_mul_ps(a_x_squared0, x_squared_minus_x0); + a_x_squared1 = _mm256_mul_ps(a_x_squared1, x_squared_minus_x1); + result_x0 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared0, float_factor)); + result_x1 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared1, float_factor)); + } else { + // Combine the fractional part of both inputs into a single register. + // The representation is fixed16<0>, ie 15 mantissa bits. + __m256i mantissa_mask = _mm256_set1_epi32(0x7fff00); + __m256i x_01 = + _mm256_srli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x0), 8); + x_01 = _mm256_or_si256( + x_01, + _mm256_slli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x1), 8)); + // Compute the polynomial of the fractional part. + // Ordering these lines carefully makes it faster, as some of the multiply + // operations can pipeline instead of waiting for the previous result. + __m256i x_squared = _mm256_mulhrs_epi16(x_01, x_01); + __m256i result, x_squared_minus_x; + if (kOrder == TM_ORDER4_16BIT) { + __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor1)); + __m256i b_x = _mm256_mulhrs_epi16(b, x_01); + __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor2)); + __m256i a_x_squared = _mm256_mulhrs_epi16(a, x_squared); + x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01); + // LOG(INFO) << "x_squared_minus_x=" << + // static_cast(_mm256_extract_epi16(x_squared_minus_x, 0)) / + // 32768.0f; + __m256i c = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0)); + b_x = _mm256_add_epi16(b_x, c); + // LOG(INFO) << "bx+c=" << static_cast(_mm256_extract_epi16(b_x, + // 0)) / 32768.0f; + result = _mm256_add_epi16(a_x_squared, b_x); + } else { // kOrder = TM_ORDER3_16BIT + __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpCubicFactor1)); + __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0)); + __m256i a_x = _mm256_mulhrs_epi16(a, x_01); + x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01); + result = _mm256_add_epi16(a_x, b); + } + result = _mm256_mulhrs_epi16(result, x_squared_minus_x); + // Extract 16x16-bit results back to the separate sets of 8x32. + result_x0 = _mm256_slli_epi32(result, 16); + result_x0 = _mm256_srai_epi32(result_x0, 8); + result_x1 = _mm256_srai_epi32(result, 16); + result_x1 = _mm256_slli_epi32(result_x1, 8); + } + // Add the constant to normalize the exponent. + __m256i exp_offset = _mm256_set1_epi32(kFloatExponentOffset); + exp_int_x0 = _mm256_add_epi32(exp_int_x0, exp_offset); + exp_int_x0 = _mm256_add_epi32(exp_int_x0, result_x0); + exp_int_x1 = _mm256_add_epi32(exp_int_x1, exp_offset); + exp_int_x1 = _mm256_add_epi32(exp_int_x1, result_x1); + // Cast back to float, as we just computed the exponent and mantissa and + // assembled them in IEEE format. + x0 = _mm256_castsi256_ps(exp_int_x0); + x1 = _mm256_castsi256_ps(exp_int_x1); +} + +// Fixed32 to to float exp approximation, using a quartic/cubic refinement of +// the exponent trick. Implementation is entirely in integer, using 16x16=16 +// multiplication, using AVX2, which enables 16 elements to be computed in +// parallel, hence the double register input/output args. +// The price paid for this speed is an increase in error over the (scalar) int32 +// example implementations above by a variable factor of 4-10. +// The TM_ORDER4_FLOAT version uses floats and improves the precision. +// Input: 2x registers containins 8x fixed32 with kMantissaBits. +// Output: 2x registers containing 8x float32. +// TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1. +// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1. +// TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1. +template +inline void float_exp_float_preclipped(__m256& y0, __m256& y1) { + // Divide by log 2 to convert problem to 2^x, and scale to match the + // mantissa bits required by IEEE floats. Without a _mm256_mulhrs_epi32, it is + // much easier to do this in float, even with the double conversion, as 16 bit + // is not precise enough here. + // This is the shift of the FP mantissa relative to the input mantissa. + constexpr int kXShift = kFloatMantissaBits - kInputMantissaBits; + constexpr float kLogFactor = static_cast(1 << kXShift); + __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2); + y0 = _mm256_mul_ps(y0, factor); + y1 = _mm256_mul_ps(y1, factor); + // Now compute 2^x. + float32_pow2(y0, y1); +} +template +inline void fixed32_exp_float(const __m256i& x0, const __m256i& x1, __m256& y0, + __m256& y1) { + // Clip to acceptable bounds to prevent overflow, and convert to float. + y0 = + _mm256_cvtepi32_ps(ClipToBounds(kInputMantissaBits, x0)); + y1 = + _mm256_cvtepi32_ps(ClipToBounds(kInputMantissaBits, x1)); + float_exp_float_preclipped(y0, y1); +} + +// Float->float tanh approximation via the exponent trick. +// Note that the input is scaled floats, as if converted raw from fixed16/32. +// Input: 2x registers containing 8x float scaled by input_mantissa_bits. +// Output: two registers containing 8x float. +// TM_ORDER4_FLOAT: Max relative error < 2.1e-5, absolute error < 2.3e-6. +// TM_ORDER4_16BIT: Max relative error < 1e-4, absolute error < 1.3e-5. +// TM_ORDER3_16BIT: Max relative error < 2.1e-3, absolute error < 3e-4. +template +inline void float_tanh_float(const __m256& x0, const __m256& x1, __m256& y0, + __m256& y1) { + // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and + // scale to the mantissa bits required by float32_pow2 all in one multiply. + // This is the shift of the FP mantissa relative to the input mantissa. + // Add one to double the input. + const float kLogFactor = + static_cast(1 << (kFloatMantissaBits - kInputMantissaBits + 1)); + __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2); + // Clip to suitable input bounds for tanh. + __m256 clip_limit = _mm256_set1_ps(kMaxTanhInput * (1 << kInputMantissaBits)); + __m256 clip0 = _mm256_min_ps(x0, clip_limit); + __m256 clip1 = _mm256_min_ps(x1, clip_limit); + clip_limit = _mm256_set1_ps(-kMaxTanhInput * (1 << kInputMantissaBits)); + clip0 = _mm256_max_ps(clip0, clip_limit); + clip1 = _mm256_max_ps(clip1, clip_limit); + __m256 exp0 = _mm256_mul_ps(clip0, factor); + __m256 exp1 = _mm256_mul_ps(clip1, factor); + // Now compute 2^x. + float32_pow2(exp0, exp1); + // Now compute tanh using (e^2x - 1) / (e^2x + 1). + __m256 one = _mm256_set1_ps(1.0f); + __m256 numerator = _mm256_sub_ps(exp0, one); + __m256 denominator = _mm256_add_ps(exp0, one); + // Approximate reciprocal is not accurate enough - use full division. + exp0 = _mm256_div_ps(numerator, denominator); + numerator = _mm256_sub_ps(exp1, one); + denominator = _mm256_add_ps(exp1, one); + exp1 = _mm256_div_ps(numerator, denominator); + // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low + // relative error close to 0. + // Normalize the inputs back to proper floats. + factor = _mm256_set1_ps(1.0f / (1 << kInputMantissaBits)); + clip0 = _mm256_mul_ps(clip0, factor); + clip1 = _mm256_mul_ps(clip1, factor); + __m256 third = _mm256_set1_ps(-1.0f / 3.0f); + __m256 taylor0 = _mm256_mul_ps(clip0, clip0); + __m256 taylor1 = _mm256_mul_ps(clip1, clip1); + taylor0 = _mm256_mul_ps(taylor0, clip0); + taylor1 = _mm256_mul_ps(taylor1, clip1); + // TODO(b/191497441): The next two pairs of instructions could be combined to + // _mm256_fmadd_ps, but requires -mfma compilation option, eg: + // taylor0 = _mm256_fmadd_ps(taylor0, third, clip0); + taylor0 = _mm256_mul_ps(taylor0, third); + taylor1 = _mm256_mul_ps(taylor1, third); + taylor0 = _mm256_add_ps(clip0, taylor0); + taylor1 = _mm256_add_ps(clip1, taylor1); + // Test |x| <= 1/9, roughly where the errors cross over, without needing yet + // another constant. + third = _mm256_mul_ps(third, third); + __m256 neg_zero = _mm256_set1_ps(-0.0f); + clip0 = _mm256_andnot_ps(neg_zero, clip0); + clip1 = _mm256_andnot_ps(neg_zero, clip1); + __m256 cmp_results0 = _mm256_cmp_ps(clip0, third, _CMP_LE_OQ); + __m256 cmp_results1 = _mm256_cmp_ps(clip1, third, _CMP_LE_OQ); + y0 = _mm256_blendv_ps(exp0, taylor0, cmp_results0); + y1 = _mm256_blendv_ps(exp1, taylor1, cmp_results1); +} + +// Fixed32 sigmoid approximation via the AVX2 implementation of the exponent +// trick. +// Input: 2x registers containins 8x float containing converted fixed32 scaled +// with kInputMantissaBits. +// Output: 2x registers containing 8x float. +// TM_ORDER4_FLOAT: Max relative error < 4e-6, absolute error < 1e-6. +// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 7e-6. +// TM_ORDER3_16BIT: Max relative error < 5.4e-4, absolute error < 1.4e-4. +template +inline void float_sigmoid_float(__m256& y0, __m256& y1) { + constexpr float kInputFactor = static_cast(1 << kInputMantissaBits); + // Negate the inputs. + __m256 minus_zero = _mm256_set1_ps(-0.0f); + y0 = _mm256_xor_ps(y0, minus_zero); + y1 = _mm256_xor_ps(y1, minus_zero); + y0 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y0); + y1 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y1); + float_exp_float_preclipped(y0, y1); + __m256 one = _mm256_set1_ps(1.0f); + // Approximate reciprocal is not accurate enough - use full division. + y0 = _mm256_div_ps(one, _mm256_add_ps(y0, one)); + y1 = _mm256_div_ps(one, _mm256_add_ps(y1, one)); +} + +#endif // defined __AVX2__ + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_ diff --git a/sparse_matmul/numerics/fasttranscendentals_test.cc b/sparse_matmul/numerics/fasttranscendentals_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..004241e55824c66cbecfdd645b70efa2ba237638 --- /dev/null +++ b/sparse_matmul/numerics/fasttranscendentals_test.cc @@ -0,0 +1,665 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined __aarch64__ +#include +#endif +#if defined __AVX__ || defined __AVX2__ +#include +#endif + +#include + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/test_utils.h" + +namespace csrblocksparse { + +const float kExpFixedRelTolerance = .084f; + +#ifdef SIGMOID_AS_TANH +#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX +const float kSigmoidRelTolerance = .093f; // 9.3% relative +const float kSigmoidAbsTolerance = .0005f; +const float kSigmoidFixedRelTolerance = .093f; +const float kSigmoidFixedAbsTolerance = .0005f; +#elif defined FAST_TRANSCENDENTALS +const float kSigmoidRelTolerance = .09f; // 9.0% relative +const float kSigmoidAbsTolerance = .003f; +const float kSigmoidFixedRelTolerance = .09f; +const float kSigmoidFixedAbsTolerance = .003f; +#endif +#elif defined FAST_TRANSCENDENTALS and defined ACCURATE_TRANSCENDENTAL_APPROX +const float kSigmoidRelTolerance = .102f; // 10.2% relative +const float kSigmoidAbsTolerance = .0003f; +const float kSigmoidFixedRelTolerance = .102f; +const float kSigmoidFixedAbsTolerance = .0003f; +#elif defined FAST_TRANSCENDENTALS +const float kSigmoidRelTolerance = .09f; // 9.0% relative +const float kSigmoidAbsTolerance = .006f; +const float kSigmoidFixedRelTolerance = .09f; +const float kSigmoidFixedAbsTolerance = .006f; +#else +const float kSigmoidRelTolerance = .0001f; +const float kSigmoidAbsTolerance = 1e-5f; +const float kSigmoidFixedRelTolerance = .001f; +const float kSigmoidFixedAbsTolerance = .001f; +#endif + +#if (defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX || \ + defined FASTER_TRANSCENDENTALS) +const float kExpRelTolerance = .03f; // 3% relative +const float kTanhRelTolerance = .006f; // .6% relative +const float kTanhAbsTolerance = .0003f; +#elif defined FAST_TRANSCENDENTALS +const float kExpRelTolerance = .03f; // 3% relative +const float kTanhRelTolerance = .091f; // .91% relative +const float kTanhAbsTolerance = .00525f; +#else +const float kExpRelTolerance = .0001f; +const float kTanhRelTolerance = .0001f; +const float kTanhAbsTolerance = 1e-5f; +#endif + +constexpr float kQuarticFloatExpRelTolerance = 8e-6f; +constexpr float kQuarticFloatExpTolerance = 9e-6f; +constexpr float kQuarticExpRelTolerance = 3e-5f; +constexpr float kQuarticExpTolerance = 6e-5f; +constexpr float kCubicExpRelTolerance = 6e-4f; +constexpr float kCubicExpTolerance = 2e-3f; +constexpr float kQuarticFloatTanhRelTolerance = 3e-5f; +constexpr float kQuarticFloatTanhTolerance = 3e-6f; +constexpr float kCubicTanhRelTolerance = 3e-3f; +constexpr float kCubicTanhTolerance = 3e-4f; +constexpr float kQuarticSigmoidRelTolerance = 3e-5f; +constexpr float kQuarticSigmoidTolerance = 7e-6f; +constexpr float kCubicSigmoidRelTolerance = 6e-4f; +constexpr float kCubicSigmoidTolerance = 2e-4f; +#ifdef __AVX2__ +constexpr float kQuarticTanhRelTolerance = 1e-4f; +constexpr float kQuarticTanhTolerance = 2e-5f; +constexpr float kQuarticFloatSigmoidRelTolerance = 4e-6f; +constexpr float kQuarticFloatSigmoidTolerance = 1e-6f; +#endif // __AVX2__ + +TEST(Transcendentals, Exp) { + // 132 - 127 = 5, we check between -63.99... and 63.99... + const int maxExponent = 132; + const int minExponent = 0; + float max_error = 0.f; + constexpr int kExponentBits = 7; + for (int s = 0; s < 2; ++s) { + for (int e = minExponent; e < maxExponent; ++e) { + // Don't check every mantissa for speed reasons. + for (int m = 0; m < (1 << 23); m += (1 << 10)) { + uint32_t int_val = s << 31 | e << 23 | m; + float x; + memcpy(&x, &int_val, sizeof(float)); + + float exact_exp = expf(x); + float approx_exp = csrblocksparse::fast_exp(x); + float approx_exp_fixed = csrblocksparse::fast_exp( + csrblocksparse::fixed32(x)); + + float rel_diff = RelDiff(exact_exp, approx_exp); + float rel_diff_fixed = RelDiff(exact_exp, approx_exp_fixed); + max_error = std::max(max_error, rel_diff); + EXPECT_LT(rel_diff, kExpRelTolerance) + << exact_exp << " " << approx_exp << " " << x; + EXPECT_LT(rel_diff_fixed, kExpRelTolerance) + << exact_exp << " " << approx_exp << " " << x; + } + } + } +} + +TEST(Transcendentals, FixedExp) { + const int maxExponent = 132; + const int minExponent = 120; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int s = 0; s < 2; ++s) { + for (int e = minExponent; e < maxExponent; ++e) { + // Don't check every mantissa for speed reasons. + for (int m = 0; m < (1 << 23); m += (1 << 10)) { + uint32_t int_val = s << 31 | e << 23 | m; + float x; + memcpy(&x, &int_val, sizeof(float)); + + float exact_exp = expf(x); + float approx_exp = + csrblocksparse::fast_exp_fixed(csrblocksparse::fixed32<16>(x)); + + float rel_diff = RelDiff(exact_exp, approx_exp); + float abs_diff = std::abs(exact_exp - approx_exp); + max_error = std::max(max_error, rel_diff); + max_abs_error = std::max(max_abs_error, abs_diff); + EXPECT_LT(rel_diff, kExpFixedRelTolerance) + << exact_exp << " " << approx_exp << " " << x; + } + } + } + LOG(INFO) << "Max relative exp error = " << max_error + << ", abs=" << max_abs_error; +} + +template +void TestExp(float abs_tolerance, float rel_tolerance) { + constexpr int kMaxInput = 80 << 16; + constexpr int kMinInput = -(80 << 16); + constexpr int kExponentBits = 15; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int i = kMinInput; i <= kMaxInput; ++i) { + csrblocksparse::fixed32 fixed_int(i); + float x = static_cast(fixed_int); + float exact_exp = expf(x); + float approx_exp = fixed32_exp(fixed_int); + float diff = exact_exp - approx_exp; + float abs_diff = std::abs(diff); + float rel_diff = RelDiff(exact_exp, approx_exp); + max_error = std::max(max_error, rel_diff); + if (x <= 1.0f) { + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << exact_exp << ", aprx=" << approx_exp; + max_abs_error = std::max(max_abs_error, abs_diff); + } + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << exact_exp << ", aprx=" << approx_exp; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticExp) { + TestExp(kQuarticFloatExpTolerance, + kQuarticFloatExpRelTolerance); +} + +TEST(Transcendentals, CubicExp) { + TestExp(kCubicExpTolerance, kCubicExpRelTolerance); +} + +template +void TestTanh(float abs_tolerance, float rel_tolerance) { + constexpr int kMaxInput = (40 << 16); + constexpr int kMinInput = -(40 << 16); + constexpr int kExponentBits = 15; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int i = kMinInput; i <= kMaxInput; ++i) { + csrblocksparse::fixed32 fixed_int(i); + float x = static_cast(fixed_int); + float exact_tanh = tanh(x); + float approx_tanh = fixed32_tanh(fixed_int); + float diff = exact_tanh - approx_tanh; + float abs_diff = std::abs(diff); + float rel_diff = RelDiff(exact_tanh, approx_tanh); + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << exact_tanh << ", aprx=" << approx_tanh; + max_abs_error = std::max(max_abs_error, abs_diff); + max_error = std::max(max_error, rel_diff); + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << exact_tanh << ", aprx=" << approx_tanh; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticTanh) { + TestTanh(kQuarticFloatTanhTolerance, + kQuarticFloatTanhRelTolerance); +} + +TEST(Transcendentals, CubicTanh) { + TestTanh(kCubicTanhTolerance, kCubicTanhRelTolerance); +} + +template +void TestSigmoid(float abs_tolerance, float rel_tolerance) { + constexpr int kMaxInput = 80 << 16; + constexpr int kMinInput = -(80 << 16); + constexpr int kExponentBits = 15; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int i = kMinInput; i <= kMaxInput; ++i) { + csrblocksparse::fixed32 fixed_int(i); + float x = static_cast(fixed_int); + float exact_sigmoid = 1.0f / (1.0f + exp(-x)); + float approx_sigmoid = fixed32_sigmoid(fixed_int); + float diff = exact_sigmoid - approx_sigmoid; + float abs_diff = std::abs(diff); + float rel_diff = RelDiff(exact_sigmoid, approx_sigmoid); + max_error = std::max(max_error, rel_diff); + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << exact_sigmoid + << ", aprx=" << approx_sigmoid; + max_abs_error = std::max(max_abs_error, abs_diff); + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << exact_sigmoid + << ", aprx=" << approx_sigmoid; + } + LOG(INFO) << "Max relative sigmoid error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticSigmoidExp) { + TestSigmoid(kQuarticSigmoidTolerance, + kQuarticSigmoidRelTolerance); +} + +TEST(Transcendentals, CubicSigmoidExp) { + TestSigmoid(kCubicSigmoidTolerance, + kCubicSigmoidRelTolerance); +} + +TEST(Transcendentals, Sigmoid) { + // 132 - 127 = 5, we check between -63.99... and 63.99... + const int maxExponent = 132; + const int minExponent = 0; + // The mantissa bits must not exceed 23, so min exponent bits here is: + // 31 - 23 = 8. + constexpr int kExponentBits = 9; + float max_error = 0.f; + float max_abs_error = 0.f; +#if defined __aarch64__ + float max_vector_error = 0.f; + float max_vector_abs_error = 0.f; +#endif + for (int s = 0; s < 2; ++s) { + for (int e = minExponent; e < maxExponent; ++e) { + // Don't check every mantissa for speed reasons. + for (int m = 0; m < (1 << 23); m += (1 << 10)) { + uint32_t int_val = s << 31 | e << 23 | m; + float x; + memcpy(&x, &int_val, sizeof(float)); + + float exact_sigmoid = 1. / (1. + expf(-x)); + float approx_sigmoid = csrblocksparse::fast_sigmoid(x); + float approx_sigmoid_fixed = + csrblocksparse::fast_sigmoid( + csrblocksparse::fixed32(x)); + + float rel_diff = RelDiff(exact_sigmoid, approx_sigmoid); + float abs_diff = std::abs(exact_sigmoid - approx_sigmoid); + float rel_diff_fixed = RelDiff(exact_sigmoid, approx_sigmoid_fixed); + max_error = std::max(max_error, rel_diff); + max_abs_error = std::max(max_abs_error, abs_diff); + EXPECT_LT(rel_diff, kSigmoidRelTolerance) + << exact_sigmoid << " " << approx_sigmoid << " " << x; + EXPECT_NEAR(approx_sigmoid, exact_sigmoid, kSigmoidAbsTolerance) << x; + + EXPECT_LT(rel_diff_fixed, kSigmoidFixedRelTolerance) + << exact_sigmoid << " " << approx_sigmoid_fixed << " " << x; + EXPECT_NEAR(approx_sigmoid_fixed, exact_sigmoid, + kSigmoidFixedAbsTolerance) + << x; +#if defined __aarch64__ + constexpr int kSIMD_WIDTH = 4; + float approx_results[kSIMD_WIDTH]; + int32x4_t input = + vdupq_n_s32(csrblocksparse::fixed32(x).raw_val()); + float32x4_t result = csrblocksparse::fast_sigmoid(input); + vst1q_f32(approx_results, result); + + for (int i = 0; i < kSIMD_WIDTH; ++i) { + float rel_diff = RelDiff(exact_sigmoid, approx_results[i]); + float abs_diff = std::abs(exact_sigmoid - approx_results[i]); + max_vector_error = std::max(max_vector_error, rel_diff); + max_vector_abs_error = std::max(max_vector_abs_error, abs_diff); + EXPECT_LT(rel_diff, kSigmoidRelTolerance) + << exact_sigmoid << " " << approx_sigmoid << " " << x; + EXPECT_NEAR(approx_sigmoid, exact_sigmoid, kSigmoidAbsTolerance) << x; + } +#endif + } + } + } + LOG(INFO) << "Max relative error in float sigmoid=" << max_error; + LOG(INFO) << "Max abs error in float sigmoid=" << max_abs_error; +#if defined __aarch64__ + LOG(INFO) << "Max relative vector error fixed sigmoid=" << max_vector_error; + LOG(INFO) << "Max abs vector error fixed sigmoid=" << max_vector_abs_error; +#endif +} + +TEST(Transcendentals, Tanh) { + // 132 - 127 = 5, we check between -63.99... and 63.99... + const int maxExponent = 132; + const int minExponent = 0; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int s = 0; s < 2; ++s) { + for (int e = minExponent; e < maxExponent; ++e) { + // Don't check every mantissa for speed reasons. + for (int m = 0; m < (1 << 23); m += (1 << 10)) { + uint32_t int_val = s << 31 | e << 23 | m; + float x; + memcpy(&x, &int_val, sizeof(float)); + + float exact_tanh = tanhf(x); + float approx_tanh = csrblocksparse::fast_tanh(x); + + float rel_diff = RelDiff(exact_tanh, approx_tanh); + float abs_diff = std::abs(exact_tanh - approx_tanh); + max_error = std::max(rel_diff, max_error); + max_abs_error = std::max(abs_diff, max_abs_error); + + EXPECT_LT(rel_diff, kTanhRelTolerance) + << exact_tanh << " " << approx_tanh << " " << x; + EXPECT_NEAR(approx_tanh, exact_tanh, kTanhAbsTolerance) << x; + } + } + } + LOG(INFO) << "Max relative error in float tanh=" << max_error; + LOG(INFO) << "Max abs error in float tanh=" << max_abs_error; + + // tanh behavior is not identical across all lanes, so need to test + // with some values in the linear region and some not. +#if defined __aarch64__ + float vals[4] = {-1.f, -.1f, .1f, 1.f}; + float exact_results[4]; + float approx_results[4]; + max_error = 0.f; + max_abs_error = 0.f; + + float32x4_t input = vld1q_f32(vals); + float32x4_t result = csrblocksparse::fast_tanh(input); + vst1q_f32(approx_results, result); + + for (int i = 0; i < 4; ++i) { + exact_results[i] = tanh(vals[i]); + float rel_diff = RelDiff(exact_results[i], approx_results[i]); + float abs_diff = std::abs(exact_results[i] - approx_results[i]); + max_error = std::max(rel_diff, max_error); + max_abs_error = std::max(abs_diff, max_abs_error); + + EXPECT_LT(rel_diff, kTanhRelTolerance) + << exact_results[i] << " " << approx_results[i] << " " << vals[i]; + EXPECT_NEAR(approx_results[i], exact_results[i], kTanhAbsTolerance) + << vals[i]; + } + LOG(INFO) << "Max relative vector error in float tanh=" << max_error; + LOG(INFO) << "Max abs vector error in float tanh=" << max_abs_error; +#endif +} + +#if defined __AVX2__ + +constexpr int kSIMDSize = 8; +constexpr int kNumExpBitsIn = 10; +constexpr int kNumExpBitsOut = 5; + +TEST(Transcendentals, TanhLut) { + // Test every value in (-1, 1) for round-trip exactness. + constexpr int kNumMantissaBitsIn = fixed32::kMantissaBits; + constexpr int kNumMantissaBitsOut = fixed16::kMantissaBits; + const int32_t* tanh_table = TanhTable(kNumMantissaBitsOut); + float in_factor = static_cast(1 << kNumMantissaBitsIn); + float out_factor = static_cast(1 << kNumMantissaBitsOut); + for (int i = 1 - (1 << kNumMantissaBitsOut); + i + kSIMDSize < (1 << kNumMantissaBitsOut); i += kSIMDSize) { + int32_t inputs[kSIMDSize]; + int32_t outputs[kSIMDSize]; + int32_t target_outputs[kSIMDSize]; + for (int j = 0; j < kSIMDSize; ++j) { + float target_tanh = (i + j) / out_factor; + float x = atanhf(static_cast(target_tanh)); + inputs[j] = static_cast(x * in_factor); + target_outputs[j] = i + j; + } + __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(inputs)); + __m256i output = + fixed32_tanh_fixed16( + tanh_table, x_in); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(outputs), output); + for (int j = 0; j < kSIMDSize; ++j) { + EXPECT_EQ(target_outputs[j], outputs[j]); + } + } +} + +TEST(Transcendentals, SigmoidLut) { + // Test every value in (-1, 1) for round-trip exactness. + constexpr int kNumMantissaBitsIn = fixed32::kMantissaBits; + constexpr int kNumMantissaBitsOut = fixed16::kMantissaBits; + const int32_t* sigmoid_table = SigmoidTable(kNumMantissaBitsOut); + float in_factor = static_cast(1 << kNumMantissaBitsIn); + float out_factor = static_cast(1 << kNumMantissaBitsOut); + for (int i = 1; i + kSIMDSize < (1 << kNumMantissaBitsOut); i += kSIMDSize) { + int32_t inputs[kSIMDSize]; + int32_t outputs[kSIMDSize]; + int32_t target_outputs[kSIMDSize]; + for (int j = 0; j < kSIMDSize; ++j) { + float target_sigmoid = (i + j) / out_factor; + float x = 2.0f * atanhf(2.0f * static_cast(target_sigmoid) - 1.0f); + inputs[j] = static_cast(x * in_factor); + target_outputs[j] = i + j; + } + __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(inputs)); + __m256i output = + fixed32_sigmoid_fixed16( + sigmoid_table, x_in); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(outputs), output); + for (int j = 0; j < kSIMDSize; ++j) { + EXPECT_EQ(target_outputs[j], outputs[j]); + } + } +} + +template +static void TestExpAVX2(float abs_tolerance, float rel_tolerance) { + constexpr int kMantissaBits = 20; + // Test every value in [-80, 80] and report the max error. + constexpr int kMinInput = -(80 << kMantissaBits); + constexpr int kMaxInput = 80 << kMantissaBits; + constexpr int kNumInputs = kMaxInput - kMinInput; + std::vector inputs(kNumInputs); + std::vector outputs(kNumInputs); + std::vector target_outputs(kNumInputs); + for (int i = 0; i < inputs.size(); ++i) { + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + inputs[i] = fixed_int.raw_val(); + target_outputs[i] = expf(x); + } + absl::Time t_start = absl::Now(); + for (int i = 0; i + kSIMDSize * 2 <= kNumInputs; i += kSIMDSize * 2) { + __m256i x0 = + _mm256_loadu_si256(reinterpret_cast(inputs.data() + i)); + __m256i x1 = _mm256_loadu_si256( + reinterpret_cast(inputs.data() + i + kSIMDSize)); + __m256 y0, y1; + fixed32_exp_float(x0, x1, y0, y1); + _mm256_storeu_ps(outputs.data() + i, y0); + _mm256_storeu_ps(outputs.data() + i + kSIMDSize, y1); + } + LOG(INFO) << "Time=" << absl::ToDoubleMilliseconds(absl::Now() - t_start); + float max_error = 0.f; + float max_abs_error = 0.f; + for (int i = 0; i < kNumInputs; ++i) { + float diff = target_outputs[i] - outputs[i]; + float abs_diff = std::abs(diff); + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + float rel_diff = RelDiff(target_outputs[i], outputs[i]); + max_error = std::max(max_error, rel_diff); + if (x <= 1.0f) { + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", result= " << outputs[i] << ", i=" << i; + max_abs_error = std::max(max_abs_error, abs_diff); + } + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", result= " << outputs[i] << ", i=" << i; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticFloatExpAVX2) { + TestExpAVX2(kQuarticFloatExpTolerance, + kQuarticFloatExpRelTolerance); +} + +TEST(Transcendentals, QuarticExpAVX2) { + TestExpAVX2(kQuarticExpTolerance, kQuarticExpRelTolerance); +} + +TEST(Transcendentals, CubicExpAVX2) { + TestExpAVX2(kCubicExpTolerance, kCubicExpRelTolerance); +} + +template +void TestTanhAVX2Float(float abs_tolerance, float rel_tolerance) { + constexpr int kMantissaBits = 16; + // Test every value in [-10, 10] and report the max error. + constexpr int kMinInput = -(10 << kMantissaBits); + constexpr int kMaxInput = 10 << kMantissaBits; + constexpr int kNumInputs = kMaxInput - kMinInput; + float max_error = 0.f; + float max_abs_error = 0.f; + std::vector inputs(kNumInputs); + std::vector outputs(kNumInputs); + std::vector target_outputs(kNumInputs); + for (int i = 0; i < inputs.size(); ++i) { + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + float exact = tanh(x); + inputs[i] = static_cast(fixed_int.raw_val()); + target_outputs[i] = exact; + } + absl::Time t_start = absl::Now(); + for (int i = 0; i + kSIMDSize * 2 <= inputs.size(); i += kSIMDSize * 2) { + __m256 x0 = _mm256_loadu_ps(inputs.data() + i); + __m256 x1 = _mm256_loadu_ps(inputs.data() + kSIMDSize + i); + __m256 y0, y1; + float_tanh_float(x0, x1, y0, y1); + _mm256_storeu_ps(outputs.data() + i, y0); + _mm256_storeu_ps(outputs.data() + i + kSIMDSize, y1); + } + LOG(INFO) << "Time=" << absl::ToDoubleMilliseconds(absl::Now() - t_start); + float worst_abs_x = 0.0f, worst_rel_x = 0.0f; + for (int i = 0; i < inputs.size(); ++i) { + float diff = target_outputs[i] - outputs[i]; + float abs_diff = std::abs(diff); + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", aprx=" << outputs[i]; + if (abs_diff > max_abs_error) worst_abs_x = x; + max_abs_error = std::max(max_abs_error, abs_diff); + float rel_diff = 0.0f; + rel_diff = RelDiff(target_outputs[i], outputs[i]); + if (rel_diff > max_error) worst_rel_x = x; + max_error = std::max(max_error, rel_diff); + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", aprx=" << outputs[i]; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; + LOG(INFO) << "Worst rel x = " << worst_rel_x << ", abs=" << worst_abs_x; +} + +TEST(Transcendentals, QuarticTanhFloatAVX2Float) { + TestTanhAVX2Float(kQuarticFloatTanhTolerance, + kQuarticFloatTanhRelTolerance); +} + +TEST(Transcendentals, QuarticTanhAVX2Float) { + TestTanhAVX2Float(kQuarticTanhTolerance, + kQuarticTanhRelTolerance); +} + +TEST(Transcendentals, CubicTanhAVX2Float) { + TestTanhAVX2Float(kCubicTanhTolerance, + kCubicTanhRelTolerance); +} + +template +void TestSigmoidAVX2Float(float abs_tolerance, float rel_tolerance) { + constexpr int kMantissaBits = 20; + // Test every value in [-20, 20] and report the max error. + constexpr int kMaxInput = 20 << kMantissaBits; + constexpr int kMinInput = -(20 << kMantissaBits); + float max_error = 0.f; + float max_abs_error = 0.f; + std::vector inputs(kMaxInput - kMinInput); + std::vector outputs(kMaxInput - kMinInput); + std::vector target_outputs(kMaxInput - kMinInput); + for (int i = 0; i < inputs.size(); ++i) { + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + float exact = 1.0f / (1.0f + expf(-x)); + inputs[i] = fixed_int.raw_val(); + target_outputs[i] = exact; + } + absl::Time t_start = absl::Now(); + for (int i = 0; i + kSIMDSize * 2 <= inputs.size(); i += kSIMDSize * 2) { + __m256i x0 = + _mm256_loadu_si256(reinterpret_cast(inputs.data() + i)); + __m256i x1 = _mm256_loadu_si256( + reinterpret_cast(inputs.data() + i + kSIMDSize)); + __m256 y0 = _mm256_cvtepi32_ps(x0); + __m256 y1 = _mm256_cvtepi32_ps(x1); + float_sigmoid_float(y0, y1); + _mm256_storeu_ps(outputs.data() + i, y0); + _mm256_storeu_ps(outputs.data() + i + kSIMDSize, y1); + } + LOG(INFO) << "Time=" << absl::ToDoubleMilliseconds(absl::Now() - t_start); + for (int i = 0; i < inputs.size(); ++i) { + float diff = target_outputs[i] - outputs[i]; + float abs_diff = std::abs(diff); + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + float rel_diff = RelDiff(target_outputs[i], outputs[i]); + max_error = std::max(max_error, rel_diff); + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", aprx=" << outputs[i]; + max_abs_error = std::max(max_abs_error, abs_diff); + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", aprx=" << outputs[i]; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticSigmoidFloatAVX2Float) { + TestSigmoidAVX2Float(kQuarticFloatSigmoidTolerance, + kQuarticFloatSigmoidRelTolerance); +} + +TEST(Transcendentals, QuarticSigmoidAVX2Float) { + TestSigmoidAVX2Float(kQuarticSigmoidTolerance, + kQuarticSigmoidRelTolerance); +} + +TEST(Transcendentals, CubicSigmoidAVX2Float) { + TestSigmoidAVX2Float(kCubicSigmoidTolerance, + kCubicSigmoidRelTolerance); +} +#endif // __AVX2__ + +} // namespace csrblocksparse diff --git a/sparse_matmul/numerics/fixed_types.h b/sparse_matmul/numerics/fixed_types.h new file mode 100644 index 0000000000000000000000000000000000000000..932f81a0f6e7769e1c69c78b92b2f97520f295ad --- /dev/null +++ b/sparse_matmul/numerics/fixed_types.h @@ -0,0 +1,139 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FIXED_TYPES_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FIXED_TYPES_H_ + +#include +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +namespace csrblocksparse { + +// Useful for meta-programming and determining if a type is a fixed point type +class fixed_type {}; +class fixed16_type : fixed_type {}; +class fixed32_type : fixed_type {}; + +// Storage class for 16-bit fixed point values, not meant to be used directly +// for computation. Used for storage and converting to/from float32. +// N = 16 - 1 - |ExponentBits|. +// range = [-2^|ExponentBits|, 2^|ExponentBits|), increment = 2^-N. +template +class fixed16 : fixed16_type { + static_assert(ExponentBits >= 0 && ExponentBits < 16, + "ExponentBits must be in" + " the interval [0, 15]"); + + public: + static constexpr int kExponentBits = ExponentBits; + static constexpr int kMantissaBits = 16 - ExponentBits - 1; + + fixed16() = default; + explicit fixed16(float x) : val_(float_to_fixed16(x)) {} + explicit fixed16(int16_t x) : val_(x) {} + + explicit operator float() const { return fixed16_to_float(val_); } + + int raw_val() const { return val_; } + + private: + inline float fixed16_to_float(int16_t x) const { + return static_cast(x) / (1 << kMantissaBits); + } + + // Conversion clips to the representable range. + inline int16_t float_to_fixed16(float x) const { + float fval = std::round(x * static_cast(1 << kMantissaBits)); + const float max_bound = std::numeric_limits::max(); + const float min_bound = std::numeric_limits::min(); + auto val = + static_cast(std::max(std::min(fval, max_bound), min_bound)); + LOG_IF(INFO, fval > max_bound || fval < min_bound) + << "Conversion clipping: " << x << " to " << fixed16_to_float(val); + return val; + } + + int16_t val_; +}; + +// Storage class for 32-bit fixed point values, not meant to be used directly +// for computation. Used for storage and converting to/from float32. +// N = 32 - 1 - |ExponentBits|. +// range = [-2^|ExponentBits|, 2^|ExponentBits|), increment = 2^-N. +template +class fixed32 : fixed32_type { + static_assert(ExponentBits >= 0 && ExponentBits < 32, + "ExponentBits must be in" + " the interval [0, 31]"); + + public: + static constexpr int kExponentBits = ExponentBits; + static constexpr int kMantissaBits = 32 - ExponentBits - 1; + + fixed32() = default; + explicit fixed32(float x) : val_(float_to_fixed32(x)) {} + explicit fixed32(int32_t x) : val_(x) {} + + explicit operator float() const { return fixed32_to_float(val_); } + + int raw_val() const { return val_; } + + private: + inline float fixed32_to_float(int32_t x) const { + return static_cast(x) / (1LL << kMantissaBits); + } + + // Conversion clips to the representable range. + inline int32_t float_to_fixed32(float x) const { + float fval = std::round(x * static_cast(1LL << kMantissaBits)); + const int32_t max_bound = std::numeric_limits::max(); + const int32_t min_bound = std::numeric_limits::min(); + int32_t val = fval >= static_cast(max_bound) + ? max_bound + : (fval < static_cast(min_bound) + ? min_bound + : static_cast(fval)); + + LOG_IF(INFO, fval >= max_bound || fval < min_bound) + << "Conversion clipping: " << x << " to " << fixed32_to_float(val); + return val; + } + + int32_t val_; +}; + +template +struct IsFixed16Type + : std::integral_constant::value> {}; + +template +struct IsFixed32Type + : std::integral_constant::value> {}; + +template +struct IsFixedType : std::integral_constant::value || + IsFixed32Type::value> { +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FIXED_TYPES_H_ diff --git a/sparse_matmul/numerics/fixed_types_test.cc b/sparse_matmul/numerics/fixed_types_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..82fcd93d8b817d4ee3c1892a7221317e8441de68 --- /dev/null +++ b/sparse_matmul/numerics/fixed_types_test.cc @@ -0,0 +1,43 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/numerics/fixed_types.h" + +#include + +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/test_utils.h" +#include "sparse_matmul/numerics/type_utils.h" + +namespace csrblocksparse { + +// Basic test that makes sure basic multiplication and TypeOfProduct work +// correctly. +TEST(FixedPoint, Multiplication) { + fixed16<4> a(.1f); + fixed16<4> b(1.f); + + TypeOfProduct, fixed16<4>>::type c(a.raw_val() * b.raw_val()); + + EXPECT_NEAR(static_cast(c), .1f, + 1. / (1 << fixed16<2>::kMantissaBits)); +} + +TEST(FixedPoint, SafeCastingIntMax) { + const float int_max_float = std::numeric_limits::max(); + const csrblocksparse::fixed32<31> int_max_fixed(int_max_float); + EXPECT_FLOAT_EQ(int_max_float, static_cast(int_max_fixed)); +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/numerics/float16_types.h b/sparse_matmul/numerics/float16_types.h new file mode 100644 index 0000000000000000000000000000000000000000..5a313271ac90dc849988f47306f94b75d2751253 --- /dev/null +++ b/sparse_matmul/numerics/float16_types.h @@ -0,0 +1,149 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FLOAT16_TYPES_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FLOAT16_TYPES_H_ + +#include +#include +#include + +namespace csrblocksparse { + +// Storage class for fp16 values, not meant to be used directly for computation. +// Used for converting to/from float32. +class fp16 { + public: + fp16() = default; + explicit fp16(float x) : val_(float_to_fp16(x)) {} + explicit fp16(uint16_t x) : val_(x) {} + static constexpr int kMantissaBits = 11; + + explicit operator float() const { return fp16_to_float(val_); } + + private: + inline float fp16_to_float(uint16_t as_int) const { +#if defined __aarch64__ + float x; + float* x_ptr = &x; + asm volatile( + "dup v0.8h, %w[as_int]\n" + "fcvtl v1.4s, v0.4h\n " + "st1 {v1.s}[0], [%[x_ptr]]\n" + : // outputs + : // inputs + [x_ptr] "r"(x_ptr), + [as_int] "r"(as_int) + : // clobbers + "cc", "memory", "v0", "v1"); + return x; +#else + unsigned int sign_bit = (as_int & 0x8000) << 16; + unsigned int exponent = as_int & 0x7c00; + + unsigned int mantissa; + if (exponent == 0) + mantissa = 0; + else + mantissa = ((as_int & 0x7fff) << 13) + 0x38000000; + mantissa |= sign_bit; + + float x; + memcpy(&x, &mantissa, sizeof(int)); + return x; +#endif // defined __aarch64__ + } + + inline uint16_t float_to_fp16(float x) const { +#if defined __aarch64__ + uint16_t as_int; + uint16_t* as_int_ptr = &as_int; + asm volatile( + "dup v0.4s, %w[x]\n" + "fcvtn v1.4h, v0.4s\n" + "st1 {v1.h}[0], [%[as_int_ptr]]\n" + : // outputs + : // inputs + [as_int_ptr] "r"(as_int_ptr), + [x] "r"(x) + : // clobbers + "cc", "memory", "v0", "v1"); + return as_int; +#else + unsigned int x_int; + memcpy(&x_int, &x, sizeof(int)); + + unsigned int sign_bit = (x_int & 0x80000000) >> 16; + unsigned int exponent = x_int & 0x7f800000; + + unsigned int mantissa; + if (exponent < 0x38800000) { // exponent too small or denormal + mantissa = 0; + } else if (exponent > 0x8e000000) { + mantissa = 0x7bff; // exponent too big, inf + } else { + mantissa = ((x_int & 0x7fffffff) >> 13) - 0x1c000; + } + + mantissa |= sign_bit; + + return static_cast(mantissa & 0xFFFF); +#endif + } + + uint16_t val_; +}; + +// Storage class for bfloat16 values, not meant to be used directly for +// computation. Used for converting to/from float32. +class bfloat16 { + public: + bfloat16() = default; + explicit bfloat16(float x) : val_(float_to_bfloat16(x)) {} + explicit bfloat16(uint16_t x) : val_(x) {} + static constexpr int kMantissaBits = 7; + + explicit operator float() const { return bfloat16_to_float(val_); } + + private: + inline uint16_t float_to_bfloat16(float x) const { + uint32_t as_int; + std::memcpy(&as_int, &x, sizeof(float)); + return as_int >> 16; + } + + inline float bfloat16_to_float(uint32_t as_int) const { + as_int <<= 16; + float x; + std::memcpy(&x, &as_int, sizeof(float)); + return x; + } + + uint16_t val_; +}; + +template +struct IsCustomFloatType + : std::integral_constant::value || + std::is_same::value> {}; +template +struct IsAnyFloatType + : std::integral_constant::value || + IsCustomFloatType::value> {}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FLOAT16_TYPES_H_ diff --git a/sparse_matmul/numerics/test_utils.h b/sparse_matmul/numerics/test_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7f33a4fe937957387707001f5cdd5dd7ba0c35d7 --- /dev/null +++ b/sparse_matmul/numerics/test_utils.h @@ -0,0 +1,75 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TEST_UTILS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TEST_UTILS_H_ + +#include +#include +#include + +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/type_utils.h" + +namespace csrblocksparse { + +// Computes the relative difference between two floating point numbers +// std::abs(b - a) / a. If the a is < 10 * epsilon, then use the absolute +// difference instead of the relative one. +template +T RelDiff(T a, T b) { + static_assert(std::is_floating_point::value, + "RelDiff should only be used on floating point types."); + if (std::abs(a) < 600 * std::numeric_limits::epsilon()) { + return std::abs(b - a); + } + return std::abs((b - a) / a); +} + +// Compares two CacheAlignedVectors elementwise, checks if each pair passes a +// RelDiff check. The result of RelDiff is scaled by the log of the size of the +// column to account for increasing summation errors as the number of summands +// increases. +template +void CheckResult(const VectorType& lhs, const VectorType& rhs, int columns) { + ASSERT_EQ(lhs.size(), rhs.size()); + float epsilon = + 1.0f / + (1 << (MantissaBitsOf::value - 1)); + + // if we're summing a large number of values, then we can relax the tolerance + float log_scale = std::max(1.f, logf(columns)); + + // The tolerance is so large because it is a relative tolerance used to test + // numbers that are close to zero at the limit of the resolution of the + // representation. It would probably be better to focus on an absolute + // tolerance, based on the epsilon above. + const float tolerance = 0.026f; + for (int i = 0; i < lhs.size(); ++i) { + float lhs_value = static_cast(lhs.data()[i]); + float rhs_value = static_cast(rhs.data()[i]); + // If the absolute difference is no more than the epsilon for the + // representation, then it is OK. + if (std::abs(lhs_value - rhs_value) <= epsilon) continue; + float rel_diff = RelDiff(lhs_value, rhs_value) / log_scale; + EXPECT_LT(rel_diff, tolerance) << i % columns << " " << i / columns << " " + << lhs_value << " " << rhs_value; + } +} + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TEST_UTILS_H_ diff --git a/sparse_matmul/numerics/type_utils.h b/sparse_matmul/numerics/type_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..51291abeefa1637a246f2095a4e4a3470e6ef853 --- /dev/null +++ b/sparse_matmul/numerics/type_utils.h @@ -0,0 +1,89 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TYPE_UTILS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TYPE_UTILS_H_ + +// A collection of useful utilities for determining types based on other types. + +#include + +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" + +namespace csrblocksparse { + +// Basic idea is that any two float types yield a float, fixed16 types +// yield a fixed32 with the exponent bits summed. Other options are not +// allowed. +template +struct TypeOfProduct {}; + +template +struct TypeOfProduct< + LhsType, RhsType, + typename std::enable_if::value && + IsAnyFloatType::value>::type> { + using type = float; +}; + +template +struct TypeOfProduct< + LhsType, RhsType, + typename std::enable_if::value && + IsFixed16Type::value>::type> { + static_assert(LhsType::kMantissaBits + RhsType::kMantissaBits < 31, + "Sum of mantissa bits must not exceed 31."); + using type = fixed32<31 - LhsType::kMantissaBits - RhsType::kMantissaBits>; +}; + +// Given a weight type T, determine what the RhsType should be for that type. +// bfloat16 / fp16 -> float; fixed16 = fixed16 +template +struct RhsTypeIs { + using type = float; +}; + +template +struct RhsTypeIs::value>::type> { + using type = T; +}; + +template +struct MantissaBitsOf { + // Although int types have zero mantissa bits, use 1 to avoid division by 0. + static constexpr int value = 1; +}; + +template +struct MantissaBitsOf< + T, typename std::enable_if::value || + IsCustomFloatType::value>::type> { + public: + static constexpr int value = T::kMantissaBits; +}; + +template +struct MantissaBitsOf< + T, typename std::enable_if::value>::type> { + public: + // Ignoring the fact that doubles have more mantissa bits. + static constexpr int value = 24; +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TYPE_UTILS_H_ diff --git a/sparse_matmul/os/BUILD b/sparse_matmul/os/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9c9d7684ff1108ad8dcef940396bd99c618ec25b --- /dev/null +++ b/sparse_matmul/os/BUILD @@ -0,0 +1,26 @@ +# Modules that interact with the operating system, and have no other dependencies. + +licenses(["notice"]) + +cc_library( + name = "coop_threads", + srcs = ["coop_threads.cc"], + hdrs = ["coop_threads.h"], + visibility = ["//sparse_matmul:__subpackages__"], + deps = [ + "@com_google_absl//absl/memory", + "@com_google_glog//:glog", + ], +) + +cc_test( + name = "coop_threads_test", + size = "small", + srcs = [ + "coop_threads_test.cc", + ], + deps = [ + ":coop_threads", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/sparse_matmul/os/coop_threads.cc b/sparse_matmul/os/coop_threads.cc new file mode 100644 index 0000000000000000000000000000000000000000..ece0995d4cf2d0a0f0f73170fb5977e08d7731b1 --- /dev/null +++ b/sparse_matmul/os/coop_threads.cc @@ -0,0 +1,63 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/os/coop_threads.h" + +#include + +namespace csrblocksparse { + +// All threads must execute a std::memory_order_seq_cst operation on +// |barrier_step_| this is what ensures the global memory consistency across +// the barrier. +// +// It is possible for the |barrier_step_| to roll over, but this is safe here. +// +// |yield| instructs the processor that it is in a spin loop and can stop doing +// things like out of order, speculative execution, prefetching, etc. On hyper +// threaded machines it can also choose to swap in the other thread. Note that +// this is a hardware level decision and the OS is never involved. +void SpinBarrier::barrier() { + if (num_threads_ < 2) return; + + int old_step = barrier_step_.load(std::memory_order_relaxed); + + int val_threads = threads_at_barrier_.fetch_add(1, std::memory_order_acq_rel); + + if (val_threads == num_threads_ - 1) { + // This is where the logic can go all wrong if the barrier is called by + // more threads than |num_threads_| -- the assumption that we're the last + // thread is inherently invalid. + + // Assuming num_threads_ are calling this barrier, then we're the last + // thread to reach the barrier, reset and advance step count. + threads_at_barrier_.store(0, std::memory_order_relaxed); + barrier_step_.store(old_step + 1, std::memory_order_release); + } else { + // Wait for step count to advance, then continue. + while (barrier_step_.load(std::memory_order_acquire) == old_step) { + // Intel recommends the equivalent instruction PAUSE, not be called more + // than once in a row, I can't find any recommendations for ARM, so + // following that advice here. +#if defined __aarch64__ || defined __arm__ + asm volatile("yield\n" ::: "memory"); +#else + // No pause for x86! The pause instruction on Skylake takes 141 clock + // cycles, which in an AVX2-down-clocked CPU is getting on for 70ns. +#endif + } + } +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/os/coop_threads.h b/sparse_matmul/os/coop_threads.h new file mode 100644 index 0000000000000000000000000000000000000000..9aefa614ea945d20f1699866e7931994b27d5842 --- /dev/null +++ b/sparse_matmul/os/coop_threads.h @@ -0,0 +1,179 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_OS_COOP_THREADS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_OS_COOP_THREADS_H_ + +#include +#include // NOLINT +#include + +#define _COOP_THREADS_USE_STD_THREAD 1 + +#include "absl/memory/memory.h" +#include "glog/logging.h" + +namespace csrblocksparse { + +// A re-usable barrier. Keeps threads in extremely tight sync without +// relinquishing control. All memory writes _before_ this barrier are visible +// to all threads _after_ this barrier. Similar in spirit to +// pthreads_barrier. If you expect arrival times at this barrier to be varied +// by more than microseconds, this is probably not the right synchronization +// primitive for you. If |num_threads| exceeds the number of physical threads +// that can run simultaneously, then using this is certainly a bad idea +// (although it should still be correct). +// +// Callers MUST NOT call barrier from more threads than |num_threads|. The +// result is undefined behavior. +class SpinBarrier { + public: + explicit SpinBarrier(int num_threads) + : num_threads_(num_threads), threads_at_barrier_(0), barrier_step_(0) {} + + void barrier(); + + private: + const int num_threads_; + std::atomic threads_at_barrier_; + std::atomic barrier_step_; // unsigned to make overflow defined. +}; + +// Producer-consumer API using the same underlying mechanism as SpinBarrier. +// This class is intended to allow >=1 producers to produce data for >=1 +// consumers, without blocking the producers. +// The consumer will block if it is ready before all the producer(s) have +// produced. +// WARNING: By design this lock does not work without some other barrier that +// prevents any producer from producing again, or consumer from consuming again +// until all consumers have consumed. Basically any loop that uses +// ProducerConsumer must have at least two consume() calls in each thread (on +// different instances) in order for the lock to work correctly. +class ProducerConsumer { + public: + ProducerConsumer(int num_producers, int num_consumers) + : num_producers_(num_producers), + num_consumers_(num_consumers), + producers_ready_(0), + consumers_passed_(0) {} + + // Indicates that the data produced by this thread is ready. Does NOT block. + // NOTE that some other lock must exist between the call to this produce and + // looping back to call produce again on the same ProducerConsumer, that + // depends on all consumers having called consume. One such candidate would + // be a call to SpinBarrier above by all producers and consumers. + // Another candidate would be a separate ProducerConsumer object in which + // these producers consume some data produced by the threads that consume + // the data produced here. Eg. + // tid 0 1 2 3 + // action 1 produce produce consume consume (on ProducerConsumer 1) + // action 2 consume consume produce produce (on ProducerConsumer 2) + // action 3 produce produce consume consume (on ProducerConsumer 3) + // action 4 consume consume produce produce (on ProducerConsumer 4) + // loop back to action 1. + // NOTE: It is inadequate to loop back after action2, as thread 0 could loop + // back and consume again on PC2 while thread 1 is still completing its call + // to consume. It is still inadequate to loop back after action 3 for the same + // reason (but tsan doesn't seem to pick this up.) + inline void produce() { + producers_ready_.fetch_add(1, std::memory_order_acq_rel); + } + + // Waits if necessary for all producers to have produced before proceeding. + // The ProducerConsumer cannot be reused until all consumers have consumed. + // See detailed comment and example on produce(). + inline void consume() { + // We can't do anything until all the producers have produced. + while (producers_ready_.load(std::memory_order_acquire) < num_producers_) { +#if defined __aarch64__ || defined __arm__ + asm volatile("yield\n" ::: "memory"); +#else + // No pause for x86! The pause instruction on Skylake takes 141 clock + // cycles, which in an AVX2-down-clocked CPU is getting on for 70ns. +#endif + } + // NOTE: It is tempting to move this fetch_add to before the wait loop to + // reduce contention for the memory location, but that would break the lock, + // as then the last to arrive could zero out the producers_ready before the + // other consumers have noticed that all producers have produced. + // With the fetch_add after the wait loop, we are guaranteed that all + // producers have produced AND all consumers have noticed that they have + // produced before we zero out the counters. + int consumers = consumers_passed_.fetch_add(1, std::memory_order_acq_rel); + if (consumers == num_consumers_ - 1) { + // The last consumer to pass has to reset everything for the next time. + producers_ready_.store(0, std::memory_order_relaxed); + consumers_passed_.store(0, std::memory_order_relaxed); + } + } + int num_producers() const { return num_producers_; } + int num_consumers() const { return num_consumers_; } + + private: + const int num_producers_; + const int num_consumers_; + std::atomic producers_ready_; + std::atomic consumers_passed_; +}; + +// We define Thread here, so we can easily change its type later. + +using Thread = std::thread; +using ThreadId = std::thread::id; + +// Creates (|num_threads|-1) threads and executes a total of |num_threads| +// copies of |func| (executes one on the calling thread). +// +// Useful for long running func bodies that are intended to run in lock step. +// A possible use case for this style parallelism over a thread pool is when +// we want tight control over which memory is resident in the L2 cache of a +// processor. With a pool we have no control over which thread gets assigned +// which portion of the computation resulting in L2 thrashing. With this +// breakdown we can make sure each thread only acceses a specific L2-sized +// portion of memory. +// +// func's signature must be (SpinBarrier*, int thread_id, ...); +template +void LaunchOnThreadsWithBarrier(int num_threads, Function&& func, + Args&&... args) { + SpinBarrier spin_barrier(num_threads); + + std::vector> threads; + threads.reserve(num_threads); + for (int tid = 1; tid < num_threads; ++tid) { + auto f = [&, tid]() { func(&spin_barrier, tid, args...); }; + + threads.emplace_back(absl::make_unique(f)); +#ifndef _COOP_THREADS_USE_STD_THREAD + CHECK_OK(threads.back()->Start()); +#endif + } + + const int kLocalTid = 0; + func(&spin_barrier, kLocalTid, args...); + + for (auto& thread : threads) { +#ifdef _COOP_THREADS_USE_STD_THREAD + thread->join(); +#else + CHECK_OK(thread->Join()); +#endif + } +} + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_OS_COOP_THREADS_H_ diff --git a/sparse_matmul/os/coop_threads_test.cc b/sparse_matmul/os/coop_threads_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0aba27f94541af41bfe5a77bbd4a39ea9633cd8b --- /dev/null +++ b/sparse_matmul/os/coop_threads_test.cc @@ -0,0 +1,134 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/os/coop_threads.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +TEST(Threads, LaunchThreads) { + std::atomic counter(0); + + auto f = [&](csrblocksparse::SpinBarrier* barrier, int tid) { + counter.fetch_add(tid); + }; + + const int kNumThreads = 10; + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); + + ASSERT_EQ(counter.load(), kNumThreads * (kNumThreads - 1) / 2); +} + +TEST(Threads, SpinBarrier) { + const int kNumThreads = 10; + + std::vector tids(kNumThreads, 0); + std::vector> expected; + for (int i = 0; i < 10; ++i) { + expected.emplace_back(kNumThreads); + std::iota(expected.back().begin(), expected.back().end(), 0); + std::transform(expected.back().begin(), expected.back().end(), + expected.back().begin(), + [i](int x) -> int { return (i + 1) * x; }); + } + + auto f = [&](csrblocksparse::SpinBarrier* barrier, int tid) { + for (int i = 0; i < 10; ++i) { + tids[tid] += tid; + barrier->barrier(); + EXPECT_EQ(tids, expected[i]); + barrier->barrier(); + } + }; + + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); +} + +TEST(Threads, ProducerConsumer) { + constexpr int kNumThreads = 4; + constexpr int kNumIterations = 10; + + std::vector shared_data(kNumThreads, 0); + std::vector> expected; + for (int i = 1; i <= kNumIterations; ++i) { + // Execute the parallel work sequentially. + // Last two threads write their id * iteration. + std::pair inputs = + std::make_pair((kNumThreads - 2) * i, (kNumThreads - 1) * i); + // First two threads compute sum and difference of those values. + std::pair diffs = std::make_pair(inputs.first + inputs.second, + inputs.first - inputs.second); + // Last two threads compute sum and product. + std::pair sums = + std::make_pair(diffs.first + diffs.second, diffs.first * diffs.second); + // First two threads compute product and difference of those values. + expected.emplace_back( + std::make_pair(sums.first * sums.second, sums.first - sums.second)); + // Last two threads will check for the correct result. + } + csrblocksparse::ProducerConsumer first_pc(2, 2); + csrblocksparse::ProducerConsumer second_pc(2, 2); + csrblocksparse::ProducerConsumer third_pc(2, 2); + csrblocksparse::ProducerConsumer fourth_pc(2, 2); + + auto f = [&](csrblocksparse::SpinBarrier* barrier, int tid) { + for (int i = 1; i <= kNumIterations; ++i) { + if (tid == kNumThreads - 2) { + // Last two threads write their id * iteration. + shared_data[tid] = tid * i; + first_pc.produce(); + second_pc.consume(); + // They then compute sum and product. + shared_data[tid] = shared_data[0] + shared_data[1]; + third_pc.produce(); + // They finally check the result. + fourth_pc.consume(); + EXPECT_EQ(expected[i - 1].first, shared_data[0]) << "i=" << i; + } else if (tid == kNumThreads - 1) { + shared_data[tid] = tid * i; + first_pc.produce(); + second_pc.consume(); + shared_data[tid] = shared_data[0] * shared_data[1]; + third_pc.produce(); + fourth_pc.consume(); + EXPECT_EQ(expected[i - 1].second, shared_data[1]) << "i=" << i; + } else if (tid == 0) { + // First two threads compute sum and difference. + first_pc.consume(); + shared_data[tid] = + shared_data[kNumThreads - 2] + shared_data[kNumThreads - 1]; + second_pc.produce(); + // They then compute product and difference. + third_pc.consume(); + shared_data[tid] = + shared_data[kNumThreads - 2] * shared_data[kNumThreads - 1]; + fourth_pc.produce(); + } else if (tid == 1) { + first_pc.consume(); + shared_data[tid] = + shared_data[kNumThreads - 2] - shared_data[kNumThreads - 1]; + second_pc.produce(); + third_pc.consume(); + shared_data[tid] = + shared_data[kNumThreads - 2] - shared_data[kNumThreads - 1]; + fourth_pc.produce(); + } + } + }; + + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); +} diff --git a/sparse_matmul/sparse_matmul.h b/sparse_matmul/sparse_matmul.h new file mode 100644 index 0000000000000000000000000000000000000000..dc50727861248bb8ffec0015d800987c518d762b --- /dev/null +++ b/sparse_matmul/sparse_matmul.h @@ -0,0 +1,34 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_SPARSE_MATMUL_H_ +#define LYRA_CODEC_SPARSE_MATMUL_SPARSE_MATMUL_H_ + +// IWYU pragma: begin_exports +#include "sparse_matmul/compute/gru_gates.h" +#include "sparse_matmul/layers/csr_blocksparse_matrix.h" +#include "sparse_matmul/layers/masked_sparse_matrix.h" +#include "sparse_matmul/layers/sparse_linear_layer.h" +#include "sparse_matmul/layers/utils.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/os/coop_threads.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" +// IWYU pragma: end_exports + +#endif // LYRA_CODEC_SPARSE_MATMUL_SPARSE_MATMUL_H_ diff --git a/sparse_matmul/vector/BUILD b/sparse_matmul/vector/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3fc064a476e70f7618b62d15cb3ff1ce24781321 --- /dev/null +++ b/sparse_matmul/vector/BUILD @@ -0,0 +1,63 @@ +# Vector that always aligns its data to the cache line of the host machine. + +licenses(["notice"]) + +cc_library( + name = "cache_aligned_vector", + hdrs = [ + "cache_aligned_vector.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + ":aligned_malloc", + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:types", + "//sparse_matmul/os:coop_threads", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "cachealignedvector_test", + size = "small", + srcs = [ + "cachealignedvector_test.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DSIGMOID_AS_TANH", + ], + deps = [ + ":cache_aligned_vector", + "//sparse_matmul/numerics:test_utils", + "//sparse_matmul/os:coop_threads", + "@com_google_googletest//:gtest_main", + ], +) + +cc_binary( + name = "cachealignedvector_benchmark", + srcs = [ + "cachealignedvector_benchmark.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DSIGMOID_AS_TANH", + "-DACCURATE_TRANSCENDENTAL_APPROX", + ], + deps = [ + ":cache_aligned_vector", + "@com_github_google_benchmark//:benchmark", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +cc_library( + name = "aligned_malloc", + srcs = ["aligned_malloc.cc"], + hdrs = [ + "aligned_malloc.h", + ], +) diff --git a/sparse_matmul/vector/aligned_malloc.cc b/sparse_matmul/vector/aligned_malloc.cc new file mode 100644 index 0000000000000000000000000000000000000000..410d268ed838541e13c045b90649884ee4c81cef --- /dev/null +++ b/sparse_matmul/vector/aligned_malloc.cc @@ -0,0 +1,46 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +namespace csrblocksparse { + +void Free(void* ptr) { free(ptr); } + +void* Malloc(size_t size) { return malloc(size); } + +void aligned_free(void* aligned_memory) { Free(aligned_memory); } + +void* aligned_malloc(size_t size, int minimum_alignment) { +#if defined(__ANDROID__) + return memalign(minimum_alignment, size); +#else // !defined(__ANDROID__) + void* ptr = nullptr; + // posix_memalign requires that the requested alignment be at least + // sizeof(void*). In this case, fall back on malloc which should return + // memory aligned to at least the size of a pointer. + const int required_alignment = sizeof(void*); + if (minimum_alignment < required_alignment) return Malloc(size); + int err = posix_memalign(&ptr, minimum_alignment, size); + if (err != 0) { + return nullptr; + } else { + return ptr; + } +#endif +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/vector/aligned_malloc.h b/sparse_matmul/vector/aligned_malloc.h new file mode 100644 index 0000000000000000000000000000000000000000..ff13d9390f202441250b1682f2e25b64f6e6f9cc --- /dev/null +++ b/sparse_matmul/vector/aligned_malloc.h @@ -0,0 +1,32 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_VECTOR_ALIGNED_MALLOC_H_ +#define LYRA_CODEC_SPARSE_MATMUL_VECTOR_ALIGNED_MALLOC_H_ + +#include +namespace csrblocksparse { + +void Free(void* ptr); + +void* Malloc(size_t size); + +void aligned_free(void* aligned_memory); + +void* aligned_malloc(size_t size, int minimum_alignment); +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_VECTOR_ALIGNED_MALLOC_H_ diff --git a/sparse_matmul/vector/cache_aligned_vector.h b/sparse_matmul/vector/cache_aligned_vector.h new file mode 100644 index 0000000000000000000000000000000000000000..871298d25b9293fa8b3c1acf97f109e007f5fd9e --- /dev/null +++ b/sparse_matmul/vector/cache_aligned_vector.h @@ -0,0 +1,1117 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_VECTOR_CACHE_ALIGNED_VECTOR_H_ +#define LYRA_CODEC_SPARSE_MATMUL_VECTOR_CACHE_ALIGNED_VECTOR_H_ + +#if defined __aarch64__ +#include +#endif +#if defined __AVX__ || defined __AVX2__ +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/os/coop_threads.h" +#include "sparse_matmul/vector/aligned_malloc.h" + +namespace csrblocksparse { + +template +class MutableVectorView; +template +class VectorView; + +// CacheAlignedVector is a simple vector-like class that makes sure its +// underlying buffer is aligned to a |kCacheLineSize| boundary. It is meant +// for numeric computation and cannot be used to store objects that are +// not POD as it will neither call their constructors nor destructors. +// +// It is meant to be used with the CSRBlockSparseMatrix class for +// implenting basic neural network layers composed of SpMV. +// +// This class is thread compatible. +template +class CacheAlignedVector { + static_assert(std::is_pod::value, + "CacheAlignedVector can only be" + " used with POD"); + + public: + using value_type = DataType; + + explicit CacheAlignedVector(std::size_t size) : size_(size), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + } + + explicit CacheAlignedVector(const std::vector& input) + : size_(input.size()), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + memcpy(data_, input.data(), size_ * sizeof(DataType)); + } + + template + explicit CacheAlignedVector(const std::vector& input) + : size_(input.size()), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + for (int i = 0; i < size_; ++i) + data_[i] = static_cast(input.data()[i]); + } + + CacheAlignedVector(const DataType* input, int size) + : size_(size), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + memcpy(data_, input, size_ * sizeof(DataType)); + } + + template + explicit CacheAlignedVector(const InputType* input, int size) + : size_(size), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + for (int i = 0; i < size_; ++i) data_[i] = static_cast(input[i]); + } + + CacheAlignedVector() : size_(0), data_(nullptr) {} + + ~CacheAlignedVector() { + aligned_free(data_); + data_ = nullptr; + size_ = 0; + } + + // Copies are _deep_ copies + CacheAlignedVector(CacheAlignedVector const& other) + : size_(0), data_(nullptr), gen_(nullptr) { + if (other.gen_) + gen_ = absl::make_unique(std::minstd_rand(*other.gen_)); + this->resize(other.size()); + memcpy(data_, other.data(), size_ * sizeof(DataType)); + } + // Copies a slice of the input. + CacheAlignedVector(CacheAlignedVector const& other, int start, int end) + : size_(0), data_(nullptr), gen_(nullptr) { + if (other.gen_) + gen_ = absl::make_unique(std::minstd_rand(*other.gen_)); + this->resize(end - start); + memcpy(data_, other.data() + start, size_ * sizeof(DataType)); + } + + void operator=(CacheAlignedVector const& other) { + if (other.gen_) + gen_ = absl::make_unique(std::minstd_rand(*other.gen_)); + else + gen_.reset(nullptr); + this->resize(other.size()); + memcpy(data_, other.data(), size_ * sizeof(DataType)); + } + + CacheAlignedVector(CacheAlignedVector&& other) + : size_(0), data_(nullptr), gen_(std::move(other.gen_)) { + size_ = other.size_; + data_ = other.data_; + other.size_ = 0; + other.data_ = nullptr; + } + + CacheAlignedVector& operator=( + CacheAlignedVector&& other) { + aligned_free(data_); + if (other.gen_) + gen_ = absl::make_unique(std::move(*other.gen_)); + else + gen_.reset(nullptr); + size_ = other.size_; + data_ = other.data_; + other.size_ = 0; + other.data_ = nullptr; + return *this; + } + + VectorView AsView() const { + return VectorView(this->data(), this->size(), 1); + } + + MutableVectorView AsMutableView() { + return MutableVectorView(this->data(), this->size(), 1); + } + + // Copies the |split_points| to use in ReducingSample. + void PrepareForThreads(const std::vector& split_points, + int block_height) { + maxes_.resize(split_points.size() - 1); + thread_starts_ = split_points; + for (int t = 0; t < thread_starts_.size(); ++t) { + thread_starts_[t] *= block_height; + } + } + + void FillRandom(float min = -10.f, float max = 10.f) { + // 10 is smaller than any nonzero bound of the range of any data type. + std::uniform_real_distribution dist(min, max); + for (std::size_t i = 0; i < size_; i++) { + data_[i] = DataType(dist(*gen_)); + } + } + + void FillZero() { + for (std::size_t i = 0; i < size_; i++) { + data_[i] = DataType(0.f); + } + } + + void FillOnes() { + for (std::size_t i = 0; i < size_; i++) { + data_[i] = DataType(1.f); + } + } + + void FillWith(const DataType& value) { + for (std::size_t i = 0; i < size_; i++) { + data_[i] = value; + } + } + + // Interprets |data_| as logits and samples from the distribution, this + // version operates IN PLACE and uses an internal random source. + template + typename std::enable_if::value, int>::type Sample( + float temperature = 1.f) { + return Sample(temperature, gen_.get(), this); + } + + // Interprets |data_| as logits and samples. This version requires the random + // source and temporary memory to be passed in. It is thread safe assuming + // no other threads are using the generator and temporary memory. +#if defined __aarch64__ + template + typename std::enable_if::value, int>::type Sample( + float temperature, std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + DCHECK(scratch->size() >= size_); + // Round down to nearest multiple of 8. + int SIMD_iterations = 8 * (size_ / 8); + float* scratch_ptr = scratch->data(); + std::uniform_real_distribution dist; + float random_number = dist(*gen); + + float32x4_t sum = vdupq_n_f32(0.f); + float32x4_t sum1 = vdupq_n_f32(0.f); + float32x4_t max_value = vdupq_n_f32(std::numeric_limits::lowest()); + float32x4_t max_value1 = vdupq_n_f32(std::numeric_limits::lowest()); + float32x4_t inv_temp = vdupq_n_f32(1.f / temperature); + // Compute sum of exp(x) for the denominator. + // Hand unroll by 2, gives speed improvement. + constexpr int kUnrollFactor = 2; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < SIMD_iterations; i += kElementsPerIter) { + max_value = vmaxq_f32(vld1q_f32(data_ + i), max_value); + max_value1 = vmaxq_f32(vld1q_f32(data_ + i + 4), max_value1); + } + + // Pairwise reduction. + max_value = vpmaxq_f32(max_value, max_value1); + // Duplicate (dupq) maximum across vector (maxnmvq). + float scalar_max_value = vmaxvq_f32(max_value); + + for (int i = SIMD_iterations; i < size_; ++i) { + scalar_max_value = std::max(data_[i], scalar_max_value); + } + + max_value = vdupq_n_f32(scalar_max_value); + + for (std::size_t i = 0; i < SIMD_iterations; i += kElementsPerIter) { + // Load and multiply by temperature. + float32x4_t x = + vmulq_f32(vsubq_f32(vld1q_f32(data_ + i), max_value), inv_temp); + float32x4_t x1 = + vmulq_f32(vsubq_f32(vld1q_f32(data_ + i + 4), max_value), inv_temp); + + float32x4_t exponent = fast_exp(x); + float32x4_t exponent1 = fast_exp(x1); + + sum = vaddq_f32(sum, exponent); + sum1 = vaddq_f32(sum1, exponent1); + + vst1q_f32(scratch_ptr + i, exponent); + vst1q_f32(scratch_ptr + i + 4, exponent1); + } + + // Horizontally reduce the two sums. + sum = vpaddq_f32(sum, sum1); + sum = vpaddq_f32(sum, sum); + float denom = vgetq_lane_f32(sum, 0) + vgetq_lane_f32(sum, 1); + + for (int i = SIMD_iterations; i < size_; ++i) { + float x = (data_[i] - scalar_max_value) / temperature; + float x_exp = expf(x); + denom += x_exp; + scratch_ptr[i] = x_exp; + } + + // Note: rather than normalize all the probabilities, we can just + // apply the inverse normalization to the random number. + random_number *= denom; + + // Now do the scan in serial, return as soon as possible. + // TODO(b/188821456): This could be made into a parallel SIMD scan + // followed by a binary search, for a small speedup. + float cumsum = 0.f; + for (std::size_t i = 0; i < size_; i++) { + cumsum += scratch_ptr[i]; + if (cumsum >= random_number) return i; + } + return size_ - 1; + } + + template + static inline int32x4_t vmul_temp_fixed(int32x4_t x, int32x2_t inv_temp) { + int32x2_t xh = vget_high_s32(x); + int32x2_t xl = vget_low_s32(x); + int32x2_t ph = vqrshrn_n_s64(vmull_s32(xh, inv_temp), Q::kMantissaBits); + int32x2_t pl = vqrshrn_n_s64(vmull_s32(xl, inv_temp), Q::kMantissaBits); + return vcombine_s32(pl, ph); + } + + template + static inline int float_to_fixed(float x) { + return static_cast(x * (1 << Q::kMantissaBits)); + } + + template + static inline float fixed_to_float(int x) { + const float inv_denom = 1.f / (1 << Q::kMantissaBits); + return static_cast(x) * inv_denom; + } + + template + typename std::enable_if::value, int>::type Sample( + float temperature, std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + DCHECK(scratch->size() >= size_); + // Round down to nearest multiple of 8. + int SIMD_iterations = 8 * (size_ / 8); + int* scratch_ptr = scratch->data(); + float scalar_inv_temp = 1.f / temperature; + + int32x4_t sum = vdupq_n_s32(0); + int32x4_t sum1 = vdupq_n_s32(0); + int32x4_t max_value = vdupq_n_s32(std::numeric_limits::lowest()); + int32x4_t max_value1 = vdupq_n_s32(std::numeric_limits::lowest()); + int32x2_t inv_temp = vdup_n_s32(float_to_fixed(scalar_inv_temp)); + // Compute sum of exp(x) for the denominator. + // Hand unroll by 2, gives speed improvement. + + const int* data_ptr = reinterpret_cast(data_); + constexpr int kUnrollFactor = 2; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < SIMD_iterations; i += kElementsPerIter) { + max_value = vmaxq_s32(vld1q_s32(data_ptr + i), max_value); + max_value1 = vmaxq_s32(vld1q_s32(data_ptr + i + kSIMDWidth), max_value1); + } + + // Pairwise reduction. + max_value = vpmaxq_s32(max_value, max_value1); + int scalar_max_value = vmaxvq_s32(max_value); + + for (int i = SIMD_iterations; i < size_; ++i) { + scalar_max_value = std::max(data_[i].raw_val(), scalar_max_value); + } + max_value = vdupq_n_s32(scalar_max_value); + // We clip all loaded values to a lower bound of the lowest possible arg to + // exp + the max value that we are going to subtract, to prevent underflow + // in exp and also to avoid wrap-around with values that are already minint. + int32x4_t clip_min = + vdupq_n_s32(scalar_max_value - (80 << MantissaBitsOf::value)); + + for (std::size_t i = 0; i < SIMD_iterations; i += kElementsPerIter) { + // Load and multiply by temperature. + int32x4_t loaded = vmaxq_s32(vld1q_s32(data_ptr + i), clip_min); + int32x4_t x = vmul_temp_fixed(vsubq_s32(loaded, max_value), inv_temp); + loaded = vmaxq_s32(vld1q_s32(data_ptr + i + kSIMDWidth), clip_min); + int32x4_t x1 = vmul_temp_fixed(vsubq_s32(loaded, max_value), inv_temp); + + int32x4_t exponent = vcvtq_n_s32_f32(fast_exp_fixed(x), + Q::kMantissaBits); + int32x4_t exponent1 = vcvtq_n_s32_f32( + fast_exp_fixed(x1), Q::kMantissaBits); + + sum = vaddq_s32(sum, exponent); + sum1 = vaddq_s32(sum1, exponent1); + + vst1q_s32(scratch_ptr + i, exponent); + vst1q_s32(scratch_ptr + i + kSIMDWidth, exponent1); + } + + // Horizontally reduce the two sums. + sum = vpaddq_s32(sum, sum1); + sum = vpaddq_s32(sum, sum); + float denom = + fixed_to_float(vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1)); + for (int i = SIMD_iterations; i < size_; ++i) { + float x_exp = fast_exp_fixed( + DataType((data_[i].raw_val() - scalar_max_value) * scalar_inv_temp)); + + denom += x_exp; + scratch_ptr[i] = float_to_fixed(x_exp); + } + + // Note: rather than normalize all the probabilities, we can just + // apply the inverse normalization to the random number. + std::uniform_real_distribution dist; + int random_number = float_to_fixed(dist(*gen) * denom); + + // Now do the scan in serial, return as soon as possible. + // TODO(b/188821456): This could be made into a parallel SIMD scan + // followed by a binary search, for a small speedup. + int cumsum = 0; + for (std::size_t i = 0; i < size_; i += kSIMDWidth) { + int32x4_t next_vals = vld1q_s32(&scratch_ptr[i]); + cumsum += vaddvq_s32(next_vals); + if (cumsum >= random_number) { + int high_sum = vaddv_s32(vget_high_s32(next_vals)); + if (cumsum - high_sum > random_number) { + // One of the lower ones. + return (cumsum - high_sum - scratch_ptr[i + 1] > random_number) + ? i + : i + 1; + } else { + // One of the upper ones. + return (cumsum - scratch_ptr[i + 3] > random_number) ? i + 2 : i + 3; + } + } + } + return size_ - 1; + } +#endif // defined __aarch64__ + + template +#if defined __aarch64__ + typename std::enable_if< + !std::is_same::value && !IsFixed32Type::value, int>::type +#else + int +#endif + Sample(float temperature, std::minstd_rand* gen, + CacheAlignedVector* scratch, int tid = 0, + SpinBarrier* barrier = nullptr) const { + return ScalarSample(temperature, gen, scratch, tid, 0, -1, barrier); + } + + int ScalarSample(float temperature, std::minstd_rand* gen, + CacheAlignedVector* scratch, int tid = 0, + const int mindex = 0, const int maxdex = -1, + SpinBarrier* barrier = nullptr) const { + // TODO(b/188821456) Don't ignore |tid| and |barrier|. Currently all threads + // duplicate the same work and ignore |tid| and |barrier|, but they could + // be used to execute a reducing max over the data before the exp operation. + DCHECK_EQ(barrier, nullptr); + DCHECK_EQ(tid, 0); + DCHECK(scratch->size() >= size_); + DCHECK(size_ % 8 == 0) << "CacheAlignedVector size must be a multiple of " + "8 to allow for maximum SIMD and loop unroll, " + "got " + << size_ % 8; + DCHECK(size_ > mindex >= 0); + DCHECK((maxdex == -1) || (0 <= mindex < maxdex < size_)); + int maxindex = maxdex > 0 ? maxdex : size_; + + float* scratch_ptr = scratch->data(); + std::uniform_real_distribution dist; + float random_number = dist(*gen); + + float sum = 0.f; + float max_value = std::numeric_limits::lowest(); + for (int i = mindex; i < maxindex; ++i) { + max_value = std::max(max_value, static_cast(data_[i])); + } + float inv_temperature = 1.f / temperature; + for (int i = mindex; i < maxindex; ++i) { + float exponent = fast_exp((static_cast(data_[i]) - max_value) * + inv_temperature); + scratch_ptr[i] = exponent; + sum += exponent; + } + + // Note: rather than normalize all the probabilities, we can just + // apply the inverse normalization to the random number. + random_number *= sum; + + float cumsum = 0.f; + for (std::size_t i = mindex; i < maxindex; i++) { + cumsum += scratch_ptr[i]; + if (cumsum >= random_number) return i; + } + return maxindex - 1; + } + +#if defined __AVX2__ + // Some AVX2-only code. + // Returns the max of |data_| in the range [|t_start|, |t_end|). + inline int ThreadMax(int t_start, int t_end) const { + // Note: The AVX2 code requires that the number of threads and the output + // size be a power of 2. For efficiency purposes, these should be checked + // when preparing for threads in an architecture class. + // The output size must be a power of 2 so the binary search for the sample + // point works correctly. + // The number of threads must be a power of 2 so that it nicely divides the + // output size, which has to be a power of 2. + __m256i maxes = + _mm256_load_si256(reinterpret_cast<__m256i const*>(data_ + t_start)); + for (int i = t_start + kSIMDWidth; i < t_end; i += kSIMDWidth) { + __m256i data = + _mm256_load_si256(reinterpret_cast<__m256i const*>(data_ + i)); + maxes = _mm256_max_epi32(maxes, data); + } + // Max within the register. + // Bring the top lane down to the bottom. + __m256i other = _mm256_permute4x64_epi64(maxes, 0xe); + maxes = _mm256_max_epi32(maxes, other); + // Bring the 2nd 64 bits to the bottom. + other = _mm256_shuffle_epi32(maxes, 0xe); + maxes = _mm256_max_epi32(maxes, other); + // Bring the 2nd 32 bits to the bottom. + other = _mm256_shuffle_epi32(maxes, 1); + maxes = _mm256_max_epi32(maxes, other); + return _mm256_extract_epi32(maxes, 0); + } + + // Applies exp (approximately) to the difference between |data_| and + // |max_value|, storing the result in scratch, and returns the sum. + template + inline float ApplyExpAndSum(int max_value, float* scratch_ptr) { + // Rough approximation for exp(x). See fast_exp_fixed. + // Constant clipping limit on exp arg. Since its value is never positive, + // we only need to clip on the negative side. + constexpr int kClipLimit = -(80 << kMantissaBits); + __m256i clip_val = _mm256_set1_epi32(kClipLimit); + // Multiplication factor to convert x from log base e to log base 2, shifted + // by an amount that lines up the binary point with the float32 + // representation, after the multiplication + static const int kLogFactor = (1 << (23 - kMantissaBits)) / logf(2.f); + __m256i log_factor = _mm256_set1_epi32(kLogFactor); + // Fix the exponent bias and add the additive fudge factor for the mantissa + // to finish the approximate conversion. + constexpr int kAddConstant = (127 << 23) - 366000; + __m256i constant = _mm256_set1_epi32(kAddConstant); + // Broadcast the max_value. + __m256i max_val = _mm256_set1_epi32(max_value); + // Add the max to the |clip_val|, so it can be used before the subtraction. + clip_val = _mm256_add_epi32(clip_val, max_val); + // The sum of the exps. + __m256 sum1 = _mm256_setzero_ps(); + for (int i = 0; i < size_; i += kSIMDWidth) { + // |data_| - |max_value|. + __m256i data = + _mm256_load_si256(reinterpret_cast<__m256i const*>(data_ + i)); + // Clip to negative limit before the subtraction of |max_val| to avoid + // wrap-around with min-int values. + data = _mm256_max_epi32(data, clip_val); + __m256i difference = _mm256_sub_epi32(data, max_val); + // Exponent trick exp. + // Multiply by |log_factor|, keeping only the lower 32 bits. + difference = _mm256_mullo_epi32(difference, log_factor); + // Add the constant. + difference = _mm256_add_epi32(difference, constant); + // Reinterpret the results as float32. + __m256 float_exp = _mm256_castsi256_ps(difference); + // Sum the results and save to scratch space. + _mm256_store_ps(scratch_ptr + i, float_exp); + sum1 = _mm256_add_ps(sum1, float_exp); + } + // Horizontally add the 8 values in sum. + // Get the top lane down to the bottom. + __m256 sum2 = _mm256_permute2f128_ps(sum1, sum1, 1); + sum1 = _mm256_add_ps(sum1, sum2); + sum1 = _mm256_hadd_ps(sum1, sum1); + sum1 = _mm256_hadd_ps(sum1, sum1); + return _mm256_cvtss_f32(sum1); + } + + // Binary search for the index where the cumulative sum meets random_target. + inline void FindSamplePoint(const float* scratch_ptr, float* random_target, + int* start, int* end) { + int halfsize = (*end - *start) / 2; + do { + // Sum the first half. + // We sum the section in two independent parts, so we can step down 2 + // levels if we get a hit in this half. + int quartersize = halfsize / (2 * kSIMDWidth); + quartersize *= kSIMDWidth; + halfsize = quartersize * 2; + // The sums of the quarters. + __m256 sum1 = _mm256_setzero_ps(); + __m256 sum2 = _mm256_setzero_ps(); + const float* ptr1 = scratch_ptr + *start; + const float* ptr2 = ptr1 + quartersize; + for (int i = 0; i < quartersize; i += kSIMDWidth) { + __m256 data1 = _mm256_load_ps(ptr1 + i); + __m256 data2 = _mm256_load_ps(ptr2 + i); + sum1 = _mm256_add_ps(sum1, data1); + sum2 = _mm256_add_ps(sum2, data2); + } + // Horizontally add the two sums, keeping the results separate. + // Numbering |sum1|=[0-7] and |sum2|=[8-15]... + sum1 = _mm256_hadd_ps(sum1, sum2); + // |sum1| now has [0+1, 2+3, 8+9, 10+11, 4+5, 6+7, 12+13, 14+15]. + // Bring the top lane down to the bottom. + sum2 = _mm256_permute2f128_ps(sum1, sum1, 1); + sum1 = _mm256_hadd_ps(sum1, sum2); + // Now |sum1| has [0-3, 8-11, 4-7, 12-15], so swap the middle two + // elements. + sum1 = _mm256_shuffle_ps(sum1, sum1, 0xd8); + sum1 = _mm256_hadd_ps(sum1, sum1); + // Now |sum1| has [0-7, 8-15, ....]. + float bottom_quarter = _mm256_cvtss_f32(sum1); + if (bottom_quarter >= *random_target) { + *end = *start + quartersize; + } else { + float bottom_half = _mm256_cvtss_f32(_mm256_hadd_ps(sum1, sum1)); + if (bottom_half >= *random_target) { + *start += quartersize; + *end = *start + quartersize; + *random_target -= bottom_quarter; + } else { + *start += halfsize; + *random_target -= bottom_half; + } + } + halfsize = (*end - *start) / 2; + } while (halfsize >= kSIMDWidth * 2); + } +#endif // __AVX2__ code + + // Fixed32 version. + template + typename std::enable_if::value, int>::type ThreadMax( + int tid) const { + int t_start = thread_starts_[tid]; + int t_end = thread_starts_[tid + 1]; +#if defined __AVX2__ + return ThreadMax(t_start, t_end); +#else + // With operator<, could use std::max_element. + int max_value = data_[t_start].raw_val(); + for (int i = t_start + 1; i < t_end; ++i) { + max_value = std::max(max_value, data_[i].raw_val()); + } + return max_value; +#endif + } + + // As Sample above, except that if |tid| and |barrier| are provided, it will + // save some time by running a local max in each thread before combining them + // and doing the rest of the work duplicated across all threads. + // Fixed32 version. + template + typename std::enable_if::value, int>::type ReducingSample( + std::minstd_rand* gen, CacheAlignedVector* scratch, int tid = 0, + float temperature = 1.0f, SpinBarrier* barrier = nullptr) { + if (barrier != nullptr) barrier->barrier(); + // Sample only accepts tid of 0, as it would ignore it anyway. + // All threads duplicate the same work in this path. + return Sample(temperature, gen, scratch, /*tid=*/0); + } + + template + typename std::enable_if::value, int>::type ReducingSample( + std::minstd_rand* gen, CacheAlignedVector* scratch, int tid = 0, + float temperature = 1.0f, SpinBarrier* barrier = nullptr) { + int max_value; + if (barrier == nullptr) { + // There is only one thread. + max_value = ThreadMax(tid); + } else { + // Reduce max using the threads to do some of the work. + maxes_[tid] = ThreadMax(tid); + barrier->barrier(); + // The rest of the work is duplicated by all threads. + max_value = *std::max_element(maxes_.begin(), maxes_.end()); + } + float* scratch_ptr = scratch->data(); + std::uniform_real_distribution dist; + float sum = 0.0f; +#if defined __AVX2__ + sum = ApplyExpAndSum::value>(max_value, scratch_ptr); +#else + int clip_limit = max_value - (80 << MantissaBitsOf::value); + for (int i = 0; i < size_; ++i) { + int difference = std::max(data_[i].raw_val(), clip_limit) - max_value; + float exponent = expf(static_cast(DataType(difference))); + scratch_ptr[i] = exponent; + sum += exponent; + } +#endif // __AVX2__ + + float random_target = dist(*gen) * sum; + int start = 0; + int end = size_; + +#if defined __AVX2__ + FindSamplePoint(scratch_ptr, &random_target, &start, &end); + // The scalar code finishes the job from here... +#endif // __AVX2__ + float cumsum = 0.f; + for (std::size_t i = start; i < end; i++) { + cumsum += scratch_ptr[i]; + if (cumsum >= random_target) return i; + } + return end - 1; + } + + template + typename std::enable_if::value, void>::type Exp() { +#if defined __aarch64__ + DCHECK(size_ % 16 == 0) << "CacheAlignedVector size must be a multiple of " + "16 to allow for maximum SIMD and loop unroll " + "got " + << size_ % 16; + constexpr int kUnrollFactor = 4; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < size_; i += kElementsPerIter) { + float32x4_t x = vld1q_f32(data_ + i); + float32x4_t x1 = vld1q_f32(data_ + i + 4); + float32x4_t x2 = vld1q_f32(data_ + i + 8); + float32x4_t x3 = vld1q_f32(data_ + i + 12); + + vst1q_f32(data_ + i, fast_exp(x)); + vst1q_f32(data_ + i + 4, fast_exp(x1)); + vst1q_f32(data_ + i + 8, fast_exp(x2)); + vst1q_f32(data_ + i + 12, fast_exp(x3)); + } +#else + for (int i = 0; i < size_; ++i) { + data_[i] = expf(data_[i]); + } +#endif // defined __aarch64__ + } + + template + typename std::enable_if::value, void>::type Sigmoid() { +#if defined __aarch64__ + DCHECK(size_ % 8 == 0) << "CacheAlignedVector size must be a multiple of " + "8 to allow for maximum SIMD and loop unroll " + "got " + << size_ % 8; + constexpr int kUnrollFactor = 2; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < size_; i += kElementsPerIter) { + float32x4_t x = vld1q_f32(data_ + i); + float32x4_t x1 = vld1q_f32(data_ + i + 4); + + vst1q_f32(data_ + i, fast_sigmoid(x)); + vst1q_f32(data_ + i + 4, fast_sigmoid(x1)); + } +#else + for (int i = 0; i < size_; ++i) { + data_[i] = 1.f / (1.f + expf(-data_[i])); + } +#endif // defined __aarch64__ + } + + template + typename std::enable_if< + IsFixed32Type::value && IsFixed32Type::value, void>::type + // For benchmarking only. + Sigmoid(const int32_t* sigmoid_table, CacheAlignedVector* result) { +#if defined __AVX2__ + for (int i = 0; i < size_; i += kSIMDWidth) { + __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(data_ + i)); + __m256i output = fixed32_sigmoid_fixed16::value, + MantissaBitsOf::value>( + sigmoid_table, x_in); + _mm256_store_si256(reinterpret_cast<__m256i*>(result->data() + i), + output); + } +#else + for (int i = 0; i < size_; ++i) { + result->data()[i] = 1.f / (1.f + expf(-data_[i])); + } +#endif // defined __AVX2__ + } + + template + typename std::enable_if::value, void>::type Tanh() { +#if defined __aarch64__ + DCHECK(size_ % 8 == 0) << "CacheAlignedVector size must be a multiple of " + "8 to allow for maximum SIMD and loop unroll " + "got " + << size_ % 8; + constexpr int kUnrollFactor = 2; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < size_; i += kElementsPerIter) { + float32x4_t x = vld1q_f32(data_ + i); + float32x4_t x1 = vld1q_f32(data_ + i + 4); + + vst1q_f32(data_ + i, fast_tanh(x)); + vst1q_f32(data_ + i + 4, fast_tanh(x1)); + } +#else + for (int i = 0; i < size_; ++i) { + data_[i] = tanhf(data_[i]); + } +#endif // defined __aarch64__ + } + + template + typename std::enable_if< + IsFixed32Type::value && IsFixed32Type::value, void>::type + // For benchmarking only + Tanh(const int32_t* tanh_table, CacheAlignedVector* result) { +#if defined __AVX2__ + for (int i = 0; i < size_; i += kSIMDWidth) { + __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(data_ + i)); + __m256i output = + fixed32_tanh_fixed16::value, + MantissaBitsOf::value>(tanh_table, x_in); + _mm256_store_si256(reinterpret_cast<__m256i*>(result->data() + i), + output); + } +#else + for (int i = 0; i < size_; ++i) { + result->data()[i] = tanhf(data_[i]); + } +#endif // defined __AVX2__ + } + + // Returns |data_| cast to the correct integer type if fixed point. + template + typename std::enable_if::value, const int32_t*>::type + cast_data() const { + return reinterpret_cast(data_); + } + template + typename std::enable_if::value, const int16_t*>::type + cast_data() const { + return reinterpret_cast(data_); + } + template + typename std::enable_if::value || IsFixed16Type::value), + const Q*>::type + cast_data() const { + return data_; + } + const DataType* begin() const { return data_; } + const DataType* end() const { return data_ + size_; } + const DataType* data() const { return data_; } + DataType* data() { return data_; } + + const DataType& operator[](int pos) const { return data_[pos]; } + DataType& operator[](int pos) { return data_[pos]; } + + std::size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + std::size_t bytes() const { return size_ * sizeof(DataType); } + + int rows() const { return size_; } + int cols() const { return 1; } + + // Stride to get to move over by one column (which is the number of rows). + int col_stride() const { return size_; } + + void Print() const { + for (int i = 0; i < size(); ++i) + absl::PrintF("[%d]=%g\n", i, static_cast(data_[i])); + } + + float maximum() const { + float max_val = std::numeric_limits::lowest(); + for (int i = 0; i < size_; ++i) { + max_val = std::max(max_val, std::abs(static_cast(data_[i]))); + } + + return max_val; + } + + private: + void resize(std::size_t size) { + aligned_free(data_); + size_ = size; + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + } + + std::size_t size_; + DataType* data_; + // Data used by the threaded version for sampling only. + std::vector maxes_; // Max value of logits. + std::vector thread_starts_; // First index for this thread. +#if defined __AVX__ || defined __AVX2__ + static constexpr int kCacheLineSize = 64; + static constexpr int kSIMDWidth = 8; +#else + static constexpr int kCacheLineSize = 128; + static constexpr int kSIMDWidth = 4; +#endif // __AVX__ + std::unique_ptr gen_; +}; + +// Used for doing Sparse Matrix * Dense Matrix multiplication. This class is +// not intended to be a general Matrix class, just for the RHS of a SpMM, hence +// the name fat vector rather than Matrix. The data layout is COLUMN MAJOR. +template +class FatCacheAlignedVector { + public: + using value_type = T; + + FatCacheAlignedVector() : rows_(0), cols_(0) {} + + // Creates a new vector that is (rows, cols), doesn't init memory. + FatCacheAlignedVector(int rows, int cols) + : vector_(rows * cols), rows_(rows), cols_(cols) {} + + // Copies and reshapes vector from (1, size) to (|rows|, size / |rows|). + FatCacheAlignedVector(const CacheAlignedVector& vector, int rows) + : vector_(vector), rows_(rows) { + CHECK_EQ(vector_.size() % rows_, 0); + cols_ = vector_.size() / rows_; + } + + template + explicit FatCacheAlignedVector(const FatCacheAlignedVector& vector) + : vector_(vector.size()), rows_(vector.rows()), cols_(vector.cols()) { + for (int i = 0; i < vector.size(); ++i) { + vector_[i] = static_cast(vector[i]); + } + } + + // Moves and reshapes vector from (1, size) to (|rows|, size / |rows|) + FatCacheAlignedVector(CacheAlignedVector&& vector, int rows) + : vector_(vector), rows_(rows) { + CHECK_EQ(vector_.size() % rows_, 0); + cols_ = vector_.size() / rows_; + } + + VectorView slice(const int col) const { + return VectorView(this->data() + rows() * col, rows(), 1); + } + MutableVectorView slice(const int col) { + return MutableVectorView(this->data() + rows() * col, rows(), 1); + } + + const T* data() const { return vector_.data(); } + T* data() { return vector_.data(); } + // Returns |data_| cast to the correct integer type if fixed point. + template + typename std::enable_if::value, const int32_t*>::type + cast_data() const { + return vector_.cast_data(); + } + template + typename std::enable_if::value, const int16_t*>::type + cast_data() const { + return vector_.cast_data(); + } + template + typename std::enable_if::value || IsFixed16Type::value), + const Q*>::type + cast_data() const { + return vector_.cast_data(); + } + + int rows() const { return rows_; } + int cols() const { return cols_; } + int size() const { return rows_ * cols_; } + bool empty() const { return rows_ == 0 || cols_ == 0; } + std::size_t bytes() const { return vector_.bytes(); } + + void reshape(int rows, int cols) { + CHECK_EQ(rows * cols, rows_ * cols_); + rows_ = rows; + cols_ = cols; + } + + float maximum() const { return vector_.maximum(); } + + // Stride to get to move over by one column (which is the number of rows). + int col_stride() const { return rows_; } + + void FillOnes() { vector_.FillOnes(); } + void FillZero() { vector_.FillZero(); } + void FillRandom(float min = -10.f, float max = 10.f) { + vector_.FillRandom(min, max); + } + + const T& operator[](int pos) const { return vector_[pos]; } + T& operator[](int pos) { return vector_[pos]; } + + private: + CacheAlignedVector vector_; + int rows_; + int cols_; +}; + +// View into a 2D Matrix. Currently only supports partitions by row. This is +// expected to be used with underlying data that is COLUMN MAJOR. +template +class MutableVectorView { + public: + using value_type = T; + + // Construct from a raw pointer, |rows|, |cols| and |col_stride|. + // |col_stride| will default to |rows| if not specified. + explicit MutableVectorView(T* data = nullptr, int rows = 0, int cols = 0, + int col_stride = 0) + : data_(data), + rows_(rows), + cols_(cols), + col_stride_(col_stride > 0 ? col_stride : rows) {} + + // Construct from a CacheAlignedVector, must have one column, can optionally + // specify an offset and row count. + explicit MutableVectorView(CacheAlignedVector* vector) + : MutableVectorView(vector->data(), vector->rows(), 1) {} + + explicit MutableVectorView(CacheAlignedVector* vector, int pos = 0, + int rows = 0) + : MutableVectorView(vector->data() + pos, + rows == 0 ? vector->rows() - pos : rows, 1, + vector->rows()) {} + + // Construct from a FatCacheAlignedVector, can optionally specify an offset, + // and row count. Views that have fewer columns than the original are not + // supported. + explicit MutableVectorView(FatCacheAlignedVector* vector) + : MutableVectorView(vector->data(), vector->rows(), vector->cols()) {} + + MutableVectorView(FatCacheAlignedVector* vector, int pos, int rows) + : MutableVectorView(vector->data() + pos, rows, vector->cols(), + vector->rows()) {} + + T* data() { return data_; } + const T* data() const { return data_; } + + // Returns |data_| cast to the correct integer type if fixed point. + template + typename std::enable_if::value, const int32_t*>::type + cast_data() const { + return reinterpret_cast(data_); + } + template + typename std::enable_if::value, const int16_t*>::type + cast_data() const { + return reinterpret_cast(data_); + } + template + typename std::enable_if::value || IsFixed16Type::value), + const Q*>::type + cast_data() const { + return data_; + } + + // Number of columns in the underlying (Fat)CacheAlignedVector. + int cols() const { return cols_; } + + // Number of rows in this view. + int rows() const { return rows_; } + + // Returns true if there's nothing in the MutableVectorView. + bool empty() const { return rows_ == 0 || cols_ == 0; } + + // Stride to get to the next column (usually the number of rows in the + // underlying data structure). + int col_stride() const { return col_stride_; } + + // Returns the total number of bytes that are "owned" by this view. Uses + // cols and not col_stride. + std::size_t bytes() const { return rows_ * cols_ * sizeof(T); } + + void reshape(int rows, int cols) { + CHECK_EQ(rows * cols, rows_ * cols_); + rows_ = rows; + cols_ = cols; + col_stride_ = rows_; + } + + const T& operator[](int pos) const { return data_[pos]; } + T& operator[](int pos) { return data_[pos]; } + + protected: + T* data_; + int rows_; + int cols_; + int col_stride_; +}; + +// Specialization of MutableVectorView which is read-only. +template +class VectorView : public MutableVectorView { + public: + using value_type = T; + + explicit VectorView(const MutableVectorView& other) + : MutableVectorView(other.data(), other.rows(), other.cols(), + other.col_stride()) {} + + // Construct from a raw pointer, |rows|, |cols| and |col_stride|. + // |col_stride| will default to |rows| if not specified. + explicit VectorView(const T* data = nullptr, int rows = 0, int cols = 0, + int col_stride = 0) + : MutableVectorView(data, rows, cols, col_stride) {} + + // Construct from a CacheAlignedVector, must have one column, can optionally + // specify an offset and row count + explicit VectorView(const CacheAlignedVector& vector) + : MutableVectorView(vector.data(), vector.rows(), 1) {} + + explicit VectorView(const CacheAlignedVector& vector, int pos = 0, + int rows = 0) + : MutableVectorView(vector.data() + pos, + rows == 0 ? vector.rows() - pos : rows, 1, + vector.rows()) {} + + // Construct from a FatCacheAlignedVector, can optionally specify an offset, + // and row count. Views that have fewer columns than the original are not + // supported. + explicit VectorView(const FatCacheAlignedVector& vector) + : MutableVectorView(vector.data(), vector.rows(), + vector.cols()) {} + + VectorView(const FatCacheAlignedVector& vector, int pos, int rows) + : MutableVectorView(vector.data() + pos, rows, vector.cols(), + vector.rows()) {} + + VectorView& operator=(const MutableVectorView& other) { + this->data_ = other.data(); + this->rows_ = other.rows(); + this->cols_ = other.cols(); + this->col_stride_ = other.col_stride(); + return *this; + } +}; + +} // namespace csrblocksparse +#endif // LYRA_CODEC_SPARSE_MATMUL_VECTOR_CACHE_ALIGNED_VECTOR_H_ diff --git a/sparse_matmul/vector/cachealignedvector_benchmark.cc b/sparse_matmul/vector/cachealignedvector_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..9141e2d570884101b286e059d3ca358b643cc376 --- /dev/null +++ b/sparse_matmul/vector/cachealignedvector_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "benchmark/benchmark.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" + +// A simple benchmark for CacheAlignedVector. +// +// Running on x86: +// As written, it's not representative of x86 performance since ReducingSample +// is used on x86 and not Sample. +// +// Running on arm64: +// bazel build -c opt --dynamic_mode=off --copt=-gmlt \ +// --copt=-DUSE_FIXED32 --config=android_arm64 \ +// sparse_matmul/vector:cachealignedvector_benchmark +namespace csrblocksparse { + +#ifdef USE_BFLOAT16 +using ComputeType = csrblocksparse::bfloat16; +#elif defined USE_FIXED32 +using ComputeType = csrblocksparse::fixed32<11>; // kGruMatMulOutBits +#else +using ComputeType = float; +#endif // USE_BFLOAT16 + +#if defined(USE_FIXED32) && defined(__aarch64__) +using ScratchType = int; +#else +using ScratchType = float; +#endif // defined(USE_FIXED32) && defined(__aarch64__) + +void BM_Sample(benchmark::State& state) { + constexpr int kVectorSize = 16384; // A large vector. + std::minstd_rand generator; + + CacheAlignedVector values(kVectorSize); + CacheAlignedVector scratch(kVectorSize); + values.FillRandom(); + + for (auto _ : state) { + values.Sample(/*temperature=*/0.98f, &generator, &scratch); + } +} +BENCHMARK(BM_Sample); + +} // namespace csrblocksparse diff --git a/sparse_matmul/vector/cachealignedvector_test.cc b/sparse_matmul/vector/cachealignedvector_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..245c64d7dc3de69e5e37c4445c9ce4c599b28ab0 --- /dev/null +++ b/sparse_matmul/vector/cachealignedvector_test.cc @@ -0,0 +1,405 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/vector/cache_aligned_vector.h" + +#if defined __aarch64__ +#include +#endif + +#include + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/test_utils.h" +#include "sparse_matmul/os/coop_threads.h" + +namespace csrblocksparse { + +const float kExpRelTolerance = .03f; // 3% relative +#ifdef SIGMOID_AS_TANH +const float kSigmoidRelTolerance = .09f; // 9.0% relative +const float kSigmoidAbsTolerance = .003f; +#else +const float kSigmoidRelTolerance = .031f; // 3.1% relative +const float kSigmoidAbsTolerance = .006f; +#endif +const float kTanhRelTolerance = .014f; // 1.4% relative +const float kTanhAbsTolerance = .00525f; + +TEST(Transcendentals, CacheAlignedVectorExp) { + const int kTestSize = 1 << 16; + CacheAlignedVector values(kTestSize); + values.FillRandom(); + CacheAlignedVector values_ref = values; + + values.Exp(); + for (int i = 0; i < kTestSize; ++i) { + float exact_val = std::exp(values_ref[i]); + float rel_diff = RelDiff(exact_val, values[i]); + + EXPECT_LT(rel_diff, kExpRelTolerance) + << exact_val << " " << values[i] << " " << values_ref[i]; + } +} + +TEST(Transcendentals, CacheAlignedVectorSigmoid) { + const int kTestSize = 1 << 16; + CacheAlignedVector values(kTestSize); + values.FillRandom(); + CacheAlignedVector values_ref = values; + + values.Sigmoid(); + for (int i = 0; i < kTestSize; ++i) { + float exact_val = 1. / (1. + std::exp(-values_ref[i])); + float rel_diff = RelDiff(exact_val, values[i]); + + EXPECT_LT(rel_diff, kSigmoidRelTolerance) + << exact_val << " " << values[i] << " " << values_ref[i]; + EXPECT_NEAR(values[i], exact_val, kSigmoidAbsTolerance) << values_ref[i]; + } +} + +TEST(Transcendentals, CacheAlignedVectorTanh) { + const int kTestSize = 1 << 16; + CacheAlignedVector values(kTestSize); + values.FillRandom(); + CacheAlignedVector values_ref = values; + + values.Tanh(); + for (int i = 0; i < kTestSize; ++i) { + float exact_val = std::tanh(values_ref[i]); + float rel_diff = RelDiff(exact_val, values[i]); + + EXPECT_LT(rel_diff, kTanhRelTolerance) + << exact_val << " " << values[i] << " " << values_ref[i]; + EXPECT_NEAR(values[i], exact_val, kTanhAbsTolerance) << values_ref[i]; + } +} + +// Uniformly sample logits and check that the resulting sample choices are +// also (nearly) uniformly distributed. +TEST(Sampling, Random) { + const int kSize = 256; + + CacheAlignedVector logits(kSize); + logits.FillZero(); + + double histogram[kSize] = {}; + + const int kIterations = 10000; + for (int i = 0; i < kIterations; ++i) { + histogram[logits.Sample()]++; + } + + for (int i = 0; i < kSize; ++i) { + // .002 is an empirical bound + EXPECT_GT(histogram[i] / kIterations, 1. / kSize - .002f); + EXPECT_LT(histogram[i] / kIterations, 1. / kSize + .002f); + } +} + +// Put (nearly) all the probability mass on one bin and make sure only that bin +// is chosen. +TEST(Sampling, FixedDistribution) { + const int kSize = 256; + + CacheAlignedVector logits(kSize); + + int histogram[kSize] = {}; + + const int kIterations = 1000; + const int kIndex = 3; + const int kAllProbabilityMass = 10; + const int kNoProbabilityMass = -10; + for (int i = 0; i < kIterations; ++i) { + for (int i = 1; i <= kSize; ++i) { + logits.data()[i - 1] = + i == (kIndex + 1) ? kAllProbabilityMass : kNoProbabilityMass; + } + + histogram[logits.Sample()]++; + } + + EXPECT_EQ(histogram[kIndex], 1000); +} + +// Put (nearly) all the probability mass on one bin outside the target range, +// and make sure that bin is not chosen. +TEST(ScalarSample, ThreadedMasked) { + const int kSize = 256; + const int mindex = 2; + const int maxdex = 3; + const int kNumThreads = 4; + const int kIterations = 1000; + const int kIndex = 3; + const int kMostProbabilityMass = 3; + const int kLittleProbabilityMass = -3; + + CacheAlignedVector logits(kSize); + std::vector> tmp_vectors; + std::vector generators(kNumThreads); + + for (int i = 0; i < kNumThreads; ++i) { + tmp_vectors.emplace_back(kSize); + } + + for (int i = 0; i < kSize; ++i) { + logits.data()[i] = + (i + 1) == (kIndex + 1) ? kMostProbabilityMass : kLittleProbabilityMass; + } + + std::vector> histograms; + for (int i = 0; i < kNumThreads; ++i) { + histograms.emplace_back(kSize); + } + + auto f = [&](csrblocksparse::SpinBarrier* /*barrier*/, int tid) { + for (int i = 0; i < kIterations; ++i) { + histograms[tid][logits.ScalarSample( + 1.f, &generators[tid], &tmp_vectors[tid], 0, mindex, maxdex)]++; + } + }; + + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); + + // Every thread should generate the exact same set of samples. + for (int i = 0; i < kSize; ++i) { + int val = histograms[0][i]; + for (int tid = 1; tid < kNumThreads; ++tid) { + EXPECT_EQ(val, histograms[tid][i]); + } + } + + // The most probable sample should be the only one we're sampling. + for (int tid = 0; tid < kNumThreads; ++tid) { + EXPECT_EQ(std::distance(histograms[tid].begin(), + std::max_element(histograms[tid].begin(), + histograms[tid].end())), + mindex); + } +} + +TEST(Sampling, Threaded) { + const int kSize = 256; + const int kNumThreads = 4; + const int kIterations = 1000; + const int kIndex = 3; + const int kMostProbabilityMass = 3; + const int kLittleProbabilityMass = -3; + + CacheAlignedVector logits(kSize); + std::vector> tmp_vectors; + std::vector generators(kNumThreads); + + for (int i = 0; i < kNumThreads; ++i) { + tmp_vectors.emplace_back(kSize); + } + + for (int i = 0; i < kSize; ++i) { + logits.data()[i] = + (i + 1) == (kIndex + 1) ? kMostProbabilityMass : kLittleProbabilityMass; + } + + std::vector> histograms; + for (int i = 0; i < kNumThreads; ++i) { + histograms.emplace_back(kSize); + } + + auto f = [&](csrblocksparse::SpinBarrier* /*barrier*/, int tid) { + for (int i = 0; i < kIterations; ++i) { + histograms[tid] + [logits.Sample(1.f, &generators[tid], &tmp_vectors[tid])]++; + } + }; + + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); + + // Every thread should generate the exact same set of samples. + for (int i = 0; i < kSize; ++i) { + int val = histograms[0][i]; + for (int tid = 1; tid < kNumThreads; ++tid) { + EXPECT_EQ(val, histograms[tid][i]); + } + } + + // The most probable sample should be the one with the most probability mass. + for (int tid = 0; tid < kNumThreads; ++tid) { + EXPECT_EQ(std::distance(histograms[tid].begin(), + std::max_element(histograms[tid].begin(), + histograms[tid].end())), + kIndex); + } +} + +void CreateVectorHelper( + csrblocksparse::FatCacheAlignedVector* fat_vector, int cols, + int rows, std::unique_ptr>* view) { + *view = absl::make_unique>(*fat_vector, + cols, rows); +} + +void CreateVectorHelper( + csrblocksparse::FatCacheAlignedVector* fat_vector, int cols, + int rows, std::unique_ptr>* view) { + *view = absl::make_unique>( + fat_vector, cols, rows); +} + +csrblocksparse::FatCacheAlignedVector CreateFatAlignedVector(int rows, + int cols) { + csrblocksparse::FatCacheAlignedVector fat_vector(rows, cols); + // Usage intent of FatCacheAlignedVector is that they are COLUMN MAJOR. + float v = 0; + for (int c = 0; c < cols; ++c) { + for (int r = 0; r < rows; ++r) { + fat_vector.data()[c * rows + r] = v++; + } + } + + return fat_vector; +} + +template +void TestFatVectorView() { + const int kRows = 6; + const int kCols = 6; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + + std::unique_ptr top; + CreateVectorHelper(&fat_vector, 0, kRows / 2, &top); + std::unique_ptr bottom; + CreateVectorHelper(&fat_vector, kRows / 2, kRows / 2, &bottom); + + EXPECT_EQ(top->cols(), kCols); + EXPECT_EQ(bottom->cols(), kCols); + EXPECT_EQ(top->rows(), kRows / 2); + EXPECT_EQ(bottom->rows(), kRows / 2); + EXPECT_EQ(top->col_stride(), kRows); + EXPECT_EQ(bottom->col_stride(), kRows); + + for (int c = 0; c < kCols; ++c) { + for (int r = 0; r < kRows; ++r) { + if (r < kRows / 2) { + EXPECT_EQ(fat_vector[c * kRows + r], + top->data()[c * top->col_stride() + r]); + } else { + EXPECT_EQ(fat_vector[c * kRows + r], + bottom->data()[c * top->col_stride() + r - kRows / 2]); + } + } + } +} + +TEST(FatVector, View) { + TestFatVectorView>(); +} +TEST(FatVector, MutableView) { + TestFatVectorView>(); +} + +TEST(FatVector, SliceMutableView) { + const int kRows = 6; + const int kCols = 3; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + + int c = 1; + csrblocksparse::MutableVectorView slice = fat_vector.slice(c); + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(slice[r], c * kRows + r); + } +} + +TEST(FatVector, SliceConstView) { + const int kRows = 6; + const int kCols = 3; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + + int c = 1; + csrblocksparse::VectorView const_slice; + { + // Take a VectorView from a non-const slice. + const_slice = fat_vector.slice(c); + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(const_slice[r], c * kRows + r); + } + } + + { + // Take a VectorView from a const slice. + const auto& const_fat_vector = fat_vector; + const_slice = const_fat_vector.slice(c); + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(const_slice[r], c * kRows + r); + } + } +} + +TEST(View, FromMutableToConst) { + const int kRows = 6; + const int kCols = 3; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + csrblocksparse::MutableVectorView slice = fat_vector.slice(0); + + csrblocksparse::VectorView const_slice(slice); + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(const_slice[r], r); + } +} + +TEST(View, CopyTest) { + const int kRows = 6; + const int kCols = 3; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + csrblocksparse::MutableVectorView slice = fat_vector.slice(0); + csrblocksparse::MutableVectorView slice2(slice); + + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(slice2[r], r); + } +} + +TEST(Vector, CopyNull) { + // Check that we can copy a vector with a null generator without segfault. + CacheAlignedVector foo((CacheAlignedVector())); + // This is here to prevent foo from being optimized out. + CHECK_EQ(foo.size(), 0); + CacheAlignedVector foo_bar = CacheAlignedVector(); + CHECK_EQ(foo_bar.size(), 0); +} + +TEST(Vector, FromRawPointer) { + std::vector input; + for (int i = 0; i < 5; ++i) { + input.push_back(i * 2); + } + + // Calls first constructor. + CacheAlignedVector foo(input.data(), 5); + CHECK_EQ(foo.size(), 5); + EXPECT_THAT(input, testing::ElementsAreArray(foo.data(), 5)); + + // Calls the second constructor. + CacheAlignedVector foo2(input.data(), 5); + CHECK_EQ(foo2.size(), 5); + EXPECT_THAT(input, testing::ElementsAreArray(foo2.data(), 5)); +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/zlib_wrapper/BUILD b/sparse_matmul/zlib_wrapper/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b9653dab0a32ff71841b40a9768ba4c57e897bbf --- /dev/null +++ b/sparse_matmul/zlib_wrapper/BUILD @@ -0,0 +1,20 @@ +licenses(["notice"]) + +cc_library( + name = "zlib_wrapper", + srcs = [ + "gzipheader.cc", + "zlibwrapper.cc", + ], + hdrs = [ + "gzipheader.h", + "zlibwrapper.h", + ], + visibility = ["//:__subpackages__"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_glog//:glog", + "@zlib", + ], +) diff --git a/sparse_matmul/zlib_wrapper/gzipheader.cc b/sparse_matmul/zlib_wrapper/gzipheader.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8d5c3ca26883106f791652f338caa4ae85b6386 --- /dev/null +++ b/sparse_matmul/zlib_wrapper/gzipheader.cc @@ -0,0 +1,190 @@ +// Copyright 2002 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: Neal Cardwell +// + +#include "sparse_matmul/zlib_wrapper/gzipheader.h" + +#include + +#include "absl/base/macros.h" +#include "glog/logging.h" +#include "zlib.h" // for Z_DEFAULT_COMPRESSION + +namespace csrblocksparse { + +const uint8_t GZipHeader::magic[] = {0x1f, 0x8b}; + +// ---------------------------------------------------------------------- +// GZipHeader::ReadMore() +// Attempt to parse the beginning of the given buffer as a gzip +// header. If these bytes do not constitute a complete gzip header, +// return INCOMPLETE_HEADER. If these bytes do not constitute a +// *valid* gzip header, return INVALID_HEADER. If we find a +// complete header, return COMPLETE_HEADER and set the pointer +// pointed to by header_end to the first byte beyond the gzip header. +// ---------------------------------------------------------------------- + +GZipHeader::Status GZipHeader::ReadMore(const char* inbuf, int inbuf_len, + const char** header_end) { + CHECK_GE(inbuf_len, 0); + const uint8_t* pos = reinterpret_cast(inbuf); + const uint8_t* const end = pos + inbuf_len; + + while (pos < end) { + switch (state_) { + case IN_HEADER_ID1: + if (*pos != magic[0]) return INVALID_HEADER; + pos++; + state_++; + break; + case IN_HEADER_ID2: + if (*pos != magic[1]) return INVALID_HEADER; + pos++; + state_++; + break; + case IN_HEADER_CM: + if (*pos != Z_DEFLATED) return INVALID_HEADER; + pos++; + state_++; + break; + case IN_HEADER_FLG: + flags_ = + (*pos) & (FLAG_FHCRC | FLAG_FEXTRA | FLAG_FNAME | FLAG_FCOMMENT); + pos++; + state_++; + break; + + case IN_HEADER_MTIME_BYTE_0: + pos++; + state_++; + break; + case IN_HEADER_MTIME_BYTE_1: + pos++; + state_++; + break; + case IN_HEADER_MTIME_BYTE_2: + pos++; + state_++; + break; + case IN_HEADER_MTIME_BYTE_3: + pos++; + state_++; + break; + + case IN_HEADER_XFL: + pos++; + state_++; + break; + + case IN_HEADER_OS: + pos++; + state_++; + break; + + case IN_XLEN_BYTE_0: + if (!(flags_ & FLAG_FEXTRA)) { + state_ = IN_FNAME; + break; + } + // We have a two-byte little-endian length, followed by a + // field of that length. + extra_length_ = *pos; + pos++; + state_++; + break; + case IN_XLEN_BYTE_1: + extra_length_ += *pos << 8; + pos++; + state_++; + // If we have a zero-length FEXTRA, we want to check to notice that + // we're done reading the FEXTRA before we exit this loop... + ABSL_FALLTHROUGH_INTENDED; + + case IN_FEXTRA: { + // Grab the rest of the bytes in the extra field, or as many + // of them as are actually present so far. + const int num_extra_bytes = std::min(extra_length_, (end - pos)); + pos += num_extra_bytes; + extra_length_ -= num_extra_bytes; + if (extra_length_ == 0) { + state_ = IN_FNAME; // advance when we've seen extra_length_ bytes + flags_ &= ~FLAG_FEXTRA; // we're done with the FEXTRA stuff + } + break; + } + + case IN_FNAME: + if (!(flags_ & FLAG_FNAME)) { + state_ = IN_FCOMMENT; + break; + } + // See if we can find the end of the \0-terminated FNAME field. + pos = reinterpret_cast(memchr(pos, '\0', (end - pos))); + if (pos != nullptr) { + pos++; // advance past the '\0' + flags_ &= ~FLAG_FNAME; // we're done with the FNAME stuff + state_ = IN_FCOMMENT; + } else { + pos = end; // everything we have so far is part of the FNAME + } + break; + + case IN_FCOMMENT: + if (!(flags_ & FLAG_FCOMMENT)) { + state_ = IN_FHCRC_BYTE_0; + break; + } + // See if we can find the end of the \0-terminated FCOMMENT field. + pos = reinterpret_cast(memchr(pos, '\0', (end - pos))); + if (pos != nullptr) { + pos++; // advance past the '\0' + flags_ &= ~FLAG_FCOMMENT; // we're done with the FCOMMENT stuff + state_ = IN_FHCRC_BYTE_0; + } else { + pos = end; // everything we have so far is part of the FNAME + } + break; + + case IN_FHCRC_BYTE_0: + if (!(flags_ & FLAG_FHCRC)) { + state_ = IN_DONE; + break; + } + pos++; + state_++; + break; + + case IN_FHCRC_BYTE_1: + pos++; + flags_ &= ~FLAG_FHCRC; // we're done with the FHCRC stuff + state_++; + break; + + case IN_DONE: + *header_end = reinterpret_cast(pos); + return COMPLETE_HEADER; + } + } + + if ((state_ > IN_HEADER_OS) && (flags_ == 0)) { + *header_end = reinterpret_cast(pos); + return COMPLETE_HEADER; + } else { + return INCOMPLETE_HEADER; + } +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/zlib_wrapper/gzipheader.h b/sparse_matmul/zlib_wrapper/gzipheader.h new file mode 100644 index 0000000000000000000000000000000000000000..21cd71e435a215dea631389accc9d8a206a53019 --- /dev/null +++ b/sparse_matmul/zlib_wrapper/gzipheader.h @@ -0,0 +1,107 @@ +// +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_GZIPHEADER_H +#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_GZIPHEADER_H + +// The GZipHeader class allows you to parse a gzip header, such as you +// might find at the beginning of a file compressed by gzip (ie, a .gz +// file), or at the beginning of an HTTP response that uses a gzip +// Content-Encoding. See RFC 1952 for the specification for the gzip +// header. +// +// The model is that you call ReadMore() for each chunk of bytes +// you've read from a file or socket. +// + +#include + +namespace csrblocksparse { + +class GZipHeader { + public: + GZipHeader() { Reset(); } + ~GZipHeader() {} + + // Wipe the slate clean and start from scratch. + void Reset() { + state_ = IN_HEADER_ID1; + flags_ = 0; + extra_length_ = 0; + } + + enum Status { + INCOMPLETE_HEADER, // don't have all the bits yet... + COMPLETE_HEADER, // complete, valid header + INVALID_HEADER, // found something invalid in the header + }; + + // Attempt to parse the given buffer as the next installment of + // bytes from a gzip header. If the bytes we've seen so far do not + // yet constitute a complete gzip header, return + // INCOMPLETE_HEADER. If these bytes do not constitute a *valid* + // gzip header, return INVALID_HEADER. When we've seen a complete + // gzip header, return COMPLETE_HEADER and set the pointer pointed + // to by header_end to the first byte beyond the gzip header. + Status ReadMore(const char* inbuf, int inbuf_len, const char** header_end); + + private: + // NOLINTNEXTLINE + static const uint8_t magic[]; // gzip magic header + + enum { // flags (see RFC) + FLAG_FTEXT = 0x01, // bit 0 set: file probably ascii text + FLAG_FHCRC = 0x02, // bit 1 set: header CRC present + FLAG_FEXTRA = 0x04, // bit 2 set: extra field present + FLAG_FNAME = 0x08, // bit 3 set: original file name present + FLAG_FCOMMENT = 0x10, // bit 4 set: file comment present + FLAG_RESERVED = 0xE0, // bits 5..7: reserved + }; + + enum State { + // The first 10 bytes are the fixed-size header: + IN_HEADER_ID1, + IN_HEADER_ID2, + IN_HEADER_CM, + IN_HEADER_FLG, + IN_HEADER_MTIME_BYTE_0, + IN_HEADER_MTIME_BYTE_1, + IN_HEADER_MTIME_BYTE_2, + IN_HEADER_MTIME_BYTE_3, + IN_HEADER_XFL, + IN_HEADER_OS, + + IN_XLEN_BYTE_0, + IN_XLEN_BYTE_1, + IN_FEXTRA, + + IN_FNAME, + + IN_FCOMMENT, + + IN_FHCRC_BYTE_0, + IN_FHCRC_BYTE_1, + + IN_DONE, + }; + + int state_; // our current State in the parsing FSM: an int so we can ++ + uint8_t flags_; // the flags byte of the header ("FLG" in the RFC) + uint16_t extra_length_; // how much of the "extra field" we have yet to read +}; + +} // namespace csrblocksparse + +#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_GZIPHEADER_H diff --git a/sparse_matmul/zlib_wrapper/zlibwrapper.cc b/sparse_matmul/zlib_wrapper/zlibwrapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3a2fa5e584354f8216428df199b027ebfac8d21 --- /dev/null +++ b/sparse_matmul/zlib_wrapper/zlibwrapper.cc @@ -0,0 +1,841 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sparse_matmul/zlib_wrapper/zlibwrapper.h" + +#include +#include + +#include +#include +#include + +#include "glog/logging.h" +#include "sparse_matmul/zlib_wrapper/gzipheader.h" +#include "zconf.h" +#include "zlib.h" + +// The GZIP header (see RFC 1952): +// +---+---+---+---+---+---+---+---+---+---+ +// |ID1|ID2|CM |FLG| MTIME |XFL|OS | +// +---+---+---+---+---+---+---+---+---+---+ +// ID1 \037 +// ID2 \213 +// CM \010 (compression method == DEFLATE) +// FLG \000 (special flags that we do not support) +// MTIME Unix format modification time (0 means not available) +// XFL 2-4? DEFLATE flags +// OS ???? Operating system indicator (255 means unknown) + +// Header value we generate: +// We use a #define so sizeof() works correctly +#define GZIP_HEADER "\037\213\010\000\000\000\000\000\002\377" + +namespace csrblocksparse { + +// We allow all kinds of bad footers when this flag is true. +// Some web servers send bad pages corresponding to these cases +// and IE is tolerant with it. +// - Extra bytes after gzip footer (see bug 69126) +// - No gzip footer (see bug 72896) +// - Incomplete gzip footer (see bug 71871706) +bool ZLib::should_be_flexible_with_gzip_footer_ = false; + +// Initialize the ZLib class +ZLib::ZLib() + : comp_init_(false), uncomp_init_(false), gzip_header_(new GZipHeader) { + Reinit(); + init_settings_ = settings_; +} + +ZLib::~ZLib() { + if (comp_init_) { + deflateEnd(&comp_stream_); + } + if (uncomp_init_) { + inflateEnd(&uncomp_stream_); + } + delete gzip_header_; +} + +void ZLib::Reinit() { + settings_.dictionary_ = nullptr; + settings_.dict_len_ = 0; + settings_.compression_level_ = Z_DEFAULT_COMPRESSION; + settings_.window_bits_ = MAX_WBITS; + settings_.mem_level_ = 8; // DEF_MEM_LEVEL + settings_.no_header_mode_ = false; + settings_.gzip_header_mode_ = false; + settings_.dont_hide_zstream_end_ = false; + + if (comp_init_) { + int err = deflateReset(&comp_stream_); + if (err != Z_OK) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + } + if (uncomp_init_) { + // Use negative window bits size to indicate bare stream with no header. + int wbits = (settings_.no_header_mode_ ? -MAX_WBITS : MAX_WBITS); + int err = inflateReset2(&uncomp_stream_, wbits); + if (err == Z_OK) { + init_settings_.no_header_mode_ = settings_.no_header_mode_; + } else { + inflateEnd(&uncomp_stream_); + uncomp_init_ = false; + } + } + crc_ = 0; + uncompressed_size_ = 0; + gzip_header_->Reset(); + gzip_footer_bytes_ = -1; + first_chunk_ = true; +} + +void ZLib::Reset() { + first_chunk_ = true; + gzip_header_->Reset(); +} + +void ZLib::CheckValidParams() { + if (settings_.dictionary_ != nullptr && + (settings_.no_header_mode_ || settings_.gzip_header_mode_)) { + LOG(FATAL) + << "Incompatible params: require zlib headers with preset dictionary"; + } +} + +void ZLib::SetNoHeaderMode(bool no_header_mode) { + settings_.no_header_mode_ = no_header_mode; + if (init_settings_.no_header_mode_ != settings_.no_header_mode_) { + // Once the header mode changes, we have to reinitialize all our streams + if (comp_init_) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + if (uncomp_init_) { + inflateEnd(&uncomp_stream_); + uncomp_init_ = false; + } + } else { + // Mode hasn't changed, but treat this as a reset request nevertheless + Reset(); + } + CheckValidParams(); +} + +void ZLib::SetGzipHeaderMode() { + settings_.gzip_header_mode_ = true; + SetNoHeaderMode(true); // we use gzip headers, not zlib headers + CheckValidParams(); +} + +void ZLib::SetDictionary(const char* initial_dict, unsigned int dict_len) { + settings_.dictionary_ = (Bytef*)initial_dict; // NOLINT + settings_.dict_len_ = dict_len; + CheckValidParams(); +} + +void ZLib::SetDontHideStreamEnd() { settings_.dont_hide_zstream_end_ = true; } + +int ZLib::MinFooterSize() const { + int min_footer_size = 2; // Room for empty chunk. + if (settings_.gzip_header_mode_) { + min_footer_size += 8; // Room for actual footer. + } + return min_footer_size; +} + +// --------- COMPRESS MODE + +// Initialization method to be called if we hit an error while +// compressing. On hitting an error, call this method before returning +// the error. +void ZLib::CompressErrorInit() { + if (comp_init_) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + Reset(); +} + +// These probably return Z_OK, but may return Z_BUF_ERROR if outbuf is full +int ZLib::WriteGzipHeader() { + if (comp_stream_.avail_out < sizeof(GZIP_HEADER)) return Z_BUF_ERROR; + memcpy(comp_stream_.next_out, GZIP_HEADER, sizeof(GZIP_HEADER) - 1); + comp_stream_.next_out += sizeof(GZIP_HEADER) - 1; + comp_stream_.avail_out -= sizeof(GZIP_HEADER) - 1; + return Z_OK; +} + +int ZLib::WriteGzipFooter(Bytef* dest, uLongf destLen) { + if (destLen < 8) // not enough space for footer + return Z_BUF_ERROR; + *dest++ = (crc_ >> 0) & 255; + *dest++ = (crc_ >> 8) & 255; + *dest++ = (crc_ >> 16) & 255; + *dest++ = (crc_ >> 24) & 255; + *dest++ = (uncompressed_size_ >> 0) & 255; + *dest++ = (uncompressed_size_ >> 8) & 255; + *dest++ = (uncompressed_size_ >> 16) & 255; + *dest++ = (uncompressed_size_ >> 24) & 255; + return Z_OK; +} + +int ZLib::DeflateInit() { + int err = + deflateInit2(&comp_stream_, settings_.compression_level_, Z_DEFLATED, + (settings_.no_header_mode_ ? -settings_.window_bits_ + : settings_.window_bits_), + settings_.mem_level_, Z_DEFAULT_STRATEGY); + if (err == Z_OK) { + // Save parameters for later reusability checks + init_settings_.compression_level_ = settings_.compression_level_; + init_settings_.window_bits_ = settings_.window_bits_; + init_settings_.mem_level_ = settings_.mem_level_; + init_settings_.no_header_mode_ = settings_.no_header_mode_; + } + return err; +} + +int ZLib::CompressInit(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen) { + int err; + + comp_stream_.next_in = (Bytef*)source; // NOLINT + comp_stream_.avail_in = (uInt)*sourceLen; + // Check for sourceLen (unsigned long) to fit into avail_in (unsigned int). + if ((uLong)comp_stream_.avail_in != *sourceLen) return Z_BUF_ERROR; + comp_stream_.next_out = dest; + comp_stream_.avail_out = (uInt)*destLen; + // Check for destLen (unsigned long) to fit into avail_out (unsigned int). + if ((uLong)comp_stream_.avail_out != *destLen) return Z_BUF_ERROR; + + if (!first_chunk_) // only need to set up stream the first time through + return Z_OK; + + // Force full reinit if properties have changed in a way we can't adjust. + if (comp_init_ && + (init_settings_.dictionary_ != settings_.dictionary_ || + init_settings_.dict_len_ != settings_.dict_len_ || + init_settings_.window_bits_ != settings_.window_bits_ || + init_settings_.mem_level_ != settings_.mem_level_ || + init_settings_.no_header_mode_ != settings_.no_header_mode_)) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + + // Reuse if we've already initted the object. + if (comp_init_) { // we've already initted it + err = deflateReset(&comp_stream_); + if (err != Z_OK) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + } + + // If compression level has changed, try to reconfigure instead of reinit + if (comp_init_ && + init_settings_.compression_level_ != settings_.compression_level_) { + err = deflateParams(&comp_stream_, settings_.compression_level_, + Z_DEFAULT_STRATEGY); + if (err == Z_OK) { + init_settings_.compression_level_ = settings_.compression_level_; + } else { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + } + + // First use or previous state was not reusable with current settings. + if (!comp_init_) { + comp_stream_.zalloc = (alloc_func)0; + comp_stream_.zfree = (free_func)0; + comp_stream_.opaque = (voidpf)0; + err = DeflateInit(); + if (err != Z_OK) return err; + comp_init_ = true; + } + return Z_OK; +} + +// In a perfect world we'd always have the full buffer to compress +// when the time came, and we could just call Compress(). Alas, we +// want to do chunked compression on our webserver. In this +// application, we compress the header, send it off, then compress the +// results, send them off, then compress the footer. Thus we need to +// use the chunked compression features of zlib. +int ZLib::CompressAtMostOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen, + int flush_mode) { // Z_FULL_FLUSH or Z_FINISH + int err; + + if ((err = CompressInit(dest, destLen, source, sourceLen)) != Z_OK) + return err; + + // This is used to figure out how many bytes we wrote *this chunk* + int compressed_size = comp_stream_.total_out; + + // Some setup happens only for the first chunk we compress in a run + if (first_chunk_) { + // Append the gzip header before we start compressing + if (settings_.gzip_header_mode_) { + if ((err = WriteGzipHeader()) != Z_OK) return err; + compressed_size -= sizeof(GZIP_HEADER) - 1; // -= is right: adds to size + crc_ = crc32(0, nullptr, 0); // initialize + } + + // Initialize the dictionary just before we start compressing + if (settings_.dictionary_) { + err = deflateSetDictionary(&comp_stream_, settings_.dictionary_, + settings_.dict_len_); + if (err != Z_OK) return err; + init_settings_.dictionary_ = settings_.dictionary_; + init_settings_.dict_len_ = settings_.dict_len_; + } + + uncompressed_size_ = 0; + first_chunk_ = false; // so we don't do this again + } + + // flush_mode is Z_FINISH for all mode, Z_SYNC_FLUSH for incremental + // compression. + err = deflate(&comp_stream_, flush_mode); + + const uLong source_bytes_consumed = *sourceLen - comp_stream_.avail_in; + *sourceLen = comp_stream_.avail_in; + + if ((err == Z_STREAM_END || err == Z_OK) && comp_stream_.avail_in == 0 && + comp_stream_.avail_out != 0) { + // we processed everything ok and the output buffer was large enough. + {} + } else if (err == Z_STREAM_END && comp_stream_.avail_in > 0) { + return Z_BUF_ERROR; // should never happen + } else if (err != Z_OK && err != Z_STREAM_END && err != Z_BUF_ERROR) { + // an error happened + CompressErrorInit(); + return err; + } else if (comp_stream_.avail_out == 0) { // not enough space + err = Z_BUF_ERROR; + } + + assert(err == Z_OK || err == Z_STREAM_END || err == Z_BUF_ERROR); + if (err == Z_STREAM_END) err = Z_OK; + + // update the crc and other metadata + uncompressed_size_ += source_bytes_consumed; + compressed_size = comp_stream_.total_out - compressed_size; // delta + *destLen = compressed_size; + if (settings_.gzip_header_mode_) // don't bother with crc else + crc_ = crc32(crc_, source, source_bytes_consumed); + + return err; +} + +int ZLib::CompressChunkOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen, + int flush_mode) { // Z_FULL_FLUSH or Z_FINISH + const int ret = + CompressAtMostOrAll(dest, destLen, source, &sourceLen, flush_mode); + if (ret == Z_BUF_ERROR) CompressErrorInit(); + return ret; +} + +int ZLib::CompressChunk(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen) { + return CompressChunkOrAll(dest, destLen, source, sourceLen, Z_SYNC_FLUSH); +} + +int ZLib::CompressAtMost(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen) { + return CompressAtMostOrAll(dest, destLen, source, sourceLen, Z_SYNC_FLUSH); +} + +// This writes the gzip footer info, if necessary. +// No matter what, we call Reset() so we can compress Chunks again. +int ZLib::CompressChunkDone(Bytef* dest, uLongf* destLen) { + // Make sure our buffer is of reasonable size. + if (*destLen < MinFooterSize()) { + *destLen = 0; + return Z_BUF_ERROR; + } + + // The underlying zlib library requires a non-nullptr source pointer, even if + // the source length is zero, otherwise it will generate an (incorrect) zero- + // valued CRC checksum. + char dummy = '\0'; + int err; + + assert(!first_chunk_ && comp_init_); + + const uLongf orig_destLen = *destLen; + // NOLINTNEXTLINE + if ((err = CompressChunkOrAll(dest, destLen, (const Bytef*)&dummy, 0, + Z_FINISH)) != Z_OK) { + Reset(); // we assume they won't retry on error + return err; + } + + // Make sure that when we exit, we can start a new round of chunks later + // (This must be set after the call to CompressChunkOrAll() above.) + Reset(); + + // Write gzip footer if necessary. They're explicitly in little-endian order + if (settings_.gzip_header_mode_) { + if ((err = WriteGzipFooter(dest + *destLen, orig_destLen - *destLen)) != + Z_OK) + return err; + *destLen += 8; // zlib footer took up another 8 bytes + } + return Z_OK; // stream_end is ok +} + +// This routine only initializes the compression stream once. Thereafter, it +// just does a deflateReset on the stream, which should be faster. +int ZLib::Compress(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen) { + int err; + const uLongf orig_destLen = *destLen; + if ((err = CompressChunkOrAll(dest, destLen, source, sourceLen, Z_FINISH)) != + Z_OK) + return err; + Reset(); // reset for next call to Compress + + if (settings_.gzip_header_mode_) { + if ((err = WriteGzipFooter(dest + *destLen, orig_destLen - *destLen)) != + Z_OK) + return err; + *destLen += 8; // zlib footer took up another 8 bytes + } + + return Z_OK; +} + +// --------- UNCOMPRESS MODE + +int ZLib::InflateInit() { + // Use negative window bits size to indicate bare stream with no header. + int wbits = (settings_.no_header_mode_ ? -MAX_WBITS : MAX_WBITS); + int err = inflateInit2(&uncomp_stream_, wbits); + if (err == Z_OK) { + init_settings_.no_header_mode_ = settings_.no_header_mode_; + } + return err; +} + +// Initialization method to be called if we hit an error while +// uncompressing. On hitting an error, call this method before +// returning the error. +void ZLib::UncompressErrorInit() { + if (uncomp_init_) { + inflateEnd(&uncomp_stream_); + uncomp_init_ = false; + } + Reset(); +} + +int ZLib::UncompressInit(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen) { + int err; + + uncomp_stream_.next_in = (Bytef*)source; // NOLINT + uncomp_stream_.avail_in = (uInt)*sourceLen; + // Check for sourceLen (unsigned long) to fit into avail_in (unsigned int). + if ((uLong)uncomp_stream_.avail_in != *sourceLen) return Z_BUF_ERROR; + + uncomp_stream_.next_out = dest; + uncomp_stream_.avail_out = (uInt)*destLen; + // Check for destLen (unsigned long) to fit into avail_out (unsigned int). + if ((uLong)uncomp_stream_.avail_out != *destLen) return Z_BUF_ERROR; + + if (!first_chunk_) // only need to set up stream the first time through + return Z_OK; + + // Force full reinit if properties have changed in a way we can't adjust. + if (uncomp_init_ && (init_settings_.dictionary_ != settings_.dictionary_ || + init_settings_.dict_len_ != settings_.dict_len_)) { + inflateEnd(&uncomp_stream_); + uncomp_init_ = false; + } + + // Reuse if we've already initted the object. + if (uncomp_init_) { + // Use negative window bits size to indicate bare stream with no header. + int wbits = (settings_.no_header_mode_ ? -MAX_WBITS : MAX_WBITS); + err = inflateReset2(&uncomp_stream_, wbits); + if (err == Z_OK) { + init_settings_.no_header_mode_ = settings_.no_header_mode_; + } else { + UncompressErrorInit(); + } + } + + // First use or previous state was not reusable with current settings. + if (!uncomp_init_) { + uncomp_stream_.zalloc = (alloc_func)0; + uncomp_stream_.zfree = (free_func)0; + uncomp_stream_.opaque = (voidpf)0; + err = InflateInit(); + if (err != Z_OK) return err; + uncomp_init_ = true; + } + return Z_OK; +} + +// If you compressed your data a chunk at a time, with CompressChunk, +// you can uncompress it a chunk at a time with UncompressChunk. +// Only difference bewteen chunked and unchunked uncompression +// is the flush mode we use: Z_SYNC_FLUSH (chunked) or Z_FINISH (unchunked). +int ZLib::UncompressAtMostOrAll(Bytef* dest, uLongf* destLen, + const Bytef* source, uLong* sourceLen, + int flush_mode) { // Z_SYNC_FLUSH or Z_FINISH + int err = Z_OK; + + if (first_chunk_) { + gzip_footer_bytes_ = -1; + if (settings_.gzip_header_mode_) { + // If we haven't read our first chunk of actual compressed data, + // and we're expecting gzip headers, then parse some more bytes + // from the gzip headers. + const Bytef* bodyBegin = nullptr; + GZipHeader::Status status = gzip_header_->ReadMore( + reinterpret_cast(source), *sourceLen, + reinterpret_cast(&bodyBegin)); + switch (status) { + case GZipHeader::INCOMPLETE_HEADER: // don't have the complete header + *destLen = 0; + *sourceLen = 0; // GZipHeader used all the input + return Z_OK; + case GZipHeader::INVALID_HEADER: // bogus header + Reset(); + return Z_DATA_ERROR; + case GZipHeader::COMPLETE_HEADER: // we have the full header + *sourceLen -= (bodyBegin - source); // skip past header bytes + source = bodyBegin; + crc_ = crc32(0, nullptr, 0); // initialize CRC + break; + default: + LOG(FATAL) << "Unexpected gzip header parsing result: " << status; + } + } + } else if (gzip_footer_bytes_ >= 0) { + // We're now just reading the gzip footer. We already read all the data. + if (gzip_footer_bytes_ + *sourceLen > sizeof(gzip_footer_) && + // When this flag is true, we allow some extra bytes after the + // gzip footer. + !should_be_flexible_with_gzip_footer_) { + VLOG(1) << "UncompressChunkOrAll: Received " + << (gzip_footer_bytes_ + *sourceLen - sizeof(gzip_footer_)) + << " extra bytes after gzip footer: " + << std::string(reinterpret_cast(source), + std::min(*sourceLen, 20UL)); + Reset(); + return Z_DATA_ERROR; + } + uLong len = sizeof(gzip_footer_) - gzip_footer_bytes_; + if (len > *sourceLen) len = *sourceLen; + if (len > 0) { + memcpy(gzip_footer_ + gzip_footer_bytes_, source, len); + gzip_footer_bytes_ += len; + } + *sourceLen -= len; + *destLen = 0; + return Z_OK; + } + + if ((err = UncompressInit(dest, destLen, source, sourceLen)) != Z_OK) { + LOG(WARNING) << "ZLib: UncompressInit: Error: " << err + << "SourceLen: " << *sourceLen; + return err; + } + + // This is used to figure out how many output bytes we wrote *this chunk*: + const uLong old_total_out = uncomp_stream_.total_out; + + // This is used to figure out how many input bytes we read *this chunk*: + const uLong old_total_in = uncomp_stream_.total_in; + + // Some setup happens only for the first chunk we compress in a run + if (first_chunk_) { + // Initialize the dictionary just before we start compressing + if (settings_.gzip_header_mode_ || settings_.no_header_mode_) { + // In no_header_mode, we can just set the dictionary, since no + // checking is done to advance past header bits to get us in the + // dictionary setting mode. In settings_.gzip_header_mode_ we've already + // removed headers, so this code works too. + if (settings_.dictionary_) { + err = inflateSetDictionary(&uncomp_stream_, settings_.dictionary_, + settings_.dict_len_); + if (err != Z_OK) { + LOG(WARNING) << "inflateSetDictionary: Error: " << err + << " dict_len: " << settings_.dict_len_; + UncompressErrorInit(); + return err; + } + init_settings_.dictionary_ = settings_.dictionary_; + init_settings_.dict_len_ = settings_.dict_len_; + } + } + + first_chunk_ = false; // so we don't do this again + + // For the first chunk *only* (to avoid infinite troubles), we let + // there be no actual data to uncompress. This sometimes triggers + // when the input is only the gzip header, say. + if (*sourceLen == 0) { + *destLen = 0; + return Z_OK; + } + } + + // We'll uncompress as much as we can. If we end OK great, otherwise + // if we get an error that seems to be the gzip footer, we store the + // gzip footer and return OK, otherwise we return the error. + + // flush_mode is Z_SYNC_FLUSH for chunked mode, Z_FINISH for all mode. + err = inflate(&uncomp_stream_, flush_mode); + if (settings_.dictionary_ && err == Z_NEED_DICT) { + err = inflateSetDictionary(&uncomp_stream_, settings_.dictionary_, + settings_.dict_len_); + if (err != Z_OK) { + LOG(WARNING) << "UncompressChunkOrAll: failed in inflateSetDictionary : " + << err; + UncompressErrorInit(); + return err; + } + init_settings_.dictionary_ = settings_.dictionary_; + init_settings_.dict_len_ = settings_.dict_len_; + err = inflate(&uncomp_stream_, flush_mode); + } + + // Figure out how many bytes of the input zlib slurped up: + const uLong bytes_read = uncomp_stream_.total_in - old_total_in; + CHECK_LE(source + bytes_read, source + *sourceLen); + *sourceLen = uncomp_stream_.avail_in; + + // Next we look at the footer, if any. Note that we might currently + // have just part of the footer (eg, if this data is arriving over a + // socket). After looking for a footer, log a warning if there is + // extra cruft. + if ((err == Z_STREAM_END) && + ((gzip_footer_bytes_ == -1) || + (gzip_footer_bytes_ < sizeof(gzip_footer_))) && + (uncomp_stream_.avail_in <= sizeof(gzip_footer_) || + // When this flag is true, we allow some extra bytes after the + // zlib footer. + should_be_flexible_with_gzip_footer_)) { + // Due to a bug in old versions of zlibwrapper, we appended the gzip + // footer even in non-gzip mode. Thus we always allow a gzip footer + // even if we're not in gzip mode, so we can continue to uncompress + // the old data. :-( + + // Store gzip footer bytes so we can check for footer consistency + // in UncompressChunkDone(). (If we have the whole footer, we + // could do the checking here, but we don't to keep consistency + // with CompressChunkDone().) + gzip_footer_bytes_ = std::min(static_cast(uncomp_stream_.avail_in), + sizeof(gzip_footer_)); + memcpy(gzip_footer_, source + bytes_read, gzip_footer_bytes_); + *sourceLen -= gzip_footer_bytes_; + } else if ((err == Z_STREAM_END || err == Z_OK) // everything went ok + && uncomp_stream_.avail_in == 0) { // and we read it all + {} + } else if (err == Z_STREAM_END && uncomp_stream_.avail_in > 0) { + VLOG(1) << "UncompressChunkOrAll: Received some extra data, bytes total: " + << uncomp_stream_.avail_in << " bytes: " + << std::string( + reinterpret_cast(uncomp_stream_.next_in), + std::min(static_cast(uncomp_stream_.avail_in), 20)); + UncompressErrorInit(); + return Z_DATA_ERROR; // what's the extra data for? + } else if (err != Z_OK && err != Z_STREAM_END && err != Z_BUF_ERROR) { + // an error happened + VLOG(1) << "UncompressChunkOrAll: Error: " << err + << " avail_out: " << uncomp_stream_.avail_out; + UncompressErrorInit(); + return err; + } else if (uncomp_stream_.avail_out == 0) { + err = Z_BUF_ERROR; + } + + assert(err == Z_OK || err == Z_BUF_ERROR || err == Z_STREAM_END); + if (err == Z_STREAM_END && !settings_.dont_hide_zstream_end_) err = Z_OK; + + // update the crc and other metadata + uncompressed_size_ = uncomp_stream_.total_out; + *destLen = uncomp_stream_.total_out - old_total_out; // size for this call + if (settings_.gzip_header_mode_) crc_ = crc32(crc_, dest, *destLen); + + return err; +} + +int ZLib::UncompressChunkOrAll(Bytef* dest, uLongf* destLen, + const Bytef* source, uLong sourceLen, + int flush_mode) { // Z_SYNC_FLUSH or Z_FINISH + const int ret = + UncompressAtMostOrAll(dest, destLen, source, &sourceLen, flush_mode); + if (ret == Z_BUF_ERROR) UncompressErrorInit(); + return ret; +} + +int ZLib::UncompressAtMost(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen) { + return UncompressAtMostOrAll(dest, destLen, source, sourceLen, Z_SYNC_FLUSH); +} + +int ZLib::UncompressChunk(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen) { + return UncompressChunkOrAll(dest, destLen, source, sourceLen, Z_SYNC_FLUSH); +} + +// We make sure we've uncompressed everything, that is, the current +// uncompress stream is at a compressed-buffer-EOF boundary. In gzip +// mode, we also check the gzip footer to make sure we pass the gzip +// consistency checks. We RETURN true iff both types of checks pass. +bool ZLib::UncompressChunkDone() { + if (first_chunk_ || !uncomp_init_) { + return false; + } + // Make sure we're at the end-of-compressed-data point. This means + // if we call inflate with Z_FINISH we won't consume any input or + // write any output + Bytef dummyin, dummyout; + uLongf dummylen = 0; + if (UncompressChunkOrAll(&dummyout, &dummylen, &dummyin, 0, Z_FINISH) != + Z_OK) { + return false; + } + + // Make sure that when we exit, we can start a new round of chunks later + Reset(); + + // We don't need to check footer when this flag is true. + if (should_be_flexible_with_gzip_footer_) { + return true; + } + + // Whether we were hoping for a gzip footer or not, we allow a gzip + // footer. (See the note above about bugs in old zlibwrappers.) But + // by the time we've seen all the input, it has to be either a + // complete gzip footer, or no footer at all. + if ((gzip_footer_bytes_ != -1) && (gzip_footer_bytes_ != 0) && + (gzip_footer_bytes_ != sizeof(gzip_footer_))) + return false; + + if (!settings_.gzip_header_mode_) return true; + + return IsGzipFooterValid(); +} + +bool ZLib::IsGzipFooterValid() const { + // If we were expecting a gzip footer, and didn't get a full one, + // that's an error. + if (gzip_footer_bytes_ == -1 || gzip_footer_bytes_ < sizeof(gzip_footer_)) + return false; + + // The footer holds the lower four bytes of the length. + uLong uncompressed_size = 0; + uncompressed_size += static_cast(gzip_footer_[7]) << 24; + uncompressed_size += gzip_footer_[6] << 16; + uncompressed_size += gzip_footer_[5] << 8; + uncompressed_size += gzip_footer_[4] << 0; + if (uncompressed_size != (uncompressed_size_ & 0xffffffff)) { + return false; + } + + uLong checksum = 0; + checksum += static_cast(gzip_footer_[3]) << 24; + checksum += gzip_footer_[2] << 16; + checksum += gzip_footer_[1] << 8; + checksum += gzip_footer_[0] << 0; + if (crc_ != checksum) return false; + + return true; +} + +// Uncompresses the source buffer into the destination buffer. +// The destination buffer must be long enough to hold the entire +// decompressed contents. +// +// We only initialize the uncomp_stream once. Thereafter, we use +// inflateReset2, which should be faster. +// +// Returns Z_OK on success, otherwise, it returns a zlib error code. +int ZLib::Uncompress(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen) { + int err; + if ((err = UncompressChunkOrAll(dest, destLen, source, sourceLen, + Z_FINISH)) != Z_OK) { + Reset(); // let us try to compress again + return err; + } + if (!UncompressChunkDone()) // calls Reset() + return Z_DATA_ERROR; + return Z_OK; // stream_end is ok +} + +// read uncompress length from gzip footer +uLongf ZLib::GzipUncompressedLength(const Bytef* source, uLong len) { + if (len <= 4) return 0; // malformed data. + + return (static_cast(source[len - 1]) << 24) + + (static_cast(source[len - 2]) << 16) + + (static_cast(source[len - 3]) << 8) + + (static_cast(source[len - 4]) << 0); +} + +int ZLib::UncompressGzipAndAllocate(Bytef** dest, uLongf* destLen, + const Bytef* source, uLong sourceLen) { + *dest = nullptr; // until we successfully allocate + if (!settings_.gzip_header_mode_) return Z_VERSION_ERROR; // *shrug* + + uLongf uncompress_length = GzipUncompressedLength(source, sourceLen); + + // Do not trust the uncompress size reported by the compressed buffer. + if (uncompress_length > *destLen) { + if (!HasGzipHeader(reinterpret_cast(source), sourceLen)) { + VLOG(1) << "Attempted to un-gzip data that is not gzipped."; + return Z_DATA_ERROR; + } + VLOG(1) << "Uncompressed size " << uncompress_length + << " exceeds maximum expected size " << *destLen; + return Z_MEM_ERROR; // probably a corrupted gzip buffer + } + + *destLen = uncompress_length; + + *dest = (Bytef*)malloc(*destLen); // NOLINT + if (*dest == nullptr) // probably a corrupted gzip buffer + return Z_MEM_ERROR; + + const int retval = Uncompress(*dest, destLen, source, sourceLen); + if (retval != Z_OK) { // just to make life easier for them + free(*dest); + *dest = nullptr; + } + return retval; +} + +// Convenience method to check if a bytestream has a header. This +// is intended as a quick test: "Is this likely a GZip file?" +bool ZLib::HasGzipHeader(const char* source, int sourceLen) { + GZipHeader gzh; + const char* ptr = nullptr; + return gzh.ReadMore(source, sourceLen, &ptr) == GZipHeader::COMPLETE_HEADER; +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/zlib_wrapper/zlibwrapper.h b/sparse_matmul/zlib_wrapper/zlibwrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..22e3980e08b8dcfca7407c8c0d22f9ebb7d9bef6 --- /dev/null +++ b/sparse_matmul/zlib_wrapper/zlibwrapper.h @@ -0,0 +1,320 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_ZLIBWRAPPER_H +#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_ZLIBWRAPPER_H + +#include "zlib.h" + +namespace csrblocksparse { + +class GZipHeader; + +class ZLib { + public: + ZLib(); + ~ZLib(); + + // Set this to true if you want to be flexible with the gzip footer. + static void set_should_be_flexible_with_gzip_footer(bool b) { + should_be_flexible_with_gzip_footer_ = b; + } + + static bool should_be_flexible_with_gzip_footer() { + return should_be_flexible_with_gzip_footer_; + } + + // Wipe a ZLib object to a virgin state. This differs from Reset() + // in that it also breaks any dictionary, gzip, etc, state. + void Reinit(); + + // Call this to make a zlib buffer as good as new. Here's the only + // case where they differ: + // CompressChunk(a); CompressChunk(b); CompressChunkDone(); vs + // CompressChunk(a); Reset(); CompressChunk(b); CompressChunkDone(); + // You'll want to use Reset(), then, when you interrupt a compress + // (or uncompress) in the middle of a chunk and want to start over. + void Reset(); + + // Sets no_header_mode appropriately. Note that using NoHeaderMode + // in conjunction with a preset dictionary is not supported (zlib + // starts behaving oddly if you try to do this). + void SetNoHeaderMode(bool no_header_mode); + + // Returns our current no_header_mode. + bool no_header_mode() const { return settings_.no_header_mode_; } + + // Uses a gzip header/footer; the output is a valid gzip file. + // This also causes us to generate a crc32 checksum used with gzip + void SetGzipHeaderMode(); + + // By default UncompressAtMostOrAll will return Z_OK upon hitting the end of + // the input stream. This function modifies that behavior by returning + // Z_STREAM_END instead. This is useful when getting multiple compressed + // documents in a single stream. Returning Z_STREAM_END will indicate the end + // of a document. + void SetDontHideStreamEnd(); + + // Sets the compression level to be used + void SetCompressionLevel(int level) { settings_.compression_level_ = level; } + + // Sets the size of the window (history buffer) used by the compressor. + // The size is expressed in bits (log base 2 of the desired size). + void SetCompressionWindowSizeInBits(int bits) { + settings_.window_bits_ = bits; + } + + // Controls the amount of memory used by the compresser. + // Legal value are 1 through 9. See zlib.h for more info. + void SetCompressionMemLevel(int level) { settings_.mem_level_ = level; } + + // Sets the initial dictionary to be used for decompression. + void SetDictionary(const char* initial_dict, unsigned int dict_len); + + // According to the zlib manual, when you Compress, the destination + // buffer must have size at least src + .1%*src + 12. This function + // helps you calculate that. Augment this to account for a potential + // gzip header and footer, plus a few bytes of slack. + static uLong MinCompressbufSize(uLong uncompress_size) { + return uncompress_size + uncompress_size / 1000 + 40; + } + + // The minimum size of footers written by CompressChunkDone(). + int MinFooterSize() const; + + // Compresses the source buffer into the destination buffer. + // sourceLen is the byte length of the source buffer. + // Upon entry, destLen is the total size of the destination buffer, + // which must be of size at least MinCompressbufSize(sourceLen). + // Upon exit, destLen is the actual size of the compressed buffer. + // + // This function can be used to compress a whole file at once if the + // input file is mmap'ed. + // + // Returns Z_OK if success, Z_MEM_ERROR if there was not + // enough memory, Z_BUF_ERROR if there was not enough room in the + // output buffer. Note that if the output buffer is exactly the same + // size as the compressed result, we still return Z_BUF_ERROR. + // (check CL#1936076) + // + // If the values of *destLen or sourceLen do not fit in an unsigned int, + // Z_BUF_ERROR is returned. + int Compress(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen); + + // Uncompresses the source buffer into the destination buffer. + // The destination buffer must be long enough to hold the entire + // decompressed contents. + // + // Returns Z_OK on success, otherwise, it returns a zlib error code. + // + // If the values of *destLen or sourceLen do not fit in an unsigned int, + // Z_BUF_ERROR is returned. + int Uncompress(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen); + + // Get the uncompressed size from the gzip header. Returns 0 if source is too + // short (len < 5). + uLongf GzipUncompressedLength(const Bytef* source, uLong len); + + // Special helper function to help uncompress gzipped documents: + // We'll allocate (with malloc) a destination buffer exactly big + // enough to hold the gzipped content. We set dest and destLen. + // If we don't return Z_OK, *dest will be NULL, otherwise you + // should free() it when you're done with it. + // Returns Z_OK on success, otherwise, it returns a zlib error code. + // Its the responsibility of the user to set *destLen to the + // expected maximum size of the uncompressed data. The size of the + // uncompressed data is read from the compressed buffer gzip footer. + // This value cannot be trusted, so we compare it to the expected + // maximum size supplied by the user, returning Z_MEM_ERROR if its + // greater than the expected maximum size. + int UncompressGzipAndAllocate(Bytef** dest, uLongf* destLen, + const Bytef* source, uLong sourceLen); + + // Streaming compression and decompression methods come in two + // variations. {Unc,C}ompressAtMost() and {Unc,C}ompressChunk(). + // The former decrements sourceLen by the amount of data that was + // consumed: if it returns Z_BUF_ERROR, set the source of the next + // {Unc,C}ompressAtMost() to the unconsumed data. + // {Unc,C}ompressChunk() is the legacy interface and does not do + // this, thus it cannot recover from a Z_BUF_ERROR (except for in + // the first chunk). + + // Compresses data one chunk at a time -- ie you can call this more + // than once. This is useful for a webserver, for instance, which + // might want to use chunked encoding with compression. To get this + // to work you need to call start and finish routines. + // + // Returns Z_OK if success, Z_MEM_ERROR if there was not + // enough memory, Z_BUF_ERROR if there was not enough room in the + // output buffer. + + int CompressAtMost(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen); + + int CompressChunk(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen); + + // Emits gzip footer information, as needed. + // destLen should be at least MinFooterSize() long. + // Returns Z_OK, Z_MEM_ERROR, and Z_BUF_ERROR as in CompressChunk(). + int CompressChunkDone(Bytef* dest, uLongf* destLen); + + // Uncompress data one chunk at a time -- ie you can call this + // more than once. To get this to work you need to call per-chunk + // and "done" routines. + // + // Returns Z_OK if success, Z_MEM_ERROR if there was not + // enough memory, Z_BUF_ERROR if there was not enough room in the + // output buffer. + + int UncompressAtMost(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen); + int UncompressChunk(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen); + + // Checks gzip footer information, as needed. Mostly this just + // makes sure the checksums match. Whenever you call this, it + // will assume the last 8 bytes from the previous UncompressChunk + // call are the footer. Returns true iff everything looks ok. + bool UncompressChunkDone(); + + // Only meaningful for chunked compressing/uncompressing. It's true + // after initialization or reset and before the first chunk of + // user data is received. + bool first_chunk() const { return first_chunk_; } + + // Returns a pointer to our current dictionary: + const Bytef* dictionary() const { return settings_.dictionary_; } + + // Convenience method to check if a bytestream has a header. This + // is intended as a quick test: "Is this likely a GZip file?" + static bool HasGzipHeader(const char* source, int sourceLen); + + // Have we parsed the complete gzip footer, and does it match the + // length and CRC checksum of the content that we have uncompressed + // so far? + bool IsGzipFooterValid() const; + + // Accessor for the uncompressed size (first added to address issue #509976) + uLong uncompressed_size() const { return uncompressed_size_; } + + private: + int InflateInit(); // sets up the zlib inflate structure + int DeflateInit(); // sets up the zlib deflate structure + + // These init the zlib data structures for compressing/uncompressing + int CompressInit(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen); + int UncompressInit(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen); + // Initialization method to be called if we hit an error while + // uncompressing. On hitting an error, call this method before + // returning the error. + void UncompressErrorInit(); + // Helper functions to write gzip-specific data + int WriteGzipHeader(); + int WriteGzipFooter(Bytef* dest, uLongf destLen); + + // Helper function for both Compress and CompressChunk + int CompressChunkOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen, int flush_mode); + int CompressAtMostOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen, int flush_mode); + + // Likewise for UncompressAndUncompressChunk + int UncompressChunkOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen, int flush_mode); + + int UncompressAtMostOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen, int flush_mode); + + // Initialization method to be called if we hit an error while + // compressing. On hitting an error, call this method before + // returning the error. + void CompressErrorInit(); + + // Makes sure the parameters are valid + void CheckValidParams(); + + struct Settings { + // null if we don't want an initial dictionary + Bytef* dictionary_; // NOLINT + + // initial dictionary length + unsigned int dict_len_; // NOLINT + + // compression level + int compression_level_; // NOLINT + + // log base 2 of the window size used in compression + int window_bits_; // NOLINT + + // specifies the amount of memory to be used by compressor (1-9) + int mem_level_; // NOLINT + + // true if we want/expect no zlib headers + bool no_header_mode_; // NOLINT + + // true if we want/expect gzip headers + bool gzip_header_mode_; // NOLINT + + // Controls behavior of UncompressAtMostOrAll with regards to returning + // Z_STREAM_END. See comments for SetDontHideStreamEnd. + bool dont_hide_zstream_end_; // NOLINT + }; + + // We allow all kinds of bad footers when this flag is true. + // Some web servers send bad pages corresponding to these cases + // and IE is tolerant with it. + // - Extra bytes after gzip footer (see bug 69126) + // - No gzip footer (see bug 72896) + // - Incomplete gzip footer (see bug 71871706) + static bool should_be_flexible_with_gzip_footer_; + + // "Current" settings. These will be used whenever we next configure zlib. + // For example changing compression level or header mode will be recorded + // in these, but don't usually get applied immediately but on next compress. + Settings settings_; + + // Settings last used to initialise and configure zlib. These are needed + // to know if the current desired configuration in settings_ is sufficiently + // compatible with the previous configuration and we can just reconfigure the + // underlying zlib objects, or have to recreate them from scratch. + Settings init_settings_; + + z_stream comp_stream_; // Zlib stream data structure + bool comp_init_; // True if we have initialized comp_stream_ + z_stream uncomp_stream_; // Zlib stream data structure + bool uncomp_init_; // True if we have initialized uncomp_stream_ + + // These are used only in gzip compression mode + uLong crc_; // stored in gzip footer, fitting 4 bytes + uLong uncompressed_size_; + + GZipHeader* gzip_header_; // our gzip header state + + Byte gzip_footer_[8]; // stored footer, used to uncompress + int gzip_footer_bytes_; // num of footer bytes read so far, or -1 + + // These are used only with chunked compression. + bool first_chunk_; // true if we need to emit headers with this chunk +}; + +} // namespace csrblocksparse + +#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_ZLIBWRAPPER_H diff --git a/tacotron.py b/tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..2671bcf9a219ad2007200278347df43e0cddde17 --- /dev/null +++ b/tacotron.py @@ -0,0 +1,451 @@ +""" +Tacotron + stepwise monotonic attention +""" + +import jax +import jax.numpy as jnp +import pax + + +def conv_block(in_ft, out_ft, kernel_size, activation_fn, use_dropout): + """ + Conv >> LayerNorm >> activation >> Dropout + """ + f = pax.Sequential( + pax.Conv1D(in_ft, out_ft, kernel_size, with_bias=False), + pax.LayerNorm(out_ft, -1, True, True), + ) + if activation_fn is not None: + f >>= activation_fn + if use_dropout: + f >>= pax.Dropout(0.5) + return f + + +class HighwayBlock(pax.Module): + """ + Highway block + """ + + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + self.fc = pax.Linear(dim, 2 * dim) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + t, h = jnp.split(self.fc(x), 2, axis=-1) + t = jax.nn.sigmoid(t - 1.0) # bias toward keeping x + h = jax.nn.relu(h) + x = x * (1.0 - t) + h * t + return x + + +class BiGRU(pax.Module): + """ + Bidirectional GRU + """ + + def __init__(self, dim): + super().__init__() + + self.rnn_fwd = pax.GRU(dim, dim) + self.rnn_bwd = pax.GRU(dim, dim) + + def __call__(self, x, reset_masks): + N = x.shape[0] + x_fwd = x + x_bwd = jnp.flip(x, axis=1) + x_fwd_states = self.rnn_fwd.initial_state(N) + x_bwd_states = self.rnn_bwd.initial_state(N) + x_fwd_states, x_fwd = pax.scan( + self.rnn_fwd, x_fwd_states, x_fwd, time_major=False + ) + + reset_masks = jnp.flip(reset_masks, axis=1) + x_bwd_states0 = x_bwd_states + + def rnn_reset_core(prev, inputs): + x, reset_mask = inputs + + def reset_state(x0, xt): + return jnp.where(reset_mask, x0, xt) + + state, _ = self.rnn_bwd(prev, x) + state = jax.tree_map(reset_state, x_bwd_states0, state) + return state, state.hidden + + x_bwd_states, x_bwd = pax.scan( + rnn_reset_core, x_bwd_states, (x_bwd, reset_masks), time_major=False + ) + x_bwd = jnp.flip(x_bwd, axis=1) + x = jnp.concatenate((x_fwd, x_bwd), axis=-1) + return x + + +class CBHG(pax.Module): + """ + Conv Bank >> Highway net >> GRU + """ + + def __init__(self, dim): + super().__init__() + self.convs = [conv_block(dim, dim, i, jax.nn.relu, False) for i in range(1, 17)] + self.conv_projection_1 = conv_block(16 * dim, dim, 3, jax.nn.relu, False) + self.conv_projection_2 = conv_block(dim, dim, 3, None, False) + + self.highway = pax.Sequential( + HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim) + ) + self.rnn = BiGRU(dim) + + def __call__(self, x, x_mask): + conv_input = x * x_mask + fts = [f(conv_input) for f in self.convs] + residual = jnp.concatenate(fts, axis=-1) + residual = pax.max_pool(residual, 2, 1, "SAME", -1) + residual = self.conv_projection_1(residual * x_mask) + residual = self.conv_projection_2(residual * x_mask) + x = x + residual + x = self.highway(x) + x = self.rnn(x * x_mask, reset_masks=1 - x_mask) + return x * x_mask + + +class PreNet(pax.Module): + """ + Linear >> relu >> dropout >> Linear >> relu >> dropout + """ + + def __init__(self, input_dim, hidden_dim, output_dim, always_dropout=True): + super().__init__() + self.fc1 = pax.Linear(input_dim, hidden_dim) + self.fc2 = pax.Linear(hidden_dim, output_dim) + self.rng_seq = pax.RngSeq() + self.always_dropout = always_dropout + + def __call__(self, x, k1=None, k2=None): + x = self.fc1(x) + x = jax.nn.relu(x) + if self.always_dropout or self.training: + if k1 is None: + k1 = self.rng_seq.next_rng_key() + x = pax.dropout(k1, 0.5, x) + x = self.fc2(x) + x = jax.nn.relu(x) + if self.always_dropout or self.training: + if k2 is None: + k2 = self.rng_seq.next_rng_key() + x = pax.dropout(k2, 0.5, x) + return x + + +class Tacotron(pax.Module): + """ + Tacotron TTS model. + + It uses stepwise monotonic attention for robust attention. + """ + + def __init__( + self, + mel_dim: int, + attn_bias, + rr, + max_rr, + mel_min, + sigmoid_noise, + pad_token, + prenet_dim, + attn_hidden_dim, + attn_rnn_dim, + rnn_dim, + postnet_dim, + text_dim, + ): + """ + New Tacotron model + + Args: + mel_dim (int): dimension of log mel-spectrogram features. + attn_bias (float): control how "slow" the attention will + move forward at initialization. + rr (int): the reduction factor. + Number of predicted frame at each time step. Default is 2. + max_rr (int): max value of rr. + mel_min (float): the minimum value of mel features. + The frame is filled by `log(mel_min)` values. + sigmoid_noise (float): the variance of gaussian noise added + to attention scores in training. + pad_token (int): the pad value at the end of text sequences. + prenet_dim (int): dimension of prenet output. + attn_hidden_dim (int): dimension of attention hidden vectors. + attn_rnn_dim (int): number of cells in the attention RNN. + rnn_dim (int): number of cells in the decoder RNNs. + postnet_dim (int): number of features in the postnet convolutions. + text_dim (int): dimension of text embedding vectors. + """ + super().__init__() + self.text_dim = text_dim + assert rr <= max_rr + self.rr = rr + self.max_rr = max_rr + self.mel_dim = mel_dim + self.mel_min = mel_min + self.sigmoid_noise = sigmoid_noise + self.pad_token = pad_token + self.prenet_dim = prenet_dim + + # encoder submodules + self.encoder_embed = pax.Embed(256, text_dim) + self.encoder_pre_net = PreNet(text_dim, 256, prenet_dim, always_dropout=True) + self.encoder_cbhg = CBHG(prenet_dim) + + # random key generator + self.rng_seq = pax.RngSeq() + + # pre-net + self.decoder_pre_net = PreNet(mel_dim, 256, prenet_dim, always_dropout=True) + + # decoder submodules + self.attn_rnn = pax.LSTM(prenet_dim + prenet_dim * 2, attn_rnn_dim) + self.text_key_fc = pax.Linear(prenet_dim * 2, attn_hidden_dim, with_bias=True) + self.attn_query_fc = pax.Linear(attn_rnn_dim, attn_hidden_dim, with_bias=False) + + self.attn_V = pax.Linear(attn_hidden_dim, 1, with_bias=False) + self.attn_V_weight_norm = jnp.array(1.0 / jnp.sqrt(attn_hidden_dim)) + self.attn_V_bias = jnp.array(attn_bias) + self.attn_log = jnp.zeros((1,)) + self.decoder_input = pax.Linear(attn_rnn_dim + 2 * prenet_dim, rnn_dim) + self.decoder_rnn1 = pax.LSTM(rnn_dim, rnn_dim) + self.decoder_rnn2 = pax.LSTM(rnn_dim, rnn_dim) + # mel + end-of-sequence token + self.output_fc = pax.Linear(rnn_dim, (mel_dim + 1) * max_rr, with_bias=True) + + # post-net + self.post_net = pax.Sequential( + conv_block(mel_dim, postnet_dim, 5, jax.nn.tanh, True), + conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True), + conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True), + conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True), + conv_block(postnet_dim, mel_dim, 5, None, True), + ) + + parameters = pax.parameters_method("attn_V_weight_norm", "attn_V_bias") + + def encode_text(self, text: jnp.ndarray) -> jnp.ndarray: + """ + Encode text to a sequence of real vectors + """ + N, L = text.shape + text_mask = (text != self.pad_token)[..., None] + x = self.encoder_embed(text) + x = self.encoder_pre_net(x) + x = self.encoder_cbhg(x, text_mask) + return x + + def go_frame(self, batch_size: int) -> jnp.ndarray: + """ + return the go frame + """ + return jnp.ones((batch_size, self.mel_dim)) * jnp.log(self.mel_min) + + def decoder_initial_state(self, N: int, L: int): + """ + setup decoder initial state + """ + attn_context = jnp.zeros((N, self.prenet_dim * 2)) + attn_pr = jax.nn.one_hot( + jnp.zeros((N,), dtype=jnp.int32), num_classes=L, axis=-1 + ) + + attn_state = (self.attn_rnn.initial_state(N), attn_context, attn_pr) + decoder_rnn_states = ( + self.decoder_rnn1.initial_state(N), + self.decoder_rnn2.initial_state(N), + ) + return attn_state, decoder_rnn_states + + def monotonic_attention(self, prev_state, inputs, envs): + """ + Stepwise monotonic attention + """ + attn_rnn_state, attn_context, prev_attn_pr = prev_state + x, attn_rng_key = inputs + text, text_key = envs + attn_rnn_input = jnp.concatenate((x, attn_context), axis=-1) + attn_rnn_state, attn_rnn_output = self.attn_rnn(attn_rnn_state, attn_rnn_input) + attn_query_input = attn_rnn_output + attn_query = self.attn_query_fc(attn_query_input) + attn_hidden = jnp.tanh(attn_query[:, None, :] + text_key) + score = self.attn_V(attn_hidden) + score = jnp.squeeze(score, axis=-1) + weight_norm = jnp.linalg.norm(self.attn_V.weight) + score = score * (self.attn_V_weight_norm / weight_norm) + score = score + self.attn_V_bias + noise = jax.random.normal(attn_rng_key, score.shape) * self.sigmoid_noise + pr_stay = jax.nn.sigmoid(score + noise) + pr_move = 1.0 - pr_stay + pr_new_location = pr_move * prev_attn_pr + pr_new_location = jnp.pad( + pr_new_location[:, :-1], ((0, 0), (1, 0)), constant_values=0 + ) + attn_pr = pr_stay * prev_attn_pr + pr_new_location + attn_context = jnp.einsum("NL,NLD->ND", attn_pr, text) + new_state = (attn_rnn_state, attn_context, attn_pr) + return new_state, attn_rnn_output + + def zoneout_lstm(self, lstm_core, rng_key, zoneout_pr=0.1): + """ + Return a zoneout lstm core. + + It will zoneout the new hidden states and keep the new cell states unchanged. + """ + + def core(state, x): + new_state, _ = lstm_core(state, x) + h_old = state.hidden + h_new = new_state.hidden + mask = jax.random.bernoulli(rng_key, zoneout_pr, h_old.shape) + h_new = h_old * mask + h_new * (1.0 - mask) + return pax.LSTMState(h_new, new_state.cell), h_new + + return core + + def decoder_step( + self, + attn_state, + decoder_rnn_states, + rng_key, + mel, + text, + text_key, + call_pre_net=False, + ): + """ + One decoder step + """ + if call_pre_net: + k1, k2, zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 6) + mel = self.decoder_pre_net(mel, k1, k2) + else: + zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 4) + attn_inputs = (mel, rng_key) + attn_envs = (text, text_key) + attn_state, attn_rnn_output = self.monotonic_attention( + attn_state, attn_inputs, attn_envs + ) + (_, attn_context, attn_pr) = attn_state + (decoder_rnn_state1, decoder_rnn_state2) = decoder_rnn_states + decoder_rnn1_input = jnp.concatenate((attn_rnn_output, attn_context), axis=-1) + decoder_rnn1_input = self.decoder_input(decoder_rnn1_input) + decoder_rnn1 = self.zoneout_lstm(self.decoder_rnn1, zk1) + decoder_rnn_state1, decoder_rnn_output1 = decoder_rnn1( + decoder_rnn_state1, decoder_rnn1_input + ) + decoder_rnn2_input = decoder_rnn1_input + decoder_rnn_output1 + decoder_rnn2 = self.zoneout_lstm(self.decoder_rnn2, zk2) + decoder_rnn_state2, decoder_rnn_output2 = decoder_rnn2( + decoder_rnn_state2, decoder_rnn2_input + ) + x = decoder_rnn1_input + decoder_rnn_output1 + decoder_rnn_output2 + decoder_rnn_states = (decoder_rnn_state1, decoder_rnn_state2) + return attn_state, decoder_rnn_states, rng_key_next, x, attn_pr[0] + + @jax.jit + def inference_step( + self, attn_state, decoder_rnn_states, rng_key, mel, text, text_key + ): + """one inference step""" + attn_state, decoder_rnn_states, rng_key, x, _ = self.decoder_step( + attn_state, + decoder_rnn_states, + rng_key, + mel, + text, + text_key, + call_pre_net=True, + ) + x = self.output_fc(x) + N, D2 = x.shape + x = jnp.reshape(x, (N, self.max_rr, D2 // self.max_rr)) + x = x[:, : self.rr, :] + x = jnp.reshape(x, (N, self.rr, -1)) + mel = x[..., :-1] + eos_logit = x[..., -1] + eos_pr = jax.nn.sigmoid(eos_logit[0, -1]) + eos_pr = jnp.where(eos_pr < 0.1, 0.0, eos_pr) + rng_key, eos_rng_key = jax.random.split(rng_key) + eos = jax.random.bernoulli(eos_rng_key, p=eos_pr) + return attn_state, decoder_rnn_states, rng_key, (mel, eos) + + def inference(self, text, seed=42, max_len=1000): + """ + text to mel + """ + text = self.encode_text(text) + text_key = self.text_key_fc(text) + N, L, D = text.shape + assert N == 1 + mel = self.go_frame(N) + + attn_state, decoder_rnn_states = self.decoder_initial_state(N, L) + rng_key = jax.random.PRNGKey(seed) + mels = [] + count = 0 + while True: + count = count + 1 + attn_state, decoder_rnn_states, rng_key, (mel, eos) = self.inference_step( + attn_state, decoder_rnn_states, rng_key, mel, text, text_key + ) + mels.append(mel) + if eos.item() or count > max_len: + break + + mel = mel[:, -1, :] + + mels = jnp.concatenate(mels, axis=1) + mel = mel + self.post_net(mel) + return mels + + def decode(self, mel, text): + """ + Attention mechanism + Decoder + """ + text_key = self.text_key_fc(text) + + def scan_fn(prev_states, inputs): + attn_state, decoder_rnn_states = prev_states + x, rng_key = inputs + attn_state, decoder_rnn_states, _, output, attn_pr = self.decoder_step( + attn_state, decoder_rnn_states, rng_key, x, text, text_key + ) + states = (attn_state, decoder_rnn_states) + return states, (output, attn_pr) + + N, L, D = text.shape + decoder_states = self.decoder_initial_state(N, L) + rng_keys = self.rng_seq.next_rng_key(mel.shape[1]) + rng_keys = jnp.stack(rng_keys, axis=1) + decoder_states, (x, attn_log) = pax.scan( + scan_fn, + decoder_states, + (mel, rng_keys), + time_major=False, + ) + self.attn_log = attn_log + del decoder_states + x = self.output_fc(x) + + N, T2, D2 = x.shape + x = jnp.reshape(x, (N, T2, self.max_rr, D2 // self.max_rr)) + x = x[:, :, : self.rr, :] + x = jnp.reshape(x, (N, T2 * self.rr, -1)) + mel = x[..., :-1] + eos = x[..., -1] + return mel, eos + + def __call__(self, mel: jnp.ndarray, text: jnp.ndarray): + text = self.encode_text(text) + mel = self.decoder_pre_net(mel) + mel, eos = self.decode(mel, text) + return mel, mel + self.post_net(mel), eos diff --git a/tacotron.toml b/tacotron.toml new file mode 100644 index 0000000000000000000000000000000000000000..e0e5bca89f0685d12badc982b93fe8881c30e043 --- /dev/null +++ b/tacotron.toml @@ -0,0 +1,32 @@ +[tacotron] + +# training +BATCH_SIZE = 64 +LR=1024e-6 # learning rate +MODEL_PREFIX = "mono_tts_cbhg_small" +LOG_DIR = "./logs" +CKPT_DIR = "./ckpts" +USE_MP = false # use mixed-precision training + +# data +TF_DATA_DIR = "./tf_data" # tensorflow data directory +TF_GTA_DATA_DIR = "./tf_gta_data" # tf gta data directory +SAMPLE_RATE = 24000 # convert to this sample rate if needed +MEL_DIM = 80 # the dimension of melspectrogram features +MEL_MIN = 1e-5 +PAD = "_" # padding character +PAD_TOKEN = 0 +END_CHARACTER = "■" # to signal the end of the transcript +TEST_DATA_SIZE = 1024 + +# model +RR = 1 # reduction factor +MAX_RR=2 +ATTN_BIAS = 0.0 # control how slow the attention moves forward +SIGMOID_NOISE = 2.0 +PRENET_DIM = 128 +TEXT_DIM = 256 +RNN_DIM = 512 +ATTN_RNN_DIM = 256 +ATTN_HIDDEN_DIM = 128 +POSTNET_DIM = 512 \ No newline at end of file diff --git a/tacotrons_ljs_24k_v1_0300000.ckpt b/tacotrons_ljs_24k_v1_0300000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..8058842ecdd9d48afec9bce12d2c3f5928fe2d03 --- /dev/null +++ b/tacotrons_ljs_24k_v1_0300000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73368d52c6c519682b73fa9676fe2eaed1712aa559026034ab6a36b2bfd8f8c0 +size 53561547 diff --git a/text.py b/text.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6a3a29549b13a68c6eb4de6997cf37db1aeb67 --- /dev/null +++ b/text.py @@ -0,0 +1,92 @@ +""" from https://github.com/keithito/tacotron """ + +""" +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +""" + +import re +from mynumbers import normalize_numbers +from unidecode import unidecode + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + """Pipeline for English text, including number and abbreviation expansion.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1dcf6d5bf06bfb53678158d5ede70156dcbeb3a --- /dev/null +++ b/utils.py @@ -0,0 +1,74 @@ +""" +Utility functions +""" +import pickle +from pathlib import Path + +import pax +import toml +import yaml + +from tacotron import Tacotron + + +def load_tacotron_config(config_file=Path("tacotron.toml")): + """ + Load the project configurations + """ + return toml.load(config_file)["tacotron"] + + +def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path): + """ + load checkpoint from disk + """ + with open(path, "rb") as f: + dic = pickle.load(f) + if net is not None: + net = net.load_state_dict(dic["model_state_dict"]) + if optim is not None: + optim = optim.load_state_dict(dic["optim_state_dict"]) + return dic["step"], net, optim + + +def create_tacotron_model(config): + """ + return a random initialized Tacotron model + """ + return Tacotron( + mel_dim=config["MEL_DIM"], + attn_bias=config["ATTN_BIAS"], + rr=config["RR"], + max_rr=config["MAX_RR"], + mel_min=config["MEL_MIN"], + sigmoid_noise=config["SIGMOID_NOISE"], + pad_token=config["PAD_TOKEN"], + prenet_dim=config["PRENET_DIM"], + attn_hidden_dim=config["ATTN_HIDDEN_DIM"], + attn_rnn_dim=config["ATTN_RNN_DIM"], + rnn_dim=config["RNN_DIM"], + postnet_dim=config["POSTNET_DIM"], + text_dim=config["TEXT_DIM"], + ) + + +def load_wavegru_config(config_file): + """ + Load project configurations + """ + with open(config_file, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def load_wavegru_ckpt(net, optim, ckpt_file): + """ + load training checkpoint from file + """ + with open(ckpt_file, "rb") as f: + dic = pickle.load(f) + + if net is not None: + net = net.load_state_dict(dic["net_state_dict"]) + if optim is not None: + optim = optim.load_state_dict(dic["optim_state_dict"]) + return dic["step"], net, optim diff --git a/wavegru.py b/wavegru.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed0fd654df51995e0bfc61d6c26810f777517c0 --- /dev/null +++ b/wavegru.py @@ -0,0 +1,300 @@ +""" +WaveGRU model: melspectrogram => mu-law encoded waveform +""" + +from typing import Tuple + +import jax +import jax.numpy as jnp +import pax +from pax import GRUState +from tqdm.cli import tqdm + + +class ReLU(pax.Module): + def __call__(self, x): + return jax.nn.relu(x) + + +def dilated_residual_conv_block(dim, kernel, stride, dilation): + """ + Use dilated convs to enlarge the receptive field + """ + return pax.Sequential( + pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False), + pax.LayerNorm(dim, -1, True, True), + ReLU(), + pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False), + pax.LayerNorm(dim, -1, True, True), + ReLU(), + ) + + +def tile_1d(x, factor): + """ + Tile tensor of shape N, L, D into N, L*factor, D + """ + N, L, D = x.shape + x = x[:, :, None, :] + x = jnp.tile(x, (1, 1, factor, 1)) + x = jnp.reshape(x, (N, L * factor, D)) + return x + + +def up_block(in_dim, out_dim, factor, relu=True): + """ + Tile >> Conv >> BatchNorm >> ReLU + """ + f = pax.Sequential( + lambda x: tile_1d(x, factor), + pax.Conv1D( + in_dim, out_dim, 2 * factor, stride=1, padding="VALID", with_bias=False + ), + pax.LayerNorm(out_dim, -1, True, True), + ) + if relu: + f >>= ReLU() + return f + + +class Upsample(pax.Module): + """ + Upsample melspectrogram to match raw audio sample rate. + """ + + def __init__( + self, input_dim, hidden_dim, rnn_dim, upsample_factors, has_linear_output=False + ): + super().__init__() + self.input_conv = pax.Sequential( + pax.Conv1D(input_dim, hidden_dim, 1, with_bias=False), + pax.LayerNorm(hidden_dim, -1, True, True), + ) + self.upsample_factors = upsample_factors + self.dilated_convs = [ + dilated_residual_conv_block(hidden_dim, 3, 1, 2**i) for i in range(5) + ] + self.up_factors = upsample_factors[:-1] + self.up_blocks = [ + up_block(hidden_dim, hidden_dim, x) for x in self.up_factors[:-1] + ] + self.up_blocks.append( + up_block( + hidden_dim, + hidden_dim if has_linear_output else 3 * rnn_dim, + self.up_factors[-1], + relu=False, + ) + ) + if has_linear_output: + self.x2zrh_fc = pax.Linear(hidden_dim, rnn_dim * 3) + self.has_linear_output = has_linear_output + + self.final_tile = upsample_factors[-1] + + def __call__(self, x, no_repeat=False): + x = self.input_conv(x) + for residual in self.dilated_convs: + y = residual(x) + pad = (x.shape[1] - y.shape[1]) // 2 + x = x[:, pad:-pad, :] + y + + for f in self.up_blocks: + x = f(x) + + if self.has_linear_output: + x = self.x2zrh_fc(x) + + if no_repeat: + return x + x = tile_1d(x, self.final_tile) + return x + + +class GRU(pax.Module): + """ + A customized GRU module. + """ + + input_dim: int + hidden_dim: int + + def __init__(self, hidden_dim: int): + super().__init__() + self.hidden_dim = hidden_dim + self.h_zrh_fc = pax.Linear( + hidden_dim, + hidden_dim * 3, + w_init=jax.nn.initializers.variance_scaling( + 1, "fan_out", "truncated_normal" + ), + ) + + def initial_state(self, batch_size: int) -> GRUState: + """Create an all zeros initial state.""" + return GRUState(jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32)) + + def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]: + hidden = state.hidden + x_zrh = x + h_zrh = self.h_zrh_fc(hidden) + x_zr, x_h = jnp.split(x_zrh, [2 * self.hidden_dim], axis=-1) + h_zr, h_h = jnp.split(h_zrh, [2 * self.hidden_dim], axis=-1) + + zr = x_zr + h_zr + zr = jax.nn.sigmoid(zr) + z, r = jnp.split(zr, 2, axis=-1) + + h_hat = x_h + r * h_h + h_hat = jnp.tanh(h_hat) + + h = (1 - z) * hidden + z * h_hat + return GRUState(h), h + + +class Pruner(pax.Module): + """ + Base class for pruners + """ + + def compute_sparsity(self, step): + t = jnp.power(1 - (step * 1.0 - 1_000) / 200_000, 3) + z = 0.95 * jnp.clip(1.0 - t, a_min=0, a_max=1) + return z + + def prune(self, step, weights): + """ + Return a mask + """ + z = self.compute_sparsity(step) + x = weights + H, W = x.shape + x = x.reshape(H // 4, 4, W // 4, 4) + x = jnp.abs(x) + x = jnp.sum(x, axis=(1, 3), keepdims=True) + q = jnp.quantile(jnp.reshape(x, (-1,)), z) + x = x >= q + x = jnp.tile(x, (1, 4, 1, 4)) + x = jnp.reshape(x, (H, W)) + return x + + +class GRUPruner(Pruner): + def __init__(self, gru): + super().__init__() + self.h_zrh_fc_mask = jnp.ones_like(gru.h_zrh_fc.weight) == 1 + + def __call__(self, gru: pax.GRU): + """ + Apply mask after an optimization step + """ + zrh_masked_weights = jnp.where(self.h_zrh_fc_mask, gru.h_zrh_fc.weight, 0) + gru = gru.replace_node(gru.h_zrh_fc.weight, zrh_masked_weights) + return gru + + def update_mask(self, step, gru: pax.GRU): + """ + Update internal masks + """ + z_weight, r_weight, h_weight = jnp.split(gru.h_zrh_fc.weight, 3, axis=1) + z_mask = self.prune(step, z_weight) + r_mask = self.prune(step, r_weight) + h_mask = self.prune(step, h_weight) + self.h_zrh_fc_mask *= jnp.concatenate((z_mask, r_mask, h_mask), axis=1) + + +class LinearPruner(Pruner): + def __init__(self, linear): + super().__init__() + self.mask = jnp.ones_like(linear.weight) == 1 + + def __call__(self, linear: pax.Linear): + """ + Apply mask after an optimization step + """ + return linear.replace(weight=jnp.where(self.mask, linear.weight, 0)) + + def update_mask(self, step, linear: pax.Linear): + """ + Update internal masks + """ + self.mask *= self.prune(step, linear.weight) + + +class WaveGRU(pax.Module): + """ + WaveGRU vocoder model. + """ + + def __init__( + self, + mel_dim=80, + rnn_dim=1024, + upsample_factors=(5, 3, 20), + has_linear_output=False, + ): + super().__init__() + self.embed = pax.Embed(256, 3 * rnn_dim) + self.upsample = Upsample( + input_dim=mel_dim, + hidden_dim=512, + rnn_dim=rnn_dim, + upsample_factors=upsample_factors, + has_linear_output=has_linear_output, + ) + self.rnn = GRU(rnn_dim) + self.o1 = pax.Linear(rnn_dim, rnn_dim) + self.o2 = pax.Linear(rnn_dim, 256) + self.gru_pruner = GRUPruner(self.rnn) + self.o1_pruner = LinearPruner(self.o1) + self.o2_pruner = LinearPruner(self.o2) + + def output(self, x): + x = self.o1(x) + x = jax.nn.relu(x) + x = self.o2(x) + return x + + def inference(self, mel, no_gru=False, seed=42): + """ + generate waveform form melspectrogram + """ + + @jax.jit + def step(rnn_state, mel, rng_key, x): + x = self.embed(x) + x = x + mel + rnn_state, x = self.rnn(rnn_state, x) + x = self.output(x) + rng_key, next_rng_key = jax.random.split(rng_key, 2) + x = jax.random.categorical(rng_key, x, axis=-1) + return rnn_state, next_rng_key, x + + y = self.upsample(mel, no_repeat=no_gru) + if no_gru: + return y + x = jnp.array([127], dtype=jnp.int32) + rnn_state = self.rnn.initial_state(1) + output = [] + rng_key = jax.random.PRNGKey(seed) + for i in tqdm(range(y.shape[1])): + rnn_state, rng_key, x = step(rnn_state, y[:, i], rng_key, x) + output.append(x) + x = jnp.concatenate(output, axis=0) + return x + + def __call__(self, mel, x): + x = self.embed(x) + y = self.upsample(mel) + pad_left = (x.shape[1] - y.shape[1]) // 2 + pad_right = x.shape[1] - y.shape[1] - pad_left + x = x[:, pad_left:-pad_right] + x = x + y + _, x = pax.scan( + self.rnn, + self.rnn.initial_state(x.shape[0]), + x, + time_major=False, + ) + x = self.output(x) + return x diff --git a/wavegru.yaml b/wavegru.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b3d53765239a3b538b1f75bc667a9f9ee0a1a95f --- /dev/null +++ b/wavegru.yaml @@ -0,0 +1,14 @@ +## dsp +sample_rate : 24000 +window_length: 50.0 # ms +hop_length: 12.5 # ms +mel_min: 1.0e-5 ## need .0 to make it a float +mel_dim: 80 +n_fft: 2048 + +## wavegru +embed_dim: 32 +rnn_dim: 1024 +frames_per_sequence: 67 +num_pad_frames: 62 +upsample_factors: [5, 3, 20] diff --git a/wavegru_cpp.py b/wavegru_cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..5b6801c9bbb794fb890c4ba493f5e78cd1db10fb --- /dev/null +++ b/wavegru_cpp.py @@ -0,0 +1,42 @@ +import numpy as np +from wavegru_mod import WaveGRU + + +def extract_weight_mask(net): + data = {} + data["embed_weight"] = net.embed.weight + data["gru_h_zrh_weight"] = net.rnn.h_zrh_fc.weight + data["gru_h_zrh_mask"] = net.gru_pruner.h_zrh_fc_mask + data["gru_h_zrh_bias"] = net.rnn.h_zrh_fc.bias + + data["o1_weight"] = net.o1.weight + data["o1_mask"] = net.o1_pruner.mask + data["o1_bias"] = net.o1.bias + data["o2_weight"] = net.o2.weight + data["o2_mask"] = net.o2_pruner.mask + data["o2_bias"] = net.o2.bias + return data + + +def load_wavegru_cpp(data, repeat_factor): + """load wavegru weight to cpp object""" + embed = data["embed_weight"] + rnn_dim = data["gru_h_zrh_bias"].shape[0] // 3 + net = WaveGRU(rnn_dim, repeat_factor) + net.load_embed(embed) + + m = np.ascontiguousarray(data["gru_h_zrh_weight"].T) + mask = np.ascontiguousarray(data["gru_h_zrh_mask"].T) + b = data["gru_h_zrh_bias"] + + o1 = np.ascontiguousarray(data["o1_weight"].T) + masko1 = np.ascontiguousarray(data["o1_mask"].T) + o1b = data["o1_bias"] + + o2 = np.ascontiguousarray(data["o2_weight"].T) + masko2 = np.ascontiguousarray(data["o2_mask"].T) + o2b = data["o2_bias"] + + net.load_weights(m, mask, b, o1, masko1, o1b, o2, masko2, o2b) + + return net diff --git a/wavegru_mod.cc b/wavegru_mod.cc new file mode 100644 index 0000000000000000000000000000000000000000..888f0455cfa0d8c7964771172251fd24b24bd571 --- /dev/null +++ b/wavegru_mod.cc @@ -0,0 +1,150 @@ +/* +WaveGRU: +> Embed > GRU > O1 > O2 > Sampling > ... +*/ + +#include +#include +#include + +#include +#include +#include + +#include "sparse_matmul/sparse_matmul.h" +namespace py = pybind11; +using namespace std; + +using fvec = std::vector; +using ivec = std::vector; +using fndarray = py::array_t; +using indarray = py::array_t; +using mat = csrblocksparse::CsrBlockSparseMatrix; +using vec = csrblocksparse::CacheAlignedVector; +using masked_mat = csrblocksparse::MaskedSparseMatrix; + +mat create_mat(int h, int w) { + auto m = masked_mat(w, h, 0.90, 4, 4, 0.0, true); + auto a = mat(m); + return a; +} + +struct WaveGRU { + int hidden_dim; + int repeat_factor; + mat m; + vec b; + vec z, r, hh, zrh; + vec fco1, fco2; + vec o1b, o2b; + vec t; + vec h; + vec logits; + mat o1, o2; + std::vector embed; + + WaveGRU(int hidden_dim, int repeat_factor) + : hidden_dim(hidden_dim), + repeat_factor(repeat_factor), + b(3*hidden_dim), + t(3*hidden_dim), + zrh(3*hidden_dim), + z(hidden_dim), + r(hidden_dim), + hh(hidden_dim), + fco1(hidden_dim), + fco2(256), + h(hidden_dim), + o1b(hidden_dim), + o2b(256), + logits(256) { + m = create_mat(hidden_dim, 3*hidden_dim); + o1 = create_mat(hidden_dim, hidden_dim); + o2 = create_mat(hidden_dim, 256); + embed = std::vector(); + for (int i = 0; i < 256; i++) { + embed.emplace_back(hidden_dim * 3); + embed[i].FillRandom(); + } + } + + void load_embed(fndarray embed_weights) { + auto a_embed = embed_weights.unchecked<2>(); + for (int i = 0; i < 256; i++) { + for (int j = 0; j < hidden_dim * 3; j++) embed[i][j] = a_embed(i, j); + } + } + + mat load_linear(vec& bias, fndarray w, indarray mask, fndarray b) { + auto w_ptr = static_cast(w.request().ptr); + auto mask_ptr = static_cast(mask.request().ptr); + auto rb = b.unchecked<1>(); + // load bias, scale by 1/4 + for (int i = 0; i < rb.shape(0); i++) bias[i] = rb(i) / 4; + // load weights + masked_mat mm(w.shape(0), w.shape(1), mask_ptr, w_ptr); + mat mmm(mm); + return mmm; + } + + void load_weights(fndarray m, indarray m_mask, fndarray b, + fndarray o1, indarray o1_mask, + fndarray o1b, fndarray o2, + indarray o2_mask, fndarray o2b) { + this->m = load_linear(this->b, m, m_mask, b); + this->o1 = load_linear(this->o1b, o1, o1_mask, o1b); + this->o2 = load_linear(this->o2b, o2, o2_mask, o2b); + } + + std::vector inference(fndarray ft, float temperature) { + auto rft = ft.unchecked<2>(); + int value = 127; + std::vector signal(rft.shape(0) * repeat_factor); + h.FillZero(); + for (int index = 0; index < signal.size(); index++) { + m.SpMM_bias(h, b, &zrh, false); + + for (int i = 0; i < 3 * hidden_dim; i++) t[i] = embed[value][i] + rft(index / repeat_factor, i); + for (int i = 0; i < hidden_dim; i++) { + z[i] = zrh[i] + t[i]; + r[i] = zrh[hidden_dim + i] + t[hidden_dim + i]; + } + + z.Sigmoid(); + r.Sigmoid(); + + for (int i = 0; i < hidden_dim; i++) { + hh[i] = zrh[hidden_dim * 2 + i] * r[i] + t[hidden_dim * 2 + i]; + } + hh.Tanh(); + for (int i = 0; i < hidden_dim; i++) { + h[i] = (1. - z[i]) * h[i] + z[i] * hh[i]; + } + o1.SpMM_bias(h, o1b, &fco1, true); + o2.SpMM_bias(fco1, o2b, &fco2, false); + // auto max_logit = fco2[0]; + // for (int i = 1; i <= 255; ++i) { + // max_logit = max(max_logit, fco2[i]); + // } + // float total = 0.0; + // for (int i = 0; i <= 255; ++i) { + // logits[i] = csrblocksparse::fast_exp(fco2[i] - max_logit); + // total += logits[i]; + // } + // for (int i = 0; i <= 255; ++i) { + // if (logits[i] < total / 1024.0) fco2[i] = -1e9; + // } + value = fco2.Sample(temperature); + signal[index] = value; + } + return signal; + } +}; + +PYBIND11_MODULE(wavegru_mod, m) { + py::class_(m, "WaveGRU") + .def(py::init()) + .def("load_embed", &WaveGRU::load_embed) + .def("load_weights", &WaveGRU::load_weights) + .def("inference", &WaveGRU::inference); +} diff --git a/wavegru_mod.so b/wavegru_mod.so new file mode 100755 index 0000000000000000000000000000000000000000..3a640ccb6a7826b43f4c95b13797284deafbc122 --- /dev/null +++ b/wavegru_mod.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:700f2cade76db615b1e38bddfc9c604ff1c8ea1af3e507f879d0ceebae5d232d +size 525536 diff --git a/wavegru_vocoder_1024_v4_1320000.ckpt b/wavegru_vocoder_1024_v4_1320000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..6420089ae7786486a82823b8a064694704f4ce0c --- /dev/null +++ b/wavegru_vocoder_1024_v4_1320000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:052a90bd510607f5cd2a6e6ce9d1ae4138db25cb69cc8504c98f2d33eac13375 +size 69717674